In [1]:
from pathlib import Path

from datasets import load_dataset
import numpy as np
from loguru import logger

# Load data

In [2]:
# Amazon Reviews 2023, config "5core_timestamp_Books"; the cells below use its
# 'train' and 'valid' splits. trust_remote_code=True lets the Hub repository's
# own loading code execute locally -- keep it only for sources you trust.
dataset = load_dataset(
    path="McAuley-Lab/Amazon-Reviews-2023",
    name="5core_timestamp_Books",
    trust_remote_code=True,
)

In [3]:
# Inspect the train split: column names and row count.
dataset['train']

Dataset({
    features: ['user_id', 'parent_asin', 'rating', 'timestamp'],
    num_rows: 8733855
})

In [4]:
# Same check for the validation split (keyed 'valid' in this dataset).
dataset['valid']

Dataset({
    features: ['user_id', 'parent_asin', 'rating', 'timestamp'],
    num_rows: 426209
})

In [5]:
def parse_dtype(df):
    """Cast the ``rating`` and ``timestamp`` columns to explicit numeric dtypes.

    Parameters
    ----------
    df : pandas.DataFrame
        Interaction frame with at least ``rating`` and ``timestamp`` columns
        (here produced by ``Dataset.to_pandas()``). Other columns pass through.

    Returns
    -------
    pandas.DataFrame
        A new frame (``df`` itself is not modified) with ``rating`` as float64
        and ``timestamp`` as int64.
    """
    return df.assign(
        rating=lambda frame: frame['rating'].astype('float64'),
        # Explicit int64 rather than the builtin ``int``: with NumPy < 2 on
        # Windows ``int`` means int32, which cannot hold the ~1.6e12
        # millisecond-scale timestamp values in this data.
        timestamp=lambda frame: frame['timestamp'].astype('int64'),
    )

# Convert each HF split to pandas and normalise the rating/timestamp dtypes.
train_raw = parse_dtype(dataset['train'].to_pandas())
val_raw = parse_dtype(dataset['valid'].to_pandas())

# Sample data

In [6]:
# Positive int: subsample that many users present in both splits.
# None/0: keep every user (train_sample/val_sample are still defined).
SAMPLE_VAL_USERS = 2000
random_seed = 42

if SAMPLE_VAL_USERS:
    # A dedicated RandomState seeded with the same value draws exactly the same
    # users as `np.random.seed(random_seed); np.random.choice(...)`, without
    # reseeding NumPy's global RNG as a side effect.
    rng = np.random.RandomState(random_seed)

    # Get users present in both train and val datasets
    users_in_train = train_raw['user_id'].unique()
    users_in_val = val_raw['user_id'].unique()
    common_users = np.intersect1d(users_in_val, users_in_train)

    # choice(..., replace=False) raises ValueError when asked for more users
    # than exist, so cap the draw at the eligible population.
    n_sample_users = min(SAMPLE_VAL_USERS, len(common_users))
    sample_users = rng.choice(common_users, size=n_sample_users, replace=False)

    # Fetch all interactions of the sampled users in both datasets
    val_sample = val_raw[val_raw['user_id'].isin(sample_users)]
    train_sample = train_raw[train_raw['user_id'].isin(sample_users)]
else:
    # Full run: keep every interaction so the cells below can still refer to
    # train_sample / val_sample.
    train_sample = train_raw
    val_sample = val_raw

# Ensure all items in val_sample exist in train_sample
train_items = train_sample['parent_asin'].unique()
val_sample = val_sample[val_sample['parent_asin'].isin(train_items)]

# Update item and user lists after filtering
val_items = val_sample['parent_asin'].unique()
train_users = train_sample['user_id'].unique()
val_users = val_sample['user_id'].unique()

# Logging
logger.info(f"{len(train_items)=}, {len(train_users)=}")
logger.info(f"{len(val_items)=}, {len(val_users)=}")
if len(val_users) == 0:
    # Nothing survived the item filter; skip the ratios to avoid dividing by zero.
    logger.warning("val_sample is empty after filtering to items seen in train_sample")
else:
    val_users_in_train = set(val_users).intersection(train_users)
    val_items_in_train = set(val_items).intersection(train_items)
    logger.info(f"Percentage of val users in train: {len(val_users_in_train) / len(val_users):,.0%}")
    logger.info(f"Percentage of val items in train: {len(val_items_in_train) / len(val_items):,.0%}")

[32m2024-09-17 23:11:29.341[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1mlen(train_items)=24723, len(train_users)=2000[0m
[32m2024-09-17 23:11:29.342[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mlen(val_items)=737, len(val_users)=593[0m
[32m2024-09-17 23:11:29.344[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m32[0m - [1mPercentage of val users in train: 100%[0m
[32m2024-09-17 23:11:29.344[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m33[0m - [1mPercentage of val items in train: 100%[0m


In [7]:
# Preview the sampled train interactions (the rendered table is truncated).
train_sample

Unnamed: 0,user_id,parent_asin,rating,timestamp
23,AHXBL3QDWZGJYH7A5CMPFNUPMF7Q,0920668372,5.0,1430056169000
24,AHXBL3QDWZGJYH7A5CMPFNUPMF7Q,1589255208,5.0,1443926150000
25,AHXBL3QDWZGJYH7A5CMPFNUPMF7Q,2764322836,5.0,1463967052000
26,AHXBL3QDWZGJYH7A5CMPFNUPMF7Q,2764330898,5.0,1489085694000
27,AHXBL3QDWZGJYH7A5CMPFNUPMF7Q,0062380761,5.0,1526591330983
...,...,...,...,...
8709429,AFM4K7CAFB2KE6BHWQSS7KEHTWLA,045141943X,5.0,1401739856000
8709430,AFM4K7CAFB2KE6BHWQSS7KEHTWLA,1137280166,5.0,1491185464000
8709431,AFM4K7CAFB2KE6BHWQSS7KEHTWLA,0451171357,5.0,1563398279401
8709432,AFM4K7CAFB2KE6BHWQSS7KEHTWLA,B07SVGNQG6,5.0,1568438557427


In [8]:
# Preview the sampled validation interactions that survived the item filter.
val_sample

Unnamed: 0,user_id,parent_asin,rating,timestamp
4,AHXBL3QDWZGJYH7A5CMPFNUPMF7Q,0451450523,2.0,1635710722120
293,AFG6YQ3GOY7TVFKQ3SKDVS6Q6RDQ,B08CV9SPDQ,4.0,1635609140286
294,AFG6YQ3GOY7TVFKQ3SKDVS6Q6RDQ,B07R3QYGHY,4.0,1657998389024
763,AFBXVB2GIANS2DHWDK3HXISL2WEA,1291332162,5.0,1651000430747
1205,AGSGLHB6G6QSTSIXWRD6ZZ7V5VZA,B0C8GJYMNH,5.0,1656800368338
...,...,...,...,...
422078,AF7F5V4G3SWPRIKQEATNV7WACR6A,0062915320,5.0,1638675622205
422346,AF2T4ZDAXUTFGFFRDG5GA5BWQXRA,1733090312,4.0,1630014011916
422347,AF2T4ZDAXUTFGFFRDG5GA5BWQXRA,1501128035,4.0,1630014353678
423566,AG3A7NFV7ZKBXWF6FV3VMF6CK3BA,1101930926,5.0,1637012094603


In [9]:
# Column summary of the train sample: include='all' adds count/unique/top/freq
# for the string id columns; .T lists one original column per row.
# NOTE(review): timestamp values look like epoch milliseconds -- confirm.
train_sample.describe(include='all').T

Unnamed: 0,count,unique,top,freq,mean,std,min,25%,50%,75%,max
user_id,30317.0,2000.0,AG7PUAYZCB2KR3U72ROZURNUYRBA,390.0,,,,,,,
parent_asin,30317.0,24723.0,B00L9B7IKE,35.0,,,,,,,
rating,30317.0,,,,4.343603,1.001445,1.0,4.0,5.0,5.0,5.0
timestamp,30317.0,,,,1489819111667.8467,114611940018.63252,878680832000.0,1428538472000.0,1513582772418.0,1578509546309.0,1628636788111.0


# Persist sample

In [10]:
# Write the (possibly sampled) splits for downstream use. to_parquet does not
# create missing parent directories, so make sure the output folder exists.
DATA_DIR = Path("../data")
DATA_DIR.mkdir(parents=True, exist_ok=True)
train_sample.to_parquet(DATA_DIR / "train.parquet")
val_sample.to_parquet(DATA_DIR / "val.parquet")

# Archive