In [1]:
from pathlib import Path

import numpy as np
from datasets import load_dataset
from loguru import logger

# Load data

In [2]:
# Download (or reuse the local cache of) the 5-core, timestamp-split Video Games
# subset. NOTE(review): trust_remote_code=True lets the dataset repository's
# loading script execute locally -- only keep it for sources you trust.
dataset = load_dataset(
    path="McAuley-Lab/Amazon-Reviews-2023",
    name="5core_timestamp_Video_Games",
    trust_remote_code=True,
)

In [3]:
# Inspect the training split: column names and row count.
dataset['train']

Dataset({
    features: ['user_id', 'parent_asin', 'rating', 'timestamp'],
    num_rows: 736827
})

In [4]:
# Inspect the validation split: column names and row count.
dataset['valid']

Dataset({
    features: ['user_id', 'parent_asin', 'rating', 'timestamp'],
    num_rows: 34510
})

In [5]:
def parse_dtype(df):
    """Return a new frame with ``rating`` cast to float64 and ``timestamp`` to int64.

    The widths are spelled out explicitly instead of using the builtin
    ``float`` / ``int``: the timestamps in this notebook are millisecond
    epochs (~1.5e12), which do not fit in a 32-bit integer, and ``int`` maps
    to a 32-bit type on some platforms (e.g. Windows with NumPy < 2).

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing at least the ``rating`` and ``timestamp`` columns.
        Other columns are passed through unchanged.

    Returns
    -------
    pandas.DataFrame
        A new frame; the input is not modified.

    Raises
    ------
    KeyError
        If ``rating`` or ``timestamp`` is missing from ``df``.
    """
    return df.astype({"rating": "float64", "timestamp": "int64"})

# Materialise each split as a pandas frame, then normalise the column dtypes.
train_raw = parse_dtype(dataset['train'].to_pandas())
val_raw = parse_dtype(dataset['valid'].to_pandas())

# Sample data

In [6]:
# Number of users to sample. A falsy value (None / 0) skips sampling and keeps
# the full frames.
SAMPLE_VAL_USERS = 5000
if SAMPLE_VAL_USERS:
    # Seed the legacy global NumPy RNG so the sampled user set is reproducible.
    random_seed = 42
    np.random.seed(random_seed)

    # Get users present in both train and val datasets
    users_in_train = train_raw['user_id'].unique()
    users_in_val = val_raw['user_id'].unique()
    common_users = np.intersect1d(users_in_val, users_in_train)

    # np.random.choice(..., replace=False) raises ValueError when asked for
    # more elements than the population holds, so clamp the request.
    n_sample_users = min(SAMPLE_VAL_USERS, len(common_users))
    if n_sample_users < SAMPLE_VAL_USERS:
        logger.warning(
            f"Requested {SAMPLE_VAL_USERS} users but only {len(common_users)} "
            f"appear in both splits; sampling {n_sample_users}"
        )

    # Sample users from the common users
    sample_users = np.random.choice(common_users, size=n_sample_users, replace=False)

    # Fetch all interactions of the sampled users in both datasets
    val_sample = val_raw[val_raw['user_id'].isin(sample_users)]
    train_sample = train_raw[train_raw['user_id'].isin(sample_users)]

    # Ensure all items in val_sample exist in train_sample. This can drop every
    # validation row of a user, so val_users may end up smaller than the
    # number of sampled users.
    train_items = train_sample['parent_asin'].unique()
    val_sample = val_sample[val_sample['parent_asin'].isin(train_items)]

    # Update item and user lists after filtering
    val_items = val_sample['parent_asin'].unique()
    train_users = train_sample['user_id'].unique()
    val_users = val_sample['user_id'].unique()

    # Logging
    logger.info(f"{len(train_items)=}, {len(train_users)=}")
    logger.info(f"{len(val_items)=}, {len(val_users)=}")
    val_users_in_train = set(val_users).intersection(set(train_users))
    val_items_in_train = set(val_items).intersection(set(train_items))
    if len(val_users) and len(val_items):
        logger.info(f"Percentage of val users in train: {len(val_users_in_train) / len(val_users):,.0%}")
        logger.info(f"Percentage of val items in train: {len(val_items_in_train) / len(val_items):,.0%}")
    else:
        # Avoid ZeroDivisionError when filtering removed every validation row.
        logger.warning("Validation sample is empty after filtering; coverage percentages skipped")
else:
    # No sampling requested: pass the full frames through unchanged so that the
    # cells below still find train_sample / val_sample.
    train_sample = train_raw
    val_sample = val_raw

[32m2024-09-18 23:38:36.495[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1mlen(train_items)=11547, len(train_users)=5000[0m
[32m2024-09-18 23:38:36.495[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mlen(val_items)=2351, len(val_users)=3433[0m
[32m2024-09-18 23:38:36.496[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m32[0m - [1mPercentage of val users in train: 100%[0m
[32m2024-09-18 23:38:36.497[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m33[0m - [1mPercentage of val items in train: 100%[0m


In [7]:
# Preview the sampled training interactions (pandas truncates the display).
train_sample

Unnamed: 0,user_id,parent_asin,rating,timestamp
10,AFFZVSTUS3U2ZD22A2NPZSKOCPGQ,B01GW3LRD2,5.0,1491589434000
11,AFFZVSTUS3U2ZD22A2NPZSKOCPGQ,B0848LKV51,4.0,1574659954094
323,AF4WLLHTQLRPEZ33OJDYG23MFLKQ,B09BQ4ZDQZ,5.0,1553570826020
324,AF4WLLHTQLRPEZ33OJDYG23MFLKQ,B01MDQP1ZU,5.0,1561162491649
325,AF4WLLHTQLRPEZ33OJDYG23MFLKQ,B0171AOQG0,5.0,1561266333491
...,...,...,...,...
736602,AGRXRIPAZTGAQHFKXZLFJDUOJSJA,B08FRMGWXQ,5.0,1609218633016
736770,AEFPHMM7CLX4UJNXJFQF4ZF5GNAA,B01BO2012O,3.0,1554983583540
736771,AEFPHMM7CLX4UJNXJFQF4ZF5GNAA,B001NJMMHG,5.0,1599584891516
736772,AEFPHMM7CLX4UJNXJFQF4ZF5GNAA,B07P27XFP7,5.0,1599585146628


In [8]:
# Preview the sampled validation interactions (pandas truncates the display).
val_sample

Unnamed: 0,user_id,parent_asin,rating,timestamp
31,AF4WLLHTQLRPEZ33OJDYG23MFLKQ,B0BHTCQXVL,5.0,1637986006587
32,AF4WLLHTQLRPEZ33OJDYG23MFLKQ,B0BWXXWVV6,5.0,1649469448750
33,AF4WLLHTQLRPEZ33OJDYG23MFLKQ,B0929CLLPW,5.0,1653191585299
34,AF4WLLHTQLRPEZ33OJDYG23MFLKQ,B09KRQY1ZF,4.0,1653192092369
35,AF4WLLHTQLRPEZ33OJDYG23MFLKQ,B094YHB1QK,5.0,1653192172671
...,...,...,...,...
34454,AF5XYBAXC5VJO4MY4JZWIKP5SETA,B09Y2WKZRZ,5.0,1630677089102
34498,AGRXRIPAZTGAQHFKXZLFJDUOJSJA,B08LZGPPBH,5.0,1635799909309
34499,AGRXRIPAZTGAQHFKXZLFJDUOJSJA,B09ZPGLK57,5.0,1638213680178
34500,AGRXRIPAZTGAQHFKXZLFJDUOJSJA,B08DKWPSWN,5.0,1638552455754


In [9]:
# Summary statistics for every column; transposed so each column is one row.
train_sample.describe(include='all').T

Unnamed: 0,count,unique,top,freq,mean,std,min,25%,50%,75%,max
user_id,36943.0,5000.0,AEWLQYBQDYWWUWK6UHHTNWO5AHYA,389.0,,,,,,,
parent_asin,36943.0,11547.0,B01N3ASPNV,204.0,,,,,,,
rating,36943.0,,,,4.171805,1.260896,1.0,4.0,5.0,5.0,5.0
timestamp,36943.0,,,,1499512178055.2026,106200179834.181,969227758000.0,1436612658500.0,1521425116538.0,1582732185292.5,1628643144373.0


# Persist sample

In [10]:
# Write both samples as parquet files under ../data (relative to the notebook's
# working directory). The directory is created first so the cell also works on
# a fresh checkout where ../data does not exist yet.
DATA_DIR = Path("../data")
DATA_DIR.mkdir(parents=True, exist_ok=True)
train_sample.to_parquet(DATA_DIR / "train.parquet")
val_sample.to_parquet(DATA_DIR / "val.parquet")

# Archive