### Libraries

In [1]:
from dotenv import load_dotenv
import os
import neptune
import praw
from datetime import datetime, timedelta, timezone
import json
import time
import pandas as pd
from tqdm import tqdm

from pathlib import Path
# Project directories. `Root` is the parent of the current working directory,
# so these paths depend on where the notebook kernel was started.
Root = Path('.').absolute().parent
# Helper-script folder located directly under Root.
SCRIPTS = Root / r'scripts'
# SCRIPTS = Root / r'C:\Users\Admin\Projects\ML Projects\ManipDetect\research\scripts'
# NOTE(review): the right-hand operand is a machine-specific absolute Windows
# path. On Windows, pathlib discards `Root` when the joined segment is absolute,
# so this only resolves on the original author's machine; elsewhere it is
# treated as a single oddly named relative component. Presumably this should be
# a project-relative data directory - confirm the intended location.
DATA = Root/ r'C:\Users\krishnadas\Projects\ML Projects\ManipDetect\data'

In [2]:
def reddit_connect():
    """Create the ``praw.Reddit`` client from environment settings.

    ``load_dotenv()`` is called first, then each of the five ``REDDIT_*``
    values is read with ``os.getenv`` (a missing variable yields ``None``).

    Returns:
        The ``praw.Reddit`` instance built from those values.
    """
    load_dotenv()

    # praw.Reddit keyword -> environment variable that supplies its value.
    env_by_keyword = {
        'client_id': 'REDDIT_CLIENT_ID',
        'client_secret': 'REDDIT_CLIENT_SECRET',
        'user_agent': 'REDDIT_USER_AGENT',
        'username': 'REDDIT_USERNAME',
        'password': 'REDDIT_PASSWORD',
    }
    settings = {keyword: os.getenv(env_var) for keyword, env_var in env_by_keyword.items()}
    return praw.Reddit(**settings)

In [4]:
def load_progress(filename="scraping_progress.json"):
    """Load previously scraped data if it exists.

    Each stored post is normalised to a fixed set of ten fields so that
    progress files written with older field names (``selftext``, ``author``)
    remain usable. Fields outside that set are not carried over.

    Args:
        filename: Path of the JSON progress file.

    Returns:
        tuple: ``(posts, scraped_ids, config)`` where
            * ``posts`` is a list of normalised post dicts,
            * ``scraped_ids`` is a ``set`` of the ``post_id`` values of those
              posts (always a set, so callers can use ``in`` / ``.add``),
            * ``config`` is the value stored under the ``'config'`` key of the
              progress file, or ``None`` when the key or the file is absent.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        # Nothing saved yet: same shape as the success path.
        return [], set(), None

    consistent_posts = []
    for post in data.get('posts', []):
        # Ensure all required fields exist, falling back to legacy names
        # ('selftext', 'author') and then to a default value.
        consistent_posts.append({
            'post_id': post.get('post_id', ''),
            'title': post.get('title', ''),
            'text': post.get('text', post.get('selftext', '')),
            'post_type': post.get('post_type', 'unknown'),
            'author_name': post.get('author_name', post.get('author', '[unknown]')),
            'author_id': post.get('author_id', ''),
            'score': post.get('score', 0),
            'num_comments': post.get('num_comments', 0),
            'created_utc': post.get('created_utc', 0),
            'url': post.get('url', ''),
        })

    scraped_ids = {post['post_id'] for post in consistent_posts}
    return consistent_posts, scraped_ids, data.get('config')

def save_progress(posts_data, scraped_ids, filename="scraping_progress.json", config=None):
    """Save current progress to a JSON file and mirror the posts to CSV.

    Args:
        posts_data: List of post dicts collected so far.
        scraped_ids: Iterable of post ids already scraped. Sets are accepted;
            the ids are written as a list because JSON has no set type.
        filename: Path of the JSON progress file. For compatibility with the
            callers in this notebook, which pass the session config as the
            third positional argument, a ``dict`` given here is treated as
            ``config`` and the default file name is used.
        config: Optional scraping-session configuration, stored under the
            ``'config'`` key of the progress file.

    Side effects:
        Writes ``filename`` and, when ``posts_data`` is non-empty,
        ``wallstreetbetsnew_posts.csv`` in the current working directory.
    """
    if isinstance(filename, dict):
        # Third positional argument was the session config, not a path.
        config, filename = filename, "scraping_progress.json"

    progress_data = {
        'posts': posts_data,
        # key=str keeps the ordering stable even if ids are not all strings.
        'scraped_ids': sorted(scraped_ids, key=str),
        'config': config,
        'saved_at': datetime.now().isoformat(),
        'total_posts': len(posts_data)
    }
    # Save progress as JSON (for resume functionality)
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(progress_data, f, ensure_ascii=False, indent=2)

    # Also save current data as CSV
    if posts_data:
        df = pd.DataFrame(posts_data)
        df.to_csv("wallstreetbetsnew_posts.csv", index=False, encoding='utf-8')

def save_final_csv(posts_data, filepath):
    """Save the final dataset as a timestamped CSV inside ``filepath``.

    Args:
        posts_data: List of post dicts; each must contain ``created_utc``
            (seconds since the epoch).
        filepath: Directory to write into, as a ``str`` or ``Path``. It is
            created (including parents) when it does not exist yet.

    Returns:
        ``Path`` of the written CSV, or ``None`` when ``posts_data`` is empty
        (nothing is written in that case).
    """
    if not posts_data:
        return None

    # Accept plain strings and make sure the target directory exists, so the
    # final save cannot fail after a long scrape.
    output_dir = Path(filepath)
    output_dir.mkdir(parents=True, exist_ok=True)

    df = pd.DataFrame(posts_data)

    # Convert timestamp to readable format
    df['created_datetime'] = pd.to_datetime(df['created_utc'], unit='s')

    # Preferred leading columns; only those present in the frame are used,
    # and every other column follows in its existing order.
    preferred_columns_order = ['post_id', 'title', 'text', 'post_type', 'author_name', 'author_id', 'score', 'num_comments',
                               'created_utc', 'url']
    available_columns = [col for col in preferred_columns_order if col in df.columns]
    remaining_columns = [col for col in df.columns if col not in available_columns]
    df = df[available_columns + remaining_columns]

    # Save with timestamp in filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_filename = output_dir / f"wallstreetbetsnew_posts_{timestamp}.csv"
    df.to_csv(csv_filename, index=False, encoding='utf-8')

    return csv_filename

In [5]:
def date_to_unix(date_string):
    """Return the Unix timestamp (whole seconds) for a ``YYYY-MM-DD`` date.

    The string is parsed into a naive ``datetime`` at midnight; ``timestamp()``
    on a naive value interprets it in the machine's local time zone.

    Raises:
        ValueError: If ``date_string`` does not match ``%Y-%m-%d``.
    """
    parsed_midnight = datetime.strptime(date_string, '%Y-%m-%d')
    epoch_seconds = parsed_midnight.timestamp()
    return int(epoch_seconds)

def unix_to_date(unix_timestamp):
    """Format a Unix timestamp as a ``YYYY-MM-DD HH:MM:SS`` string in UTC."""
    moment_utc = datetime.fromtimestamp(unix_timestamp, tz=timezone.utc)
    return moment_utc.strftime('%Y-%m-%d %H:%M:%S')

def _extract_post_data(submission):
        """Extract comprehensive post data"""
        
        # Handle author
        author_name = "[deleted]"
        author_id = "[deleted]"
        if submission.author is not None:
            try:
                author_name = submission.author.name
                author_id = submission.author.id
            except:
                author_name = "[unavailable]"
                author_id = "[unavailable]"
        
        # Handle post text
        post_text = ""
        post_type = "text"
        
        if submission.is_self:
            if submission.selftext:
                post_text = submission.selftext
                post_type = "text"
            else:
                post_text = "[Empty text post]"
                post_type = "text_empty"
        else:
            post_text = "[Link Post]"
            post_type = "link"
            
            # Categorize link types
            if any(ext in submission.url.lower() for ext in ['.jpg', '.jpeg', '.png', '.gif']):
                post_type = "image"
            elif any(site in submission.url.lower() for site in ['youtube.com', 'youtu.be']):
                post_type = "video"
        
        return {
            "post_id": submission.id,
            "title": submission.title,
            "text": post_text,
            "post_type": post_type,
            "author_name": author_name,
            "author_id": author_id,
            "score": submission.score,
            # "upvote_ratio": getattr(submission, 'upvote_ratio', None),
            "num_comments": submission.num_comments,
            "created_utc": submission.created_utc,
            # "created_datetime": unix_to_date(submission.created_utc),
            "url": submission.url
            # "permalink": f"https://reddit.com{submission.permalink}",
            # "subreddit": submission.subreddit.display_name,
            # "gilded": submission.gilded,
            # "locked": submission.locked,
            # "over_18": submission.over_18,
            # "spoiler": submission.spoiler,
            # "stickied": submission.stickied
        }

def _is_same_session(self, current_config, saved_config):
    """Check if current scraping session matches saved session"""
    key_fields = ['method', 'subreddit', 'start_date', 'end_date', 
                    'max_posts', 'sort_method', 'keywords', 'days_back']
    
    for field in key_fields:
        if current_config.get(field) != saved_config.get(field):
            return False
    return True

In [None]:
def scrape_by_date_range(subreddit_name, start_date, end_date, 
                        max_posts=None, sort_method='new', resume=True, batch_size = 500):
        """
        Scrape posts from a specific date range with batch processing and resume capability

        Args:
            subreddit_name: Name of the subreddit to read from.
            start_date: Lower bound as a date string accepted by ``date_to_unix``.
            end_date: Upper bound as a date string accepted by ``date_to_unix``;
                posts with ``created_utc`` greater than its timestamp are skipped.
            max_posts: Stop after this many posts were collected in this call
                (``None`` or 0 disables the limit).
            sort_method: Listing to iterate: 'new', 'hot' or 'top'.
            resume: When true, previously saved progress is loaded and reused
                if its saved config matches this call's config.
            batch_size: Number of newly collected posts between progress saves.

        Returns:
            tuple: ``(posts_data, config)`` - the list of post dicts (including
            any resumed ones) and the config dict describing this session.

        Raises:
            ValueError: If ``sort_method`` is not 'new', 'hot' or 'top'.
        """
        print(f"Scraping r/{subreddit_name} from {start_date} to {end_date}")
        
        # Configuration for this scraping session
        config = {
            'method': 'date_range',
            'subreddit': subreddit_name,
            'start_date': start_date,
            'end_date': end_date,
            'max_posts': max_posts,
            'sort_method': sort_method,
            'batch_size': batch_size
        }
        # Directory handed to save_final_csv at the end of a completed run.
        filepath = SCRIPTS/'temp_data'  # Define the path to save progress
        # Load previous progress if resuming
        if resume:
            posts_data, scraped_ids, saved_config = load_progress()
            
            # Check if we're resuming the same scraping session
            if saved_config and _is_same_session(config, saved_config):
                print(f"Resuming previous session with {len(posts_data)} posts")
                print(f"Will skip {len(scraped_ids)} already scraped post IDs")
                # NOTE(review): start_count is assigned in every branch but is
                # not read anywhere else in this function.
                start_count = len(posts_data)
            else:
                # Saved progress belongs to a different session (or there is
                # none): discard what was loaded and start empty.
                print("Starting new scraping session")
                posts_data, scraped_ids = [], set()
                start_count = 0
        else:
            print("Starting fresh scraping session (no resume)")
            posts_data, scraped_ids = [], set()
            start_count = 0
        
        # Convert dates to Unix timestamps
        start_timestamp = date_to_unix(start_date)
        end_timestamp = date_to_unix(end_date)
        
        reddit = reddit_connect()
        subreddit = reddit.subreddit(subreddit_name)
        
        # Get posts based on sort method
        if sort_method == 'new':
            posts_generator = subreddit.new(limit=None)
        elif sort_method == 'hot':
            posts_generator = subreddit.hot(limit=None)
        elif sort_method == 'top':
            posts_generator = subreddit.top(time_filter='all', limit=None)
        else:
            raise ValueError("sort_method must be 'new', 'hot', or 'top'")
        
        print("Fetching posts from Reddit...")
        # Counters cover this call only; resumed posts are not included.
        collected_posts = 0
        batch_count = 0
        errors_count = 0
        
        try:
            for submission in tqdm(posts_generator, desc="Processing posts"):
                # Skip if already scraped
                if submission.id in scraped_ids:
                    continue
                
                # Check if post is within date range
                if submission.created_utc < start_timestamp:
                    if sort_method == 'new':
                        # Presumably safe because the 'new' listing is ordered
                        # newest-first, so nothing later can be in range.
                        print(f"Reached posts older than {start_date}, stopping...")
                        break
                    else:
                        # Other listings are not treated as date-ordered: keep scanning.
                        continue
                
                if submission.created_utc > end_timestamp:
                    continue
                
                # Extract post data
                try:
                    post_data = _extract_post_data(submission)
                    posts_data.append(post_data)
                    scraped_ids.add(submission.id)
                    collected_posts += 1
                    
                    # Batch-wise progress saving
                    if collected_posts % batch_size == 0:
                        batch_count += 1
                        print(f"\nCompleted batch {batch_count} ({collected_posts} posts)")
                        save_progress(posts_data, scraped_ids, config)
                        
                        # Rate limiting between batches
                        print("Taking 30-second break between batches...")
                        time.sleep(30)
                    
                    # Stop if we've reached max_posts
                    if max_posts and collected_posts >= max_posts:
                        print(f"Reached target of {max_posts} posts")
                        break
                    
                    # Rate limiting within batch
                    if collected_posts % 50 == 0:
                        time.sleep(1)
                        
                except Exception as e:
                    # A failure on one post is counted and skipped; the whole
                    # loop is abandoned once more than 50 posts have failed.
                    errors_count += 1
                    print(f"Error processing post {submission.id}: {e}")
                    if errors_count > 50:  # Stop if too many errors
                        print("Too many errors, stopping...")
                        break
                    continue
        
        except KeyboardInterrupt:
            # Manual interrupt: persist what was collected and return early
            # (the final CSV below is not written on this path).
            print("\nScraping interrupted by user")
            save_progress(posts_data, scraped_ids, config)
            return posts_data, config
        
        except Exception as e:
            # Any other failure while iterating the listing: same early exit.
            print(f"Unexpected error: {e}")
            save_progress(posts_data, scraped_ids, config)
            return posts_data, config
        
        # Final save
        save_final_csv(posts_data, filepath)
        
        new_posts = collected_posts
        total_posts = len(posts_data)
        
        print(f"\nScraping completed!")
        print(f"New posts this session: {new_posts}")
        print(f"Total posts collected: {total_posts}")
        print(f"Errors encountered: {errors_count}")
        print(f"Date range: {start_date} to {end_date}")
        
        return posts_data, config

In [15]:
def scrape_historical_batch(subreddit_name, days_back=30,
                            posts_per_day=100, sort_method='new'):
    """
    Scrape historical data in daily batches (Pushshift-style)

    One ``scrape_by_date_range`` call is made per day, walking backwards
    from today, with a 30-second pause between days.

    Args:
        subreddit_name: Name of subreddit
        days_back: How many days back to scrape
        posts_per_day: Target posts per day
        sort_method: Sorting method

    Returns:
        list: Post dicts from all days, in the order the days were scraped.
    """
    print(f"Scraping {days_back} days of historical data")

    all_posts = []
    end_date = datetime.now()

    for day in range(days_back):
        current_date = end_date - timedelta(days=day)
        start_date = current_date - timedelta(days=1)

        start_str = start_date.strftime('%Y-%m-%d')
        end_str = current_date.strftime('%Y-%m-%d')

        print(f"\nScraping day {day + 1}/{days_back}: {start_str}")

        result = scrape_by_date_range(
            subreddit_name, start_str, end_str,
            max_posts=posts_per_day, sort_method=sort_method
        )
        # scrape_by_date_range returns (posts, config). Extending with the
        # tuple itself would append a list and a dict instead of the posts.
        day_posts = result[0] if isinstance(result, tuple) else result

        all_posts.extend(day_posts)

        # Longer break between days
        if day < days_back - 1:
            print("Waiting 30 seconds before next day...")
            time.sleep(30)

    return all_posts

In [16]:
def scrape_with_keywords(subreddit_name, keywords, max_posts=1000,
                        days_back=7):
    """
    Scrape posts containing specific keywords (Pushshift-style search)

    Posts from the last ``days_back`` days are fetched with
    ``scrape_by_date_range`` (up to ``max_posts * 5`` of them) and then
    filtered locally by case-insensitive substring match on title and text.

    Args:
        subreddit_name: Subreddit name
        keywords: List of keywords to search for
        max_posts: Maximum posts to return
        days_back: Days to look back

    Returns:
        list: At most ``max_posts`` post dicts whose title or text contains
        at least one keyword.
    """
    print(f"Searching for posts with keywords: {keywords}")

    # Get recent posts
    end_date = datetime.now()
    start_date = end_date - timedelta(days=days_back)

    start_str = start_date.strftime('%Y-%m-%d')
    end_str = end_date.strftime('%Y-%m-%d')

    result = scrape_by_date_range(
        subreddit_name, start_str, end_str,
        max_posts=max_posts * 5  # Get more to filter
    )
    # scrape_by_date_range returns (posts, config); only the posts are
    # filtered. Iterating the tuple itself would fail on post['title'].
    all_posts = result[0] if isinstance(result, tuple) else result

    # Lower-case the keywords once instead of once per post.
    lowered_keywords = [keyword.lower() for keyword in keywords]

    # Filter by keywords
    filtered_posts = []
    for post in all_posts:
        title_lower = post['title'].lower()
        text_lower = post['text'].lower()

        # Check if any keyword is in title or text
        if any(keyword in title_lower or keyword in text_lower
               for keyword in lowered_keywords):
            filtered_posts.append(post)

            if len(filtered_posts) >= max_posts:
                break

    print(f"Found {len(filtered_posts)} posts with keywords")
    return filtered_posts

def _extract_post_data(submission):
    """Extract comprehensive post data from a submission into a plain dict.

    NOTE: an earlier cell in this notebook also defines ``_extract_post_data``
    (with fewer fields); once this cell has run, this definition replaces it.

    Args:
        submission: Object exposing the attributes read below (a PRAW
            submission in this notebook).

    Returns:
        dict with the post id, title, text, derived ``post_type`` (``text``,
        ``text_empty``, ``link``, ``image`` or ``video``), author name/id,
        score, upvote ratio (``None`` when the attribute is absent), comment
        count, creation time (raw and formatted via ``unix_to_date``), url,
        permalink, subreddit name and the gilded/locked/over_18/spoiler/
        stickied values.
    """
    # Handle author: None means there is no author object to read from.
    author_name = "[deleted]"
    author_id = "[deleted]"
    if submission.author is not None:
        try:
            author_name = submission.author.name
            author_id = submission.author.id
        except Exception:
            # Only ordinary errors are absorbed here. A bare ``except`` would
            # also swallow KeyboardInterrupt, which the scraping loop relies
            # on to save progress when the user interrupts the run.
            author_name = "[unavailable]"
            author_id = "[unavailable]"

    # Handle post text
    if submission.is_self:
        if submission.selftext:
            post_text = submission.selftext
            post_type = "text"
        else:
            post_text = "[Empty text post]"
            post_type = "text_empty"
    else:
        post_text = "[Link Post]"
        post_type = "link"

        # Categorize link types by substring match on the lower-cased URL.
        url_lower = submission.url.lower()
        if any(ext in url_lower for ext in ['.jpg', '.jpeg', '.png', '.gif']):
            post_type = "image"
        elif any(site in url_lower for site in ['youtube.com', 'youtu.be']):
            post_type = "video"

    return {
        "post_id": submission.id,
        "title": submission.title,
        "text": post_text,
        "post_type": post_type,
        "author_name": author_name,
        "author_id": author_id,
        "score": submission.score,
        "upvote_ratio": getattr(submission, 'upvote_ratio', None),
        "num_comments": submission.num_comments,
        "created_utc": submission.created_utc,
        "created_datetime": unix_to_date(submission.created_utc),
        "url": submission.url,
        "permalink": f"https://reddit.com{submission.permalink}",
        "subreddit": submission.subreddit.display_name,
        "gilded": submission.gilded,
        "locked": submission.locked,
        "over_18": submission.over_18,
        "spoiler": submission.spoiler,
        "stickied": submission.stickied
    }

In [17]:
def save_to_csv(posts_data, filename=None):
    """Write ``posts_data`` to a CSV file and return the file name used.

    Args:
        posts_data: Records accepted by ``pd.DataFrame`` (list of post dicts).
        filename: Target path. When falsy, a name of the form
            ``reddit_posts_<YYYYmmdd_HHMMSS>.csv`` is generated.

    Returns:
        The file name that was written.
    """
    if filename:
        pass  # caller chose the target explicitly
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"reddit_posts_{timestamp}.csv"

    pd.DataFrame(posts_data).to_csv(filename, index=False, encoding='utf-8')
    print(f"Saved {len(posts_data)} posts to {filename}")
    return filename

def track_with_neptune(posts_data, config):
    """Track scraping results with Neptune.

    Logs every config entry, the post count, the date range and a per-type
    post count, then writes the posts to CSV via ``save_to_csv`` and uploads
    that file to the run.

    Args:
        posts_data: List of post dicts (each with a ``post_type`` key when
            the list is non-empty).
        config: Dict describing the scraping session; ``start_date`` and
            ``end_date`` are shown as 'N/A' when absent.

    Returns:
        The CSV file name returned by ``save_to_csv``.
    """
    run = neptune.init_run(project="krishnadasm/wallstreetbets-scraper")

    try:
        # Log configuration
        for key, value in config.items():
            run[f"config/{key}"] = value

        # Log results
        run["results/total_posts"] = len(posts_data)
        run["results/date_range"] = f"{config.get('start_date', 'N/A')} to {config.get('end_date', 'N/A')}"

        # Analyze post types
        if posts_data:
            df = pd.DataFrame(posts_data)
            post_type_counts = df['post_type'].value_counts().to_dict()
            for post_type, count in post_type_counts.items():
                run[f"analysis/post_types/{post_type}"] = count

        # Upload CSV
        csv_filename = save_to_csv(posts_data)
        run["data/posts_csv"].upload(csv_filename)
    finally:
        # Always close the run, even when logging or the upload raised, so a
        # failed cell does not leave the run open.
        run.stop()

    return csv_filename

In [18]:
def example_historical_scraping():
    """Example: scrape two days of r/wallstreetbets and log the run to Neptune."""
    target_subreddit = 'wallstreetbets'
    n_days = 2
    daily_target = 100

    collected = scrape_historical_batch(
        subreddit_name=target_subreddit,
        days_back=n_days,
        posts_per_day=daily_target,
        sort_method='new',
    )

    # Session description logged alongside the results.
    run_config = dict(
        method='historical_batch',
        subreddit=target_subreddit,
        days_back=n_days,
        posts_per_day=daily_target,
    )
    track_with_neptune(collected, run_config)

def example_keyword_search():
    """Example: Search for posts with specific keywords and log the run.

    The scrape arguments and the logged config are built from the same local
    values, so the tracked metadata always names the subreddit that was
    actually scraped.
    """
    subreddit_name = 'wallstreetbets'
    keywords = ['GME', 'GameStop', 'TSLA', 'Tesla']
    max_posts = 500
    days_back = 7

    posts = scrape_with_keywords(
        subreddit_name=subreddit_name,
        keywords=keywords,
        max_posts=max_posts,
        days_back=days_back
    )

    config = {
        'method': 'keyword_search',
        'subreddit': subreddit_name,
        'keywords': keywords,
        'max_posts': max_posts,
        'days_back': days_back
    }

    track_with_neptune(posts, config)


In [19]:

# Driver cell: prints a banner, runs the selected example, prints a closing hint.
# NOTE(review): the closing message is printed after the selected example has
# already run, so it reads as an instruction even when an example was executed.
if __name__ == "__main__":
    print("Pushshift-Style Reddit Scraper")
    print("=" * 50)
    
    # Exactly one example call should be left uncommented below.
    
    # Example 1: Date range scraping
    # NOTE(review): no example_date_range_scraping function is defined in the
    # cells visible here - confirm it exists before enabling this line.
    # example_date_range_scraping()
    
    # Example 2: Historical batch scraping (currently the active example)
    example_historical_scraping()
    
    # Example 3: Keyword search
    # example_keyword_search()
    
    print("Choose an example to run by uncommenting the appropriate line!")

Pushshift-Style Reddit Scraper
Scraping 2 days of historical data

Scraping day 1/2: 2025-07-08
🔍 Scraping r/wallstreetbets from 2025-07-08 to 2025-07-09
Fetching posts from Reddit...


Filtering by date: 34it [00:07,  4.83it/s]


Collected 20 posts from 2025-07-08 to 2025-07-09
Waiting 30 seconds before next day...

Scraping day 2/2: 2025-07-07
🔍 Scraping r/wallstreetbets from 2025-07-07 to 2025-07-08
Fetching posts from Reddit...


Filtering by date: 52it [00:06,  7.72it/s]


Collected 18 posts from 2025-07-07 to 2025-07-08
[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/krishnadasm/wallstreetbets-scraper/e/WAL-73
Saved 38 posts to reddit_posts_20250709_172911.csv
[neptune] [info   ] Shutting down background jobs, please wait a moment...
[neptune] [info   ] Done!
[neptune] [info   ] Waiting for the remaining 17 operations to synchronize with Neptune. Do not kill this process.
[neptune] [info   ] All 17 operations synced, thanks for waiting!
[neptune] [info   ] Explore the metadata in the Neptune app: https://app.neptune.ai/krishnadasm/wallstreetbets-scraper/e/WAL-73/metadata
Choose an example to run by uncommenting the appropriate line!
