<a href="https://colab.research.google.com/github/danielmlow/tutorials/blob/main/text/reddit_download.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install praw

import time
from datetime import datetime, timezone
from typing import List

import pandas as pd
import praw

Search for similar subreddits here:
https://anvaka.github.io/sayit/?query=GriefSupport

In [None]:
# Maximum number of unique submissions to collect from each subreddit.
sample_size = 100

# Subreddits to sample (related communities can be found via the search link above).
# Written as one whitespace-separated string and split into a list of names.
subreddits = """
bullying abusesurvivors sexualassault relationship_advice GriefSupport
lonely Anxiety depression asktransgender EatingDisorders addiction
""".split()



In [None]:
class RedditSampler:
    """Collect de-duplicated submission samples from subreddits using PRAW.

    For each subreddit several listings ('new', 'top', 'controversial', 'hot')
    are walked in turn, each requested with ``limit=1000``, until
    ``sample_size`` unique submissions have been gathered.
    """

    # Listing methods tried, in order.
    _SORTS = ('new', 'top', 'controversial', 'hot')
    # Only these listings accept a ``time_filter`` argument; 'new' and 'hot'
    # do not, so they are fetched once rather than once per time filter.
    _TIME_FILTERED_SORTS = ('top', 'controversial')
    _TIME_FILTERS = ('all', 'year', 'month', 'week')
    _SUBMISSION_TYPES = ('all', 'self', 'link')
    # Column order of the returned DataFrame; keep in sync with _to_record().
    # Fixing the columns means an empty result still has the expected schema.
    _COLUMNS = (
        'subreddit', 'id', 'title', 'author', 'created_utc', 'score',
        'upvote_ratio', 'num_comments', 'url', 'is_self', 'selftext',
        'collection_sort', 'collection_time_filter',
    )

    def __init__(self, client_id: str, client_secret: str, user_agent: str):
        """Create the underlying ``praw.Reddit`` client.

        Args:
            client_id: Client ID of your Reddit API application.
            client_secret: Client secret of your Reddit API application.
            user_agent: User-agent string sent with every API request.
        """
        self.reddit = praw.Reddit(
            client_id=client_id,
            client_secret=client_secret,
            user_agent=user_agent
        )

    @staticmethod
    def _listing(subreddit, sort, time_filter):
        """Return the submission iterator for one listing (``time_filter`` None = not applicable)."""
        listing_method = getattr(subreddit, sort)
        if time_filter is None:
            return listing_method(limit=1000)
        return listing_method(time_filter=time_filter, limit=1000)

    @staticmethod
    def _to_record(submission, subreddit_name, sort, time_filter):
        """Flatten one PRAW submission into a plain dict (one DataFrame row)."""
        return {
            'subreddit': subreddit_name,
            'id': submission.id,
            'title': submission.title,
            'author': str(submission.author),
            # created_utc is a Unix timestamp; keep it in UTC instead of the
            # machine's local time zone.
            'created_utc': datetime.fromtimestamp(submission.created_utc, tz=timezone.utc),
            'score': submission.score,
            'upvote_ratio': submission.upvote_ratio,
            'num_comments': submission.num_comments,
            'url': submission.url,
            'is_self': submission.is_self,
            'selftext': submission.selftext if submission.is_self else None,
            # Which listing the row was first found in (informational only).
            'collection_sort': sort,
            'collection_time_filter': time_filter,
        }

    def get_samples(
        self,
        subreddits: List[str],
        sample_size: int = 6000,
        sleep_amount: float = 0.1,
        submission_type: str = "all",
        listing_pause: float = 1.0,
    ) -> pd.DataFrame:
        """Collect up to ``sample_size`` unique submissions from each subreddit.

        Args:
            subreddits: Subreddit names (without the ``r/`` prefix).
            sample_size: Maximum number of unique submissions per subreddit.
            sleep_amount: Seconds to sleep after each collected submission.
            submission_type: ``"all"``, ``"self"`` (text posts only) or
                ``"link"`` (link posts only).
            listing_pause: Seconds to sleep after finishing each listing.

        Returns:
            One row per submission, with the columns in ``_COLUMNS`` (an empty
            DataFrame with those columns if nothing was collected).

        Raises:
            ValueError: If ``submission_type`` is not one of the values above.

        Errors while fetching a listing or a subreddit are printed and skipped
        (best effort), so a private/banned subreddit does not abort the run.
        """
        if submission_type not in self._SUBMISSION_TYPES:
            raise ValueError(
                f"submission_type must be one of {self._SUBMISSION_TYPES}, got {submission_type!r}"
            )

        # (sort, time_filter) pairs to request, in order.
        plan = []
        for sort in self._SORTS:
            if sort in self._TIME_FILTERED_SORTS:
                plan.extend((sort, time_filter) for time_filter in self._TIME_FILTERS)
            else:
                plan.append((sort, None))

        all_submissions = []
        for subreddit_name in subreddits:
            try:
                subreddit = self.reddit.subreddit(subreddit_name)
                # Keyed by submission ID so a post returned by several listings is kept once.
                submission_dict = {}

                for sort, time_filter in plan:
                    if len(submission_dict) >= sample_size:
                        break
                    label = sort if time_filter is None else f"{sort}/{time_filter}"

                    try:
                        for submission in self._listing(subreddit, sort, time_filter):
                            if len(submission_dict) >= sample_size:
                                break
                            if submission_type == "self" and not submission.is_self:
                                continue
                            if submission_type == "link" and submission.is_self:
                                continue
                            if submission.id in submission_dict:
                                continue

                            submission_dict[submission.id] = self._to_record(
                                submission, subreddit_name, sort, time_filter
                            )
                            time.sleep(sleep_amount)

                        print(f"Collected {len(submission_dict)} unique samples from r/{subreddit_name} ({label})")
                        time.sleep(listing_pause)

                    except Exception as e:
                        print(f"Error with r/{subreddit_name} {label}: {e}")
                        continue

                # submission_dict never grows beyond sample_size (checked above).
                all_submissions.extend(submission_dict.values())
                print(f"Added {len(submission_dict)} unique submissions from r/{subreddit_name}")

            except Exception as e:
                print(f"Error collecting from r/{subreddit_name}: {e}")
                continue

        return pd.DataFrame(all_submissions, columns=list(self._COLUMNS))

To get a Reddit API client ID and secret:

1. Visit https://www.reddit.com/prefs/apps
2. Sign in to your Reddit account
3. Click "create app" or "create another app" button
4. Fill out the form:
   - Name: your app name
   - Select "script" (intended for personal scripts such as this notebook, which authenticates with the client ID and secret)
   - Description: brief description of your app
   - About URL: optional website URL
   - Redirect URI: use http://localhost:8000 for testing
5. Click "create app"

After creation, you'll see:
- Client ID: under your app name
- Client secret: displayed as "secret"

Note that Reddit has significantly restricted API access with new usage limitations and pricing since April 2023, which may affect your development plans.

In [None]:


# Example usage:
if __name__ == "__main__":
    import os

    # Credentials of your Reddit API application (see the instructions above).
    # Environment variables are used when set, so secrets need not be pasted
    # into the notebook; otherwise the placeholders below are used.
    CLIENT_ID = os.environ.get("REDDIT_CLIENT_ID", 'YOUR_CLIENT_ID')
    CLIENT_SECRET = os.environ.get("REDDIT_CLIENT_SECRET", "YOUR_CLIENT_SECRET")
    REDDIT_USERNAME = os.environ.get("REDDIT_USERNAME", "YOUR_USERNAME")
    # Plain name substitution: the original nested "..." inside a "..." f-string,
    # which is a SyntaxError before Python 3.12.
    USER_AGENT = f"script:data_sampler:v1.0 (by /u/{REDDIT_USERNAME})"

    sampler = RedditSampler(CLIENT_ID, CLIENT_SECRET, USER_AGENT)

    # Get samples
    samples_df = sampler.get_samples(
        subreddits=subreddits,
        sample_size=sample_size,
        submission_type="all"
    )

In [None]:
# Display the collected samples (the cell's last bare expression is rendered by Jupyter).
samples_df

In [None]:

# Any duplicates? Show only submission IDs that occur more than once;
# an empty result means every row is a distinct submission.
# (value_counts() already sorts by count, descending.)
id_counts = samples_df['id'].value_counts()
id_counts[id_counts > 1]

In [None]:
from datetime import datetime

# Timestamp used in the output filename, e.g. '24-05-01T13-45-00'.
# Hyphens instead of colons keep it filename-safe (colons are invalid on Windows).
now = datetime.now()
date_format = '%y-%m-%dT%H-%M-%S'  # not named `format`, to avoid shadowing the built-in
date_string = now.strftime(date_format)
date_string

In [None]:
# Number of collected rows per subreddit (get_samples caps each at sample_size).
samples_df['subreddit'].value_counts()

In [None]:
from pathlib import Path

# Combine title and body into one text field. Link posts have no selftext
# (stored as None), so treat it as empty instead of letting the concatenation produce NaN.
samples_df['title_text'] = samples_df['title'] + "\n---\n" + samples_df['selftext'].fillna("")

# Save to CSV, creating the output directory first if needed.
# NOTE(review): the filename says 10 subreddits but `subreddits` lists 11 — confirm the name.
output_path = Path("data/input") / f"reddit_10_mental_health_{date_string}_incomplete.csv"
output_path.parent.mkdir(parents=True, exist_ok=True)
samples_df.to_csv(output_path, index=False)
print(f"Saved {len(samples_df)} total samples to {output_path}")