In [2]:
import pandas as pd
import json
import random

In [5]:
df = pd.read_csv('../data/reddit_ratings.csv', header=None)
df.columns = ['username', 'subreddit', 'utc']
df = df.sort_values(['username', 'utc'])
df.head()

Unnamed: 0,username,subreddit,utc
8960539,---LFC---,soccer,1562356000.0
8960538,---LFC---,soccer,1562358000.0
8960537,---LFC---,soccer,1562359000.0
8960536,---LFC---,LiverpoolFC,1562364000.0
8960535,---LFC---,nba,1562395000.0


In [7]:
df = df.drop_duplicates(subset=['username', 'subreddit'])
df = df.groupby('username').filter(lambda x: len(x) >= 20)

In [12]:
df['username'] = df['username'].astype('category')
df['subreddit'] = df['subreddit'].astype('category')

In [13]:
username_categories = df['username'].cat.categories
user_dict = {i: cat for i, cat in enumerate(username_categories)}
inv_user_dict = {cat: i for i, cat in enumerate(username_categories)}

subreddit_categories = df['subreddit'].cat.categories
subreddit_dict = {i: cat for i, cat in enumerate(subreddit_categories)}
inv_subreddit_dict = {cat: i for i, cat in enumerate(subreddit_categories)}

In [141]:
with open('../data/user.json', 'w') as f:
    json.dump([user_dict, inv_user_dict], f)
    
with open('../data/subreddit.json', 'w') as f:
    json.dump([subreddit_dict, inv_subreddit_dict], f)

In [142]:
df['username'] = df['username'].cat.codes
df['subreddit'] = df['subreddit'].cat.codes

In [148]:
df_train = df.groupby('username', group_keys=False).apply(lambda group: group.iloc[1:])
df_test_positive = df.groupby('username', group_keys=False).first().reset_index()

In [163]:
df_test_negative = pd.DataFrame(columns=['user_item'] + ['negativeItemID' + str(i) for i in range(1,100)])
subreddits = range(len(df['subreddit'].unique()))

for i, row in df_test_positive.iterrows():
    if i%500 == 0:
        print('Processing: ' + str(i))
    username = row['username']
    subreddit = row['subreddit']
    interaction = (username, subreddit)
    user_subreddits = df[df['username'] == username]['subreddit'].unique()
    user_neg_subreddits = [sub for sub in subreddits if sub not in user_subreddits]
    sampled_neg_subreddits = random.sample(user_neg_subreddits, k=99)
    df_test_negative.loc[len(df_test_negative)] = [interaction] + sampled_neg_subreddits
df_test_negative

Processing: 0
Processing: 500
Processing: 1000
Processing: 1500
Processing: 2000
Processing: 2500
Processing: 3000
Processing: 3500
Processing: 4000
Processing: 4500
Processing: 5000
Processing: 5500
Processing: 6000
Processing: 6500
Processing: 7000
Processing: 7500
Processing: 8000
Processing: 8500
Processing: 9000
Processing: 9500
Processing: 10000
Processing: 10500
Processing: 11000
Processing: 11500
Processing: 12000
Processing: 12500
Processing: 13000
Processing: 13500


Unnamed: 0,user_item,negativeItemID1,negativeItemID2,negativeItemID3,negativeItemID4,negativeItemID5,negativeItemID6,negativeItemID7,negativeItemID8,negativeItemID9,...,negativeItemID90,negativeItemID91,negativeItemID92,negativeItemID93,negativeItemID94,negativeItemID95,negativeItemID96,negativeItemID97,negativeItemID98,negativeItemID99
0,"(0.0, 27185.0)",28313,26548,2829,9105,16807,12497,3440,29982,17608,...,19496,6542,14540,26733,6876,26799,15889,9675,17345,23864
1,"(1.0, 1345.0)",25551,7644,12320,29810,28072,32244,5666,20033,29781,...,16745,1364,27229,32780,6930,31535,10570,12499,4002,23489
2,"(2.0, 31588.0)",2916,27779,32299,7815,2940,5005,442,15980,20021,...,15068,628,4665,7363,22378,19119,19730,31598,12040,29177
3,"(3.0, 31277.0)",21950,11727,17436,23625,11272,15718,18914,92,16765,...,30378,14837,22762,19621,24540,7046,7727,5070,3591,10552
4,"(4.0, 16211.0)",11846,29032,31322,21373,3782,28524,15355,32581,26808,...,10701,14536,14589,23356,24303,10212,26956,32864,20943,13098
5,"(5.0, 19058.0)",9283,19465,14972,17416,6394,15608,5817,3693,13331,...,5409,4231,5708,23691,31883,3265,9310,31941,21294,22119
6,"(6.0, 18762.0)",18992,17222,14082,9885,277,307,16708,22971,28273,...,13586,10397,3231,32630,1295,14189,5714,32490,25810,31501
7,"(7.0, 193.0)",13837,8199,29163,29160,27289,28881,22454,33122,15138,...,28509,18135,25471,27859,7995,14647,27825,21567,27650,3712
8,"(8.0, 8501.0)",20099,1493,537,24866,525,26489,9697,8263,11583,...,14211,24962,19044,24301,2905,9377,28265,15903,408,17950
9,"(9.0, 16063.0)",18616,2553,24932,298,14362,12096,24483,32675,6761,...,22926,13629,29607,16298,31868,18798,6053,2112,31493,23394


In [171]:
df_train.to_csv('../data/reddit_train.csv', index=False, header=False)
df_test_positive.to_csv('../data/reddit_test_positive.csv', index=False, header=False)
df_test_negative.to_csv('../data/reddit_test_negative.csv', index=False, header=False)