In [1]:
from pathlib import Path

import numpy as np
from datasets import load_dataset
from loguru import logger

# Load data

In [2]:
# Hugging Face Hub dataset repository and configuration to load.
DATASET_NAME = "McAuley-Lab/Amazon-Reviews-2023"
DATASET_CONFIG = "5core_timestamp_Video_Games"
# SECURITY: trust_remote_code=True allows the Hub repository's own loading script to run
# locally. Review that script (or pin a `revision=`) before executing this cell.
dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, trust_remote_code=True)

In [3]:
# Inspect the train split's schema (features) and row count.
dataset['train']

Dataset({
    features: ['user_id', 'parent_asin', 'rating', 'timestamp'],
    num_rows: 736827
})

In [4]:
# Inspect the validation split; this dataset config names it 'valid', not 'validation'.
dataset['valid']

Dataset({
    features: ['user_id', 'parent_asin', 'rating', 'timestamp'],
    num_rows: 34510
})

In [5]:
def parse_dtype(df):
    """Cast the ``rating`` and ``timestamp`` columns to numeric dtypes.

    Args:
        df: DataFrame containing ``rating`` and ``timestamp`` columns. The values are
            presumably strings as exported by the HF dataset (TODO confirm). All other
            columns pass through unchanged.

    Returns:
        A new DataFrame with ``rating`` as float64 and ``timestamp`` as int64. The input
        frame is not mutated.
    """
    # Use explicit 64-bit dtypes. Timestamp values here are on the order of 1e12, which
    # overflows int32, and plain ``int`` resolves to the platform C long. That type is
    # 32-bit on Windows with NumPy < 2.
    return df.assign(
        rating=lambda frame: frame['rating'].astype('float64'),
        timestamp=lambda frame: frame['timestamp'].astype('int64'),
    )

# Convert both HF splits (train first, then valid) to pandas and normalise their dtypes.
train_raw, val_raw = (
    dataset[split_name].to_pandas().pipe(parse_dtype)
    for split_name in ('train', 'valid')
)

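# Sample data

Downsample the validation split to `SAMPLE_VAL_ROWS` rows. Then restrict train to interactions whose user and item both occur in that sample, and drop validation rows whose user or item has no remaining training history. This guarantees that every validation user and item is seen in training.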

In [6]:
# Number of validation rows to sample. Set to None (or 0) to keep the full splits.
SAMPLE_VAL_ROWS = 30000
random_seed = 42
# pandas sampling below uses random_state explicitly. The global seed only covers
# np.random-based code that runs later.
np.random.seed(random_seed)

if SAMPLE_VAL_ROWS:
    # Cap at the available rows, because DataFrame.sample raises when n > len(df)
    # (without replace=True).
    val_sample = val_raw.sample(min(SAMPLE_VAL_ROWS, len(val_raw)), random_state=random_seed)
    sample_users = val_sample['user_id'].unique()
    sample_items = val_sample['parent_asin'].unique()
    # Keep only the train interactions whose user AND item both occur in the val sample...
    train_sample = train_raw.loc[
        lambda df: df['parent_asin'].isin(sample_items) & df['user_id'].isin(sample_users)
    ]
    # ...then drop val rows whose user or item has no remaining train interaction.
    # Afterwards every val user/item is present in train. The converse does not hold:
    # train can keep users/items whose val rows were all dropped here.
    val_sample = val_sample.loc[
        lambda df: df['parent_asin'].isin(train_sample['parent_asin'])
        & df['user_id'].isin(train_sample['user_id'])
    ]
    # TODO: optionally cap train_sample size. A SAMPLE_TRAIN_ROWS knob was sketched but never defined.
else:
    # Sampling disabled: use the full splits. The coverage logs below then report how
    # much of val is covered by train.
    train_sample = train_raw
    val_sample = val_raw

train_items = train_sample['parent_asin'].unique()
train_users = train_sample['user_id'].unique()
val_items = val_sample['parent_asin'].unique()
val_users = val_sample['user_id'].unique()
# Guard the coverage percentages below against division by zero.
assert not val_sample.empty, "val_sample is empty after filtering; nothing to evaluate on"

logger.info(f"{len(train_items)=}, {len(train_users)=}")
logger.info(f"{len(val_items)=}, {len(val_users)=}")
val_users_in_train = set(val_users).intersection(set(train_users))
val_items_in_train = set(val_items).intersection(set(train_items))
logger.info(f"Percentage of val users in train: {len(val_users_in_train) / len(val_users):,.0%}")
logger.info(f"Percentage of val items in train: {len(val_items_in_train) / len(val_items):,.0%}")

[32m2024-09-14 10:15:41.912[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m20[0m - [1mlen(train_items)=5429, len(train_users)=12397[0m
[32m2024-09-14 10:15:41.913[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m21[0m - [1mlen(val_items)=4524, len(val_users)=9388[0m
[32m2024-09-14 10:15:41.917[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m24[0m - [1mPercentage of val users in train: 100%[0m
[32m2024-09-14 10:15:41.917[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1mPercentage of val items in train: 100%[0m


In [7]:
# Eyeball the sampled train interactions (pandas truncates the middle rows in the display).
train_sample

Unnamed: 0,user_id,parent_asin,rating,timestamp
1,AEVPPTMG43C6GWSR7I2UGRQN7WFQ,B0863MT183,4.0,1613701986538
2,AEVPPTMG43C6GWSR7I2UGRQN7WFQ,B08P8P7686,5.0,1613702112995
3,AEVPPTMG43C6GWSR7I2UGRQN7WFQ,B0B7LV3DN2,4.0,1617641445475
4,AEVPPTMG43C6GWSR7I2UGRQN7WFQ,B09WMQ6DXG,5.0,1620231368468
10,AFFZVSTUS3U2ZD22A2NPZSKOCPGQ,B01GW3LRD2,5.0,1491589434000
...,...,...,...,...
736532,AF4VJ4NQ7LO256VSOVNQ6Q5PGNBA,B003LJSJXW,4.0,1293327789000
736602,AGRXRIPAZTGAQHFKXZLFJDUOJSJA,B08FRMGWXQ,5.0,1609218633016
736700,AERQISDPMPFJZKZ7P6A5FGX6RP5Q,B0000AHOOK,5.0,1415342350000
736772,AEFPHMM7CLX4UJNXJFQF4ZF5GNAA,B07P27XFP7,5.0,1599585146628


In [8]:
# Eyeball the sampled validation interactions. The row order follows DataFrame.sample,
# not the original index order.
val_sample

Unnamed: 0,user_id,parent_asin,rating,timestamp
27793,AGCL7QDBZ24RZHTSPHSQ4ZXSG3RQ,B07D29PHFY,3.0,1640487364737
3662,AGS4TR4K5DMBRAFNBYSB2I2RCHHQ,B0936HDGJ6,5.0,1652494657651
15518,AF7HTSEWIKYSP5D3ST4EZIUK6PJQ,B08F5T3F9Y,5.0,1644540517651
6758,AFQAPWVESEJYTNZC23LDPQOH7QBA,B09GM4283G,5.0,1630119475785
13574,AGNK22JGAD5WE2TVGQTD2BTIXUNA,B000LSJKAM,5.0,1636138124907
...,...,...,...,...
33256,AEUDRWR7PNA6JPXZ5RTN6KCXKOBA,B0BK673BF4,5.0,1636340286500
25398,AFEYPRZTCVN4WURYQOEYNTD2JFYQ,B000N5Z2L4,5.0,1648670003849
8003,AHPWCVRF23GACCNGHT6VRIPFTFFA,B0B5SV7L99,5.0,1631554680324
26711,AGRQ2ELNB47RERPPMRDMXK7EOGZA,B0771371PM,5.0,1630392097798


In [9]:
# Per-column summary, one row per column: count/unique/top for the ID columns and
# distribution statistics for rating and timestamp.
train_sample.describe(include='all').transpose()

Unnamed: 0,count,unique,top,freq,mean,std,min,25%,50%,75%,max
user_id,54970.0,12397.0,AGMWACNMAG74AXBF7IJ22IOZSZPA,236.0,,,,,,,
parent_asin,54970.0,5429.0,B01N3ASPNV,527.0,,,,,,,
rating,54970.0,,,,4.212279,1.248914,1.0,4.0,5.0,5.0,5.0
timestamp,54970.0,,,,1528120880047.892,91724441151.59406,975042289000.0,1480705888000.0,1554444818254.5,1599823304243.75,1628643144373.0


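# Persist samples

Write the sampled train and validation interactions to Parquet so downstream notebooks can load them without re-sampling.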

In [10]:
# Output directory, relative to the notebook's working directory.
DATA_DIR = Path("../data")
# to_parquet does not create missing parent directories, so ensure the target exists first.
DATA_DIR.mkdir(parents=True, exist_ok=True)
train_sample.to_parquet(DATA_DIR / "train.parquet")
val_sample.to_parquet(DATA_DIR / "val.parquet")