In [None]:
%load_ext autoreload
%autoreload 2

In [6]:
import os
import sys
from functools import partial
from typing import Optional

import numpy as np
import pandas as pd
import plotly.express as px
from feast import FeatureStore
from load_dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel, model_validator
from sqlalchemy import create_engine

sys.path.insert(0, "..")

from src.utils.split_time_based import train_test_split_timebased
from src.utils.embedding_id_mapper import IDMapper
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

_ = load_dotenv(override=True)

## Controller

In [98]:
class Args(BaseModel):
    """Run configuration for this data-preparation notebook.

    The two ``*_persit_path`` fields stay ``None`` until :meth:`init` is called;
    every other field has a usable default.
    """

    run_name: str = "000-data-prep"
    run_description: str = "Splitting data into train, val, test sets, then sampling data for quick iteration"
    testing: bool = False  # when True, `init()` does not create any directory
    # Directory of the sampled data (train, val and test); set by `init()`.
    sample_data_persit_path: Optional[str] = None
    # Directory for artefacts produced by this notebook; set by `init()`.
    notebook_persit_path: Optional[str] = None
    random_seed: int = 41

    # Column names of the interaction table.
    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"

    # Sampling target and minimum-interaction constraints.
    sample_users: int = 5000
    min_user_interactions: int = 5
    min_item_interactions: int = 10

    # Parameters forwarded to the time-based train/val/test split.
    val_num_days: int = 420
    test_num_days: int = 540

    # Resolved relative to the working directory at class-definition time.
    rating_dataset_path: str = os.path.abspath("../data_for_ai/raw/amz_raw_rating.parquet")

    def init(self):
        """Resolve the output directories and create them unless ``testing`` is set.

        Returns:
            The same instance, so the call can be chained: ``Args().init()``.
        """
        self.sample_data_persit_path = os.path.abspath("../data_for_ai/interim")
        self.notebook_persit_path = os.path.abspath(f"./data/{self.run_name}")
        if not self.testing:
            for directory in (self.sample_data_persit_path, self.notebook_persit_path):
                os.makedirs(directory, exist_ok=True)

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

{
  "run_name": "000-data-prep",
  "run_description": "Splitting data into train, val, test sets, then sampling data for quick iteration",
  "testing": false,
  "sample_data_persit_path": "/home/dinhln/Desktop/real_time_recsys/data_for_ai/interim",
  "notebook_persit_path": "/home/dinhln/Desktop/real_time_recsys/notebooks/data/000-data-prep",
  "random_seed": 41,
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "sample_users": 5000,
  "min_user_interactions": 5,
  "min_item_interactions": 10,
  "val_num_days": 420,
  "test_num_days": 540,
  "rating_dataset_path": "/home/dinhln/Desktop/real_time_recsys/data_for_ai/raw/amz_raw_rating.parquet"
}


## Load data from a specific period in order to train the model

In notebook 002-simulate-oltp, we saw that the period from March 2020 to September 2020 is a good choice: users and items interact actively in this period, and we can keep the data recent. So, we will load data from this period to train the model.

In [65]:
# Load the raw rating interactions from the parquet file configured in `args`
full_df = pd.read_parquet(args.rating_dataset_path)
print(f"DataFrame shape: {full_df.shape}")


DataFrame shape: (43334103, 4)


In [60]:
full_df.head()

Unnamed: 0,user_id,parent_asin,rating,timestamp
0,AFKZENTNBQ7A7V7UXW5JJI6UGRYQ,B01G8JO5F2,5.0,2018-04-07 09:23:37.534
1,AFKZENTNBQ7A7V7UXW5JJI6UGRYQ,B07N69T6TM,1.0,2020-06-20 18:42:29.731
2,AFKZENTNBQ7A7V7UXW5JJI6UGRYQ,B083NRGZMM,3.0,2022-07-18 22:58:37.948
3,AGGZ357AO26RQZVRLGU4D4N52DZQ,B001OC5JKY,5.0,2010-11-20 18:41:35.000
4,AG2L7H23R5LLKDKLBEF2Q3L2MVDA,B07CJYMRWM,5.0,2023-02-17 02:39:41.238


In [31]:
def remove_random_users(df, k=10, user_col=None, random_seed=None):
    """Return `df` without the rows of `k` randomly chosen users.

    Args:
        df: Interaction frame with one row per interaction.
        k: Number of distinct users to drop. NumPy raises ``ValueError`` when
            `k` exceeds the number of distinct users in `df`.
        user_col: Name of the user-id column. Defaults to ``args.user_col``.
        random_seed: Seed of the draw. Defaults to ``args.random_seed``.

    Returns:
        A new frame holding only the rows of the users that were not drawn.
    """
    if user_col is None:
        user_col = args.user_col
    if random_seed is None:
        random_seed = args.random_seed

    users = df[user_col].unique()
    # A private RandomState draws the same users as `np.random.seed(seed)` followed by
    # `np.random.choice(...)`, but leaves the global NumPy random state untouched.
    rng = np.random.RandomState(random_seed)
    to_remove_users = rng.choice(users, size=k, replace=False)
    return df.loc[~df[user_col].isin(to_remove_users)]


def get_unqualified(df, col: str, threshold: int):
    """Return the values of `col` that appear in fewer than `threshold` rows of `df`.

    The result is the index of the per-value row counts, restricted to the
    values whose count is strictly below `threshold`.
    """
    occurrences = df.groupby(col).size()
    too_rare = occurrences < threshold
    return occurrences[too_rare].index

In [19]:
# Finders for the 5-core filter: both the user and the item finder use a fixed threshold of 5.
get_unqualified_users, get_unqualified_items = (
    partial(get_unqualified, col=column, threshold=5)
    for column in (args.user_col, args.item_col)
)

In [20]:
# get 5-core df
uu = get_unqualified_users(full_df)
ui = get_unqualified_items(full_df)
print(f"Number of unqualified users: {len(uu)}")
print(f"Number of unqualified items: {len(ui)}")

Number of unqualified users: 16420548
Number of unqualified items: 987017


In [21]:
# Drop the flagged users first, then the flagged items.
# NOTE: `uu` and `ui` were both computed on the unfiltered frame and the filter runs only
# once, so a user can drop below the threshold again after some of their items are removed.
# This is a single-pass filter, not an iterated k-core.
full_5core_df = full_df.loc[lambda df: ~df[args.user_col].isin(uu)]
full_5core_df = full_5core_df.loc[lambda df: ~df[args.item_col].isin(ui)]
print(f"5-core DataFrame shape: {full_5core_df.shape}")

5-core DataFrame shape: (16879366, 4)


In [22]:
# Number of interactions dropped by the filter: rows of `full_df` minus rows of `full_5core_df`
logger.info(f"{len(full_df) - len(full_5core_df)} ")

[32m2025-03-26 23:40:31.668[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1m26454737 [0m


In [99]:
# Split train, val, test
train_df, val_df, test_df = train_test_split_timebased(
    full_5core_df, user_id_col="user_id",
        item_id_col="parent_asin",
        timestamp_col="timestamp",
        val_num_days=args.val_num_days,
        test_num_days=args.test_num_days)

[32m2025-03-27 01:23:18.987[0m | [1mINFO    [0m | [36msrc.utils.split_time_based[0m:[36mtrain_test_split_timebased[0m:[36m26[0m - [1mRemoving users from val and test sets...[0m
[32m2025-03-27 01:23:24.875[0m | [1mINFO    [0m | [36msrc.utils.split_time_based[0m:[36mtrain_test_split_timebased[0m:[36m37[0m - [1mRemoved 141738 users from val set[0m
[32m2025-03-27 01:23:25.016[0m | [1mINFO    [0m | [36msrc.utils.split_time_based[0m:[36mtrain_test_split_timebased[0m:[36m40[0m - [1mRemoved 265886 users from test set[0m
[32m2025-03-27 01:23:25.017[0m | [1mINFO    [0m | [36msrc.utils.split_time_based[0m:[36mtrain_test_split_timebased[0m:[36m43[0m - [1mTrain set has 1762078 users[0m
[32m2025-03-27 01:23:25.236[0m | [1mINFO    [0m | [36msrc.utils.split_time_based[0m:[36mtrain_test_split_timebased[0m:[36m44[0m - [1mVal set has 680644 users[0m
[32m2025-03-27 01:23:25.424[0m | [1mINFO    [0m | [36msrc.utils.split_time_based[0m:[36mtr

In [100]:
# The three splits must be strictly ordered in time: train < val < test
assert train_df[args.timestamp_col].max() < val_df[args.timestamp_col].min(), "There are overlapping timestamps between train and validation datasets."
assert val_df[args.timestamp_col].max() < test_df[args.timestamp_col].min(), "There are overlapping timestamps between validation and test datasets."

In [101]:
logger.info(f"Train: {train_df.shape}, Val: {val_df.shape}, Test: {test_df.shape}")

[32m2025-03-27 01:23:27.019[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mTrain: (13043980, 4), Val: (1375948, 4), Test: (749578, 4)[0m


## Sampling data

Simply picking X users at random does not guarantee that the resulting dataset satisfies the **richness** condition. Instead, we take an iterative approach: we gradually drop random users from the dataset while keeping an eye on the conditions and on our sampling target.

In [102]:
# Rebind the finders so that they use the minimum-interaction thresholds configured in `args`.
get_unqualified_users, get_unqualified_items = (
    partial(get_unqualified, col=column, threshold=minimum)
    for column, minimum in (
        (args.user_col, args.min_user_interactions),
        (args.item_col, args.min_item_interactions),
    )
)

In [None]:
# --- Sampling knobs -----------------------------------------------------------------
buffer_perc = 0.2  # accepted overshoot above `args.sample_users` before we stop
perc_users_removed_each_round = 0.05  # share of the remaining users dropped per round
small_sample_num_rows = 20000  # below this many train rows the share above is set to 0.1
min_eval_records = 2000  # stop as soon as the val or the test sample has fewer rows
debug = True
keep_random_removing = True
r = 1

sample_df = train_df.copy()

# Outer loop: drop a random share of users, restore the minimum-interaction constraints,
# then check the stop criteria (user target reached, or val/test sample getting too small).
while keep_random_removing:
    num_remaining_users = sample_df[args.user_col].nunique()
    # int() truncates to 0 for small user counts; a round that removes nobody could repeat
    # forever. Always remove at least one user while any user is left.
    num_users_removed_each_round = min(
        num_remaining_users,
        max(1, int(perc_users_removed_each_round * num_remaining_users)),
    )
    print(
        f"\n\nRandomly removing {num_users_removed_each_round} users - Round {r} started"
    )
    sample_df = remove_random_users(sample_df, k=num_users_removed_each_round)

    keep_removing = True
    i = 1

    # Inner loop: removing users can push items below their threshold and vice versa,
    # so repeat until a full pass removes nothing.
    while keep_removing:
        if debug:
            logger.info(f"Sampling round {i} started")
        keep_removing = False
        uu = get_unqualified_users(sample_df)
        if debug:
            logger.info(f"{len(uu)=}")
        if len(uu):
            sample_df = sample_df.loc[~sample_df[args.user_col].isin(uu)]
            if debug:
                logger.info(f"After removing uu: {len(sample_df)=}")
            assert len(get_unqualified_users(sample_df)) == 0
            keep_removing = True
        ui = get_unqualified_items(sample_df)
        if debug:
            logger.info(f"{len(ui)=}")
        if len(ui):
            sample_df = sample_df.loc[~sample_df[args.item_col].isin(ui)]
            if debug:
                logger.info(f"After removing ui: {len(sample_df)=}")
            assert len(get_unqualified_items(sample_df)) == 0
            keep_removing = True
        i += 1

    sample_users = sample_df[args.user_col].unique()
    sample_items = sample_df[args.item_col].unique()
    num_users = len(sample_users)
    logger.info(f"After randomly removing users - round {r}: {num_users=}")

    if len(sample_df) < small_sample_num_rows:
        perc_users_removed_each_round = 0.1

    if num_users > args.sample_users * (1 + buffer_perc):
        logger.info(
            f"Number of users {num_users} are still greater than expected, keep removing..."
        )
    else:
        logger.info(
            f"Number of users {num_users} are falling below expected threshold, stop and use `sample_df` as final output..."
        )
        keep_random_removing = False

    # Restrict val/test to the users and items that survived in the train sample.
    val_sample_df = val_df.loc[
        val_df[args.user_col].isin(sample_users)
        & val_df[args.item_col].isin(sample_items)
    ]
    test_sample_df = test_df.loc[
        test_df[args.user_col].isin(sample_users)
        & test_df[args.item_col].isin(sample_items)
    ]
    if (num_val_records := val_sample_df.shape[0]) < min_eval_records:
        logger.info(
            f"Number of val_df records {num_val_records:,.0f} are falling below expected threshold, stop and use `sample_df` as final output..."
        )
        keep_random_removing = False
    if (num_test_records := test_sample_df.shape[0]) < min_eval_records:
        logger.info(
            f"Number of test_df records {num_test_records:,.0f} are falling below expected threshold, stop and use `sample_df` as final output..."
        )
        keep_random_removing = False

    r += 1

sample_users = sample_df[args.user_col].unique()
sample_items = sample_df[args.item_col].unique()
logger.info(f"Final sample sizes: {len(sample_users)=:,.0f}, {len(sample_items)=:,.0f}")


In [104]:
# The sampled splits must keep the strict time ordering: train sample < val sample < test sample
assert sample_df[args.timestamp_col].max() < val_sample_df[args.timestamp_col].min(), "There are overlapping timestamps between train and validation datasets."
assert val_sample_df[args.timestamp_col].max() < test_sample_df[args.timestamp_col].min(), "There are overlapping timestamps between validation and test datasets."

In [105]:
# Every user and every item of the val/test samples must also be part of the train sample.
assert val_sample_df[args.user_col].isin(sample_users).all(), "Validation DataFrame contains unexpected users."
assert test_sample_df[args.user_col].isin(sample_users).all(), "Test DataFrame contains unexpected users."
assert val_sample_df[args.item_col].isin(sample_items).all(), "Validation DataFrame contains unexpected items."
assert test_sample_df[args.item_col].isin(sample_items).all(), "Test DataFrame contains unexpected items."

In [106]:
px.histogram(sample_df.groupby(args.user_col).size())

In [107]:
px.histogram(sample_df.groupby(args.item_col).size())

In [108]:
sample_df

Unnamed: 0,user_id,parent_asin,rating,timestamp
3194,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B06XKCPK5W,2.0,2012-06-11 16:41:10
3199,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B000CKVOOY,3.0,2012-08-02 02:04:13
3200,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B006GWO5WK,5.0,2012-09-15 16:34:46
3204,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B008LURQ76,5.0,2013-01-03 23:08:45
3208,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B00AQRUW4Q,4.0,2013-05-06 01:24:39
...,...,...,...,...
40882304,AFB4DWWKZBQFS22FAWDEP37EL2FA,B00KAF5RQ2,5.0,2016-02-22 17:44:10
40882305,AFB4DWWKZBQFS22FAWDEP37EL2FA,B001F6TXME,5.0,2016-02-22 17:44:40
40882306,AFB4DWWKZBQFS22FAWDEP37EL2FA,B007VGGIB6,5.0,2016-02-22 17:45:10
40882307,AFB4DWWKZBQFS22FAWDEP37EL2FA,B00WUID73W,5.0,2016-02-22 17:45:37


In [109]:
val_sample_df

Unnamed: 0,user_id,parent_asin,rating,timestamp
4668,AGZE3IYHOEGKUTJZSQCSFSQ4IFFQ,B0B787CN26,5.0,2021-10-27 19:43:57.873
10425,AEANO5BIASSZNFWNXBR2ECHCPJQQ,B0002MQGOA,5.0,2021-02-02 14:20:48.424
10426,AEANO5BIASSZNFWNXBR2ECHCPJQQ,B07HZLHPKP,5.0,2021-03-08 13:56:57.795
13265,AHDXCFTV7RS3AM6E2TRPWOG3A33Q,B07QWPVZJY,3.0,2021-12-11 00:34:19.152
14423,AEFHRRLFCZQ3TWNYCBA7UD3NIXCA,B00D96J8IM,1.0,2021-10-17 20:54:19.325
...,...,...,...,...
33760091,AHIIISHZP6YAVVHMDEBLJ5CWZ7ZA,B0BZ62FQ13,3.0,2021-07-16 17:08:55.044
34470392,AFTE3G43QHXWD3DJGDCI2DHEWQJQ,B08DMXDPW5,5.0,2021-01-14 01:48:09.423
35019360,AFENZZDPVUYFVBS47YDOWJCDYBSQ,B09XBT6DS9,4.0,2021-12-05 00:35:40.874
35323250,AFMBZYPDAXT5VO3ME67HW5Q5TAOQ,B097KBF8JK,5.0,2022-02-18 11:32:46.732


In [110]:
test_sample_df

Unnamed: 0,user_id,parent_asin,rating,timestamp
13270,AHDXCFTV7RS3AM6E2TRPWOG3A33Q,B0BHMVBV9M,2.0,2022-08-27 16:17:48.228
13271,AHDXCFTV7RS3AM6E2TRPWOG3A33Q,B08F1P3BCC,4.0,2023-01-16 04:19:10.669
29821,AHZ6GFHFM6Z7CRPSXRIYQ5Z7GERQ,B08YF1VBYD,4.0,2022-05-28 18:44:43.983
38367,AHAI4X3YAVRMXXUR6USAT5L5WG3A,B0BMK6DC5W,1.0,2022-09-16 13:15:24.402
41950,AGBF2BZRN6M65YBFZCF54ENDRRAA,B0BM73T3K6,5.0,2022-08-26 12:13:41.632
...,...,...,...,...
33756582,AEUXNGJ4HXZXHHU5OF3BPR6ZCLNQ,B0BHGRJDCK,5.0,2022-04-18 15:02:10.839
33756584,AEUXNGJ4HXZXHHU5OF3BPR6ZCLNQ,B08CKZ36N7,5.0,2023-01-28 23:46:05.518
35019361,AFENZZDPVUYFVBS47YDOWJCDYBSQ,B09PRD4T26,5.0,2023-03-13 00:48:18.717
35323251,AFMBZYPDAXT5VO3ME67HW5Q5TAOQ,B09PB85B9K,5.0,2022-07-05 23:12:38.472


In [111]:
# Row counts of every split, before and after sampling, keyed by split name.
subsets = ["train", "val", "test"]
original_length = dict(zip(subsets, (len(train_df), len(val_df), len(test_df))))
sampled_length = dict(zip(subsets, (len(sample_df), len(val_sample_df), len(test_sample_df))))


In [112]:
original_length

{'train': 13043980, 'val': 1375948, 'test': 749578}

In [113]:
sampled_length

{'train': 127392, 'val': 3479, 'test': 1822}

In [114]:
# Compare the size of every split before and after sampling, one subplot per split.
fig = make_subplots(rows=1, cols=len(subsets))

# (legend name, row counts per split, bar colour) - one bar per entry in every subplot
bar_specs = [
    ("original", original_length, "lightblue"),
    ("sample", sampled_length, "lightgreen"),
]

for i, subset in enumerate(subsets):
    row = 1
    col = i + 1

    for trace_name, lengths, bar_color in bar_specs:
        fig.add_trace(
            go.Bar(
                name=trace_name,
                x=[subset],
                y=[lengths[subset]],
                marker_color=bar_color,
                showlegend=(i == 0),  # one legend entry per series, not one per subplot
                texttemplate="%{y:.2}",
            ),
            row=row,
            col=col,
        )

    # Relative change of the row count; undefined when the original split is empty.
    if original_length[subset]:
        difference = (sampled_length[subset] - original_length[subset]) / original_length[subset]
    else:
        difference = float("nan")

    fig.add_annotation(
        x=subset,
        # Position above the tallest of the two bars (the original split is usually the taller one)
        y=max(original_length[subset], sampled_length[subset]) * 1.10,
        text=f"Δ={difference:.2%}",
        showarrow=False,
        font=dict(color="black", size=14),
        row=row,
        col=col,
    )

fig.update_layout(showlegend=True, title_text="Interactions per split: original vs. sample")

fig.show()

In [115]:
num_users = sample_df[args.user_col].nunique()
num_users

16407

In [116]:
# Persist the sampled data; the file names embed the number of sampled users (`num_users`)
sample_df.to_parquet(f"{args.sample_data_persit_path}/train_sample_interactions_{num_users}u.parquet")
val_sample_df.to_parquet(f"{args.sample_data_persit_path}/val_sample_interactions_{num_users}u.parquet")
test_sample_df.to_parquet(f"{args.sample_data_persit_path}/test_sample_interactions_{num_users}u.parquet")

Remember to version your data with DVC.

In [117]:
train_sample_df = pd.read_parquet(f"{args.sample_data_persit_path}/train_sample_interactions_{num_users}u.parquet")

In [118]:
def plot_interactions_over_time(df, timestamp_col=None):
    """Show a line chart of the number of interactions per calendar date.

    Args:
        df: Interaction frame; `timestamp_col` must be a datetime column
            (the `.dt` accessor is used on it).
        timestamp_col: Name of the timestamp column. Defaults to
            ``args.timestamp_col``.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    if timestamp_col is None:
        timestamp_col = args.timestamp_col

    # Group on the date part of the configured column. The previous version always wrote
    # the dates into a column literally named "timestamp", so any other configured
    # column name was grouped on the raw, untruncated timestamps.
    plot_df = df.groupby(df[timestamp_col].dt.date).size()

    fig = px.line(
        x=plot_df.index,
        y=plot_df.values,
        labels={"x": "Date", "y": "Number of Interactions"},
        title="Interactions Over Time",
        height=500,
    )

    fig.update_layout(yaxis=dict(showticklabels=True, tickformat=","))

    fig.show()

In [119]:
train_sample_df

Unnamed: 0,user_id,parent_asin,rating,timestamp
3194,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B06XKCPK5W,2.0,2012-06-11 16:41:10
3199,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B000CKVOOY,3.0,2012-08-02 02:04:13
3200,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B006GWO5WK,5.0,2012-09-15 16:34:46
3204,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B008LURQ76,5.0,2013-01-03 23:08:45
3208,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B00AQRUW4Q,4.0,2013-05-06 01:24:39
...,...,...,...,...
40882304,AFB4DWWKZBQFS22FAWDEP37EL2FA,B00KAF5RQ2,5.0,2016-02-22 17:44:10
40882305,AFB4DWWKZBQFS22FAWDEP37EL2FA,B001F6TXME,5.0,2016-02-22 17:44:40
40882306,AFB4DWWKZBQFS22FAWDEP37EL2FA,B007VGGIB6,5.0,2016-02-22 17:45:10
40882307,AFB4DWWKZBQFS22FAWDEP37EL2FA,B00WUID73W,5.0,2016-02-22 17:45:37


In [120]:
# Build up idm
# Sorted to make sure that even rerun we get same idm mapping
unique_user_ids = sorted(train_sample_df[args.user_col].unique())
unique_item_ids = sorted(train_sample_df[args.item_col].unique())
logger.info(f"Number of unique users: {len(unique_user_ids):,.0f}")
logger.info(f"Number of unique items: {len(unique_item_ids):,.0f}")
idm = IDMapper()
idm.fit(unique_user_ids, unique_item_ids)

[32m2025-03-27 01:37:15.000[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1mNumber of unique users: 16,407[0m
[32m2025-03-27 01:37:15.000[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mNumber of unique items: 4,817[0m


In [121]:
# Write the mapper to disk, then rebind `idm` to whatever `load` returns for that same file.
idm_persist_fp = f"{args.notebook_persit_path}/idm_{num_users}u.json"
idm.save(idm_persist_fp)
idm = IDMapper().load(idm_persist_fp)

In [122]:
len(idm.item_to_index)

4817

In [123]:
# Sanity check on the reloaded mapper: every raw id used as a key must be a plain `str`.
# (The two messages were swapped before: the item loop reported "user id" and vice versa.)
for item_id in idm.item_to_index:
    assert type(item_id) is str, "Type of item id should be string"
for user_id in idm.user_to_index:
    assert type(user_id) is str, "Type of user id should be string"