In [1]:
import os
import pandas as pd
import numpy as np
from loguru import logger
from collections import defaultdict
from pydantic import BaseModel

# Controller

In [2]:
class Args(BaseModel):
    testing: bool = False
    experiment_name: str = "FSDS RecSys - L5 - Reco Algo"
    run_name: str = '041-offline-negative-sampling-rating-prediction'
    notebook_persist_dp: str = None
    random_seed: int = 41

    user_col: str = 'user_id'
    item_col: str = 'parent_asin'
    rating_col: str = 'rating'
    timestamp_col: str = 'timestamp'

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        return self
    
args = Args().init()

print(args.model_dump_json(indent=2))

{
  "testing": false,
  "experiment_name": "FSDS RecSys - L5 - Reco Algo",
  "run_name": "041-offline-negative-sampling-rating-prediction",
  "notebook_persist_dp": "/Users/dvq/frostmourne/reco-algo/notebooks/data/041-offline-negative-sampling-rating-prediction",
  "random_seed": 41,
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp"
}


# Test implementation

In [3]:
# Sample input: List of (user_id, item_id) interactions
interactions = [
    (1, 101, 1, 1),
    (1, 102, 2, 2),
    (1, 103, 3, 4),
    (2, 101, 4, 1),
    (2, 104, 5, 2),
    (3, 105, 1, 1),
    (3, 106, 2, 5),
    # Add more interactions as needed
]

# Convert the list to a DataFrame for easier manipulation
df = pd.DataFrame(interactions, columns=['user_id', 'item_id', 'rating', 'timestamp'])

In [4]:
def generate_negative_samples(
    df,
    user_col='user_id',
    item_col='item_id',
    label_col='rating',
    neg_label=0,
    seed=None,
    progress_bar_type='tqdm'  # Options: 'tqdm', 'tqdm_notebook', None
):
    """
    Optimized function to generate negative samples for a user-item interaction DataFrame.
    """
    
    # Handle random seed
    if seed is not None:
        np.random.seed(seed)
    
    # Import tqdm based on the progress_bar_type
    if progress_bar_type == 'tqdm':
        try:
            from tqdm import tqdm
            tqdm_bar = tqdm
        except ImportError:
            raise ImportError("tqdm is not installed. Please install it using 'pip install tqdm'.")
    elif progress_bar_type == 'tqdm_notebook':
        try:
            from tqdm.notebook import tqdm
            tqdm_bar = tqdm
        except ImportError:
            raise ImportError("tqdm.notebook is not available. Please install it using 'pip install tqdm'.")
    elif progress_bar_type is None:
        # Define a dummy tqdm function that does nothing
        def tqdm_bar(iterable, **kwargs):
            return iterable
    else:
        raise ValueError("Invalid progress_bar_type. Choose 'tqdm', 'tqdm_notebook', or None.")
    
    # Calculate item popularity based on the number of interactions
    item_popularity = df[item_col].value_counts()
    
    # Define all unique items from the DataFrame
    items = item_popularity.index.values
    all_items_set = set(items)
    
    # Create a user-item interaction dictionary
    user_item_dict = df.groupby(user_col)[item_col].apply(set).to_dict()
    
    # Prepare items list and corresponding popularity array
    popularity = item_popularity.values.astype(np.float64)
    
    # Calculate sampling probabilities based on item popularity
    total_popularity = popularity.sum()
    if total_popularity == 0:
        sampling_probs = np.ones(len(items)) / len(items)
    else:
        sampling_probs = popularity / total_popularity
    
    # Create item to index mapping for quick access
    item_to_index = {item: idx for idx, item in enumerate(items)}
    
    # Initialize a list to store negative samples
    negative_samples = []
    
    # Initialize the progress bar
    total_users = len(user_item_dict)
    progress_bar = tqdm_bar(user_item_dict.items(), total=total_users, desc="Generating Negative Samples")
    
    for user, pos_items in progress_bar:
        num_pos = len(pos_items)
        
        # Identify items not interacted with by the user
        negative_candidates = all_items_set - pos_items
        num_neg_candidates = len(negative_candidates)
        
        if num_neg_candidates == 0:
            # User has interacted with all items, skip negative sampling
            continue
        
        # Determine the number of negative samples to generate
        num_neg = min(num_pos, num_neg_candidates)
        
        # Convert set to list for indexing
        negative_candidates_list = list(negative_candidates)
        
        # Get the indices and probabilities of negative candidates
        candidate_indices = [item_to_index[item] for item in negative_candidates_list]
        candidate_probs = sampling_probs[candidate_indices]
        candidate_probs /= candidate_probs.sum()
        
        # Sample negative items without replacement
        sampled_items = np.random.choice(
            negative_candidates_list, size=num_neg, replace=False, p=candidate_probs
        )
        
        # Append the sampled negative items to the list
        negative_samples.extend([(user, item) for item in sampled_items])
    
    # Convert negative samples to a DataFrame
    df_negative = pd.DataFrame(negative_samples, columns=[user_col, item_col])
    df_negative[label_col] = neg_label  # Assign label for negative samples
    
    return df_negative

def add_features_to_neg_df(pos_df, neg_df, user_col, timestamp_col, feature_cols=[]):
    neg_df = neg_df.assign(
        timestamp_pseudo=lambda df: df.groupby('user_id').cumcount() + 1
    )
    neg_df = (
        pd.merge(
            neg_df,
            pos_df.assign(
                timestamp_pseudo=lambda df: df.groupby([user_col])[timestamp_col].rank(method='first')   
            )[[user_col, timestamp_col, 'timestamp_pseudo', *feature_cols]],
            how='left',
            on=[user_col, 'timestamp_pseudo']
        )
        .drop(columns=['timestamp_pseudo'])
    )
    return neg_df

In [5]:
neg_df = generate_negative_samples(df, progress_bar_type='tqdm_notebook')
neg_df = add_features_to_neg_df(df, neg_df, 'user_id', 'timestamp')

Generating Negative Samples:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
neg_df.sort_values(['user_id', 'rating'])

Unnamed: 0,user_id,item_id,rating,timestamp
0,1,104,0,1
1,1,106,0,2
2,1,105,0,4
3,2,103,0,1
4,2,102,0,2
5,3,104,0,1
6,3,101,0,5


# Load data

In [7]:
train_df = pd.read_parquet("../data/train_features.parquet")
val_df = pd.read_parquet("../data/val_features.parquet")

In [8]:
assert (val_df[args.timestamp_col].min() - train_df[args.timestamp_col].max()) > 0
val_timestamp = train_df[args.timestamp_col].max() + 1
logger.info(f"{val_timestamp=}")

[32m2024-09-20 09:39:43.678[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mval_timestamp=np.int64(1628642557238)[0m


In [9]:
full_df = pd.concat([train_df, val_df], axis=0)
full_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence
0,AFQFGIC62CA6X7B5WNYQJC3DQS6A,037376099X,5.0,878061365000,2021,18576,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
1,AFJFQKVLBLJLGKHZYUHIDZLGVBDQ,1565922573,5.0,878680832000,2921,30312,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
2,AFJFQKVLBLJLGKHZYUHIDZLGVBDQ,0449909433,4.0,879712608000,2921,6651,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
3,AFQGSL2NLM3XYV4VU5YCHQZEMFRA,0553571818,4.0,887759677000,6551,42520,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
4,AFQGSL2NLM3XYV4VU5YCHQZEMFRA,014018869X,5.0,888091095000,6551,17257,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
...,...,...,...,...,...,...,...
9141,AE7CC33RBTGEOQ2MBIAZDHXEAG7A,B08XQWFMK4,2.0,1657994280406,5841,82836,"[-1, -1, -1, -1, -1, -1, -1, -1, 21304, 54333]"
9142,AHRDEE3ZO5VMRWUK7CUILRWSTB7A,1629798266,5.0,1657996230659,9471,53402,"[-1, 77791, 13305, 34579, 58946, 40435, 28041,..."
9143,AFG6YQ3GOY7TVFKQ3SKDVS6Q6RDQ,B07R3QYGHY,4.0,1657998389024,7843,86266,"[-1, -1, -1, -1, 66104, 12441, 57040, 4640, 33..."
9144,AHNN7AG7AL5Z7ZTX3ES5A4ZOQWUA,B01D1LNYWK,5.0,1657999964843,6013,7431,"[-1, -1, -1, -1, -1, -1, 42180, 71861, 73082, ..."


In [10]:
neg_df = generate_negative_samples(full_df, args.user_col, args.item_col, args.rating_col, neg_label=0, seed=args.random_seed, progress_bar_type='tqdm_notebook')
features = ["user_indice", "item_indice", "item_sequence"]
neg_ts_df = add_features_to_neg_df(full_df, neg_df, args.user_col, args.timestamp_col, features)
neg_ts_df

Generating Negative Samples:   0%|          | 0/10000 [00:00<?, ?it/s]

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence
0,AE22QFIC5SDTXPDXBANVVZI6FX3Q,0374108234,0,1454944233000,7970,36931,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
1,AE22QFIC5SDTXPDXBANVVZI6FX3Q,1617739251,0,1454944287000,7970,18292,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
2,AE22QFIC5SDTXPDXBANVVZI6FX3Q,B007TC861M,0,1508347362448,7970,9740,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
3,AE22QFIC5SDTXPDXBANVVZI6FX3Q,1451681755,0,1508347439290,7970,28049,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 369..."
4,AE22QFIC5SDTXPDXBANVVZI6FX3Q,B0876F272G,0,1572701704383,7970,83063,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 36931.0, ..."
...,...,...,...,...,...,...,...
161728,AHZZQNSG7UUC6YE5SKKA4HMCOQUQ,1465445609,0,1642920227166,967,58342,"[-1, -1, -1, 78320, 34949, 3911, 45408, 38195,..."
161729,AHZZRNJYTJETXCG4D43GZB7XL5VQ,1250619351,0,1395112263000,8104,69423,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
161730,AHZZRNJYTJETXCG4D43GZB7XL5VQ,1423133161,0,1491674307000,8104,41281,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
161731,AHZZRNJYTJETXCG4D43GZB7XL5VQ,B01MRH5HVW,0,1510787707469,8104,75796,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."


In [11]:
full_df = pd.concat([full_df, neg_ts_df], axis=0).sample(frac=1, replace=False, random_state=args.random_seed)

In [12]:
full_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence
109825,AHRHRBEQUC2QWMOKM4NKZSURZYSA,1984830171,5.0,1567553652011,3997,65608,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
36985,AEYYFUHPXZHZXW2NIDV723D5LNZQ,1455586420,0.0,1601576819114,642,67252,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
32995,AFDLEF7U2BTXKWEQABKVOWWCMNIQ,1578563232,2.0,1406559422000,6603,50943,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
116395,AEPQATTU5B6SKYIC7TFOU4ZRRTNA,1982123672,5.0,1578517855504,7418,39459,"[-1.0, -1.0, 35631.0, 69237.0, 89604.0, 29439...."
101319,AGMGDRKZHMDZY3F7AZBFWMJY77LA,0073381063,0.0,1486650643000,4563,32149,"[7543.0, 47867.0, 25646.0, 2350.0, 78923.0, 25..."
...,...,...,...,...,...,...,...
53491,AHLJ7GWIA6KYU5X47DCVLFQHTKGA,1619026007,5.0,1457579638000,3208,84823,"[30718.0, 39791.0, 34210.0, 72773.0, 71729.0, ..."
89227,AGBRSFUTITDVXT47M3DT77L3HR4A,0451233565,0.0,1557605001431,4061,62364,"[93199.0, 6352.0, 75733.0, 58077.0, 21320.0, 5..."
55325,AFHHMLPOSP2SP3UQVYIAJGKN35QQ,0529120666,0.0,1452564272000,8575,44681,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 68658.0, ..."
132003,AHMHBEWFVDLJEYTWRK3Q2E7JXHSQ,1844487156,5.0,1603138515370,6882,67962,"[9167.0, 60142.0, 47248.0, 69297.0, 15339.0, 2..."


In [13]:
full_df.to_parquet('../data/full_features_neg_sampling_df.parquet', index=False)

In [14]:
val_timestamp

np.int64(1628642557238)

In [15]:
train_neg_df = full_df.loc[lambda df: df[args.timestamp_col].lt(val_timestamp)]
val_neg_df = full_df.loc[lambda df: df[args.timestamp_col].ge(val_timestamp)]

In [16]:
train_neg_df.to_parquet("../data/train_features_neg_df.parquet", index=False)
val_neg_df.to_parquet("../data/val_features_neg_df.parquet", index=False)

In [17]:
full_df.loc[lambda df: df['user_id'].eq('AEYYFUHPXZHZXW2NIDV723D5LNZQ')].sort_values('timestamp')

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence
36983,AEYYFUHPXZHZXW2NIDV723D5LNZQ,1938150457,0.0,1571063783442,642,34301,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
111978,AEYYFUHPXZHZXW2NIDV723D5LNZQ,0142405965,5.0,1571063783442,642,34301,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
128725,AEYYFUHPXZHZXW2NIDV723D5LNZQ,0991243560,5.0,1598565454026,642,20064,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
36984,AEYYFUHPXZHZXW2NIDV723D5LNZQ,B07PLHWP46,0.0,1598565454026,642,20064,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
36985,AEYYFUHPXZHZXW2NIDV723D5LNZQ,1455586420,0.0,1601576819114,642,67252,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
130937,AEYYFUHPXZHZXW2NIDV723D5LNZQ,1524855154,5.0,1601576819114,642,67252,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
36986,AEYYFUHPXZHZXW2NIDV723D5LNZQ,0062374818,0.0,1601576901716,642,74482,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 343..."
130938,AEYYFUHPXZHZXW2NIDV723D5LNZQ,0062861867,5.0,1601576901716,642,74482,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 343..."
36987,AEYYFUHPXZHZXW2NIDV723D5LNZQ,097196128X,0.0,1611182168373,642,86883,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 34301.0, ..."
137594,AEYYFUHPXZHZXW2NIDV723D5LNZQ,1524744603,5.0,1611182168373,642,86883,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 34301.0, ..."
