In [None]:
import logging
import time
from datetime import datetime, timezone
from typing import Dict, List, Optional

import praw
from prawcore.exceptions import PrawcoreException, ResponseException

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class RedditERNIEFetcher:
    """Fetch submissions and comments from Reddit subreddits through PRAW.

    Every record (post or comment) is returned as a flat dict with the same
    set of keys, so posts and comments can be loaded into a single table and
    told apart by the ``"type"`` key (``"post"`` or ``"comment"``).
    """

    # Approximate look-back window, in seconds, for each supported time
    # filter. ``"all"`` maps to None, meaning "no cutoff".
    _TIME_FILTER_SECONDS = {
        "hour": 60 * 60,
        "day": 24 * 60 * 60,
        "week": 7 * 24 * 60 * 60,
        "month": 30 * 24 * 60 * 60,
        "year": 365 * 24 * 60 * 60,
        "all": None,
    }

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        user_agent: str = "ERNIE_Discussion_Fetcher/1.0",
        verbose: bool = False
    ):
        """Create the PRAW client used by all fetch methods.

        Args:
            client_id: Reddit application client id.
            client_secret: Reddit application client secret.
            user_agent: User-agent string sent with every request.
            verbose: When True, progress messages are printed to stdout.

        Raises:
            Exception: Whatever ``praw.Reddit`` raises is logged and re-raised.
        """
        # Set before the try block so the attribute exists regardless of how
        # client construction goes.
        self.verbose = verbose
        try:
            self.reddit = praw.Reddit(
                client_id=client_id,
                client_secret=client_secret,
                user_agent=user_agent,
                check_for_async=False
            )
        except Exception as e:
            logger.error("Reddit 连接失败: %s", e)
            raise

    def _convert_timestamp(self, timestamp: float) -> str:
        """Format a Unix timestamp as a ``YYYY-MM-DD HH:MM:SS`` string in UTC.

        The inputs are ``created_utc`` values; formatting them in UTC keeps
        the output identical no matter which timezone the machine runs in.
        """
        moment = datetime.fromtimestamp(timestamp, tz=timezone.utc)
        return moment.strftime("%Y-%m-%d %H:%M:%S")

    def _cutoff_timestamp(self, time_filter: Optional[str]) -> Optional[float]:
        """Return the oldest accepted Unix timestamp for *time_filter*.

        Returns None when no cutoff applies: *time_filter* is None, ``"all"``,
        or an unrecognised value (which is logged and ignored).
        """
        if time_filter is None:
            return None
        if time_filter not in self._TIME_FILTER_SECONDS:
            logger.warning("未知的 time_filter: %r，已忽略时间过滤", time_filter)
            return None
        window = self._TIME_FILTER_SECONDS[time_filter]
        return None if window is None else time.time() - window

    def _post_record(self, submission) -> Dict:
        """Build the flat record dict for one submission."""
        author = submission.author
        return {
            "type": "post",
            "post_id": submission.id,
            "title": submission.title,
            "subreddit": submission.subreddit.display_name,
            # A missing author means the account was deleted.
            "author": author.name if author else "[deleted]",
            "score": submission.score,
            "upvote_ratio": submission.upvote_ratio,
            "num_comments": submission.num_comments,
            "created_at": self._convert_timestamp(submission.created_utc),
            "url": submission.url,
            "permalink": f"https://reddit.com{submission.permalink}",
            "content": submission.selftext,
            "is_self": submission.is_self,
            "link_flair_text": submission.link_flair_text,
            "parent_post_id": None,
            "parent_title": None
        }

    def _comment_record(self, submission, comment) -> Dict:
        """Build the flat record dict for one comment of *submission*.

        Post-only fields (title, upvote_ratio, ...) are set to None so the
        record has the same keys as a post record.
        """
        return {
            "type": "comment",
            "post_id": submission.id,
            "comment_id": comment.id,
            "title": None,
            "subreddit": submission.subreddit.display_name,
            "author": comment.author.name,
            "score": comment.score,
            "upvote_ratio": None,
            "num_comments": None,
            "created_at": self._convert_timestamp(comment.created_utc),
            "url": f"https://reddit.com{comment.permalink}",
            "permalink": comment.permalink,
            "content": comment.body,
            "is_self": True,
            "link_flair_text": None,
            "parent_post_id": submission.id,
            "parent_title": submission.title
        }

    def fetch_subreddit_posts(
        self,
        subreddit_name: str,
        time_filter: Optional[str] = None,
        sort_by: str = "new",
        limit: Optional[int] = None
    ) -> List[Dict]:
        """Fetch the posts of one subreddit.

        Args:
            subreddit_name: Subreddit name (e.g. "LanguageModels").
            time_filter: One of "hour", "day", "week", "month", "year",
                "all", or None for no filtering. For ``sort_by="top"`` it is
                passed to Reddit; for the other sorts, posts older than the
                window are dropped locally.
            sort_by: "new", "hot", "top" or "rising"; any other value falls
                back to "new".
            limit: Maximum number of posts requested from the listing
                (None = no explicit limit; Reddit still caps listings). The
                limit is applied before local time filtering.

        Returns:
            A list of post records. Errors are logged rather than raised, so
            the list may be empty or partial.
        """
        posts: List[Dict] = []

        try:
            subreddit = self.reddit.subreddit(subreddit_name)

            if sort_by == "top":
                # Reddit applies the time filter server-side for "top".
                submissions = subreddit.top(time_filter=time_filter or "all", limit=limit)
                cutoff = None
                newest_first = False
            else:
                listing_name = sort_by if sort_by in ("new", "hot", "rising") else "new"
                submissions = getattr(subreddit, listing_name)(limit=limit)
                # These listings take no time filter, so it is enforced here.
                cutoff = self._cutoff_timestamp(time_filter)
                newest_first = listing_name == "new"

            # No per-item sleep: PRAW pulls listings in batches and throttles
            # its own requests, so iterating normally costs no request per item.
            for submission in submissions:
                if cutoff is not None and submission.created_utc < cutoff:
                    if newest_first:
                        # "new" is ordered newest first: everything after
                        # this point is older still, so stop paging.
                        break
                    continue
                posts.append(self._post_record(submission))

            if self.verbose:
                print(f"✓ r/{subreddit_name}: 获取到 {len(posts)} 个帖子")

        except ResponseException as e:
            # Report the real HTTP status instead of assuming 401.
            status = getattr(getattr(e, "response", None), "status_code", "?")
            logger.error("获取 r/%s 失败 (%s): %s", subreddit_name, status, e)
        except PrawcoreException as e:
            logger.warning("获取 r/%s 失败: %s", subreddit_name, e)
        except Exception as e:
            logger.error("处理 r/%s 时出错: %s", subreddit_name, e)

        return posts

    def fetch_comments_for_post(self, post_url: str) -> List[Dict]:
        """Fetch every comment of one post.

        Args:
            post_url: Full URL of the post.

        Returns:
            A list of comment records; comments without an author (deleted)
            are skipped. Errors are logged rather than raised, so the list
            may be empty.
        """
        comments: List[Dict] = []

        try:
            submission = self.reddit.submission(url=post_url)

            # limit=None expands every "load more comments" placeholder.
            submission.comments.replace_more(limit=None)

            comments.extend(
                self._comment_record(submission, comment)
                for comment in submission.comments.list()
                if comment.author  # skip deleted comments
            )

        except Exception as e:
            logger.error("获取评论失败: %s", e)

        return comments

    def fetch_subreddit_full(
        self,
        subreddit_name: str,
        time_filter: Optional[str] = None,
        sort_by: str = "new",
        post_limit: Optional[int] = None,
        fetch_comments: bool = True
    ) -> List[Dict]:
        """Fetch the posts of one subreddit and, optionally, their comments.

        Args:
            subreddit_name: Subreddit name.
            time_filter: Time filter, see ``fetch_subreddit_posts``.
            sort_by: Sort order, see ``fetch_subreddit_posts``.
            post_limit: Maximum number of posts to request.
            fetch_comments: Whether to also fetch each post's comments.

        Returns:
            All post records followed by the comment records.
        """
        if self.verbose:
            print(f"\n{'='*50}")
            print(f"正在获取 r/{subreddit_name} 的数据...")
            print(f"{'='*50}")

        posts = self.fetch_subreddit_posts(
            subreddit_name,
            time_filter=time_filter,
            sort_by=sort_by,
            limit=post_limit
        )
        all_data: List[Dict] = list(posts)

        if fetch_comments:
            for idx, post in enumerate(posts, 1):
                if self.verbose and idx % 10 == 0:
                    print(f"  处理评论进度: {idx}/{len(posts)}")

                all_data.extend(self.fetch_comments_for_post(post['permalink']))

                # Short pause between the request bursts of consecutive posts.
                time.sleep(0.5)

        if self.verbose:
            # all_data holds the posts first, then only comments.
            posts_count = len(posts)
            comments_count = len(all_data) - posts_count
            print(f"✓ 完成！共 {posts_count} 个帖子，{comments_count} 条评论\n")

        return all_data

    def fetch_multiple_subreddits(
        self,
        subreddit_names: List[str],
        time_filter: Optional[str] = None,
        sort_by: str = "new",
        post_limit: Optional[int] = None,
        fetch_comments: bool = True
    ) -> List[Dict]:
        """Fetch posts (and optionally comments) for several subreddits.

        Args:
            subreddit_names: Subreddit names, processed in the given order.
            time_filter: Time filter, see ``fetch_subreddit_posts``.
            sort_by: Sort order, see ``fetch_subreddit_posts``.
            post_limit: Maximum number of posts to request per subreddit.
            fetch_comments: Whether to also fetch each post's comments.

        Returns:
            The concatenated records of every subreddit.
        """
        all_data: List[Dict] = []

        for subreddit_name in subreddit_names:
            all_data.extend(
                self.fetch_subreddit_full(
                    subreddit_name,
                    time_filter=time_filter,
                    sort_by=sort_by,
                    post_limit=post_limit,
                    fetch_comments=fetch_comments
                )
            )

        return all_data


if __name__ == "__main__":
    import os

    import pandas as pd

    # Credentials are read from the environment so that secrets are never
    # stored in the source file or in an exported notebook.
    client_id = os.environ.get("REDDIT_CLIENT_ID")
    client_secret = os.environ.get("REDDIT_CLIENT_SECRET")
    if not client_id or not client_secret:
        raise SystemExit("请先设置环境变量 REDDIT_CLIENT_ID 和 REDDIT_CLIENT_SECRET")
    user_agent = "ERNIE_Discussion_Fetcher/1.0"

    fetcher = RedditERNIEFetcher(
        client_id=client_id,
        client_secret=client_secret,
        user_agent=user_agent,
        verbose=True
    )

    # Subreddits to scan.
    key_subreddits = [
        "LocalLLM",
        "LocalLlaMa",
        "ChatGPT",
        "ArtificialIntelligence",
        "OpenSourceeAI",
        "singularity",
        "machinelearningnews",
        "SillyTavernAI",
        "StableDiffusion"
    ]

    # Option 1: initial full fetch (no time restriction).
    print("首次全量获取...")
    data = fetcher.fetch_multiple_subreddits(
        key_subreddits,
        time_filter=None,  # no time restriction
        sort_by="new",
        post_limit=1000,  # at most 1000 posts per subreddit
        fetch_comments=True
    )

    # Option 2: weekly incremental update (only the most recent week).
    # print("每周增量更新...")
    # data = fetcher.fetch_multiple_subreddits(
    #     key_subreddits,
    #     time_filter="week",
    #     sort_by="new",
    #     post_limit=100,
    #     fetch_comments=True
    # )

    df = pd.DataFrame(data)

    if not df.empty:
        # Keep the rows whose title or content mentions ERNIE. The match is
        # case-insensitive and rejects a preceding letter, so "ERNIE-4.5" and
        # "ernie bot" match while "Bernie" does not.
        ernie_pattern = r"(?<![a-z])ernie"
        mask = (
            df["title"].astype(str).str.contains(ernie_pattern, case=False, regex=True)
            | df["content"].astype(str).str.contains(ernie_pattern, case=False, regex=True)
        )
        df_filtered = df[mask]

        print(f"\n总数据: {len(df)} 条")
        print(f"包含 ERNIE 的数据: {len(df_filtered)} 条")

        # Save everything; utf-8-sig adds a BOM so Excel detects the encoding.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        df.to_csv(f"reddit_all_{timestamp}.csv", index=False, encoding='utf-8-sig')

        # Save the ERNIE-only subset when there is one.
        if not df_filtered.empty:
            df_filtered.to_csv(f"reddit_ernie_{timestamp}.csv", index=False, encoding='utf-8-sig')
            print(f"\n已保存到: reddit_ernie_{timestamp}.csv")
    else:
        print("未获取到任何数据")

In [14]:
# Export the DataFrame built by the previous cell to an Excel workbook,
# leaving out the index column.
df.to_excel("Reddit.xlsx", index=False)

In [15]:
# 测试连接
import praw

import os

# Credentials are read from the environment so that secrets are never stored
# in the source file or in an exported notebook.
client_id = os.environ.get("REDDIT_CLIENT_ID")
client_secret = os.environ.get("REDDIT_CLIENT_SECRET")
user_agent = "ERNIE_fetcher/1.0 by YOUR_REDDIT_USERNAME"  # replace with your own Reddit username

try:
    if not client_id or not client_secret:
        raise RuntimeError("未设置环境变量 REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET")

    reddit = praw.Reddit(
        client_id=client_id,
        client_secret=client_secret,
        user_agent=user_agent
    )

    # Creating the client and a Subreddit object sends no request, so pull
    # one listing item to make Reddit actually validate the credentials.
    test_subreddit = reddit.subreddit('test')
    next(iter(test_subreddit.new(limit=1)), None)

    print(f"认证成功！只读模式: {reddit.read_only}")
    print(f"测试获取 r/test: {test_subreddit.display_name}")

except Exception as e:
    print(f"认证失败: {e}")


认证成功！只读模式: True
测试获取 r/test: test
