In [1]:
from pathlib import Path

import numpy as np
from datasets import load_dataset
from loguru import logger

# Load data

In [2]:
# Fetch the "5core_timestamp_Books" configuration of the Amazon Reviews 2023 dataset.
# NOTE(review): trust_remote_code=True permits the dataset repository's loading
# script to run locally — confirm the source is trusted before re-running.
dataset = load_dataset(
    "McAuley-Lab/Amazon-Reviews-2023",
    "5core_timestamp_Books",
    trust_remote_code=True,
)

In [3]:
# Inspect the train split: its feature names and row count.
dataset['train']

Dataset({
    features: ['user_id', 'parent_asin', 'rating', 'timestamp'],
    num_rows: 8733855
})

In [4]:
# Inspect the validation split: its feature names and row count.
dataset['valid']

Dataset({
    features: ['user_id', 'parent_asin', 'rating', 'timestamp'],
    num_rows: 426209
})

In [5]:
def parse_dtype(df):
    """Return a copy of ``df`` with numeric dtypes fixed for the interaction columns.

    Args:
        df: DataFrame containing at least a ``rating`` and a ``timestamp`` column.

    Returns:
        A new DataFrame with the same columns in the same order, where ``rating``
        is float64 and ``timestamp`` is int64. The input frame is not modified.

    Raises:
        KeyError: if ``rating`` or ``timestamp`` is missing from ``df``.
    """
    # Spell out the 64-bit widths instead of the builtin ``int``/``float``:
    # ``astype(int)`` resolves to a 32-bit integer on Windows with NumPy < 2,
    # and the timestamps in this dataset (~1.4e12) do not fit in 32 bits.
    return df.astype({"rating": "float64", "timestamp": "int64"})

# Materialise each split as a pandas DataFrame, then normalise its column dtypes.
train_raw = parse_dtype(dataset['train'].to_pandas())
val_raw = parse_dtype(dataset['valid'].to_pandas())

# Sample data

In [6]:
# Number of users to keep. Set to None (or 0) to skip sampling and use the
# full train/validation frames instead.
SAMPLE_TRAIN_USERS = 10000
if SAMPLE_TRAIN_USERS:
    random_seed = 42
    # Seed NumPy's global RNG so the same users are drawn on every run
    np.random.seed(random_seed)

    # Get users present in both train and val datasets
    users_in_train = train_raw['user_id'].unique()
    users_in_val = val_raw['user_id'].unique()
    common_users = np.intersect1d(users_in_val, users_in_train)

    # Sample users from the common users. Clamp the request to the number of
    # eligible users: choice(..., replace=False) raises ValueError when asked
    # for more elements than the population holds.
    n_sample_users = min(SAMPLE_TRAIN_USERS, len(common_users))
    if n_sample_users < SAMPLE_TRAIN_USERS:
        logger.warning(
            f"Requested {SAMPLE_TRAIN_USERS} users but only {len(common_users)} "
            f"appear in both train and val; sampling {n_sample_users}"
        )
    sample_users = np.random.choice(common_users, size=n_sample_users, replace=False)

    # Fetch all interactions of the sampled users in both datasets
    val_sample = val_raw[val_raw['user_id'].isin(sample_users)]
    train_sample = train_raw[train_raw['user_id'].isin(sample_users)]
else:
    # Sampling disabled: keep every interaction so that the cells below still
    # find train_sample / val_sample defined.
    train_sample = train_raw
    val_sample = val_raw

# Ensure all items in val_sample exist in train_sample
train_items = train_sample['parent_asin'].unique()
val_sample = val_sample[val_sample['parent_asin'].isin(train_items)]

# Update item and user lists after filtering
val_items = val_sample['parent_asin'].unique()
train_users = train_sample['user_id'].unique()
val_users = val_sample['user_id'].unique()

# Logging
logger.info(f"{len(train_items)=}, {len(train_users)=}")
logger.info(f"{len(val_items)=}, {len(val_users)=}")
val_users_in_train = set(val_users).intersection(set(train_users))
val_items_in_train = set(val_items).intersection(set(train_items))
logger.info(f"Percentage of val users in train: {len(val_users_in_train) / len(val_users):,.0%}")
logger.info(f"Percentage of val items in train: {len(val_items_in_train) / len(val_items):,.0%}")
# Share of (user, item) pairs in the train sample that have no interaction
logger.info(f"Sparsity: {1 - len(train_sample) / (len(train_items) * len(train_users)):,.4%}")

[32m2024-09-19 20:13:24.354[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1mlen(train_items)=93655, len(train_users)=10000[0m
[32m2024-09-19 20:13:24.354[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mlen(val_items)=6056, len(val_users)=4932[0m
[32m2024-09-19 20:13:24.361[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m32[0m - [1mPercentage of val users in train: 100%[0m
[32m2024-09-19 20:13:24.362[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m33[0m - [1mPercentage of val items in train: 100%[0m
[32m2024-09-19 20:13:24.362[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m34[0m - [1mSparsity: 99.9837%[0m


In [7]:
# Preview the sampled training interactions.
train_sample

Unnamed: 0,user_id,parent_asin,rating,timestamp
23,AHXBL3QDWZGJYH7A5CMPFNUPMF7Q,0920668372,5.0,1430056169000
24,AHXBL3QDWZGJYH7A5CMPFNUPMF7Q,1589255208,5.0,1443926150000
25,AHXBL3QDWZGJYH7A5CMPFNUPMF7Q,2764322836,5.0,1463967052000
26,AHXBL3QDWZGJYH7A5CMPFNUPMF7Q,2764330898,5.0,1489085694000
27,AHXBL3QDWZGJYH7A5CMPFNUPMF7Q,0062380761,5.0,1526591330983
...,...,...,...,...
8727081,AFHRNLTISIMZANNK3FV7Y2GHIU4A,1620148021,5.0,1575652999784
8727082,AFHRNLTISIMZANNK3FV7Y2GHIU4A,1546034587,5.0,1620237294827
8729931,AFSYIYI3FQPLRS3P7PTGN6H6FHGA,1940189004,5.0,1412824706000
8729932,AFSYIYI3FQPLRS3P7PTGN6H6FHGA,1939650380,5.0,1475293391000


In [8]:
# Preview the validation interactions kept for the sampled users.
val_sample

Unnamed: 0,user_id,parent_asin,rating,timestamp
4,AHXBL3QDWZGJYH7A5CMPFNUPMF7Q,0451450523,2.0,1635710722120
113,AGUWL2R2JFLC3K65HLD6AHJV3KBA,1439153663,5.0,1641321916399
247,AETGCWXC47MSMK6B2TLZ44KCFJZQ,1465497676,3.0,1634149673109
249,AETGCWXC47MSMK6B2TLZ44KCFJZQ,B07ZQFT4B1,5.0,1649517456801
256,AFJP74KDEKRSPJN5JINL452T6WNA,1250766567,5.0,1629407827607
...,...,...,...,...
424800,AHXZ66ATLSPVIW5HC5OTNLYGBDTQ,1416542744,4.0,1645198331443
424884,AHWBSG5WTNDC47SPUMJTWPIDZ7HQ,B08MQLJ99B,5.0,1629558239986
425194,AE5AXNZSQK6R5J2EXFUCFPDPSA6A,1643260448,2.0,1637475668742
425440,AFM4K7CAFB2KE6BHWQSS7KEHTWLA,0452282314,5.0,1643339582810


In [9]:
# Summary statistics for every column (numeric and non-numeric), transposed so
# that each original column becomes one row.
train_sample.describe(include='all').T

Unnamed: 0,count,unique,top,freq,mean,std,min,25%,50%,75%,max
user_id,152587.0,10000.0,AHAR2ITQ3O2FJJDLIORAQ226KZHQ,1241.0,,,,,,,
parent_asin,152587.0,93655.0,B00L9B7IKE,180.0,,,,,,,
rating,152587.0,,,,4.349597,0.998744,1.0,4.0,5.0,5.0,5.0
timestamp,152587.0,,,,1483382993471.7302,117693146983.97728,878061365000.0,1420299130000.0,1506746458355.0,1575254557129.5,1628642557237.0


# Persist sample

In [10]:
# Write both samples as Parquet files under ../data. The directory is created
# first so the cell also succeeds on a checkout where it does not exist yet.
output_dir = Path("../data")
output_dir.mkdir(parents=True, exist_ok=True)
train_sample.to_parquet(output_dir / "train.parquet")
val_sample.to_parquet(output_dir / "val.parquet")

# Archive