In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import csv
import os
import numpy as np
import pandas as pd
import seaborn as sns
sns.set_theme()

In [3]:
from wilds import get_dataset

In [4]:
data_dir = "/data/ddmg/redditlanguagemodeling/data/AmazonReviews/data"

In [5]:
dataset = get_dataset(dataset='amazon', download=True, root_dir=data_dir)

## Exploring Default Data Split

539,502 total reviews
3,920 total users
Mean # of reviews per user 137
Min # of reviews per user 75

5 splits: 
* Train: 245,502 reviews from 1,252 users 
    * Mean reviews per user 196; min is 75
* Val (OOD): 100,050 reviews from 1,334 users
    * 75 reviews per user
    * Users are disjoint with users in train data
* Val (ID): 46,950 reviews from 626 users
    * 75 reviews per user
    * Users are subset of training users
* Test (OOD): 100,050 reviews from 1,334 users
    * 75 reviews per user
    * Users are disjoint with both training users and val OOD users
* Test (ID): 46,950 reviews from 626 users
    * 75 reviews per user
    * Users are subset of training users
    * Users are disjoint with Val (ID) users

## Creating My Own Split

Question: should we use data from the target user when training?
* + Should we include the target user's validation data in the training data or just their train data?
* For now, could start with user that has enough data to have 75+ examples for training, 75 validation, and 75 test examples

So my own split would be:
* Pick target user and use them for validation & test data
* Train data is the same, except maybe some of the target user's data has been moved from train to val/test

In [6]:
dataset._split_array

array([0., 1., 1., ..., 3., 1., 3.])

In [7]:
len(dataset)

539502

In [8]:
dataset._metadata_fields

['user', 'product', 'category', 'year', 'y']

In [9]:
dataset._metadata_array[:, 0]

tensor([   0, 2204, 2204,  ..., 3320, 2575, 3394])

In [12]:
split_0_users = dataset._metadata_array[:, 0][dataset._split_array == 0]

In [13]:
users, counts = np.unique(split_0_users, return_counts=True)

In [14]:
print(min(counts), max(counts), np.mean(counts), np.std(counts))

75 1265 196.08785942492014 148.01675692920605


In [15]:
split_1_users = dataset._metadata_array[:, 0][dataset._split_array == 1]

In [16]:
users, counts = np.unique(split_1_users, return_counts=True)

(array([1252, 1253, 1254, ..., 2583, 2584, 2585]),
 array([75, 75, 75, ..., 75, 75, 75]))

In [64]:
len(set(np.array(dataset._metadata_array[:, 0])))

3920

In [70]:
users, counts = np.unique(np.array(dataset._metadata_array[:, 0]), return_counts=True)

In [71]:
print(min(counts), max(counts), np.mean(counts), np.std(counts))

75 1340 137.62806122448978 123.91844927092389


In [17]:
set(split_0_users).intersection(set(split_1_users))

set()

In [18]:
split_2_users = dataset._metadata_array[:, 0][dataset._split_array == 2]
np.unique(split_2_users, return_counts=True)

(array([   0,    2,    3,    7,    9,   11,   14,   15,   16,   17,   18,
          21,   24,   26,   27,   28,   29,   30,   32,   36,   39,   43,
          44,   45,   46,   49,   50,   54,   55,   56,   58,   60,   62,
          63,   64,   65,   71,   73,   75,   78,   83,   84,   85,   87,
          89,   90,   92,   96,   97,   99,  100,  101,  102,  106,  107,
         108,  109,  111,  112,  113,  117,  118,  119,  120,  121,  122,
         123,  125,  126,  128,  129,  131,  137,  138,  139,  142,  146,
         151,  153,  158,  162,  164,  167,  168,  170,  172,  176,  181,
         182,  184,  185,  189,  190,  192,  193,  194,  196,  197,  198,
         200,  201,  202,  203,  206,  207,  210,  212,  220,  224,  226,
         227,  230,  231,  232,  233,  235,  236,  238,  239,  241,  242,
         245,  246,  248,  251,  252,  256,  257,  259,  262,  263,  264,
         267,  268,  269,  271,  272,  273,  276,  277,  281,  282,  285,
         286,  287,  288,  290,  291, 

In [47]:
split_3_users = dataset._metadata_array[:, 0][dataset._split_array == 3]
np.unique(split_3_users, return_counts=True)

(array([2586, 2587, 2588, ..., 3917, 3918, 3919]),
 array([75, 75, 75, ..., 75, 75, 75]))

In [50]:
split_4_users = dataset._metadata_array[:, 0][dataset._split_array == 4]
np.unique(split_4_users, return_counts=True)

(array([   1,    4,    5,    6,    8,   10,   12,   13,   19,   20,   22,
          23,   25,   31,   33,   34,   35,   37,   38,   40,   41,   42,
          47,   48,   51,   52,   53,   57,   59,   61,   66,   67,   68,
          69,   70,   72,   74,   76,   77,   79,   80,   81,   82,   86,
          88,   91,   93,   94,   95,   98,  103,  104,  105,  110,  114,
         115,  116,  124,  127,  130,  132,  133,  134,  135,  136,  140,
         141,  143,  144,  145,  147,  148,  149,  150,  152,  154,  155,
         156,  157,  159,  160,  161,  163,  165,  166,  169,  171,  173,
         174,  175,  177,  178,  179,  180,  183,  186,  187,  188,  191,
         195,  199,  204,  205,  208,  209,  211,  213,  214,  215,  216,
         217,  218,  219,  221,  222,  223,  225,  228,  229,  234,  237,
         240,  243,  244,  247,  249,  250,  253,  254,  255,  258,  260,
         261,  265,  266,  270,  274,  275,  278,  279,  280,  283,  284,
         289,  292,  296,  298,  299, 

In [28]:
len(dataset._split_array)

539502

In [29]:
sum(dataset._split_array == 0)

245502

In [30]:
len(set(np.array(split_0_users)))

1252

In [31]:
sum(dataset._split_array == 1)

100050

In [32]:
len(set(np.array(split_1_users)))

1334

In [34]:
set(np.array(split_0_users)).intersection(np.array(split_1_users))

set()

In [35]:
sum(dataset._split_array == 2)

46950

In [36]:
len(set(np.array(split_2_users)))

626

In [56]:
len(set(np.array(split_2_users)).intersection(set(np.array(split_0_users))))

626

In [44]:
sum(dataset._split_array == 3)

100050

In [53]:
len(set(np.array(split_3_users)))

1334

In [46]:
set(np.array(split_3_users)).intersection(np.array(split_1_users))

set()

In [48]:
set(np.array(split_3_users)).intersection(np.array(split_0_users))

set()

In [49]:
sum(dataset._split_array == 4)

46950

In [57]:
len(set(np.array(split_4_users)).intersection(np.array(split_0_users)))

626

In [58]:
len(set(np.array(split_4_users)).intersection(np.array(split_2_users)))

0

In [54]:
len(set(np.array(split_4_users)))

626

In [63]:
dataset._eval_grouper._n_groups

3920

## What does full dataset (before restricting to those in user splits look like)?

In [6]:
data_df = pd.read_csv(os.path.join(data_dir, 'amazon_v2.0/reviews.csv'),
                      dtype={'reviewerID':str, 'asin':str, 'reviewTime':str,'unixReviewTime':int,
                             'reviewText':str,'summary':str,'verified':bool,'category':str, 'reviewYear':int},
                      keep_default_na=False, na_values=[], quoting=csv.QUOTE_NONNUMERIC)

In [68]:
len(data_df)

10116947

In [72]:
len(set(data_df['reviewerID']))

155656

In [77]:
in_dataset = split_df['split'] != -1

In [78]:
len(in_dataset)

10116947

In [76]:
split_df = pd.read_csv(os.path.join(data_dir, 'amazon_v2.0', 'splits', 'user.csv'))

In [81]:
# how much data to existing, in domain test users have?
is_id_test_user = split_df['split'] == 4
id_test_users_df = data_df[is_id_test_user]

In [82]:
id_test_users = set(id_test_users_df["reviewerID"])

In [83]:
len(id_test_users)

626

In [87]:
75 * 3

225

In [101]:
clean_df = data_df[split_df["clean"] == True]

In [102]:
len(clean_df)

8707610

In [103]:
user_count_df = clean_df.groupby(["reviewerID"]).count()["reviewText"]
user_count_df.loc[id_test_users]

reviewerID
A2LYT2QET2KC4N    194
A55ER33W947XU     175
A2MF4TISBBQT5A    466
A1RV4398Y02WZH    257
A3AAWCXSF99MWG    167
                 ... 
A216SKRFFQU4ZM    224
A2NRARXIUZ2XON    195
AQIA837SENFW2     209
ABLLHG0REGFX7     164
AH2KUKJWI6XW3     158
Name: reviewText, Length: 626, dtype: int64

In [104]:
test_user_count_df = user_count_df.loc[id_test_users]

In [105]:
test_user_count_df[test_user_count_df >= 225]

reviewerID
A2MF4TISBBQT5A    466
A1RV4398Y02WZH    257
A3OQCKMMBDXYPX    253
A3IZSGFX4FOG85    268
A2KQP13HS13X0U    476
                 ... 
A221FPD86E5O87    225
A3I2P8G8JTEH23    333
A33KNBGC6UM133    587
A3KEB381QTO61R    243
A152EA7ZPCA1Q8    283
Name: reviewText, Length: 311, dtype: int64

In [106]:
test_users = test_user_count_df[test_user_count_df >= 225].index.to_numpy()

In [107]:
# can select any of these users!
np.random.choice(311)

206

In [108]:
target_user = test_users[206]

In [109]:
target_user

'A2KWQ64TRHB3YH'

In [110]:
test_user_count_df.loc[target_user]

961

In [111]:
test_user_idxs = data_df["reviewerID"] == target_user

In [112]:
sum(test_user_idxs)

1006

In [113]:
test_user_splits = split_df[test_user_idxs]

In [114]:
test_user_splits

Unnamed: 0,split,clean
108038,0.0,True
108039,0.0,True
110336,0.0,True
111796,-1.0,False
112501,0.0,True
...,...,...
9781230,0.0,True
9967131,0.0,True
9973476,0.0,True
9975059,0.0,True


In [116]:
np.unique(test_user_splits["split"], return_counts=True)

(array([-1.,  0.,  4.]), array([ 45, 886,  75]))

In [117]:
886 + 75

961

In [118]:
split_df[split_df["clean"] == False]["split"].unique()

array([-1.])

In [129]:
# make new split df where 75 of these training exmaples are for validation and there is no OOD val/test splits

new_split_df = split_df.copy()
# get rid of old val/test splits
old_split_idxs = new_split_df[new_split_df["split"].isin([1, 2, 3, 4])].index
new_split_df.loc[old_split_idxs, "split"] = -1
# make new test split with just the new target user's data
test_idxs = split_df[split_df["split"] == 4].index
user_test_idxs = set(test_user_splits.index).intersection(test_idxs)
print("num user test idxs", len(user_test_idxs))
new_split_df.loc[user_test_idxs, "split"] = 2
# make new val split with just the new target user's data
train_idxs = split_df[split_df["split"] == 0].index
user_train_idxs = set(test_user_splits.index).intersection(set(train_idxs))
print("number of target user train samples", len(user_train_idxs))
val_idxs = np.random.choice(list(user_train_idxs), 75, replace=False)
new_split_df.loc[val_idxs, "split"] = 1

num user test idxs 75
number of target user train samples 886


In [130]:
np.unique(new_split_df["split"], return_counts=True)

(array([-1.,  0.,  1.,  2.]), array([9871370,  245427,      75,      75]))

In [131]:
new_data_df = data_df[new_split_df["split"] != -1]

In [135]:
245502 - 75

245427

In [136]:
len(data_df[split_df["split"] == 0]["reviewerID"].unique())

1252

In [138]:
len(data_df[new_split_df["split"] == 0]["reviewerID"].unique())

1252

In [142]:
data_df[new_split_df["split"] == 2]["reviewerID"].unique()

array(['A2KWQ64TRHB3YH'], dtype=object)

In [144]:
np.unique(new_split_df[data_df["reviewerID"] == 'A2KWQ64TRHB3YH']['split'], return_counts=True)

(array([-1.,  0.,  1.,  2.]), array([ 45, 811,  75,  75]))

In [145]:
# save split df
new_split_df.to_csv(os.path.join(data_dir, 'amazon_v2.0', 'splits', 'user_{}.csv'.format(target_user)), index=False)

## Create Smaller Datasets

In [None]:
data_df = pd.read_csv(os.path.join(data_dir, 'amazon_v2.0/reviews.csv'),
                      dtype={'reviewerID':str, 'asin':str, 'reviewTime':str,'unixReviewTime':int,
                             'reviewText':str,'summary':str,'verified':bool,'category':str, 'reviewYear':int},
                      keep_default_na=False, na_values=[], quoting=csv.QUOTE_NONNUMERIC)

In [16]:
og_split_df = pd.read_csv(os.path.join(data_dir, 'amazon_v2.0', 'splits', 'user.csv'))

In [17]:
new_split_df = pd.read_csv(os.path.join(data_dir, 'amazon_v2.0', 'splits', 'target_user_A2KWQ64TRHB3YH.csv'))

In [18]:
is_train1 = og_split_df["split"] == 0
is_train2 = new_split_df["split"] == 0

In [19]:
sum(is_train1)

245502

In [20]:
sum(is_train2)

245427

In [21]:
sum(is_train1 * is_train2)



245427

In [22]:
is_target_user = data_df["reviewerID"] == 'A2KWQ64TRHB3YH'

In [23]:
sum(is_target_user)

1006

In [25]:
# remove some training data that isn't from target user
other_train = is_train2 & ~is_target_user

In [26]:
sum(other_train)

244616

In [27]:
other_train_idxs = np.arange(len(data_df))[other_train]

In [28]:
len(other_train_idxs)

244616

In [34]:
# figure out what test/val idxs to remove from original data
is_val = og_split_df["split"] == 1
sum(is_val)

100050

In [35]:
is_test = og_split_df["split"] == 3
sum(is_test)

100050

In [37]:
# take 1,000 of these
val_idxs = np.arange(len(data_df))[is_val]
test_idxs = np.arange(len(data_df))[is_test]
keep_val = np.random.choice(val_idxs, 1000, replace=False)
keep_test = np.random.choice(test_idxs, 1000, replace=False)
remove_val = set(val_idxs).difference(keep_val)
remove_test = set(test_idxs).difference(keep_test)

In [38]:
len(remove_val)

99050

In [40]:
len(other_train_idxs)

244616

In [43]:
len(remove_idxs)

243616

In [47]:
og_split_df.loc[remove_idxs, "split"]

1572865    0.0
2097153    0.0
6815745    0.0
524292     0.0
7864321    0.0
          ... 
3145719    0.0
8388600    0.0
2621435    0.0
4718588    0.0
4718589    0.0
Name: split, Length: 243616, dtype: float64

In [48]:
# sample random subsets
for subset in [1000, 10000]:
    keep_idxs = np.random.choice(other_train_idxs, subset, replace=False)
    remove_idxs = list(set(other_train_idxs).difference(set(keep_idxs)))
    s_og_split_df = og_split_df.copy()
    print(remove_idxs)
    s_og_split_df.loc[remove_idxs, "split"] = -1
    s_og_split_df.loc[list(remove_val), "split"] = -1
    s_og_split_df.loc[list(remove_test), "split"] = -1
    s_og_split_df.to_csv(os.path.join(data_dir, 'amazon_v2.0', 'splits', 'user_subset_{}.csv'.format(subset)), index=False)
    s_new_split_df = new_split_df.copy()
    s_new_split_df.loc[remove_idxs, "split"] = -1
    s_new_split_df.to_csv(os.path.join(data_dir, 'amazon_v2.0', 'splits', 'target_user_A2KWQ64TRHB3YH_subset_{}.csv'.format(subset)), index=False)

[1572865, 2097153, 6815745, 524292, 7864321, 4718599, 3145736, 7864329, 10, 3145741, 8388621, 1572879, 7864339, 3145748, 524309, 1572885, 6815767, 8388631, 8912922, 6815772, 1572893, 3670049, 1572898, 9961507, 524329, 2621483, 4194352, 5242933, 4718647, 5242940, 8388671, 3145792, 9961540, 4718669, 524366, 8388688, 3670100, 8388697, 1572954, 2621534, 8388704, 1572963, 7864420, 7864421, 7864422, 4194407, 2097256, 4194409, 4718699, 7864433, 4194423, 1572992, 2621568, 4194434, 4718721, 5243010, 1572999, 8388748, 5243024, 8388753, 1573013, 3145877, 2621593, 154, 1573018, 3145882, 5243037, 4718757, 4194470, 8388773, 8388775, 4194476, 8388781, 5243056, 7864498, 5243069, 3670212, 4194500, 524494, 6815950, 7864531, 4194516, 4194517, 2621654, 3670232, 6291676, 524511, 4718818, 3670244, 1048805, 8388838, 4194537, 7864553, 2621675, 3670255, 4718831, 5243120, 2621682, 4194547, 5243124, 7864566, 8388855, 4194555, 1048828, 1573117, 254, 1048832, 1048835, 3145988, 4194564, 4718853, 4194567, 4718856, 2

[1572865, 2097153, 6815745, 524292, 7864321, 4718599, 3145736, 7864329, 3145741, 8388621, 1572879, 7864339, 3145748, 524309, 1572885, 6815767, 8388631, 8912922, 6815772, 3670049, 1572898, 9961507, 524329, 2621483, 4194352, 5242933, 4718647, 5242940, 8388671, 3145792, 9961540, 4718669, 524366, 8388688, 3670100, 8388697, 1572954, 2621534, 8388704, 1572963, 7864420, 7864421, 7864422, 4194407, 2097256, 4194409, 4718699, 7864433, 4194423, 1572992, 2621568, 4194434, 4718721, 5243010, 1572999, 8388748, 5243024, 8388753, 1573013, 3145877, 2621593, 154, 1573018, 3145882, 5243037, 4718757, 4194470, 8388773, 8388775, 4194476, 8388781, 5243056, 7864498, 5243069, 4194498, 3670212, 4194500, 524494, 6815950, 7864531, 4194516, 4194517, 2621654, 3670232, 6291676, 524511, 4718818, 3670244, 1048805, 8388838, 4194537, 7864553, 2621675, 3670255, 4718831, 5243120, 2621682, 4194547, 5243124, 7864566, 8388855, 4194555, 1048828, 1573117, 254, 1048832, 1048835, 3145988, 4194564, 4718853, 4194567, 4718856, 265, 

In [57]:
split_df = pd.read_csv(os.path.join(data_dir, 'amazon_v2.0', 'splits', 'target_user_A2KWQ64TRHB3YH_subset_10000.csv'))

In [62]:
sum(split_df["split"] == 2)

75