# Reddit Flair Classification
## Scraping from r/india
The Reddit posts are scraped using the Pushshift API and saved to a CSV file. Posts belonging to the flair categories listed below are then extracted by randomly sampling from the pool of posts in the CSV.

Flair categories considered (the Non-Political flair is excluded because its attribution is vague: a "Non-Political" post can belong to any category that is not political):
1. 'Business/Finance'
2. 'Policy/Economy'
3. 'Photography'
4. 'Politics'
5. 'Sports'
6. '[R]eddiquette'
7. 'Food'
8. 'Science/Technology'
9. 'AskIndia'
10. 'CAA-NRC'
11. 'Coronavirus'

To scrape the Reddit posts, the Pushshift API is used. The PRAW API could also have been used, and scraping is much simpler with PRAW, but PRAW does not allow crawling more than 1000 posts, so I resort to the Pushshift API.

The Pushshift API returns JSON from which the required fields can be extracted.
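As a rough illustration of what extracting fields from the JSON means: a submission search response is a JSON object whose `data` key holds a list of submission objects, which is exactly what `getPushshiftData()` returns. The payload below is made up and trimmed to a few of the fields used later; real responses carry many more keys.

```python
import json

# Hypothetical, heavily trimmed response body; real Pushshift responses contain many more fields.
raw = '''
{"data": [
  {"id": "abc123", "title": "Example post", "link_flair_text": "Politics",
   "created_utc": 1546300800, "num_comments": 3},
  {"id": "def456", "title": "Another post", "created_utc": 1546300900, "num_comments": 0}
]}
'''

payload = json.loads(raw)
submissions = payload["data"]  # list of dicts, one per post

# Pull out a few fields; .get() returns None when a post has no flair key.
rows = [(s["id"], s["title"], s.get("link_flair_text")) for s in submissions]
print(rows)
```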



### Importing Libraries

In [0]:
import numpy as np
import pandas as pd
import requests
import json
import csv
import time
import datetime


### Functions
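The functions below perform elementary tasks:
1. `getPushshiftData()` builds the request URL, fetches the JSON response and returns its `data` list (one dictionary per submission) for further extraction.
2. `collectSubData()` extracts the required fields from a single submission's dictionary by key and stores them in the global `subStats` dictionary, keyed by submission id.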
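The query string in `getPushshiftData()` is assembled by string concatenation. A slightly safer sketch uses the standard library's `urllib.parse.urlencode`, which percent-encodes the values; the helper name `build_search_url` is hypothetical, and the endpoint is the one used in this notebook.

```python
from urllib.parse import urlencode

BASE_URL = "https://api.pushshift.io/reddit/search/submission/"

def build_search_url(sub, after, before):
    """Return the submission-search URL for `sub` between two unix timestamps."""
    # Same parameter order as the hand-built URL in getPushshiftData().
    query = urlencode({"after": after, "before": before, "subreddit": sub})
    return BASE_URL + "?" + query

print(build_search_url("india", 1546300800, 1586217355))
```

With `requests`, passing the same dictionary as `requests.get(BASE_URL, params=...)` lets the library do the encoding instead.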

In [0]:
def getPushshiftData(sub, after, before):
    url = 'https://api.pushshift.io/reddit/search/submission/?after='+str(after)+'&before='+str(before)+'&subreddit='+str(sub)
    print(url)
    r = requests.get(url)
    data = json.loads(r.text)
    return data['data']

def collectSubData(subm):
    subData = list() #list to store data points
    title = subm['title']
    url = subm['url']

    try:
        flair = subm['link_flair_text']
    except KeyError:
        flair = "NaN" 
    try:
        selftext = subm['selftext']
    except KeyError:
        selftext = ""
    
    author = subm['author']
    sub_id = subm['id']
    score = subm['score']
    created = datetime.datetime.fromtimestamp(subm['created_utc']) 
    numComms = subm['num_comments']
    permalink = subm['permalink']
    
    subData.append((sub_id,title,url,author,score,created,numComms,permalink,flair, selftext))
    subStats[sub_id] = subData
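`collectSubData()` writes into the global `subStats` dictionary and wraps a single tuple in a list. Below is a sketch of a side-effect-free variant (the name `extract_row` is hypothetical) that returns the tuple directly and uses `dict.get` for the optional keys, keeping the field order of the CSV header written later. One assumption worth flagging: `datetime.fromtimestamp()` without a time zone converts to the local time zone of the machine running the code, so this sketch passes `tz=timezone.utc` to pin the dates to UTC.

```python
from datetime import datetime, timezone

def extract_row(subm):
    """Return one CSV row (tuple) for a Pushshift submission dict."""
    return (
        subm["id"],
        subm["title"],
        subm["url"],
        subm["author"],
        subm["score"],
        # created_utc is a unix timestamp; convert it explicitly as UTC
        datetime.fromtimestamp(subm["created_utc"], tz=timezone.utc),
        subm["num_comments"],
        subm["permalink"],
        subm.get("link_flair_text", "NaN"),  # same fallback as collectSubData()
        subm.get("selftext", ""),
    )

# Usage: build an id -> row mapping without a global variable.
sample = {"id": "abc123", "title": "t", "url": "u", "author": "a", "score": 1,
          "created_utc": 1546300800, "num_comments": 0, "permalink": "/r/india/abc123/"}
rows = {s["id"]: extract_row(s) for s in [sample]}
print(rows["abc123"][5])
```

If this variant replaced `collectSubData()`, the CSV writer would use `subStats[sub]` rather than `subStats[sub][0]`, since there is no wrapping list.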

In [0]:
#Variable initializations 
#Subreddit to query
sub='india'
#before and after dates

after = "1546300800"  #January 1st 2019
before = "1586217355" # April 6th 2020

subCount = 0
subStats = {}
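The `after` and `before` values are Unix timestamps (seconds since 1970-01-01 UTC). A small sketch of deriving them from calendar dates with the standard library, assuming the boundaries are meant in UTC; `to_unix` is a hypothetical helper. It returns integers, whereas the notebook stores strings, which makes no difference once they are formatted into the URL.

```python
from datetime import datetime, timezone

def to_unix(year, month, day, hour=0, minute=0, second=0):
    """Unix timestamp (int seconds) of a UTC calendar time."""
    return int(datetime(year, month, day, hour, minute, second,
                        tzinfo=timezone.utc).timestamp())

after = to_unix(2019, 1, 1)                # 1546300800, January 1st 2019
before = to_unix(2020, 4, 6, 23, 55, 55)   # 1586217355, April 6th 2020
print(after, before)
```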

### Main scraping code

Each API call returns 25 Reddit posts starting from the time passed as the `after` argument (a Unix timestamp). The call is therefore repeated, moving `after` forward to the creation time of the last post received, until an empty batch comes back, i.e. until all posts before the `before` timestamp have been collected.
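The loop below is cursor-style pagination: each request's `after` is the `created_utc` of the last post in the previous batch, and an empty batch signals the end. The sketch isolates that logic behind a `fetch(after, before)` callable so it can be exercised against a fake in-memory source; `paginate` and `fake_fetch` are hypothetical names, and the fake returns posts in ascending time order because the real loop relies on `data[-1]` being the newest post.

```python
def paginate(fetch, after, before):
    """Yield submissions batch by batch until fetch() returns an empty list."""
    batch = fetch(after, before)
    while batch:
        yield from batch
        # Advance the cursor to the newest post seen so far.
        after = batch[-1]["created_utc"]
        batch = fetch(after, before)

# Fake data source standing in for the API: 60 posts, one per second, 25 per page.
POSTS = [{"id": str(i), "created_utc": 1000 + i} for i in range(60)]

def fake_fetch(after, before, size=25):
    hits = [p for p in POSTS if after < p["created_utc"] < before]
    return hits[:size]

collected = list(paginate(fake_fetch, after=999, before=2000))
print(len(collected))  # 60
```

Because the cursor moves to the last timestamp of each batch, posts sharing that exact second could be missed or repeated depending on how the API treats the `after` boundary; keying `subStats` by post id at least prevents duplicates.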

In [0]:
# Runs until all posts from the 'after' date up to the 'before' date have been gathered
data = getPushshiftData(sub, after, before)
while len(data) > 0:
    for submission in data:
        collectSubData(submission)
        subCount += 1
    # Progress: batch size and creation time of the last post in the batch
    print(len(data))
    print(str(datetime.datetime.fromtimestamp(data[-1]['created_utc'])))
    # Call getPushshiftData() again, starting from the created date of the last submission
    after = data[-1]['created_utc']
    data = getPushshiftData(sub, after, before)

print(len(data))

Streaming output truncated to the last 5000 lines.
25
2020-01-11 10:01:51
https://api.pushshift.io/reddit/search/submission/?after=1578736911&before=1586217355&subreddit=india
25
2020-01-11 10:49:26
https://api.pushshift.io/reddit/search/submission/?after=1578739766&before=1586217355&subreddit=india
25
2020-01-11 11:38:58
https://api.pushshift.io/reddit/search/submission/?after=1578742738&before=1586217355&subreddit=india
25
2020-01-11 12:14:26
https://api.pushshift.io/reddit/search/submission/?after=1578744866&before=1586217355&subreddit=india
25
2020-01-11 13:07:41
https://api.pushshift.io/reddit/search/submission/?after=1578748061&before=1586217355&subreddit=india
25
2020-01-11 14:18:32
https://api.pushshift.io/reddit/search/submission/?after=1578752312&before=1586217355&subreddit=india
25
2020-01-11 15:39:28
https://api.pushshift.io/reddit/search/submission/?after=1578757168&before=1586217355&subreddit=india
25
2020-01-11 16:26:03
https://api.pushshift.io/reddit/searc
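The scraping loop sends requests back to back, and a throttled or failed request makes `json.loads` raise and stops the scrape. A hedged sketch of a retry wrapper with exponential backoff; `get_with_retry` is a hypothetical helper, and the retry count and delays are arbitrary choices, not documented Pushshift limits. The optional `session` argument only exists so the function can be tested without network access.

```python
import time
import requests

def get_with_retry(url, retries=5, base_delay=1.0, session=None):
    """GET `url` and return the decoded JSON, retrying with exponential backoff."""
    http = session or requests
    for attempt in range(retries):
        try:
            r = http.get(url, timeout=30)
            r.raise_for_status()   # raises on 4xx/5xx, e.g. 429 Too Many Requests
            return r.json()
        except (requests.RequestException, ValueError):
            # ValueError also covers a response body that is not valid JSON
            if attempt == retries - 1:
                raise
            time.sleep(base_delay * 2 ** attempt)  # 1 s, 2 s, 4 s, ...
```

Inside `getPushshiftData()`, `get_with_retry(url)['data']` would then take the place of the `requests.get` / `json.loads` pair.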

In [0]:
print(str(len(subStats)) + " submissions have been added to list")
# Each value in subStats is a one-element list holding the row tuple;
# index 1 is the title and index 5 the creation date.
first = list(subStats.values())[0][0]
last = list(subStats.values())[-1][0]
print("1st entry is:")
print(first[1] + " created: " + str(first[5]))
print("Last entry is:")
print(last[1] + " created: " + str(last[5]))

232981 submissions have been added to list
1st entry is:
[P] New Year smiles: What do you say to get Mukesh Ambani’s attention? ‘Sir, Jio chal nahi raha hai’ created: 2019-01-01 00:08:08
Last entry is:
Coronavirus Cases Doubling Every 4 Days, Set To Touch 17,000 In A Week created: 2020-04-06 23:55:15


In [0]:
def updateSubs_file():
    upload_count = 0
    location = "./"
    print("input filename of submission file, please add .csv")
    filename = input()
    path = location + filename
    # 'w' overwrites the file, so re-running the cell does not append a second header row
    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, delimiter=',')
        headers = ["id", "title", "url", "author", "score", "publish_date",
                   "num_comment", "permalink", "flair", "selftext"]
        writer.writerow(headers)
        for sub_id in subStats:
            writer.writerow(subStats[sub_id][0])
            upload_count += 1

        print(str(upload_count) + " submissions have been uploaded")

In [0]:
updateSubs_file()

input filename of submission file, please add .csv
scrapped.csv
232981 submissions have been uploaded
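Since pandas is already imported, the same file could also be written from `subStats` with `DataFrame.to_csv`, which writes the header once and overwrites the target file. A sketch, assuming `subStats` has the `{id: [row_tuple]}` shape produced by `collectSubData()`; `save_submissions` is a hypothetical helper.

```python
import pandas as pd

COLUMNS = ["id", "title", "url", "author", "score", "publish_date",
           "num_comment", "permalink", "flair", "selftext"]

def save_submissions(sub_stats, path):
    """Write the {id: [row_tuple]} mapping to `path` as CSV and return the row count."""
    rows = [entry[0] for entry in sub_stats.values()]  # unwrap the one-element lists
    df = pd.DataFrame(rows, columns=COLUMNS)
    df.to_csv(path, index=False, encoding="utf-8")
    return len(df)
```

Usage in this notebook would be `save_submissions(subStats, "./scrapped.csv")`.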


### Extracting relevant posts
The saved CSV is loaded into a DataFrame so that the posts belonging to the flairs listed at the beginning can be extracted.
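One detail worth knowing when loading this file: by default `pd.read_csv` treats the string `NaN` (written by `collectSubData()` for posts without a flair key) and empty fields (which is how `csv.writer` writes `None`) as missing values. Those flairs therefore come back as float `NaN` rather than as strings, which is why the flair list is later filtered with `str(flair) == 'nan'`. A small self-contained check on a made-up in-memory CSV:

```python
import io
import pandas as pd

csv_text = "id,flair\na,Politics\nb,NaN\nc,\n"

df = pd.read_csv(io.StringIO(csv_text))
print(df["flair"].isna().tolist())  # [False, True, True]

# keep_default_na=False keeps the literal strings instead
raw = pd.read_csv(io.StringIO(csv_text), keep_default_na=False)
print(raw["flair"].tolist())  # ['Politics', 'NaN', '']
```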

In [0]:
data = pd.read_csv("./scrapped.csv")

In [0]:
def CountFrequency(my_list): 
    freq = {} 
    for item in my_list: 
        if (item in freq): 
            freq[item] += 1
        else: 
            freq[item] = 1
  
    return freq

unique = set(list(data['flair']))
print(unique)

In [0]:
freq = CountFrequency(list(data['flair']))
sorted_freq = {k: v for k, v in sorted(freq.items(), key=lambda item: item[1], reverse=True)}
print(sorted_freq)
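`CountFrequency()` followed by the sort reproduces what pandas' `Series.value_counts()` already provides, sorted in descending order; `dropna=False` keeps the count of missing flairs as well. A practical difference: because NaN is not equal to itself, a plain dictionary or set can in principle hold several separate NaN keys, whereas `value_counts` groups all missing values into a single entry. A sketch on toy data (the series name `flair_col` is hypothetical):

```python
import numpy as np
import pandas as pd

flair_col = pd.Series(["Politics", "Food", "Politics", np.nan, "AskIndia", np.nan])

# Counts sorted in descending order, with all missing values counted as one group.
counts = flair_col.value_counts(dropna=False)
print(counts)
```

On the notebook's data the equivalent call would be `data['flair'].value_counts(dropna=False)`.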

In [0]:
data['flair'] = data['flair'].replace('CAA-NRC-NPR', 'CAA-NRC')  # Merge these two near-identical flair classes (flair column only)

In [0]:
freq = CountFrequency(list(data['flair']))
sorted_freq = {k: v for k, v in sorted(freq.items(), key=lambda item: item[1], reverse=True)}
print(sorted_freq)

In [0]:
flairs = ['Business/Finance', 'Policy/Economy', 'Photography', 'Politics', 'Sports', '[R]eddiquette', 'Food', 'Science/Technology', 'AskIndia','CAA-NRC', 'Coronavirus']

Building a balanced dataset by randomly sampling at most `n` (defined below) posts per flair from the scraped CSV; a flair with fewer than `n` posts contributes all of its posts.

In [0]:
n = 2000  # maximum number of posts per flair

np.random.seed(42)  # fixed seed so the sample is reproducible
keep = []
for flair in flairs:
    idx = list(data[data['flair'] == flair]['id'])
    l = min(len(idx), n)  # take every post of a flair that has fewer than n
    # Sample l distinct post ids for this flair
    keep.extend(np.random.choice(idx, l, replace=False))

print(len(keep))
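The same capped per-flair sampling can be expressed with pandas grouping. A sketch, assuming a DataFrame with `flair` and `id` columns; `cap_per_class` is a hypothetical helper, and because it draws through pandas' `random_state` rather than `np.random.seed`, it will not select the same ids as the loop above.

```python
import pandas as pd

def cap_per_class(df, classes, n, seed=42):
    """Return up to n randomly chosen rows for each flair in `classes`."""
    subset = df[df["flair"].isin(classes)]
    parts = [
        group.sample(n=min(len(group), n), random_state=seed)  # no replacement by default
        for _, group in subset.groupby("flair")
    ]
    return pd.concat(parts)

toy = pd.DataFrame({
    "id": [f"p{i}" for i in range(10)],
    "flair": ["Food"] * 7 + ["Sports"] * 2 + ["Other"],
})
sampled = cap_per_class(toy, ["Food", "Sports"], n=3)
print(sampled["flair"].value_counts().to_dict())  # {'Food': 3, 'Sports': 2}
```

On the notebook's data this would read `cap_per_class(data, flairs, n)`, which returns the sampled rows directly instead of a list of ids to filter with `isin`.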

In [0]:
data = data[data['id'].isin(keep)]
data.to_csv("for_preprocessing.csv",index=False)