# Reddit Data Scraper

### Import Libraries

In [None]:
import datetime
import os
import re
import time
from io import StringIO

import pandas as pd
import praw
import requests
from bs4 import BeautifulSoup
from dotenv import load_dotenv

## Reddit API Client Initialization

In [None]:
def initialize_reddit_client():
    """Build a PRAW ``Reddit`` client from credentials stored in ``api.env``.

    ``api.env`` is loaded into the environment first. REDDIT_CLIENT_ID and
    REDDIT_CLIENT_SECRET are required; REDDIT_USER_AGENT is optional and
    falls back to ``'StockDataScraper v1.0'``.

    Raises:
        ValueError: if the client id or client secret is missing or empty.
    """
    load_dotenv("api.env")

    # Map praw.Reddit keyword names to the environment variables backing them.
    credentials = {
        kwarg: os.getenv(env_var)
        for kwarg, env_var in (
            ("client_id", "REDDIT_CLIENT_ID"),
            ("client_secret", "REDDIT_CLIENT_SECRET"),
        )
    }
    if not all(credentials.values()):
        raise ValueError("Reddit API credentials not found in api.env file")

    agent = os.getenv('REDDIT_USER_AGENT', 'StockDataScraper v1.0')
    return praw.Reddit(user_agent=agent, **credentials)

## Get S&P 500 Tickers from Wikipedia
Source: [List of S&P 500 companies (Wikipedia)](https://en.wikipedia.org/wiki/List_of_S%26P_500_companies)

In [None]:
def get_sp500_tickers(timeout=10):
    """Return S&P 500 ticker symbols scraped from Wikipedia.

    Downloads the "List of S&P 500 companies" article, finds the first
    ``wikitable`` on the page and returns its ``Symbol`` column as a list.

    This is best-effort. On any failure the error is printed and a small
    hard-coded fallback list is returned. Failures include a network error
    or timeout, a non-2xx response, a missing table and a missing column.

    Args:
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        A list of ticker strings.
    """
    url = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
    # Wikimedia asks API/scraper clients to identify themselves; generic
    # library User-Agents may be refused.
    headers = {"User-Agent": "StockDataScraper/1.0 (S&P 500 ticker lookup)"}
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
        # Don't try to parse an error page as if it were the article.
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')
        table = soup.find("table", {"class": "wikitable"})
        if table is None:
            raise ValueError("no 'wikitable' table found on the page")
        # pandas >= 2.1 deprecates passing literal HTML to read_html.
        sp500 = pd.read_html(StringIO(str(table)))[0]
        return sp500['Symbol'].to_list()
    except Exception as e:
        print(f"Error fetching S&P 500 tickers: {e}")
        # Fallback: the "Magnificent 7" mega-cap tickers.
        return ['AAPL', 'MSFT', 'AMZN', 'GOOGL', 'META', 'TSLA', 'NVDA']

## Data Collection Function

In [None]:
def get_reddit_data(reddit, subreddit_name, data_type='posts', search_term=None, 
                   time_filter='year', limit=100, comment_limit=25):
    subreddit = reddit.subreddit(subreddit_name)
    posts_list = []
    comments_list = []
    
    try:
        # Determine which data to fetch
        if data_type == 'search' and search_term:
            print(f"Searching for '{search_term}' in r/{subreddit_name}...")
            posts = subreddit.search(search_term, limit=limit)
            search_keywords = [search_term.lower(), f"${search_term.lower()}"]
        else:
            print(f"Getting top posts from r/{subreddit_name} for {time_filter}...")
            posts = subreddit.top(time_filter=time_filter, limit=limit)
            search_keywords = None
            
        # Process posts
        for i, post in enumerate(posts):
            # Filter search results if needed
            if search_keywords and not any(kw in (post.title + " " + post.selftext).lower() for kw in search_keywords):
                continue
                
            # Extract post data
            post_data = {
                'post_id': post.id,
                'title': post.title,
                'selftext': post.selftext,
                'score': post.score,
                'upvote_ratio': post.upvote_ratio,
                'created_utc': datetime.datetime.fromtimestamp(post.created_utc),
                'num_comments': post.num_comments,
                'author': str(post.author),
                'permalink': post.permalink,
                'url': post.url,
                'is_self': post.is_self,
                'flair': post.link_flair_text,
                'subreddit': subreddit_name,
                'category': 'stock_specific' if data_type == 'search' else 'general'
            }
            
            # Add search term if applicable
            if search_term:
                post_data['search_term'] = search_term
                
            posts_list.append(post_data)
            
            # Get comments
            try:
                post.comments.replace_more(limit=0)
                for comment in post.comments.list()[:comment_limit]:
                    # Filter comments for search terms if needed
                    if search_keywords and not any(kw in comment.body.lower() for kw in search_keywords):
                        continue
                        
                    comment_data = {
                        'comment_id': comment.id,
                        'post_id': post.id,
                        'parent_id': comment.parent_id,
                        'body': comment.body,
                        'score': comment.score,
                        'created_utc': datetime.datetime.fromtimestamp(comment.created_utc),
                        'author': str(comment.author),
                        'subreddit': subreddit_name,
                        'category': 'stock_specific' if data_type == 'search' else 'general'
                    }
                    
                    # Add search term if applicable
                    if search_term:
                        comment_data['search_term'] = search_term
                        
                    comments_list.append(comment_data)
            except Exception as e:
                print(f"Error processing comments for post {post.id}: {e}")
                
            # Be nice to Reddit's servers
            if (i + 1) % 25 == 0:
                time.sleep(1)
                
        print(f"Found {len(posts_list)} posts and {len(comments_list)} comments")
        return pd.DataFrame(posts_list) if posts_list else pd.DataFrame(), \
               pd.DataFrame(comments_list) if comments_list else pd.DataFrame()
               
    except Exception as e:
        print(f"Error fetching data from r/{subreddit_name}: {e}")
        return pd.DataFrame(), pd.DataFrame()

## Main Data Collection Script

In [None]:
def main(subreddits=None, max_tickers=30, time_filter='year'):
    """Scrape finance subreddits and save the results in a timestamped directory.

    Two passes are made over ``subreddits``:

    1. the top posts (and their comments) for ``time_filter``;
    2. a search for each of the first ``max_tickers`` S&P 500 tickers.

    Each non-empty per-subreddit or per-ticker result is written as CSV to
    ``reddit_data_<timestamp>/<subreddit>/``. The aggregated
    ``all_posts.csv``, ``all_comments.csv`` and ``collection_summary.json``
    are written to ``reddit_data_<timestamp>/``.

    Errors raised after the output directory is created are printed with a
    traceback rather than raised.

    Args:
        subreddits: Subreddit names to scrape. Defaults to wallstreetbets,
            stocks, investing and StockMarket.
        max_tickers: How many tickers, from the start of the S&P 500 list,
            to search for. This bounds the number of API calls.
        time_filter: Time window for the top-posts pass; it is also recorded
            in the summary.
    """
    if subreddits is None:
        subreddits = ['wallstreetbets', 'stocks', 'investing', 'StockMarket']

    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    data_dir = f'reddit_data_{timestamp}'
    os.makedirs(data_dir, exist_ok=True)

    try:
        reddit = initialize_reddit_client()
        tickers = get_sp500_tickers()[:max_tickers]

        # Collect frames and concatenate once at the end. Repeated pd.concat
        # inside the loops is quadratic, and concatenating empty frames
        # triggers a pandas FutureWarning.
        post_frames = []
        comment_frames = []

        def record(posts_df, comments_df, subreddit_dir, stem):
            """Write one non-empty result pair as CSV and queue it for aggregation."""
            if not posts_df.empty:
                posts_df.to_csv(os.path.join(subreddit_dir, f'{stem}_posts.csv'), index=False)
                post_frames.append(posts_df)
            if not comments_df.empty:
                comments_df.to_csv(os.path.join(subreddit_dir, f'{stem}_comments.csv'), index=False)
                comment_frames.append(comments_df)

        # 1. General top posts from each subreddit.
        for subreddit in subreddits:
            posts_df, comments_df = get_reddit_data(
                reddit, subreddit, data_type='posts', time_filter=time_filter,
                limit=100, comment_limit=25,
            )
            subreddit_dir = os.path.join(data_dir, subreddit)
            os.makedirs(subreddit_dir, exist_ok=True)
            record(posts_df, comments_df, subreddit_dir, f'{subreddit}_general')
            time.sleep(2)  # Be nice to Reddit's servers

        # 2. Ticker-specific searches. get_reddit_data does not pass
        #    time_filter to subreddit.search, so these are not limited by it.
        for subreddit in subreddits:
            subreddit_dir = os.path.join(data_dir, subreddit)
            os.makedirs(subreddit_dir, exist_ok=True)
            for ticker in tickers:
                posts_df, comments_df = get_reddit_data(
                    reddit, subreddit, data_type='search', search_term=ticker,
                    limit=50, comment_limit=25,
                )
                record(posts_df, comments_df, subreddit_dir, f'{subreddit}_{ticker}')
                time.sleep(1)  # Be nice to Reddit's servers

            time.sleep(2)  # Additional pause between subreddits

        all_posts = pd.concat(post_frames, ignore_index=True) if post_frames else pd.DataFrame()
        all_comments = pd.concat(comment_frames, ignore_index=True) if comment_frames else pd.DataFrame()

        all_posts.to_csv(os.path.join(data_dir, 'all_posts.csv'), index=False)
        all_comments.to_csv(os.path.join(data_dir, 'all_comments.csv'), index=False)

        summary = {
            'total_posts': len(all_posts),
            'total_comments': len(all_comments),
            'subreddits': subreddits,
            'tickers_searched': tickers,
            'time_filter': time_filter,
            'collection_date': datetime.datetime.now().strftime('%Y-%m-%d'),
        }
        pd.Series(summary).to_json(os.path.join(data_dir, 'collection_summary.json'))

        print("\nData collection complete!")
        print(f"Total posts: {len(all_posts)}")
        print(f"Total comments: {len(all_comments)}")
        print(f"Data saved in: {data_dir}")

    except Exception as e:
        import traceback
        print(f"An error occurred: {e}")
        traceback.print_exc()
# Only scrape when run as a script or notebook (the kernel's __name__ is
# "__main__"); importing this module should not trigger network calls.
if __name__ == "__main__":
    main()