In [None]:
import os
import json

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import praw

from transformers import AutoTokenizer, pipeline

from thefuzz import fuzz,process
from word2number import w2n

In [None]:
# Load the Reddit API credentials from a JSON file in the parent directory so
# that no secrets are hard-coded in the notebook.
with open('../reddit_api.json') as json_file:
    reddit_api_credentials = json.load(json_file)

# Only the app credentials are passed (no username/password).
reddit_read_only = praw.Reddit(
    client_id=reddit_api_credentials['client_id'],
    client_secret=reddit_api_credentials['secret'],
    user_agent=reddit_api_credentials['user_agent'],
)

# Handle to the subreddit that the later cells sample posts from.
subreddit = reddit_read_only.subreddit("AskDocs")

In [None]:
# Token-level tagger for age mentions. It is paired explicitly with the
# distilbert-base-cased tokenizer rather than one resolved from the model path.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
age_extractor = pipeline(
    task="ner",
    model="../models/age_token_classification",
    tokenizer=tokenizer,
)

# Whole-post classifiers; inputs are padded and truncated by the pipeline.
gender_extractor = pipeline(
    task="text-classification",
    model="../models/gender_training",
    padding=True,
    truncation=True,
)
subject_extractor = pipeline(
    task="text-classification",
    model="../models/subject_training",
    padding=True,
    truncation=True,
)

In [None]:
def _join_age_tokens(age_words):
    """Rebuild the text of the first age mention from its tagged word pieces."""
    pieces = [age_words[0]]
    for previous, current in zip(age_words, age_words[1:]):
        # A new 'B-age' tag starts a different mention; stop at the first one.
        if current['entity'] != 'I-age':
            break
        # When token positions are available, only accept adjacent tokens.
        prev_pos, cur_pos = previous.get('index'), current.get('index')
        if prev_pos is not None and cur_pos is not None and cur_pos != prev_pos + 1:
            break
        pieces.append(current)

    text = ''
    for piece in pieces:
        word = str(piece['word'])
        if word.startswith('##'):
            # '##' marks a sub-word continuation: glue it to the previous piece.
            text += word[2:]
        elif text:
            text += ' ' + word
        else:
            text = word
    return text


def _parse_age(text):
    """Parse digits ('25') or number words ('twenty five'); None if neither."""
    try:
        return int(text)
    except (TypeError, ValueError):
        pass
    try:
        return w2n.word_to_num(text)
    except Exception:
        # Best effort: an unparseable mention simply yields no age.
        return None


def resolve_age(age_extracts):
    """Turn the age tagger's token predictions for one post into an age in years.

    Parameters
    ----------
    age_extracts : list of dict
        Token predictions for a single post. Each dict must provide 'entity'
        (tags used here: 'B-age', 'I-age', 'B-age_unit') and 'word'. An
        optional 'index' (token position) is used to check that the pieces
        of a mention are adjacent.

    Returns
    -------
    int or None
        The first age mentioned in the post, or None when there is no age
        token, the text cannot be parsed as a number, or the first unit
        token fuzzy-matches 'months' rather than 'years'.

    Notes
    -----
    The pieces of the first mention ('B-age' followed by adjacent 'I-age'
    tokens) are joined before parsing, so "2" + "##5" or "twenty" + "five"
    resolve to 25 instead of the value of the first piece alone. If the joined
    text does not parse, the first token on its own is tried as a fallback.
    """
    age_words = [entity for entity in age_extracts
                 if entity['entity'] in ('B-age', 'I-age')]
    if not age_words:
        return None

    units = [entity for entity in age_extracts if entity['entity'] == 'B-age_unit']
    if units:
        # Snap the (possibly abbreviated or misspelled) unit to a known one.
        collection = ['years', 'months']
        resolved_unit = process.extract(units[0]['word'], collection,
                                        scorer=fuzz.ratio)[0][0]
        # Ages given in months are discarded rather than converted.
        if resolved_unit == 'months':
            return None

    for candidate in (_join_age_tokens(age_words), str(age_words[0]['word'])):
        resolved_age = _parse_age(candidate)
        if resolved_age is not None:
            return resolved_age
    return None

In [None]:
# Number of draws from the subreddit's random-post endpoint (one call each).
N_RANDOM_POSTS = 100
POST_COLUMNS = ['id', 'post_text', 'score', 'total_comments', 'post_url']

sampled_posts = (subreddit.random() for _ in range(N_RANDOM_POSTS))

# subreddit.random() returns None when the subreddit does not support random
# posts, so those draws are skipped instead of failing on `None.id`.
# Keying by post id also collapses posts that were drawn more than once.
random_posts = list({
    post.id: {
        'id': post.id,
        'post_text': f"{post.title}\n{post.selftext}",
        'score': post.score,
        'total_comments': post.num_comments,
        'post_url': post.url,
    }
    for post in sampled_posts
    if post is not None
}.values())

# Explicit columns keep the schema stable even if no post was retrieved.
random_posts_df = pd.DataFrame(random_posts, columns=POST_COLUMNS)

In [None]:
# Build the model input once; all three extractors read the same post texts.
post_texts = [post['post_text'] for post in random_posts]

# Age: token-level predictions per post, reduced to one value by resolve_age.
age_extracts = age_extractor(post_texts)
random_posts_df['resolved_age'] = list(map(resolve_age, age_extracts))

# Gender and subject: keep the predicted label of each post.
gender_predictions = gender_extractor(post_texts)
random_posts_df['resolved_gender'] = [prediction['label']
                                      for prediction in gender_predictions]

subject_predictions = subject_extractor(post_texts)
random_posts_df['resolved_subject'] = [prediction['label']
                                       for prediction in subject_predictions]

random_posts_df

In [None]:
DATA_DIR = '../data'
RESOLVED_POSTS_CSV = os.path.join(DATA_DIR, 'resolved_random_posts.csv')

# os.listdir() on a missing directory raises, so make sure the folder exists
# before looking for (and later writing) the CSV.
os.makedirs(DATA_DIR, exist_ok=True)

if os.path.isfile(RESOLVED_POSTS_CSV):
    # Append this run to the earlier runs. For a post seen in both, keep='last'
    # retains the freshly resolved row because the new frame is concatenated
    # after the stored one.
    random_posts_df = (
        pd.concat(
            [pd.read_csv(RESOLVED_POSTS_CSV), random_posts_df],
            ignore_index=True,
        )
        .drop_duplicates(subset='id', keep='last')
        .reset_index(drop=True)
    )

random_posts_df.to_csv(RESOLVED_POSTS_CSV, index=False)
len(random_posts_df)

In [None]:
# Posts without a resolvable age hold None/NaN; drop them so that only real
# ages are binned.
known_ages = random_posts_df['resolved_age'].dropna()

fig, ax = plt.subplots()
ax.hist(known_ages)
ax.set_title('Who is posting on r/AskDocs?')
ax.set_xlabel('Age')
ax.set_ylabel('Frequency');

In [None]:
# Horizontal bar chart of how many posts fall under each predicted gender label.
gender_counts = random_posts_df['resolved_gender'].value_counts()
gender_counts.plot.barh()


In [None]:
# Horizontal bar chart of how many posts fall under each predicted subject label.
subject_counts = random_posts_df['resolved_subject'].value_counts()
subject_counts.plot.barh()


In [None]:
# Ages of the posts whose predicted subject label is 'Other'. Rows without a
# resolvable age (None/NaN) are dropped before binning.
is_other = random_posts_df['resolved_subject'] == 'Other'
other_ages = random_posts_df.loc[is_other, 'resolved_age'].dropna()

fig, ax = plt.subplots()
ax.hist(other_ages)
ax.set_xlabel('Age')
ax.set_ylabel('Frequency')
ax.set_title('How old are people in the other category');

In [None]:
# Show the posts table with its resolved age / gender / subject columns
# (displayed because it is the last expression of the cell).
random_posts_df