In [None]:
import json
import logging
import time
import warnings
from logging import getLogger
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from recbole.config import Config
from recbole.data import create_dataset, data_preparation
from recbole.quick_start import run_recbole
from recbole.utils import init_seed, init_logger
from tqdm import tqdm

# Silence all warnings for the rest of the kernel session.
# NOTE(review): this also hides FutureWarnings/DeprecationWarnings from numpy, pandas,
# torch and RecBole that announce behavior changes — consider narrowing the filter
# (by category or module) once the noisy warnings are identified.
warnings.simplefilter("ignore")
from recbole.model.general_recommender.multivae import MultiVAE

# Read and process data

In [None]:
# Load the pre-processed exports. Only `interactions_df` feeds the RecBole pipeline below.
interactions_df, users_df, items_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../data/interactions_processed.csv",
        "../data/users_processed.csv",
        "../data/items_processed.csv",
    )
)

In [None]:
# Parse the watch date, then express it as whole seconds since the Unix epoch for
# RecBole's TIME_FIELD. Dividing the offset from the epoch by a one-second Timedelta
# gives seconds whatever datetime64 resolution pandas chose. By contrast,
# `.astype(np.int64) // 10**9` is only correct for nanosecond resolution.
interactions_df["t_dat"] = pd.to_datetime(interactions_df["last_watch_dt"], format="%Y-%m-%d")
interactions_df["timestamp"] = (interactions_df["t_dat"] - pd.Timestamp(0)) // pd.Timedelta(seconds=1)

In [None]:
# RecBole atomic files name each column "<field>:<type>". Keep just the three loaded fields.
recbole_header = {
    "user_id": "user_id:token",
    "item_id": "item_id:token",
    "timestamp": "timestamp:float",
}
df = interactions_df.loc[:, list(recbole_header)].rename(columns=recbole_header)

In [None]:
# Create the RecBole dataset directory. Unlike `!mkdir`, this is cross-platform and
# idempotent: it does nothing when the directory already exists, so re-runs stay clean.
Path("recbox_data").mkdir(exist_ok=True)

In [None]:
# Tab-separated atomic file where RecBole looks for dataset "recbox_data" when data_path is "".
df.to_csv("recbox_data/recbox_data.inter", sep="\t", index=False)

# Prepare training pipeline

## Hyperparameters

In [None]:
# RecBole settings shared by `Config` below and by every `run_recbole` call in this notebook.
parameter_dict = {
    # data_path "" + dataset "recbox_data" -> RecBole reads recbox_data/recbox_data.inter
    "data_path": "",
    "USER_ID_FIELD": "user_id",
    "ITEM_ID_FIELD": "item_id",
    "TIME_FIELD": "timestamp",
    # RecBole computes config["device"] itself: CUDA when available and use_gpu is on,
    # otherwise CPU. A user-supplied "device" entry is therefore overwritten, and "GPU"
    # is not a valid torch device string anyway. `use_gpu` is the switch RecBole reads.
    "use_gpu": True,
    # Keep only users and items with at least 40 interactions.
    "user_inter_num_interval": "[40,inf)",
    "item_inter_num_interval": "[40,inf)",
    "load_col": {"inter": ["user_id", "item_id", "timestamp"]},
    "neg_sampling": None,
    "epochs": 10,
    # Per-user, time-ordered 9:0:1 split, evaluated by full ranking over all items.
    # NOTE(review): the validation share is 0, so there is no validation data for
    # early stopping or best-epoch selection. Confirm this is intended.
    "eval_args": {"split": {"RS": [9, 0, 1]}, "group_by": "user", "order": "TO", "mode": "full"},
}
config = Config(model="MultiVAE", dataset="recbox_data", config_dict=parameter_dict)

# init random seed
init_seed(config["seed"], config["reproducibility"])

# logger initialization
init_logger(config)
logger = getLogger()  # root logger

# Attach an INFO-level console handler only once. Without this guard, every re-run of
# the cell adds another handler to the root logger, and each record is printed again
# per extra handler.
CONSOLE_HANDLER_NAME = "notebook_console"
if not any(handler.get_name() == CONSOLE_HANDLER_NAME for handler in logger.handlers):
    c_handler = logging.StreamHandler()
    c_handler.set_name(CONSOLE_HANDLER_NAME)
    c_handler.setLevel(logging.INFO)
    logger.addHandler(c_handler)

## Splitting dataset

In [None]:
# Build the RecBole dataset described by `config` and log its summary via the root logger.
dataset = create_dataset(config)
logger.info("%s", dataset)

In [None]:
# Split into train/valid/test loaders according to `eval_args` in the config.
train_data, valid_data, test_data = data_preparation(config=config, dataset=dataset)

# Brute-force model exploration

Train and evaluate several RecBole models with the same settings to choose one for production.

In [None]:
%%time
model_list = ["MultiVAE", "MultiDAE", "MacridVAE", "NeuMF", "RecVAE"]

for model_name in model_list:
    print(f"running {model_name}...")
    start = time.time()
    result = run_recbole(model=model_name, dataset="recbox_data", config_dict=parameter_dict)
    t = time.time() - start
    print(f"It took {t/60:.2f} mins")
    print(result)

# Training MultiVAE for production

In [None]:
# Retrain MultiVAE with the shared settings. RecBole writes the checkpoint to its
# `checkpoint_dir` (presumably the default `saved/` here, since it is not overridden).
result = run_recbole(dataset="recbox_data", model="MultiVAE", config_dict=parameter_dict)

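## Loading model

The weights are read from `../model_weights/MultiVAE.pth`. `run_recbole` stores its checkpoint in RecBole's `checkpoint_dir` (`saved/` by default) under a timestamped file name. Copy the checkpoint from the training run above to that path before running the next cell.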

In [None]:
# Rebuild MultiVAE from the same config/dataset used for training, then load the trained weights.
# A freshly constructed nn.Module holds its parameters on the CPU. Move it to the device
# RecBole resolved in config["device"] so weights and inference inputs agree.
model = MultiVAE(config, dataset).to(config["device"])
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint loaded on a CPU-only box).
# NOTE(review): RecBole checkpoints presumably pickle more than tensors (e.g. the config). On
# torch versions where torch.load defaults to weights_only=True, this may need
# weights_only=False, which is only safe for trusted files.
checkpoint = torch.load("../model_weights/MultiVAE.pth", map_location=config["device"])
model.load_state_dict(checkpoint["state_dict"])

## Create recommendations

In [None]:
def recommend_to_user(external_user_id: str, dataset, model, k: int = 10) -> list:
    """Recommend the top-``k`` items for one user, returned as raw item tokens.

    Args:
        external_user_id: User token as it appears in the interaction file.
        dataset: RecBole dataset providing the user/item token <-> internal id
            mappings the model was trained with.
        model: Trained recommender exposing ``full_sort_predict``.
        k: Number of items to return. Capped at ``dataset.item_num - 1``, because
            internal item id 0 is masked out and never recommended.

    Returns:
        ``[]`` when the user is unknown, is the ``"[PAD]"`` token, or has no
        interaction rows. Otherwise, the ``.tolist()`` of ``dataset.id2token``
        applied to a one-element list of index lists.

        NOTE(review): with current NumPy that is a nested list ``[[item, ...]]``.
        NumPy < 1.23 may treat the nested-list index as a tuple and yield a flat
        list. The call is kept unchanged for existing consumers; confirm which
        shape they expect.
    """
    if external_user_id == "[PAD]" or external_user_id not in dataset.field2token_id[dataset.uid_field]:
        return []

    model.eval()
    # Send inputs to wherever the model's weights live, so the two always agree.
    device = next(model.parameters()).device
    with torch.no_grad():
        uid_series = dataset.token2id(dataset.uid_field, [external_user_id])
        user_rows = np.isin(dataset[dataset.uid_field].numpy(), uid_series)
        # Only the first row's scores are used, so score just that row instead of
        # one row per interaction of this user.
        new_inter = dataset[user_rows][:1]
        if len(new_inter) == 0:
            return []
        new_scores = model.full_sort_predict(new_inter.to(device))
        new_scores = new_scores.view(-1, dataset.item_num)
        new_scores[:, 0] = -np.inf  # internal item id 0 is the padding slot; never recommend it
        top_k = min(k, dataset.item_num - 1)
        recommended_item_indices = torch.topk(new_scores, top_k).indices[0].tolist()
    return dataset.id2token(dataset.iid_field, [recommended_item_indices]).tolist()


# Score every real user token. "[PAD]" is the padding entry of the user vocabulary,
# not a user, so it is left out of the exported file.
user_tokens = [token for token in dataset.field2token_id[dataset.uid_field] if token != "[PAD]"]
recos = {
    user_id: recommend_to_user(user_id, dataset, model)
    for user_id in tqdm(user_tokens, desc="recommending")
}

with open("../data/MultiVAE-recommendations.json", "w", encoding="utf-8") as f:
    json.dump(recos, f)