# Setup

In [1]:
import os
import pandas as pd
import praw
import toml
from openai import OpenAI
import plotly.express as px
from collections import Counter
import itertools
from IPython.display import HTML, display
import numpy as np

# Credentials are read from secrets.toml rather than hard-coded in the notebook.
secrets = toml.load("secrets.toml")
openai_key = secrets["OPEN_AI_KEY"]
reddit_creds = secrets["reddit"]

# Reddit API client used by the search cells.
REDDIT = praw.Reddit(
    client_id=reddit_creds["client_id"],
    client_secret=reddit_creds["client_secret"],
    user_agent=reddit_creds["user_agent"],
)

# OpenAI client shared by the relevance filter and the field extraction.
client = OpenAI(api_key=openai_key)

In [2]:
# System-prompt preamble for the structured field-extraction pass: it defines what counts
# as a relevant post and how schema keywords should be chosen. The exact text (including
# whitespace) is sent to the model, so any wording change alters extraction results.
ANALYSIS_USE_CASE = """
We are analyzing Reddit posts to understand how people are using AI and chatbots for mental health, coaching, or emotional support.
Specifically, we want to identify posts where users share their personal experiences using AI tools for:
- Managing mental health conditions (anxiety, depression, ADHD, OCD, PTSD, trauma, etc.)
- Emotional support and wellbeing
- Therapy supplements or alternatives
- Wellness coaching and goal setting
- Help focusing, goal setting, managing stress, overcoming obstacles, etc.
- Other similar use cases for AI in mental health

The post should include first-hand experience using AI tools, not just general discussion about AI in mental health.
This does NOT need to be the main focus of the post, but it should clearly mention using AI for the use case described.
We want to extract structured data about their experiences, including benefits, challenges, and specific use cases.
Do NOT make stuff up.  ONLY use keywords that accurately fit what the schema describes. 
A keyword that applies to the post generally but not specifically to what is asked for by the schema should not be used.
"""

def _schema_field(kind, description):
    """Build one JSON-schema property; kind "array" means a list of string keywords."""
    prop = {"type": kind}
    if kind == "array":
        prop["items"] = {"type": "string"}
    prop["description"] = description
    return prop


# JSON-schema properties for the structured-output extraction. Insertion order matters:
# it becomes the schema's "required" list and the column order of the results.
FIELDS = {
    "relevant_sample": _schema_field(
        "boolean",
        "Boolean indicating if text describes personal experience using AI for the use case described in the prompt",
    ),
    "relevant_sample_explanation": _schema_field(
        "string",
        "Explanation of why the sample was classified as relevant or not relevant",
    ),
    "sentiment": _schema_field(
        "integer",
        "Integer 1-10 indicating sentiment TOWARDS using AI for mental health (10 most positive).  This is NOT sentiment of the post overall, just sentiment towards the interaction with AI.",
    ),
    "benefits": _schema_field(
        "array",
        "List of keywords relating to perceived benefits of using the AI, e.g.: non_judgemental, on_demand, affordable, accessible, anonymous, consistent, supportive, patient",
    ),
    "downsides": _schema_field(
        "array",
        "List of keywords relating to downsides of using the AI, e.g.: repetitive, robotic, shallow, unreliable, addictive, avoidant, limited",
    ),
    "use_cases": _schema_field(
        "array",
        "List of keywords relating to how AI is used, e.g.: reflection, venting, self_talk, planning, CBT, journaling, motivation, reminders, emotional_support",
    ),
    "conditions": _schema_field(
        "array",
        "List of keywords describing conditions being addressed, e.g.: ADHD, depression, anxiety, addiction, OCD, PTSD, bipolar, eating_disorder",
    ),
    "seeing_provider": _schema_field(
        "boolean",
        "Boolean indicating if subject indicates they are CURRENTLY seeing a therapist or mental health provider",
    ),
    "previous_provider": _schema_field(
        "boolean",
        "Boolean indicating if subject indicates they have EVER seen a therapist or mental health provider",
    ),
    "provider_problems": _schema_field(
        "array",
        "List of keywords relating to perceived issues with HUMAN PROVIDERS, e.g.: expensive, unavailable, inaccessible, scheduling, inconsistent, judgmental",
    ),
    "fields_explanation": _schema_field(
        "string",
        "Concise but thorough explanation of your reasoning for each field in the schema (except for relevant_sample and relevant_sample_explanation)",
    ),
}

# Query Reddit Posts

- [Reddit search docs](https://support.reddithelp.com/hc/en-us/articles/19696541895316-Available-search-features)
- [PRAW docs](https://praw.readthedocs.io/en/stable/code_overview/models/subreddit.html)

### Specific Subreddits

In [28]:
# Subreddits searched individually, stored as one comma-separated string.
subreddits = ", ".join([
    "ADHD", "Advice", "Adulting", "Alcoholism", "Anger", "Anxiety", "AsianParentStories",
    "aspergirls", "BipolarReddit", "BlackMentalHealth", "bodyacceptance", "bpd",
    "careerguidance", "CPTSD", "dating_advice", "dbtselfhelp",
    "DecidingToBeBetter", "depression", "depression_help", "EDAnonymous", "Enneagram",
    "GetMotivated", "HealthAnxiety", "Healthygamergg", "hopefulmentalhealth",
    "lawofattraction", "LucidDreaming", "malementalhealth", "meditation",
    "mental", "mentalhealth", "mentalhealthadvice",
    "mentalhealthph", "mentalhealthsupport", "mentalhealthuk",
    "mentalillness", "MensMentalHealth", "microdosing",
    "MMFB", "nofap", "nosurf", "OCD", "offmychest", "pornfree", "productivity",
    "Psychiatry", "psychology", "ptsd", "QAnonCasualties",
    "raisedbynarcissists", "relationship_advice", "relationships",
    "selfimprovement", "socialanxiety", "socialskills", "StopSmoking", "Stress",
    "suicidewatch", "TalkTherapy", "teenagers", "therapy", "therapists",
    "traumatoolbox", "TrueOffMyChest", "WellnessPT",
])

In [None]:
import pandas as pd
from tqdm import tqdm

# Split subreddits string into a list of names (blank entries ignored)
subreddit_list = [name.strip() for name in subreddits.split(',') if name.strip()]

# Reddit search only treats UPPERCASE AND/OR/NOT as boolean operators; a lowercase
# "or" is searched as an ordinary word, so every operator here is written as OR.
query = '(AI OR "artificial intelligence" OR chatbot OR gpt OR chatGPT OR Claude OR characterAI OR Gemini OR Woebot OR Wysa OR Youper OR Sintelly)'

# Subreddits below this size are skipped as too low-traffic to be useful.
MIN_SUBSCRIBERS = 1000

posts = []
for subreddit in tqdm(subreddit_list):
    try:
        sub = REDDIT.subreddit(subreddit)

        # Reading .subscribers fetches the subreddit and can fail (e.g. private, banned
        # or nonexistent subreddits); such subreddits are skipped. `except Exception`
        # rather than a bare except so KeyboardInterrupt still stops the loop.
        try:
            subscribers = sub.subscribers
            if subscribers < MIN_SUBSCRIBERS:
                print(f"Warning: {subreddit} has only {subscribers} subscribers")
                continue
        except Exception as e:
            print(f"Warning: Could not access subscriber count for {subreddit} ({type(e).__name__})")
            continue

        # Search within this subreddit
        search_results = sub.search(
            query,
            sort='relevance',
            time_filter='year',
            limit=100
        )

        # Add posts from this subreddit
        for post in search_results:
            posts.append({
                'title': post.title,
                'text': post.selftext,
                'score': post.score,
                'created_utc': post.created_utc,
                'id': post.id,
                'subreddit': post.subreddit.display_name,
                'url': f"https://reddit.com{post.permalink}",
                'num_comments': post.num_comments
            })

    except Exception as e:
        print(f"Error accessing subreddit {subreddit}: {e}")
        continue

# Convert to dataframe
subreddits_df = pd.DataFrame(posts)
print(f"\nFound {len(subreddits_df)} total posts across all subreddits")

### Search all of Reddit with a title query

In [None]:
# Lucene query over post titles: (an AI term) AND (a mental-health term).
# chatGPT is listed alongside gpt, matching the per-subreddit query (presumably search
# matches whole words, so "gpt" alone would not match "ChatGPT").
query = """
(title:AI OR title:"artificial intelligence" OR title:chatbot OR title:gpt OR title:chatGPT OR title:Claude OR title:characterAI OR title:Gemini) AND
(title:therapy OR title:therapist OR title:"mental health" OR title:anxiety OR title:adhd OR title:depression OR title:stress OR title:ocd OR title:relationships)
"""
posts = []
# NOTE(review): Reddit listings generally stop around 1000 items, so limit=10000 is
# effectively "as many as Reddit returns".
search_results = REDDIT.subreddit("all").search(
    query,
    sort='relevance',
    syntax='lucene',
    time_filter='year',
    limit=10000
)

for post in search_results:
    posts.append({
        'title': post.title,
        'text': post.selftext,
        'score': post.score,
        'created_utc': post.created_utc,
        'id': post.id,
        'subreddit': post.subreddit.display_name,
        'url': f"https://reddit.com{post.permalink}",
        'num_comments': post.num_comments
    })

search_df = pd.DataFrame(posts)
print(f"Found {len(search_df)} posts")
# An empty result has no columns, so selecting them would raise KeyError.
if search_df.empty:
    print("\nNo posts returned; nothing to preview.")
else:
    print("\nSample titles:")
    display(search_df[['title', 'subreddit', 'score']].head())

# Combine all posts

In [None]:
# Combine posts from both sources and deduplicate
all_posts = pd.concat([subreddits_df, search_df], ignore_index=True)

# Drop duplicates based on post ID since that's unique per Reddit post
all_posts = all_posts.drop_duplicates(subset=['id'], keep='first')

print(f"Total posts after combining and deduping: {len(all_posts)}")
print(f"Posts from subreddit search: {len(subreddits_df)}")
print(f"Posts from keyword search: {len(search_df)}")
print(f"Duplicates removed: {len(subreddits_df) + len(search_df) - len(all_posts)}")

In [None]:
def deduplicate_posts(df):
    """Remove duplicate posts with same title/text, keeping the one with most comments.

    Exactly one row survives per (title, text) group: the one with the highest
    ``num_comments``. Ties go to the row that appears first. Missing title/text
    values are grouped like any other value, and missing ``num_comments`` ranks
    lowest. Survivors keep their original order and index labels. The input frame
    is not modified.

    Args:
        df: Posts with at least 'title', 'text' and 'num_comments' columns.

    Returns:
        A new DataFrame with one row per distinct (title, text).
    """
    print(f"Posts before deduplication: {len(df)}")
    if df.empty:
        print("Posts after deduplication: 0")
        return df.copy()

    # Work with positions (not index labels) so a non-unique index is handled safely.
    ranking = df['num_comments'].fillna(-np.inf).reset_index(drop=True)
    titles = df['title'].reset_index(drop=True)
    texts = df['text'].reset_index(drop=True)
    # idxmax returns the FIRST row holding the group maximum, so ties collapse to one row.
    keep_positions = (
        ranking.groupby([titles, texts], dropna=False, sort=False)
        .idxmax()
        .to_numpy(dtype=np.intp)
    )
    deduped = df.iloc[np.sort(keep_positions)]

    print(f"Posts after deduplication: {len(deduped)}")
    return deduped

all_posts = all_posts.pipe(deduplicate_posts)  # drop reposts/cross-posts with identical title+text

# Filter for relevance

In [None]:
def analyze_post_relevance(post, use_case):
    """Ask gpt-4o-mini whether a post fits `use_case` (first-pass relevance filter).

    Args:
        post: Mapping-like row (dict or pandas Series) with 'title' and optionally 'text'.
            A missing, empty or non-string body (e.g. NaN) is sent as '[No content]'.
        use_case: Plain-text description of what counts as relevant.

    Returns:
        (is_relevant, response_text). is_relevant is True/False when the reply starts
        with "yes"/"no" (case-insensitive, ignoring leading markdown such as '**'),
        otherwise None. response_text is '' if the model returned no content.
    """
    body = post.get('text')
    # Link posts have an empty selftext; .get's default only applies to a missing key.
    content = body if isinstance(body, str) and body.strip() else '[No content]'

    prompt = f"""Post Title: {post['title']}
Post Content: {content}
Use Case: {use_case}
Question: Is this post relevant to our use case? Please answer with a brief 'Yes' or 'No' and short explanation."""

    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}]
    )

    response_text = response.choices[0].message.content or ''

    # Models often bold or quote the verdict ("**Yes**"), so strip that before matching.
    answer = response_text.strip().lstrip('*_#>"\'` ').lower()
    is_relevant = None
    if answer.startswith('yes'):
        is_relevant = True
    elif answer.startswith('no'):
        is_relevant = False

    return (is_relevant, response_text)

# First-pass relevance criteria sent to the cheaper model for every post.
use_case = """
We are looking for posts relating to how people are using AI, chatbots, or virtual companions for mental health support or coaching.
This can include for any purpose: anxiety, adhd, depression, stress, ocd, relationships, goal setting, wellbeing, etc.
To qualify, the user must discuss their own experience using one of these tools, not just discussing in abstract or commenting on the use of AI in general.
"""

print_posts = False  # flip to True to echo each verdict as it is produced
from tqdm import tqdm
for idx, post in tqdm(all_posts.iterrows(), total=len(all_posts)):
    # Rows that already carry a verdict (e.g. from an interrupted earlier run) are skipped.
    previous_verdict = post.get('is_relevant')
    if pd.notna(previous_verdict):
        continue

    verdict, reasoning = analyze_post_relevance(post, use_case)
    all_posts.at[idx, 'is_relevant'] = verdict
    all_posts.at[idx, 'relevant_explanation'] = reasoning

    if not print_posts:
        continue
    print(f"\n[{post['subreddit']}] {post['title']}")
    print(f"Link: {post['url']}")
    print(f"Analysis: {reasoning}")
    print("-" * 50)

print("\nRelevant Post Counts:\n", all_posts['is_relevant'].value_counts())

In [245]:

# Keep first-pass-relevant posts, drop the first-pass columns, and convert epoch
# seconds to datetimes before writing samples.json.
relevant_mask = all_posts['is_relevant'] == True
samples = (
    all_posts.loc[relevant_mask]
    .drop(columns=['is_relevant', 'relevant_explanation'])
    .copy()
)
samples['created_utc'] = pd.to_datetime(samples['created_utc'], unit='s')
samples.to_json('samples.json', orient='records', date_format='iso')

# Extract Fields with AI

In [None]:
from typing import Dict, Any
import pprint
import json

def extract_fields(text: str, fields: Dict[str, dict], prompt: str, model: str = "gpt-4o") -> Dict[str, Any]:
    """Extract structured fields from text using OpenAI structured outputs.

    Every key of `fields` is sent as a required property of a strict JSON schema
    (strict mode requires all properties to be required and additionalProperties=False).

    Args:
        text: The content to analyze (sent as the user message).
        fields: JSON-schema property definitions, keyed by field name.
        prompt: System prompt describing the task.
        model: OpenAI chat model to use (default "gpt-4o").

    Returns:
        The parsed JSON object with one entry per field.

    Raises:
        ValueError: if the model refuses, its output was truncated, or it returned
            no content (json.JSONDecodeError, also a ValueError, for invalid JSON).
    """
    schema = {
        "name": "extract_fields",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": fields,
            "required": list(fields.keys()),
            "additionalProperties": False
        }
    }

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": text}
        ],
        response_format={
            "type": "json_schema",
            "json_schema": schema
        }
    )
    choice = response.choices[0]
    message = choice.message

    # With structured outputs a refusal arrives in `refusal` and `content` is None.
    refusal = getattr(message, "refusal", None)
    if refusal:
        raise ValueError(f"Model refused to extract fields: {refusal}")
    if getattr(choice, "finish_reason", None) == "length":
        raise ValueError("Model output was truncated (finish_reason='length'); JSON is incomplete")
    if not message.content:
        raise ValueError("Model returned an empty response")
    return json.loads(message.content)


prompt = ANALYSIS_USE_CASE + "\n\nAnalyze the following Reddit post and extract the requested fields, carefully, according to the schema.  Explain your reasoning in the last field."

# Extraction results keyed by the samples index label.
results = {}

# Loop through all samples and extract fields
for idx, sample in tqdm(samples.iterrows(), total=len(samples), desc="Extracting fields"):
    # Guard against a non-string body (e.g. NaN) so the literal "nan" is not sent.
    body = sample['text'] if isinstance(sample['text'], str) else ''
    text = f"{sample['title']}\n\n{body}"
    results[idx] = extract_fields(text, FIELDS, prompt)

# Convert results dictionary to dataframe and merge with original.
results_df = pd.DataFrame.from_dict(results, orient='index')
# Drop field columns left over from an earlier run of this cell; otherwise the merge
# would produce duplicated *_x / *_y columns instead of the fresh values.
samples = samples.drop(columns=list(FIELDS), errors='ignore').merge(
    results_df, left_index=True, right_index=True
)
samples.to_json('samples-with-fields-4o.json', orient='records', date_format='iso')

In [8]:
# Read data back in
fields_path = 'samples-with-fields-4o.json'  # written by the extraction cell above
samples = pd.read_json(fields_path, orient='records')

# Filter for relevance (second pass)

In [9]:
print("Breakdown of relevant_sample vs not:")
relevant_counts = samples.loc[:, 'relevant_sample'].value_counts()
print(relevant_counts)

Breakdown of relevant_sample vs not:
relevant_sample
True     516
False    210
Name: count, dtype: int64


In [None]:

def display_irrelevant_samples(df):
    """Show posts next to their relevance verdict as a wrapped, bordered table.

    Each row shows "[subreddit] title : text" beside
    "relevant_sample : relevant_sample_explanation". The display columns are built in a
    new frame, so `df` (often a slice of a larger frame) is not modified, and missing
    text values render as empty strings instead of turning the whole cell into NaN.

    Args:
        df: Posts with subreddit, title, text, relevant_sample and
            relevant_sample_explanation columns.
    """
    def as_text(column):
        return df[column].fillna('').astype(str)

    view = pd.DataFrame(
        {
            'combined_text': '[' + as_text('subreddit') + '] ' + as_text('title') + ' : ' + as_text('text'),
            'relevance_info': df['relevant_sample'].astype(str) + ' : ' + as_text('relevant_sample_explanation'),
        },
        index=df.index,
    )

    display(view.style.set_properties(**{
        'white-space': 'pre-wrap',
        'text-align': 'left'
    }).set_table_styles([
        {'selector': 'th', 'props': [('text-align', 'left')]},
        {'selector': '', 'props': [('border', '1px solid grey')]}
    ]))

# Spot-check up to 10 posts the second pass rejected (sample() raises if n exceeds the rows available).
irrelevant_samples = samples[samples['relevant_sample'] == False]
display_irrelevant_samples(irrelevant_samples.sample(n=min(10, len(irrelevant_samples))))

In [19]:
samples = samples.loc[samples['relevant_sample'].eq(True)]  # keep only second-pass-confirmed posts

# Analysis Utilities

In [12]:
def plot_distribution(data, column, title, x_label, save=True):
    """Bar chart of a long-tailed count column, bucketed at percentile edges.

    Bucket edges are 0, 1 and the integer parts of the 25th, 50th, 75th, 90th, 95th,
    97.5th, 99th, 99.9th and 100th percentiles, plus one edge just past the maximum.
    Buckets are half-open [lo, hi) and labelled "lo" (width 1) or "lo-(hi-1)". The
    closing edge ensures the largest value is counted, and the lowest edge extends
    below 0 if the data does. NaNs are ignored.

    Args:
        data: DataFrame containing `column`.
        column: Numeric column to bucket (e.g. score, num_comments).
        title: Chart title; also the PNG name (lower-cased, spaces -> underscores).
        x_label: X-axis label.
        save: If True, also write the chart to plots/<title>.png.
    """
    values = data[column].dropna().to_numpy()
    if values.size == 0:
        print(f"No values in '{column}' to plot")
        return

    percentiles = np.percentile(values, [0, 25, 50, 75, 90, 95, 97.5, 99, 99.9, 100])
    lowest = min(0, int(np.floor(values.min())))
    # Right edges are exclusive, so the maximum needs an edge strictly above it.
    upper = int(np.floor(values.max())) + 1
    # np.unique sorts and removes repeats, so every bucket has width >= 1.
    bucket_edges = np.unique([lowest, 0, 1] + [int(p) for p in percentiles[1:]] + [upper])

    bucket_labels = [
        str(lo) if hi == lo + 1 else f'{lo}-{hi - 1}'
        for lo, hi in zip(bucket_edges[:-1], bucket_edges[1:])
    ]

    bucketed_values = pd.cut(values, bins=bucket_edges, labels=bucket_labels, right=False)
    value_counts = bucketed_values.value_counts().sort_index()

    fig = px.bar(x=value_counts.index, y=value_counts.values,
                 title=title,
                 labels={'x': x_label, 'y': 'Number of Posts'})

    fig.update_layout(
        bargap=0.2,
        xaxis_title=x_label,
        yaxis_title='Number of Posts'
    )
    fig.update_xaxes(tickangle=45)

    if save:
        os.makedirs('plots', exist_ok=True)
        fig.write_image(f"plots/{title.lower().replace(' ', '_')}.png")

    fig.show()

def format_post(subreddit, title, text, score=None, url=None, keywords=None):
    """Return an HTML card for one Reddit post, for side-by-side example displays.

    Post-supplied values (subreddit, title, body, URL, score, keywords) are HTML-escaped
    so characters such as '<' or '&' in user content cannot break or inject into the
    surrounding markup. The body is cut to its first 500 characters, with '...' appended
    only when something was cut; a non-string body (e.g. NaN) renders as empty.

    Args:
        subreddit: Subreddit name (without the r/ prefix).
        title: Post title; links to `url` when given, otherwise to '#'.
        text: Post body.
        score: Optional score line; omitted when None.
        url: Optional permalink.
        keywords: Optional iterable of keywords, shown as a comma-separated line when non-empty.

    Returns:
        str: an inline-block <div>, roughly 30% wide.
    """
    from html import escape

    body = text if isinstance(text, str) else ''
    snippet = escape(body[:500]) + ('...' if len(body) > 500 else '')
    sub = escape(str(subreddit))
    href = escape(url) if isinstance(url, str) and url else '#'
    keyword_list = [] if keywords is None else [str(k) for k in keywords]

    score_html = (
        '<div style="color: #1a0dab; margin-bottom: 10px;">Score: ' + escape(str(score)) + '</div>'
        if score is not None else ''
    )
    keywords_html = (
        '<div style="color: #666; margin-bottom: 10px;">Keywords: ' + escape(', '.join(keyword_list)) + '</div>'
        if keyword_list else ''
    )

    card = f"""
    <div style="border: 1px solid #ddd; border-radius: 8px; padding: 15px; margin: 10px 5px; background-color: #f9f9f9; display: inline-block; vertical-align: top; width: 30%;">
        <div style="color: #666; font-size: 0.9em; margin-bottom: 5px;">
            <a href="https://reddit.com/r/{sub}" target="_blank" style="text-decoration: none; color: inherit;">r/{sub}</a>
        </div>
        <div style="color: #333; font-size: 1.1em; font-weight: bold; margin-bottom: 10px;">
            <a href="{href}" target="_blank" style="text-decoration: none; color: inherit;">{escape(str(title))}</a>
        </div>
        {score_html}
        {keywords_html}
        <div style="color: #444; line-height: 1.4; max-height: 200px; overflow-y: auto;">{snippet}</div>
    </div>
    """
    return card

def display_examples_section(title):
    """Render `title` as a small h4 heading with a bottom rule in the notebook output."""
    heading_css = "color: #333; margin: 0; padding: 5px 0; border-bottom: 1px solid #ccc;"
    markup = f'<div style="margin: 10px 0;"><h4 style="{heading_css}">{title}</h4></div>'
    display(HTML(markup))

def plot_binary_distribution(df, field_name, title, show_examples=False, save=True):
    """Pie chart of a boolean field's values, labelled as percentages of all rows.

    Percentages are computed over len(df), so rows with a missing value lower the
    shown percentages instead of getting their own slice. With show_examples, up to
    3 random posts per value are shown under a header naming that value.

    Args:
        df: Posts with `field_name`, subreddit, title, text and (optionally) url columns.
        field_name: Boolean column to plot.
        title: Chart title; also the PNG name when saving.
        show_examples: Display example posts for each value.
        save: Also write the chart to plots/<title>.png.
    """
    counts = df[field_name].value_counts()
    percentages = (counts / len(df)) * 100
    fig = px.pie(values=percentages.values, names=percentages.index, title=title)
    fig.update_traces(texttemplate='%{value:.1f}%')

    if save:
        os.makedirs('plots', exist_ok=True)
        fig.write_image(f"plots/{title.lower().replace(' ', '_')}.png")

    fig.show()

    if not show_examples:
        return
    for value in counts.index:
        display_examples_section(f"Example Posts for {field_name} = {value}")
        matching = df[df[field_name] == value]
        # sample() raises ValueError when n exceeds the available rows.
        example_posts = matching.sample(n=min(3, len(matching)))
        html_output = '<div style="display: flex; flex-direction: row; justify-content: space-between;">'
        for _, post in example_posts.iterrows():
            html_output += format_post(post['subreddit'], post['title'], post['text'], url=post.get('url'))
        html_output += "</div>"
        display(HTML(html_output))

def plot_integer_distribution(df, field_name, title, show_examples=False, save=True):
    """Histogram (percent of posts) of an integer field with a dashed mean line.

    Args:
        df: Posts with `field_name`, subreddit, title, text and (optionally) url columns.
        field_name: Integer column to plot (e.g. sentiment).
        title: Chart title; also the PNG name when saving.
        show_examples: Also show up to 3 random posts together with their value.
        save: Also write the chart to plots/<title>.png.
    """
    fig = px.histogram(df, x=field_name, title=title, nbins=10)
    fig.update_traces(histnorm='percent')

    mean_val = df[field_name].mean()
    fig.add_vline(x=mean_val, line_dash="dash", line_color="red",
                  annotation_text=f"Mean: {mean_val:.2f}",
                  annotation_position="top right")

    fig.update_layout(yaxis_title="Percent")

    if save:
        os.makedirs('plots', exist_ok=True)
        fig.write_image(f"plots/{title.lower().replace(' ', '_')}.png")

    fig.show()

    if show_examples:
        display_examples_section(f"Example Posts with {field_name} Scores")
        html_output = '<div style="display: flex; flex-direction: row; justify-content: space-between;">'
        # sample() raises ValueError when n exceeds the available rows.
        example_posts = df.sample(n=min(3, len(df)))
        for _, post in example_posts.iterrows():
            html_output += format_post(post['subreddit'], post['title'], post['text'], post[field_name], url=post.get('url'))
        html_output += "</div>"
        display(HTML(html_output))

def plot_list_field(df, field_name, title, limit=10, show_examples=False, save=True):
    """Bar chart of the `limit` most common keywords in a list field, as % of posts.

    Each keyword is counted at most once per post, so bars are true "percent of posts
    mentioning it" (denominator: all rows of df). Cells that are not list-like (e.g.
    NaN) count as having no keywords. Nothing is plotted when no post has a keyword.

    Args:
        df: Posts with `field_name`, subreddit, title, text and (optionally) url columns.
        field_name: List-valued column (e.g. 'benefits').
        title: Chart title; also the PNG name when saving.
        limit: Number of top keywords to show.
        show_examples: Also show up to 3 random posts that have at least one keyword.
        save: Also write the chart to plots/<title>.png.
    """
    keyword_lists = df[field_name].map(
        lambda x: list(x) if isinstance(x, (list, tuple, set, np.ndarray)) else []
    )
    item_counts = Counter(
        item for items in keyword_lists for item in set(items)
    ).most_common(limit)
    if not item_counts:
        return

    df_counts = pd.DataFrame(item_counts, columns=[field_name, 'count'])
    total_posts = len(df)
    df_counts['percent'] = df_counts['count'] / total_posts * 100

    fig = px.bar(df_counts, x=field_name, y='percent', title=title)
    fig.update_layout(yaxis_title="Percent of Posts")

    if save:
        os.makedirs('plots', exist_ok=True)
        fig.write_image(f"plots/{title.lower().replace(' ', '_')}.png")

    fig.show()

    if show_examples:
        display_examples_section("Example Posts with Keywords")
        html_output = '<div style="display: flex; flex-direction: row; justify-content: space-between;">'
        with_keywords = df[keyword_lists.map(len) > 0]
        # sample() raises ValueError when n exceeds the available rows.
        example_posts = with_keywords.sample(n=min(3, len(with_keywords)))
        for _, post in example_posts.iterrows():
            html_output += format_post(
                post['subreddit'],
                post['title'],
                post['text'],
                url=post.get('url'),
                keywords=post[field_name]
            )
        html_output += "</div>"
        display(HTML(html_output))

def plot_all_fields(df, fields_dict, show_examples=True, save=False):
    """Plot every schema field with the chart type matching its JSON type.

    boolean -> pie chart, integer -> histogram, array -> top-keywords bar chart.
    Fields of any other type (e.g. free-text string explanations) are skipped.
    """
    for field_name, field_info in fields_dict.items():
        kind = field_info['type']
        if kind == 'boolean':
            plotter, chart_title = plot_binary_distribution, f'Distribution of {field_name}'
        elif kind == 'integer':
            plotter, chart_title = plot_integer_distribution, f'{field_name} Distribution'
        elif kind == 'array':
            plotter, chart_title = plot_list_field, f'Most Common {field_name.capitalize()}'
        else:
            continue
        plotter(df, field_name, chart_title, show_examples=show_examples, save=save)

In [27]:
# Display 3 random posts side by side, tagged with their extracted fields
LIST_FIELDS = ['benefits', 'downsides', 'use_cases', 'conditions', 'provider_problems']

html_output = ""
for _, post in samples.sample(n=min(3, len(samples))).iterrows():
    # Build keywords list from all array fields, in schema order
    keywords = []
    for field in LIST_FIELDS:
        if field in post and isinstance(post[field], list):
            keywords.extend(post[field])

    # Add boolean fields as keywords if True
    if post.get('seeing_provider'):
        keywords.append('Currently seeing provider')
    if post.get('previous_provider'):
        keywords.append('Has previous provider')

    # Add sentiment if present
    if 'sentiment' in post and pd.notna(post['sentiment']):
        keywords.append(f'Sentiment: {post["sentiment"]}/10')

    html_output += format_post(
        subreddit=post['subreddit'],
        title=post['title'],
        text=post['text'],
        url=post.get('url'),
        keywords=keywords
    )

from IPython.display import HTML
display(HTML(html_output))

In [32]:
samples.shape[0]  # number of posts confirmed relevant by the second pass

516

# Analysis - general

In [33]:
# Distribution of posts over time, counted per calendar month
posts_over_time = pd.to_datetime(samples['created_utc']).dt.to_period('M').value_counts().sort_index()
fig1 = px.line(x=posts_over_time.index.astype(str), y=posts_over_time.values,
               title='Distribution of Posts Over Time', labels={'x': 'Month', 'y': 'Number of Posts'})
fig1.update_xaxes(tickvals=list(posts_over_time.index.astype(str)),
                  ticktext=list(posts_over_time.index.start_time.strftime('%b %Y')))
fig1.show()

# Histogram of subreddit counts, sorted by frequency, showing the top 20
subreddit_counts = samples['subreddit'].value_counts().head(20)
fig = px.histogram(
    samples[samples['subreddit'].isin(subreddit_counts.index)],
    x='subreddit',
    title='Distribution of Posts Across Top 20 Subreddits',
    category_orders={"subreddit": list(subreddit_counts.index)}
)
fig.update_layout(
    xaxis_title="Subreddit",
    yaxis_title="Number of Posts",
    xaxis_tickangle=45
)
fig.show()

# Plot score distribution
plot_distribution(samples, 'score', 'Distribution of Post Scores', 'Score Range', save=False)

# Plot comment distribution
plot_distribution(samples, 'num_comments', 'Distribution of Post Comments', 'Comment Range', save=False)

In [40]:
# filtered_samples = samples[samples['subreddit'] == 'CPTSD']
# html_output = "<div>"
# for _, row in filtered_samples.iterrows():
#     html_output += f"""
#     <div style='margin-bottom: 20px;'>
#         <a href='{row['url']}' target='_blank'>{row['title']}</a>
#         <div style='color: #666; font-size: 0.9em;'>Posted in r/{row['subreddit']}</div>
#     </div>
#     """
# html_output += "</div>"
# display(HTML(html_output))


# Analyze Extracted Fields

In [41]:
# Visualize every extracted field according to its type in FIELDS.
# No filtering happens here: `samples` was already restricted to relevant_sample == True
# in the second-pass filter cell.
plot_all_fields(samples, FIELDS, show_examples=True)

# Investigate specific groups

In [47]:
def show_samples_with_value(samples, field, value, n_examples=5):
    """
    Display the share of samples whose list `field` contains `value`, plus example posts.

    Args:
        samples (pd.DataFrame): DataFrame containing the samples
        field (str): Name of a list-valued field (e.g. 'downsides', 'benefits')
        value (str): Keyword to look for in the field
        n_examples (int): Maximum number of matching posts to show (default 5)

    Rows whose field is not a list/tuple/set (e.g. NaN) count as non-matching, and an
    empty `samples` reports 0.0% instead of dividing by zero.
    """
    mask = samples[field].map(
        lambda x: isinstance(x, (list, tuple, set)) and value in x
    ).astype(bool)
    filtered_samples = samples.loc[mask]
    share = 100 * len(filtered_samples) / len(samples) if len(samples) else 0.0

    html_output = "<div>"
    html_output += f"<h3>{share:.1f}% of samples mention '{value}' in {field}</h3>"

    for _, sample in filtered_samples.head(n_examples).iterrows():
        html_output += format_post(title=sample['title'], text=sample['text'], subreddit=sample['subreddit'], url=sample['url'])

    html_output += "</div>"
    display(HTML(html_output))

# Example usage:
show_samples_with_value(samples, 'downsides', 'shallow')


In [49]:
# Look up a post by (part of) its title. regex=False: the title contains '?', which
# str.contains would otherwise treat as a regex quantifier.
specific_title = "am I cheating on my bf with an AI?"
matching_post = samples[samples['title'].str.contains(specific_title, case=False, na=False, regex=False)]
matching_post[['subreddit', 'title', 'url']]

In [None]:
def show_subreddit_examples(df, subreddit, n=5, random_state=None):
    """Display up to `n` random posts from one subreddit as side-by-side HTML cards.

    Args:
        df: Posts with subreddit, title, text, score and url columns.
        subreddit: Exact (case-sensitive) subreddit name to filter on.
        n: Maximum number of posts; all posts are shown if the subreddit has fewer.
        random_state: Optional seed forwarded to DataFrame.sample for reproducible picks.
    """
    in_subreddit = df[df['subreddit'] == subreddit]
    # sample() raises ValueError when n exceeds the available rows.
    subreddit_posts = in_subreddit.sample(n=min(n, len(in_subreddit)), random_state=random_state)

    html_output = '<div style="display: flex; flex-direction: row; justify-content: space-between;">'
    for _, post in subreddit_posts.iterrows():
        html_output += format_post(
            subreddit=post['subreddit'],
            title=post['title'],
            text=post['text'],
            score=post['score'],
            url=post['url']
        )
    html_output += "</div>"
    display(HTML(html_output))

show_subreddit_examples(samples, subreddit='Healthygamergg')  # random example posts from one community

In [None]:
def plot_sentiment_by_keyword(df, column, label, plural_label, top_n=10):
    """Bar chart of mean sentiment for the `top_n` most frequent keywords of a list column.

    Bars are ordered by keyword frequency (most common first) and annotated with the
    number of posts mentioning each keyword. A keyword repeated within a single post
    is counted once for that post.

    Args:
        df: Posts with `column` (lists of keywords) and a numeric 'sentiment' column.
        column: List-valued column to explode (e.g. 'use_cases').
        label: Singular x-axis label (e.g. 'Use Case').
        plural_label: Plural used in the title (e.g. 'Use Cases').
        top_n: Number of most frequent keywords to show.
    """
    per_post = df[[column, 'sentiment']].copy()
    per_post[column] = per_post[column].map(
        lambda x: list(dict.fromkeys(x)) if isinstance(x, list) else x
    )
    exploded = per_post.explode(column)
    keyword_counts = exploded[column].value_counts().head(top_n)
    # Reindex mean sentiment into frequency order, limited to the top keywords.
    mean_sentiment = exploded.groupby(column)['sentiment'].mean().reindex(keyword_counts.index)

    fig = px.bar(x=mean_sentiment.index,
                 y=mean_sentiment.values,
                 title=f'Average Sentiment Score by Top {top_n} {plural_label} (Ordered by Frequency)',
                 labels={'x': label, 'y': 'Average Sentiment Score'},
                 text=keyword_counts.values)
    fig.update_xaxes(tickangle=45)
    fig.update_traces(textposition='outside')
    fig.show()


plot_sentiment_by_keyword(samples, 'use_cases', 'Use Case', 'Use Cases')
plot_sentiment_by_keyword(samples, 'conditions', 'Condition', 'Conditions')



# Archive

### Query the most popular subreddits (for easier analysis)

In [None]:
# Get top 10000 subreddits from Reddit API
import praw
import pandas as pd
from tqdm import tqdm
# Initialize list to store subreddit data
subreddit_data = []

# NOTE(review): Reddit listings usually stop around 1000 items despite limit=10000.
for subreddit in tqdm(REDDIT.subreddits.popular(limit=10000)):
    subreddit_data.append({
        'subreddit': subreddit.display_name,
        'count': subreddit.subscribers,
    })

# Stored under its own name so this archive cell does not overwrite `subreddits_df`
# (the per-subreddit search results) when it is run after the main pipeline.
top_subreddits_df = pd.DataFrame(subreddit_data)
top_subreddits_df.to_csv('top_subreddits.csv', index=False)
print(f"\nSaved {len(top_subreddits_df)} subreddits to top_subreddits.csv")