In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import csv
import os
import numpy as np
import pandas as pd
import seaborn as sns
sns.set_theme()

In [3]:
# Root of the Amazon Reviews data. Set AMAZON_REVIEWS_DATA_DIR to point the
# notebook elsewhere instead of editing this hard-coded absolute path.
data_dir = os.environ.get(
    "AMAZON_REVIEWS_DATA_DIR",
    "/data/ddmg/redditlanguagemodeling/data/AmazonReviews/data",
)

In [4]:
# Load every review. Column dtypes are pinned explicitly, and pandas' default
# NA-string detection is switched off (keep_default_na=False, na_values=[])
# so text fields are kept as strings rather than being turned into NaN.
data_df = pd.read_csv(
    os.path.join(data_dir, 'amazon_v2.0/reviews.csv'),
    dtype={
        'reviewerID': str,
        'asin': str,
        'reviewTime': str,
        'unixReviewTime': int,
        'reviewText': str,
        'summary': str,
        'verified': bool,
        'category': str,
        'reviewYear': int,
    },
    keep_default_na=False,
    na_values=[],
    quoting=csv.QUOTE_NONNUMERIC,
)

In [5]:
# Per-row split file; its `split` and `clean` columns are used by later cells.
split_df = pd.read_csv(
    os.path.join(data_dir, "amazon_v2.0", "splits", "user.csv")
)

In [6]:
# Attach the split labels to the reviews. The assignment aligns on the index,
# which assumes row i of user.csv describes row i of reviews.csv (both are
# read with the default RangeIndex). If the files differ in length, pandas
# would silently fill the missing rows with NaN, so fail loudly instead.
assert data_df.index.equals(split_df.index), (
    f"reviews.csv has {len(data_df)} rows but user.csv has {len(split_df)}"
)
data_df["split"] = split_df["split"]

In [20]:
# Keep only the rows that the split file flags as clean (boolean mask, aligned on the index).
clean_df = data_df.loc[split_df["clean"]]

In [21]:
# Number of clean reviews.
clean_df.shape[0]

8707610

In [28]:
# Categories kept for the experiment; reviews from all other categories are
# excluded from the sampled splits.
include_categories = [
    "Movies_and_TV",
    # Books has by far the most reviews; a further subset may be worthwhile.
    "Books",
    "Tools_and_Home_Improvement",
    "Home_and_Kitchen",
]

In [29]:
# Clean reviews restricted to the selected categories.
my_df = clean_df.loc[clean_df["category"].isin(include_categories)]

In [30]:
# Number of selected clean reviews (non-null `overall`) in each original split.
my_df.groupby("split")["overall"].count()

split
-1.0    5921174
 0.0     198171
 1.0      79846
 2.0      37385
 3.0      79940
 4.0      37649
Name: overall, dtype: int64

In [31]:
# Star-rating distribution per original split and category (selected clean reviews).
label_dist = my_df.groupby(["split", "category", "overall"])["summary"].count()
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(label_dist)

split  category                    overall
-1.0   Books                       1.0          93030
                                   2.0         155479
                                   3.0         441083
                                   4.0        1199356
                                   5.0        2820013
       Home_and_Kitchen            1.0          20488
                                   2.0          18837
                                   3.0          38820
                                   4.0          92737
                                   5.0         352273
       Movies_and_TV               1.0          26983
                                   2.0          29547
                                   3.0          61736
                                   4.0         113558
                                   5.0         271932
       Tools_and_Home_Improvement  1.0           6859
                                   2.0           5918
                                   3.0 

In [27]:
# Star-rating distribution per category across all clean reviews.
label_dist = clean_df.groupby(["category", "overall"])["summary"].count()
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(label_dist)

category                     overall
All_Beauty                   1.0              1
                             2.0              8
                             3.0             14
                             4.0             48
                             5.0            156
Arts_Crafts_and_Sewing       1.0            504
                             2.0            469
                             3.0           1142
                             4.0           2625
                             5.0          13710
Automotive                   1.0           2199
                             2.0           1821
                             3.0           4219
                             4.0          11024
                             5.0          45485
Books                        1.0          97106
                             2.0         166401
                             3.0         478081
                             4.0        1314414
                             5.0        3049325
CDs

In [32]:
# Star-rating distribution per selected category (clean reviews only); these
# counts bound how many examples per (category, rating) can be sampled.
label_dist = my_df.groupby(["category", "overall"])["summary"].count()
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(label_dist)

category                    overall
Books                       1.0          97106
                            2.0         166401
                            3.0         478081
                            4.0        1314414
                            5.0        3049325
Home_and_Kitchen            1.0          20746
                            2.0          19161
                            3.0          39715
                            4.0          95460
                            5.0         359516
Movies_and_TV               1.0          27786
                            2.0          30918
                            3.0          65052
                            4.0         119884
                            5.0         281980
Tools_and_Home_Improvement  1.0           6946
                            2.0           6006
                            3.0          13030
                            4.0          34255
                            5.0         128383
Name: summary, dtype: in

In [14]:
# Star ratings to balance over; compared against the float `overall` column (1 == 1.0).
labels = list(range(1, 6))

In [41]:
# Build a small, class-balanced split. For every (category, star rating) pair,
# draw N_TRAIN + N_VAL + N_TEST clean reviews without replacement and assign
# them to train (0) / val (1) / test (2). Every other row of the category is
# set to -1 (unused).
# - The generator is seeded, so re-running the cell yields the same split.
# - Candidates come only from rows flagged clean in split_df, consistent with
#   the clean_df / my_df counts inspected above.
SAMPLE_SEED = 0
N_TRAIN, N_VAL, N_TEST = 500, 200, 200
sample_rng = np.random.default_rng(SAMPLE_SEED)

n_needed = N_TRAIN + N_VAL + N_TEST
is_clean = split_df["clean"]
for cat in include_categories:
    in_cat = data_df["category"] == cat
    # Reset the whole category first (vectorised, instead of building an
    # "exclude" list row by row), so only the rows sampled below end up in a
    # split. This also clears any pre-existing split on non-clean rows.
    data_df.loc[in_cat, "split"] = -1
    cat_df = data_df[in_cat & is_clean]
    for label in labels:
        full_idx = cat_df.index[cat_df["overall"] == label]
        if len(full_idx) < n_needed:
            raise ValueError(
                f"only {len(full_idx)} clean reviews for category {cat!r}, "
                f"rating {label}; need {n_needed}"
            )
        keep = sample_rng.choice(full_idx, size=n_needed, replace=False)
        data_df.loc[keep[:N_TRAIN], "split"] = 0
        data_df.loc[keep[N_TRAIN:N_TRAIN + N_VAL], "split"] = 1
        data_df.loc[keep[N_TRAIN + N_VAL:], "split"] = 2

In [42]:
# Mark every review outside the selected categories as unused (-1).
# One vectorised mask replaces the per-category loop (one full-column scan per
# category). Unlike comparing against each value of set(data_df["category"]),
# it also catches rows whose category is missing: NaN != NaN, so the loop
# never matched them.
data_df.loc[~data_df["category"].isin(include_categories), "split"] = -1

In [36]:
# Star-rating distribution per split and category over ALL rows (clean or not)
# of the selected categories.
# NOTE(review): the saved output below has execution count 36, i.e. it was
# produced before the resampling cells (41/42), so it shows the original
# user.csv splits rather than the sampled ones.
selected_rows = data_df.loc[data_df["category"].isin(include_categories)]
label_dist = selected_rows.groupby(["split", "category", "overall"])["summary"].count()
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(label_dist)

split  category                    overall
-1.0   Books                       1.0         103228
                                   2.0         174866
                                   3.0         509680
                                   4.0        1395106
                                   5.0        3326090
       Home_and_Kitchen            1.0          21527
                                   2.0          19803
                                   3.0          43087
                                   4.0         109217
                                   5.0         427899
       Movies_and_TV               1.0          30470
                                   2.0          33985
                                   3.0          76659
                                   4.0         147005
                                   5.0         397341
       Tools_and_Home_Improvement  1.0           6765
                                   2.0           5697
                                   3.0 

In [43]:
# Same columns as user.csv, with `split` replaced by the resampled assignment;
# show how many rows landed in each split.
my_split_df = split_df.assign(split=data_df["split"])
my_split_df["split"].value_counts()

-1.0    10098947
 0.0       10000
 1.0        4000
 2.0        4000
Name: split, dtype: int64

In [44]:
# Write the resampled split next to user.csv (same columns, no index column).
my_split_df.to_csv(
    os.path.join(
        data_dir,
        "amazon_v2.0",
        "splits",
        "small_even_books_movie_kitchen_tools.csv",
    ),
    index=False,
)