In [6]:
import pandas as pd

subreddit_file = open('archive/50_subreddits_list.csv', mode='r', encoding='utf-8-sig')
subreddits = {}
for line in subreddit_file.readlines():
  subreddit = line.rstrip("\n").lower()
  df = pd.read_csv(f'archive/{subreddit}.csv')
  subreddits[subreddit] = df

### Get Title and Text from DataFrames
Our classifier only works based on Title and Text

In [7]:
subreddit_bodies = {}
for subreddit in subreddits:
  df = subreddits[subreddit]
  df['title'] = df['title'].fillna('')
  df['body'] = df['body'].fillna('')
  subreddit_bodies[subreddit] = df['title'] + " " + df['body'] 

subreddit_bodies['anime']

0      'Dragon Ball' Creator Akira Toryiyama Has Pass...
1        Kaguya-sama: Love Is War - Season 3 announced! 
2                         Aqua in yoga pants | Konosuba 
3                     This is not a Cigarette [Gintama] 
4         The Devil is a Part-Timer Season 2 Announced! 
                             ...                        
991    Mob Psycho 100 Season 2 - Episode 5 discussion...
992    Never thought in a million years I’d come acro...
993    I seriously hate Sundays [Engaged to the Unide...
994                "Spy Classroom" New Character Visual 
995    Hayasaka Ai in Spy Suit from "Kaguya: Love is ...
Length: 996, dtype: object

### Remove Links
Our dataset includes links of the form "\[link text](link)". Since we want to analayze the text content of the subreddits, we should filter out these links and just keep "link text"

In [8]:
import re # Regex Library

no_links = {}
for subreddit in subreddit_bodies:
  data = subreddit_bodies[subreddit]
  no_links[subreddit] = data.map(lambda txt: re.sub("\[([^\]]*)\]\(([^\)]*)\)", " \g<1> ", txt))

no_links['travel']

0      I visited North Korea recently, these are some...
1      Taken with a phone out of my hotel window in V...
2      Taking a ride on the Bernina Express through t...
3      Wife and I hate big social events and love tra...
4      The exact moment I took a step too close to th...
                             ...                        
992    Sisteron- France. Beautiful place we had a cof...
993    Croatia, probably the most beautiful country i...
994    Michelangelo's David is great, but pieta is on...
995    If you don't mind a little dust and grit and y...
996    I’d never realized how beautiful Montenegro wa...
Length: 997, dtype: object

### Tokenizing

Since reddit posts contain long sections of prose that are unique to each user's writing we felt finding attributes about the text as a whole would be difficult. Rather, we tokenize each head and body in hopes that individual words will vary between subreddits. We also turned each word to lowercase to remove issues with capitalization between different posts, since capitalization likely doesn't affect the semantic meaning of each word in a post.

Additionally, we found the special unicode character ’ (as opposed to ') in several entries across datasets likely due to differences in keyboards among different languages. We replaced the former with the latter to correctly match words with apostrophes (e.g. I’d => I'd)

In [9]:
import nltk

tokenizer = nltk.RegexpTokenizer(
  pattern=r"[\w']+", # Only match words as tokens (coarsely, \w + apostrophes)
  gaps=False,
  discard_empty=True # Remove empty tokens caused by markdown content
)
tokenized_subreddits = {}
for subreddit in no_links:
  data = no_links[subreddit]
  data = data.map(lambda txt: txt.replace("’", "'").lower())
  tokenized_subreddits[subreddit] = data.map(lambda txt: tokenizer.tokenize(txt)) 

tokenized_subreddits['history']

0      [new, discovery, mode, turns, video, game, ass...
1      [we, are, not, here, to, help, you, with, your...
2      [a, 1776, excerpt, from, john, adam's, diary, ...
3      [famous, viking, warrior, burial, revealed, to...
4      [3, 000, year, old, underwater, castle, discov...
                             ...                        
987    [dna, study, has, now, provided, support, for,...
988    [stonehenge, megalith, came, from, scotland, n...
989    [french, resistance, man, breaks, silence, ove...
990    [holy, grail, of, shipwrecks', to, be, raised,...
991    [emily, wilson's, new, translation, of, the, i...
Length: 992, dtype: object

In [10]:
# bring tokens back together as data
subreddit_classes = {}
subreddit_df = pd.DataFrame(columns=['text', 'subreddit'])

next_class = 1
for subreddit in tokenized_subreddits:
  subreddit_classes[subreddit] = next_class
  data = pd.DataFrame()
  data['text'] = tokenized_subreddits[subreddit].map(lambda arr: ' '.join(arr))
  data['subreddit'] = next_class
  subreddit_df = pd.concat([subreddit_df, data], ignore_index=True)
  next_class += 1

subreddit_df

Unnamed: 0,text,subreddit
0,my cab driver tonight was so excited to share ...,1
1,guardians of the front page,1
2,gas station worker takes precautionary measure...,1
3,the conversation my son and i will have on chr...,1
4,the denver broncos have the entire town of sou...,1
...,...,...
49261,how come no one has invented a foot pedal for ...,50
49262,why are trans people talked about so much desp...,50
49263,did your penis ever fall asleep like your legs...,50
49264,someone stole my bike i tracked it to its loca...,50


In [11]:
from sklearn.model_selection import train_test_split

X = subreddit_df['text']
y = subreddit_df['subreddit']

X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True)
X_train = X_train.astype('str').to_numpy()
X_test = X_test.astype('str')
y_train = y_train.astype('int')
y_test = y_test.astype('int')

In [12]:
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline

reddit_classifer = Pipeline([
  ('cv', CountVectorizer()),
  ('tfidf', TfidfTransformer()),
  ('sgd', SGDClassifier()),
])

In [89]:
from sklearn.model_selection import GridSearchCV

grid_search = GridSearchCV(reddit_classifer, {
  'cv__ngram_range': [(1, 1), (1, 2)],
  'cv__max_df': [0.8, 0.9, 0.95, 1.0],
  'cv__min_df': [1, 3, 5],
  'cv__max_features': [1000, 2000, None],
  'sgd__loss': ['log_loss', 'modified_huber']
}, n_jobs=8, verbose=3)
grid_search.fit(X_train[:10000], y_train[:10000])
print(grid_search.best_score_)
grid_search.best_params_

Fitting 5 folds for each of 144 candidates, totalling 720 fits
[CV 1/5] END cv__max_df=0.8, cv__max_features=1000, cv__min_df=1, cv__ngram_range=(1, 1), sgd__loss=log_loss;, score=0.472 total time=   3.1s
[CV 2/5] END cv__max_df=0.8, cv__max_features=1000, cv__min_df=1, cv__ngram_range=(1, 1), sgd__loss=log_loss;, score=0.475 total time=   3.1s
[CV 3/5] END cv__max_df=0.8, cv__max_features=1000, cv__min_df=1, cv__ngram_range=(1, 1), sgd__loss=log_loss;, score=0.472 total time=   3.3s
[CV 4/5] END cv__max_df=0.8, cv__max_features=1000, cv__min_df=1, cv__ngram_range=(1, 1), sgd__loss=log_loss;, score=0.469 total time=   3.5s
[CV 5/5] END cv__max_df=0.8, cv__max_features=1000, cv__min_df=1, cv__ngram_range=(1, 1), sgd__loss=log_loss;, score=0.469 total time=   3.7s
[CV 1/5] END cv__max_df=0.8, cv__max_features=1000, cv__min_df=1, cv__ngram_range=(1, 1), sgd__loss=modified_huber;, score=0.425 total time=   3.5s
[CV 4/5] END cv__max_df=0.8, cv__max_features=1000, cv__min_df=1, cv__ngram_ran

{'cv__max_df': 1.0,
 'cv__max_features': None,
 'cv__min_df': 1,
 'cv__ngram_range': (1, 2),
 'sgd__loss': 'modified_huber'}

In [13]:
reddit_classifer = Pipeline([
  ('cv', CountVectorizer(ngram_range=(1,2), lowercase=False)),
  ('tfidf', TfidfTransformer()),
  ('sgd', SGDClassifier(loss='modified_huber')),
])
reddit_classifer.fit(X_train, y_train)

In [14]:
import numpy as np

y_pred = reddit_classifer.predict(X_test)
np.mean(y_pred == y_test)

0.6611187789234392

In [16]:
def predictPrompt(prompt):
  predicted_class = reddit_classifer.predict([prompt])[0].item()
  for subreddit in subreddit_classes:
    if subreddit_classes[subreddit] == predicted_class:
      return subreddit
  print("No applicable subreddit (should be unreachable)")

def predictPromptProbabilities(prompt, threshold=0.0):
  prediction = {}
  predictions = reddit_classifer.predict_proba([prompt])[0]
  print(predictions)
  print(reddit_classifer.classes_)
  for subreddit in subreddit_classes:
    print(subreddit_classes[subreddit], subreddit)
    if predictions[subreddit_classes[subreddit]-1] > threshold:
      prediction[subreddit] = predictions[subreddit_classes[subreddit]-1]
  return prediction

print(predictPrompt("gf"))
predictPromptProbabilities("gf", 0.05)

wholesomememes
[0.10031242 0.         0.02741464 0.         0.         0.05624285
 0.         0.01050269 0.00427124 0.         0.04077281 0.0272257
 0.00304892 0.         0.0351608  0.         0.         0.
 0.         0.         0.128286   0.03414995 0.02982647 0.01735402
 0.         0.         0.         0.0172122  0.01773006 0.
 0.         0.         0.0691712  0.01284538 0.         0.
 0.         0.         0.06692333 0.02176229 0.02721318 0.
 0.01262867 0.15147497 0.         0.05149836 0.         0.01668823
 0.02028362 0.        ]
[ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
 49 50]
1 funny
2 askreddit
3 gaming
4 aww
5 music
6 pics
7 science
8 worldnews
9 movies
10 todayilearned
11 videos
12 news
13 showerthoughts
14 earthporn
15 gifs
16 jokes
17 mildlyinteresting
18 iama
19 books
20 lifeprotips
21 diy
22 sports
23 nottheonion
24 food
25 explainlikeimfive
26 space
27 history
28 art

{'funny': 0.10031242101537506,
 'pics': 0.056242851713737406,
 'diy': 0.1282859976008176,
 'dataisbeautiful': 0.06917119811685594,
 'creepy': 0.06692333483015514,
 'wholesomememes': 0.151474966542808,
 'memes': 0.051498360483754946}

In [None]:
from flask import Flask, request, jsonify
from flask_cors import CORS

app = Flask(__name__)
CORS(app)

@app.route("/predict")
def predictBackend():
  print(request.args.get('prompt'))
  prediction = predictPromptProbabilities(request.args.get('prompt'), float(request.args.get('threshold')))
  return jsonify({'prediction': prediction})

app.run(host='0.0.0.0', port=3000)

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:3000
 * Running on http://192.168.1.86:3000
Press CTRL+C to quit
[2024-11-28 22:04:27,275] ERROR in app: Exception on /predict [GET]
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/flask/app.py", line 2528, in wsgi_app
    response = self.full_dispatch_request()
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/flask/app.py", line 1825, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/flask_cors/extension.py", line 165, in wrapped_function
    return cors_after_request(app.make_response(f(*args, **kwargs)))
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/flask/app.py", line 1823, in full_dispatch_request
    rv = self.dispatch_request()
  File "/Library/Frameworks/Py

tennis 


[2024-11-28 22:04:30,832] ERROR in app: Exception on /predict [GET]
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/flask/app.py", line 2528, in wsgi_app
    response = self.full_dispatch_request()
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/flask/app.py", line 1825, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/flask_cors/extension.py", line 165, in wrapped_function
    return cors_after_request(app.make_response(f(*args, **kwargs)))
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/flask/app.py", line 1823, in full_dispatch_request
    rv = self.dispatch_request()
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/flask/app.py", line 1799, in dispatch_request
    return self.ensure_sync(se

tennis 
