In [1]:
from pathlib import Path

import numpy as np
from datasets import load_dataset
from loguru import logger

# Load data

In [2]:
# Load the "5core_timestamp_Video_Games" configuration of the Amazon Reviews 2023 dataset.
# NOTE(review): trust_remote_code=True allows the dataset repository's own loading
# script to execute locally — confirm the repository is trusted before running.
dataset = load_dataset("McAuley-Lab/Amazon-Reviews-2023", "5core_timestamp_Video_Games", trust_remote_code=True)

In [3]:
# Show the training split's summary (feature names and row count).
dataset['train']

Dataset({
    features: ['user_id', 'parent_asin', 'rating', 'timestamp'],
    num_rows: 736827
})

In [4]:
# Show the validation split's summary (feature names and row count).
dataset['valid']

Dataset({
    features: ['user_id', 'parent_asin', 'rating', 'timestamp'],
    num_rows: 34510
})

In [5]:
def parse_dtype(df):
    """Return a copy of ``df`` with ``rating`` and ``timestamp`` cast to numeric dtypes.

    Args:
        df: DataFrame containing at least the ``rating`` and ``timestamp``
            columns, whose values must be convertible to float and integer
            respectively.

    Returns:
        A new DataFrame (the input is not modified) in which ``rating`` is
        ``float64`` and ``timestamp`` is ``int64``. All other columns are
        passed through unchanged.
    """
    return df.assign(
        rating=lambda frame: frame['rating'].astype('float64'),
        # Pin the width explicitly: a bare ``int`` can resolve to a 32-bit
        # integer on some platforms (e.g. Windows with NumPy < 2), which is
        # too narrow for 13-digit millisecond timestamps.
        timestamp=lambda frame: frame['timestamp'].astype('int64'),
    )

# Materialise each split as a pandas DataFrame and normalise its column dtypes.
train_raw = parse_dtype(dataset['train'].to_pandas())
val_raw = parse_dtype(dataset['valid'].to_pandas())

# Sample data

In [6]:
# Number of validation rows to draw. Set to None (or 0) to keep every
# validation row; the user/item consistency filtering below still applies.
SAMPLE_VAL_ROWS = 10000
random_seed = 42

if SAMPLE_VAL_ROWS:
    np.random.seed(random_seed)
    # Clamp the request so asking for more rows than exist does not make
    # DataFrame.sample raise.
    n_val_rows = min(SAMPLE_VAL_ROWS, len(val_raw))
    val_sample = val_raw.sample(n_val_rows, random_state=random_seed)
else:
    # Sampling disabled: start from the complete validation split so that every
    # name used further down is defined in both modes.
    val_sample = val_raw

sample_users = val_sample['user_id'].unique()
sample_items = val_sample['parent_asin'].unique()

# Keep only train interactions between a selected user and a selected item.
# NOTE: train is not downsampled any further; its size is bounded by the
# users and items present in the validation selection.
train_sample = train_raw.loc[
    train_raw['parent_asin'].isin(sample_items) & train_raw['user_id'].isin(sample_users)
]
train_items = train_sample['parent_asin'].unique()
train_users = train_sample['user_id'].unique()

# Drop validation rows whose user or item no longer appears in the filtered
# train set, so every validation user and item is also present in train.
val_sample = val_sample.loc[
    val_sample['parent_asin'].isin(train_items) & val_sample['user_id'].isin(train_users)
]

val_items = val_sample['parent_asin'].unique()
val_users = val_sample['user_id'].unique()

logger.info(f"{len(train_items)=}, {len(train_users)=}")
logger.info(f"{len(val_items)=}, {len(val_users)=}")
# Sanity check: both percentages are expected to be 100% after the filtering above.
val_users_in_train = set(val_users).intersection(set(train_users))
val_items_in_train = set(val_items).intersection(set(train_items))
logger.info(f"Percentage of val users in train: {len(val_users_in_train) / len(val_users):,.0%}")
logger.info(f"Percentage of val items in train: {len(val_items_in_train) / len(val_items):,.0%}")

[32m2024-09-08 23:23:36.416[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m21[0m - [1mlen(train_items)=2653, len(train_users)=5223[0m
[32m2024-09-08 23:23:36.417[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m22[0m - [1mlen(val_items)=2056, len(val_users)=3460[0m
[32m2024-09-08 23:23:36.419[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1mPercentage of val users in train: 100%[0m
[32m2024-09-08 23:23:36.420[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mPercentage of val items in train: 100%[0m


In [7]:
# Display the filtered training interactions (user, item, rating, timestamp).
train_sample

Unnamed: 0,user_id,parent_asin,rating,timestamp
1,AEVPPTMG43C6GWSR7I2UGRQN7WFQ,B0863MT183,4.0,1613701986538
2,AEVPPTMG43C6GWSR7I2UGRQN7WFQ,B08P8P7686,5.0,1613702112995
3,AEVPPTMG43C6GWSR7I2UGRQN7WFQ,B0B7LV3DN2,4.0,1617641445475
4,AEVPPTMG43C6GWSR7I2UGRQN7WFQ,B09WMQ6DXG,5.0,1620231368468
70,AHV6QCNBJNSGLATP56JAWJ3C4G2A,B019WRM1IA,5.0,1451860309000
...,...,...,...,...
735704,AHS2PQ33BWQLXC5NNUZS2BFXD34Q,B07TZT67KX,5.0,1622844181866
735800,AFO5SNKILFVJMSJJ2E3BRLDGE4NA,B09T5VN7D1,4.0,1601154352542
735801,AFO5SNKILFVJMSJJ2E3BRLDGE4NA,B09918MSTF,5.0,1602615880364
736772,AEFPHMM7CLX4UJNXJFQF4ZF5GNAA,B07P27XFP7,5.0,1599585146628


In [8]:
# Display the filtered validation interactions (user, item, rating, timestamp).
val_sample

Unnamed: 0,user_id,parent_asin,rating,timestamp
3662,AGS4TR4K5DMBRAFNBYSB2I2RCHHQ,B0936HDGJ6,5.0,1652494657651
15518,AF7HTSEWIKYSP5D3ST4EZIUK6PJQ,B08F5T3F9Y,5.0,1644540517651
6758,AFQAPWVESEJYTNZC23LDPQOH7QBA,B09GM4283G,5.0,1630119475785
13574,AGNK22JGAD5WE2TVGQTD2BTIXUNA,B000LSJKAM,5.0,1636138124907
28977,AHEYSO45HU7ECMRIGBLPQZADLDDA,B07QQ8N7LL,2.0,1643905704106
...,...,...,...,...
9329,AFTJ4C6AVWJSOPT3NKUX5JQJDUKA,B0BYNMZ3SP,5.0,1643164686173
478,AFG6UJ2SWJJPMQXSW77MFJKHCEJQ,B08FC5L3RG,1.0,1634511131032
14740,AHAYPUV4RUPVQ2EYVVJBOFFKCFPA,B0CB9GDK9P,5.0,1638698211953
3461,AFVIZWLSRFUWN65MI4VT4JATJZIA,B07R9PBHP2,3.0,1656269943479


In [9]:
# Summary statistics for every column, transposed so each column becomes a row.
train_sample.describe(include='all').T

Unnamed: 0,count,unique,top,freq,mean,std,min,25%,50%,75%,max
user_id,18095.0,5223.0,AGMWACNMAG74AXBF7IJ22IOZSZPA,165.0,,,,,,,
parent_asin,18095.0,2653.0,B01N3ASPNV,230.0,,,,,,,
rating,18095.0,,,,4.250511,1.223807,1.0,4.0,5.0,5.0,5.0
timestamp,18095.0,,,,1545981709246.6826,81530755521.68243,975042289000.0,1508006692433.5,1571875102706.0,1607936733469.5,1628642755246.0


# Persist sample

In [10]:
# Write both samples to ../data as Parquet files.
output_dir = Path("../data")
# Create the target directory first: the write fails if it does not exist
# (e.g. on a fresh checkout).
output_dir.mkdir(parents=True, exist_ok=True)
train_sample.to_parquet(output_dir / "train.parquet")
val_sample.to_parquet(output_dir / "val.parquet")