In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
from loguru import logger
from dotenv import load_dotenv
from pydantic import BaseModel
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.sparse as sparse
import torch.optim as optim

import numpy as np # required for the scikit-learn pipeline to work
import pandas as pd
import plotly.express as px
import mlflow

load_dotenv()

sys.path.insert(0, '..')

from src.viz import blueq_colors

# Controller

In [3]:
class Args(BaseModel):
    testing: bool = False
    log_to_mlflow: bool = True
    experiment_name: str = "FSDS RecSys - L5 - Reco Algo"
    run_name: str = '065-content-based'
    notebook_persist_dp: str = None
    random_seed: int = 41

    user_col: str = 'user_id'
    item_col: str = 'parent_asin'
    rating_col: str = 'rating'
    timestamp_col: str = 'timestamp'
    
    top_K: int = 100
    top_k: int = 10

    batch_size: int = 128

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        
        if not os.environ.get("MLFLOW_TRACKING_URI"):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            import mlflow

            mlflow.set_experiment(self.experiment_name)
            mlflow.start_run(run_name=self.run_name)

        return self
    
args = Args().init()

print(args.model_dump_json(indent=2))

[32m2024-09-21 19:54:22.295[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m29[0m - [1mSetting up MLflow experiment FSDS RecSys - L5 - Reco Algo - run 065-content-based...[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "FSDS RecSys - L5 - Reco Algo",
  "run_name": "065-content-based",
  "notebook_persist_dp": "/home/dvquys/frostmourne/reco-algo/notebooks/data/065-content-based",
  "random_seed": 41,
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "top_K": 100,
  "top_k": 10,
  "batch_size": 128
}


# Implement

In [4]:
from src.train_utils import train, MetricLogCallback
from src.model import ContentBased

In [5]:
def init_model():
    model = ContentBased(item_features=train_item_features, device=device)
    return model

In [6]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
# device = 'cpu'
logger.info(f"Using {device} device")

[32m2024-09-21 19:54:22.635[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mUsing cuda device[0m


In [7]:
import dill
import numpy as np # required for the pipeline to work
from src.data_prep_utils import chunk_transform

with open('../data/item_metadata_pipeline.dill', 'rb') as f:
    item_metadata_pipeline = dill.load(f)

# Test implementation

In [8]:
# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
ratings = [1, 4, 5, 3, 2]
timestamps = [0, 1, 2, 3, 4]
main_category = ['All Electronics', 'Video Games', 'All Electronics', 'Video Games', "Unknown"]
title = ['All Electronics', 'Video Games', 'All Electronics', 'Video Games', "Unknown"]
description = [[], [], ["Video games blah blah"], [], ["blah blah"]]
categories = [[], ["Headsets"], ["Video Games"], [], ["blah blah"]]
price = ["from 14.99", "14.99", "price: 9.99", "20 dollars", "None"]

train_df = pd.DataFrame({
    "user_indice": user_indices,
    "item_indice": item_indices,
    args.rating_col: ratings,
    args.timestamp_col: timestamps,
    "main_category": main_category,
    "title": title,
    "description": description,
    "categories": categories,
    "price": price,
})
# Drop duplicated item features so that the ContentBased model will fit correctly in terms of index mapping
fit_df = train_df.drop_duplicates(subset=['item_indice'])
train_item_features = item_metadata_pipeline.transform(fit_df).astype(np.float32)

val_user_indices = [0, 1, 2]
val_item_indices = [2, 1, 2]
val_ratings = [2, 4, 5]
val_timestamps = [5, 6, 7]
val_main_category = ['All Electronics', 'Video Games', 'All Electronics']
val_title = ['All Electronics', 'Video Games', 'All Electronics']
val_description = [["Video games blah blah"], [], ["Video games blah blah"]]
val_categories = [["Video Games"], ["Headsets"], ["Video Games"]]
val_price = ["price: 9.99", "14.99", "price: 9.99"]

val_df = pd.DataFrame({
    "user_indice": val_user_indices,
    "item_indice": val_item_indices,
    args.rating_col: val_ratings,
    args.timestamp_col: val_timestamps,
    "main_category": val_main_category,
    "title": val_title,
    "description": val_description,
    "categories": val_categories,
    "price": val_price,
})
val_item_features = item_metadata_pipeline.transform(val_df).astype(np.float32)

In [9]:
n_users = len(set(user_indices))
n_items = len(set(item_indices))

model = init_model()

items1 = [1, 2]
items2 = [0, 3]
predictions = model.predict(items1, items2)
print(predictions)

print("\n\n")

users = [0, 1]
anchor_items = [2, 3]
recommendations = model.recommend(users, anchor_items, k=args.top_K, progress_bar_type='tqdm_notebook')
print(recommendations)

tensor([0.1249, 0.0388], device='cuda:0')





Generating Recommendations:   0%|          | 0/2 [00:00<?, ?it/s]

{'user_indice': [0, 0, 0, 0, 1, 1, 1, 1], 'recommendation': [0, 4, 1, 3, 1, 4, 0, 2], 'score': [0.6671584248542786, 0.13775993883609772, 0.03927098959684372, 0.03875051066279411, 0.9532012939453125, 0.32575443387031555, 0.1877065747976303, 0.03875051066279411]}


# Prep data

In [10]:
from src.id_mapper import IDMapper
from src.train_utils import map_indice

In [11]:
def get_last_item(df, item_sequence_col='item_sequence'):
    return (
        df
        .assign(
            last_item_indice=lambda df: df[item_sequence_col].apply(lambda s: s[-1])
        )
    )

In [12]:
train_df = pd.read_parquet("../data/train_features_neg_df.parquet")
val_df = pd.read_parquet("../data/val_features_neg_df.parquet")
idm = IDMapper().load("../data/idm.json")
# val_timestamp = 1628643414042  # https://amazon-reviews-2023.github.io/data_processing/5core.html
assert (val_df[args.timestamp_col].min() - train_df[args.timestamp_col].max()) > 0
val_timestamp = train_df[args.timestamp_col].max() + 1
print(f"{val_timestamp=}")

val_timestamp=np.int64(1628641464793)


In [13]:
train_df = train_df.pipe(get_last_item)
val_df = val_df.pipe(get_last_item)
full_df = pd.concat([train_df, val_df], axis=0)
user_ids = train_df[args.user_col].values
item_ids = train_df[args.item_col].values
unique_user_ids = list(set(user_ids))
unique_item_ids = list(set(item_ids))

logger.info(f"{len(unique_user_ids)=:,.0f}, {len(unique_item_ids)=:,.0f}")

[32m2024-09-21 19:54:25.928[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mlen(unique_user_ids)=20,366, len(unique_item_ids)=4,696[0m


In [14]:
train_df = train_df.pipe(map_indice, idm, args.user_col, args.item_col)
val_df = val_df.pipe(map_indice, idm, args.user_col, args.item_col)
full_df = full_df.pipe(map_indice, idm, args.user_col, args.item_col)

In [15]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,item_sequence,last_item_indice
0,AEPV6L74QXWEH2DZGDL42UUBYV6A,B0B4CRTWGM,0.0,1417479397000,13282,2783,Video Games,Star Wars Knights of the Old Republic II: The ...,"[From the Manufacturer, Smug Statement: The Si...","[Video Games, Legacy Systems, Xbox Systems, Xb...",24.95,"[3355.0, 4276.0, 4032.0, 1909.0, 3803.0, 3713....",1271.0
1,AECRQW3YB7HM3V37DAZD4UPPNO3A,B08P1NS2X1,0.0,1375928743000,7731,3958,Video Games,LEGO City Undercover - PlayStation 4,"[Join the Chase! In LEGO CITY Undercover, play...","[Video Games, PlayStation 4, Games]",19.74,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....",3909.0
2,AERF6JFF76TH6FSCFQWL66ANOHVA,B001G6064C,0.0,1545403509538,9451,2678,Video Games,F.E.A.R. 2: Project Origin - Playstation 3,"[Product Description, Fear Alma Again, Amazon....","[Video Games, Legacy Systems, PlayStation Syst...",44.49,"[-1.0, -1.0, -1.0, 1037.0, 1650.0, 3445.0, 151...",1132.0
3,AEAYIBNW4QQFLEC35Z7RQNZTUUOA,B005OGKBPE,0.0,1583917678434,12227,2545,Video Games,Syndicate - Origin PC [Online Game Code],"[From the Manufacturer, Syndicate is the re-im...","[Video Games, PC, Games]",,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 290.0, 22...",2845.0
4,AFAUFOADANHRMH3GR2FA73A6NHIA,B08VFQ3XJX,5.0,1408655680000,15838,1218,Video Games,Final Fantasy X,"[Product Description, Final Fantasy X finally ...","[Video Games, Legacy Systems, PlayStation Syst...",20.0,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 3674.0, 1...",3332.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
340085,AFS7QJZZOXPJ4MDIXESEZ665TSJQ,B00XZQ58AI,0.0,1463416675000,9464,154,Video Games,NBA 2K16 - PlayStation 3,[The NBA 2K franchise is back with the most tr...,"[Video Games, Legacy Systems, PlayStation Syst...",59.68,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....",2451.0
340086,AG2MDQBZUVR647JELUYNZL3UYSFA,B017GY07L4,0.0,1516402184763,5512,1898,Video Games,Nights of Azure - PlayStation 4,[Nights of Azure is a tragic tale of two frien...,"[Video Games, PlayStation 4, Games]",49.98,"[-1.0, -1.0, -1.0, -1.0, -1.0, 781.0, 2796.0, ...",198.0
340087,AGNE5FJF5GSLINB5RZ45EJVUVKYQ,B000P5BSUQ,0.0,1398805150000,18179,706,Video Games,Age Of Mythology: Titans - PC,[With the Age Of Mythology and The Titans bund...,"[Video Games, PC, Games]",,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....",4279.0
340088,AEFJKBRX2TPPMMJ53DKHANKRYDXQ,B08N6NMGNB,5.0,1510350967758,18559,1746,Video Games,Thrustmaster T300 RS - Gran Turismo Edition Ra...,[Works with PS5 games (PS5 games compatibility...,"[Video Games, Legacy Systems, PlayStation Syst...",449.0,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 891...",3509.0


# Fit

In [16]:
fit_df = train_df.drop_duplicates(subset=[args.item_col])
train_item_features = item_metadata_pipeline.transform(fit_df).astype(np.float32)
model = init_model()

# Predict

In [17]:
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,item_sequence,last_item_indice
1315,AEXN3VFNZS7CKHX2NHDHLYDBZZIQ,B0087ACBAW,0.0,1633099679814,6987,3732,Video Games,Rayman Origins [Download],"[From the Manufacturer, Product Description:, ...","[Video Games, PC]",,"[2128, 1144, 2691, 303, 3974, 3175, 4212, 4035...",1861
578,AESRA7QT2GPZYV73VIZZZHSCH7JA,B0BJKR3LJJ,5.0,1638056044509,9486,4237,Computers,EasySMX Wireless 2.4g Gaming Controller Suppor...,[],"[Video Games, Legacy Systems, PlayStation Syst...",,"[1753, 3952, 4239, 2038, 3251, 1324, 1650, 192...",1146
314,AHHGI57WUFRNZGU3CYCOI37PT4RA,B0001ZZNME,0.0,1652105390744,4455,3200,Video Games,The Legend of Zelda - Classic NES Series,"[From the Manufacturer, Embark on a quest to f...","[Video Games, Legacy Systems, Nintendo Systems...",52.49,"[4239, 2345, 173, 970, 4676, 4289, 3337, 1037,...",4238
1417,AEXFEQ7QOP6EHDEZ3K6NN27MQ7KA,B09V25XG1G,5.0,1651718479413,6008,4599,Video Games,"8BitDo Pro 2 Bluetooth Controller for Switch, ...","[Compatible with Switch, Windows PC, macOS, An...","[Video Games, Nintendo Switch, Accessories, Co...",49.99,"[1546, 1457, 3966, 3107, 891, 4461, 1759, 251,...",2707
1711,AEOY2365QPPEVDTOXL6N7ZA4NSAA,B00PDRZG9U,5.0,1628820275218,11654,275,Video Games,Code Name: S.T.E.A.M.,"[Launch S.T.E.A.M., an elite team of steam-pow...","[Video Games, Legacy Systems, Nintendo Systems...",12.99,"[-1, -1, -1, -1, -1, 1741, 928, 331, 4061, 1821]",1821
48,AE3U66S5YBEMPF36PVYR6QAS5ETA,B003R7JXOM,0.0,1630816495703,14425,1704,Video Games,Need For Speed Hot Pursuit - Nintendo Wii,"[Product Description, In Need for Speed Hot Pu...","[Video Games, Legacy Systems, Nintendo Systems...",45.22,"[-1, 2897, 3428, 254, 832, 200, 1276, 223, 224...",3292
376,AG2LMAQDTQSWAOLZYSZVFGXTZGLA,B00HX1UXD8,5.0,1634734203188,6682,2641,Video Games,Star Wars: The Old Republic 60-Day Pre-paid Ti...,[Star Wars: The Old Republic Prepaid Time Card],"[Video Games, PC, Games]",37.5,"[-1, -1, -1, 1923, 511, 357, 2690, 2850, 4035,...",2504
1479,AFMOSTKHH2HFLI35E3YMI7GLYDCQ,B06Y8XBVX8,0.0,1628687441776,19837,2333,Video Games,Eagle Flight - PlayStation VR,[Experience free flight as you soar through th...,"[Video Games, PlayStation 4, Games]",24.99,"[-1, -1, 1592, 3974, 4574, 902, 239, 1397, 227...",4013
1054,AECGR42H24LUU7DA7HEUMCRLTX7Q,B0BVVTQ5JP,4.0,1634956260370,19963,3058,Computers,Logitech G502 HERO High Performance Wired Gami...,[Logitech updated its iconic G502 gaming mouse...,"[Video Games, PC, Accessories, Gaming Mice]",45.87,"[316, 3799, 3891, 4472, 1259, 478, 4563, 2129,...",4673
1296,AFBRTNVOROW7UVA66UPX5YCFC6MQ,B01LRLJV28,0.0,1636189764550,6380,3797,Video Games,PlayStation 4 Slim 500GB Console - Uncharted 4...,[The new slim PlayStation 4 opens the door to ...,"[Video Games, PlayStation 4, Consoles]",272.95,"[-1, -1, -1, -1, 4147, 3768, 2921, 3638, 2141,...",3425


In [18]:
user_id = val_df.sample(1)[args.user_col].values[0]
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
test_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,item_sequence,last_item_indice
933,AHR6RGZTLOMBD7EBF3OK43JWMGNQ,B00005QEFD,0.0,1648282613310,12835,1803,Computers,GameCube (Jet Black),"[Product Description, The GameCube is the firs...","[Video Games, Legacy Systems, Nintendo Systems...",159.99,"[-1, -1, -1, -1, 896, 213, 3364, 759, 1260, 3203]",3203
1167,AHR6RGZTLOMBD7EBF3OK43JWMGNQ,B0BG2DPCKM,4.0,1637111038077,12835,3203,Computers,Razer DeathAdder V2 Gaming Mouse: 20K DPI Opti...,"[With over 10 million Razer DeathAdders sold, ...","[Video Games, PC, Accessories, Gaming Mice]",41.0,"[-1, -1, -1, -1, -1, 896, 213, 3364, 759, 1260]",1260
1363,AHR6RGZTLOMBD7EBF3OK43JWMGNQ,B0BDWVBWC9,5.0,1648282613310,12835,4213,Video Games,PowerA Charging Stand for PlayStation 4,[Charge and display your DUALSHOCK 4 wireless ...,"[Video Games, PlayStation 4, PlayStation VR Ha...",,"[-1, -1, -1, -1, 896, 213, 3364, 759, 1260, 3203]",3203
1870,AHR6RGZTLOMBD7EBF3OK43JWMGNQ,B00IAVDQCK,0.0,1637111038077,12835,2128,Video Games,Xbox One Stereo Headset,"[Surround your senses, Immerse yourself in ric...","[Video Games, Xbox One, Accessories, Headsets]",18.99,"[-1, -1, -1, -1, -1, 896, 213, 3364, 759, 1260]",1260


In [19]:
item_id = test_df.loc[lambda df: df[args.rating_col].gt(0)][args.item_col].values[0]
logger.info(f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}")
anchor_item_indice = test_df['last_item_indice'].values[0]
item_indice = idm.get_item_index(item_id)
model.predict([item_indice], [anchor_item_indice])

[32m2024-09-21 19:54:28.493[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mTest predicting before training with user_id = AHR6RGZTLOMBD7EBF3OK43JWMGNQ and parent_asin = B0BG2DPCKM[0m


tensor([1.0000], device='cuda:0')

# Recommend

In [20]:
val_anchor_item_indices = val_df['last_item_indice'].values
val_user_indices = val_df['user_indice'].values

In [21]:
recommendations = model.recommend(
    val_user_indices,
    val_anchor_item_indices,
    args.top_K,
    progress_bar_type='tqdm_notebook'
)

Generating Recommendations:   0%|          | 0/1898 [00:00<?, ?it/s]

# Evaluate

## Ranking metrics

In [22]:
from src.eval import create_label_df, create_rec_df, merge_recs_with_target
from src.eval import log_ranking_metrics

In [23]:
recommendations_df = pd.DataFrame(recommendations).pipe(create_rec_df, idm, args.user_col, args.item_col)
recommendations_df

Unnamed: 0,user_indice,recommendation,score,rec_ranking,user_id,parent_asin
0,2377,3641,0.306040,1.0,AEFWYBITAJIQEAGJMGBBZQPD246Q,B004V9QC80
1,2377,2442,0.290738,3.0,AEFWYBITAJIQEAGJMGBBZQPD246Q,B01K1OO5PU
2,2377,1241,0.288765,5.0,AEFWYBITAJIQEAGJMGBBZQPD246Q,B074HBNNH6
3,2377,3759,0.287635,7.0,AEFWYBITAJIQEAGJMGBBZQPD246Q,B002C1AUP0
4,2377,1637,0.281889,9.0,AEFWYBITAJIQEAGJMGBBZQPD246Q,B06WVCWY41
...,...,...,...,...,...,...
189795,19050,1306,0.082616,192.0,AHAKU6TTWIHJPZIODW7MGC52M2DA,B000PD0HQE
189796,19050,3111,0.082434,194.0,AHAKU6TTWIHJPZIODW7MGC52M2DA,B002P5UBG6
189797,19050,3431,0.081589,196.0,AHAKU6TTWIHJPZIODW7MGC52M2DA,B00EQNP8F4
189798,19050,1665,0.081103,198.0,AHAKU6TTWIHJPZIODW7MGC52M2DA,B07X2LNHGB


In [24]:
label_df = create_label_df(val_df, user_col=args.user_col, item_col=args.item_col, rating_col=args.rating_col, timestamp_col=args.timestamp_col)
label_df

Unnamed: 0,user_id,parent_asin,rating,rating_rank
1711,AEOY2365QPPEVDTOXL6N7ZA4NSAA,B00PDRZG9U,5.0,1.0
425,AFGHX4VLP6P5XORLDJX3LZKUAAZA,B00Z9TJBUW,5.0,1.0
189,AFCH2PDOFM2S3622QFV6PHCHGMCA,B00KSQHX1K,5.0,1.0
1297,AEURBISVS35ALE7YQLR5L4K7AHCA,B07QQ8N7LL,1.0,1.0
320,AEMA3SW3WPNLEH3IACW23K2ZSUFA,B09JDLC31H,4.0,1.0
...,...,...,...,...
663,AFB6FYPPCN33UMUU5536IHXNOHCQ,B00BGA9WK2,0.0,18.0
453,AESD4RLWUKM6JTD6SNNWYLHLLQQA,B00Z9TJHEC,0.0,18.0
582,AG4RCXKPTC6QRORJLUSBY4SO2IAA,B001G7PSGW,0.0,18.0
1374,AFB6FYPPCN33UMUU5536IHXNOHCQ,B01K1OO5PU,0.0,19.0


In [25]:
eval_df = merge_recs_with_target(recommendations_df, label_df, k=args.top_K, user_col=args.user_col, item_col=args.item_col, rating_col=args.rating_col)
eval_df

Unnamed: 0,user_indice,recommendation,score,rec_ranking,user_id,parent_asin,rating,rating_rank
159,2711.0,3244.0,0.577719,1,AE2AZ2MNROPF33U6SS53VI22OXJA,B07BZP7HML,0,
160,2711.0,3244.0,0.577719,2,AE2AZ2MNROPF33U6SS53VI22OXJA,B07BZP7HML,0,
151,2711.0,4565.0,0.518209,3,AE2AZ2MNROPF33U6SS53VI22OXJA,B072K63ZPD,0,
152,2711.0,4565.0,0.518209,4,AE2AZ2MNROPF33U6SS53VI22OXJA,B072K63ZPD,0,
8,2711.0,4329.0,0.498134,5,AE2AZ2MNROPF33U6SS53VI22OXJA,B00005Q8J1,0,
...,...,...,...,...,...,...,...,...
191501,15813.0,1343.0,0.105370,196,AHZNHP6OKXRZV2UJMYDPLWCKFKEA,B007NUQICE,0,
191444,15813.0,3407.0,0.105184,197,AHZNHP6OKXRZV2UJMYDPLWCKFKEA,B000B9RI00,0,
191445,15813.0,3407.0,0.105184,198,AHZNHP6OKXRZV2UJMYDPLWCKFKEA,B000B9RI00,0,
191492,15813.0,3079.0,0.104983,199,AHZNHP6OKXRZV2UJMYDPLWCKFKEA,B006PP4136,0,


In [26]:
ranking_report = log_ranking_metrics(args, eval_df)

  return (1 + beta_sqr) * precision_arr * recall_arr / (beta_sqr * precision_arr + recall_arr)


## Classification metrics

In [27]:
from evidently.metric_preset import ClassificationPreset
from src.eval import log_classification_metrics
from sklearn.preprocessing import MinMaxScaler

In [28]:
val_anchor_item_indices = val_df['last_item_indice'].values
val_item_indices = val_df['item_indice'].values

In [29]:
classifications = model.predict(val_item_indices, val_anchor_item_indices).cpu().detach().numpy().reshape(-1, 1)
classifications = MinMaxScaler(feature_range=(0, 1)).fit_transform(classifications)

In [30]:
eval_classification_df = val_df.assign(
    classification_proba=classifications,
    label=lambda df: df[args.rating_col].gt(0).astype(int)
)
eval_classification_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,item_sequence,last_item_indice,classification_proba,label
0,AEFWYBITAJIQEAGJMGBBZQPD246Q,B001EYUS7G,0.0,1650810855155,2377,2080,Video Games,Far Cry 2: Fortune's Edition | PC Code - Ubiso...,"[Product Description, Includes Game + Fortune'...","[Video Games, Legacy Systems, PlayStation Syst...",,"[-1, -1, -1, -1, 2044, 1400, 4253, 3448, 3402,...",2818,0.135022,0
1,AEXN3VFNZS7CKHX2NHDHLYDBZZIQ,B002CZ38KA,0.0,1633099443693,6987,2376,Video Games,Heavy Rain - Greatest Hits,"[Product Description, Experience a gripping ps...","[Video Games, Legacy Systems, PlayStation Syst...",7.66,"[-1, 3431, 2128, 1144, 2691, 303, 3974, 3175, ...",4035,0.144466,0
2,AGCYZBKXV6Q5BGHWJB7J7D2HRWSA,B09R21G9DL,0.0,1640957371979,7520,4611,Computers,"Cipon Gamecube Controller, Wired Controller Ga...",[],"[Video Games, Legacy Systems, Nintendo Systems...",17.99,"[-1, -1, -1, -1, 1103, 2459, 750, 673, 2850, 3...",3872,0.142974,0
3,AEWCUX5UKUYPDZJIOB6XMLCBJ3KA,B0BLFYF8K2,4.0,1630263342566,9303,4165,Computers,"Logitech G600 MMO Gaming Mouse, RGB Backlit, 2...","[With 20 buttons, the Logitech G600 MMO Gaming...","[Video Games, PC, Accessories, Gaming Mice]",37.99,"[1829, 1711, 3115, 1930, 1657, 4651, 1579, 250...",4036,0.121690,1
4,AFFPVZ3JNCTQIKAK4XK37E2ENWWA,B00HVBPRUO,4.0,1655428133046,6775,2216,Video Games,Gold Wireless Stereo Headset - PlayStation 4,[A Headset for Gamers: Experience everything f...,"[Video Games, PlayStation 4, Accessories, Head...",,"[-1, -1, 4399, 3877, 1233, 3713, 2050, 3803, 2...",2469,0.142480,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1893,AFUWPAK6VCGEL2OVIL2YGZNFQJZQ,B08N6NCR3Q,4.0,1642699950266,3144,4617,Video Games,Thrustmaster T 16000M SPACE SIM DUO STICK (PC),[The THRUSTMASTER T.16000M FCS Space Sim Duo c...,"[Video Games, PC, Accessories, Controllers, Fl...",119.51,"[-1, -1, -1, -1, 3648, 3017, 4093, 3173, 4263,...",3622,0.373774,1
1894,AEPOQDJZJCF5APANNFRSABUNU4IA,B07G3KB7RT,0.0,1643422574208,10070,200,Video Games,Satisfye – ZenGrip Pro Gen 3 OLED Elite Bundle...,[],"[Video Games, Nintendo Switch, Accessories, Ha...",89.99,"[3808, 1356, 638, 3934, 495, 4213, 2717, 1721,...",2021,0.156401,0
1895,AFH63KLSVQQYRNFS7NLQGD3GSP3A,B094YHB1QK,5.0,1652564728981,13283,3456,Video Games,PlayStation DualSense Wireless Controller – Ga...,[Plot a course for astronomical adventures on ...,"[Video Games, PlayStation 5, Accessories, Cont...",74.99,"[-1, 1999, 1652, 2454, 2557, 1334, 129, 2409, ...",4445,0.217129,1
1896,AFPPTJOEUPVXA5C63SNRGID3EQNA,B0BVVTQ5JP,4.0,1635968491390,15033,3058,Computers,Logitech G502 HERO High Performance Wired Gami...,[Logitech updated its iconic G502 gaming mouse...,"[Video Games, PC, Accessories, Gaming Mice]",45.87,"[-1, -1, -1, -1, -1, 2884, 1953, 1724, 3591, 1...",1371,0.159081,1


In [31]:
classification_report = log_classification_metrics(args, eval_classification_df, target_col='label', prediction_col='classification_proba')

# Clean up

In [32]:
all_params = [args]

if args.log_to_mlflow:
    for params in all_params:
        params_dict = params.dict()
        params_ = {f"{params.__repr_name__()}.{k}": v for k, v in params_dict.items()}
        mlflow.log_params(params_)

    mlflow.end_run()

2024/09/21 19:54:33 INFO mlflow.tracking._tracking_service.client: 🏃 View run 065-content-based at: http://localhost:5003/#/experiments/1/runs/7c68a689e9724be48b52b7ae907cc77e.
2024/09/21 19:54:33 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5003/#/experiments/1.
