# ProtoVQ-VAE:

## Resources

### Papers
Melchiorre, Alessandro B., Navid Rekabsaz, Christian Ganhör, and Markus Schedl. 2022. **“ProtoMF: Prototype-Based Matrix Factorization for Effective and Explainable Recommendations.”** In Sixteenth ACM Conference on Recommender Systems, 246–56. Seattle WA USA: ACM. https://doi.org/10.1145/3523227.3546756.

Oord, Aaron van den, Oriol Vinyals, and Koray Kavukcuoglu. 2018. **“Neural Discrete Representation Learning.”** arXiv. http://arxiv.org/abs/1711.00937.

Liang, Dawen, Rahul G. Krishnan, Matthew D. Hoffman, and Tony Jebara. 2018. **“Variational Autoencoders for Collaborative Filtering.”** arXiv. http://arxiv.org/abs/1802.05814.

Shenbin, Ilya, Anton Alekseev, Elena Tutubalina, Valentin Malykh, and Sergey I. Nikolenko. 2020. **“RecVAE: A New Variational Autoencoder for Top-N Recommendations with Implicit Feedback.”** In Proceedings of the 13th International Conference on Web Search and Data Mining, 528–36. Houston TX USA: ACM. https://doi.org/10.1145/3336191.3371831.

### Code
RecBole BaseDataset adapted from:               https://github.com/RUCAIBox/RecSysDatasets/blob/master/conversion_tools/src/base_dataset.py

LFM2b1monDataset adapted from:                  https://github.com/RUCAIBox/RecSysDatasets/blob/master/conversion_tools/src/extended_dataset.py

Implementation of VQ-VAE:                       https://colab.research.google.com/github/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb

### Dataset
Last FM Music Listening Events (2020 Subset):   http://www.cp.jku.at/datasets/LFM-2b/

### General
Thesis writing guide (ML institute):            https://docs.google.com/document/d/1p5d0eykqaw0dfB-L62kPOpdetn8X3CgcJQ3wh7RhPlw/edit#heading=h.8hffm3guomu

NeurIPS 2024 template:                          https://www.overleaf.com/latex/templates/neurips-2024/tpsbbrdqcmsh

### Changelog
v0.1: for-loop used for vector quantization, very inefficient

v0.2: for-loop replaced with matrix-based vector quantization
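The change from v0.1 to v0.2 can be illustrated with a small NumPy sketch. NumPy stands in for PyTorch here, and `nearest_loop` / `nearest_matrix` are hypothetical helpers, not functions from this notebook. Expanding the squared distance as `||x||^2 + ||e||^2 - 2 x.e`, the same identity `vector_quantizer` uses, replaces the per-vector Python loop with one matrix product while selecting the same nearest prototypes:

```python
import numpy as np

def nearest_loop(x, codebook):
    """v0.1-style: one Python iteration per input vector."""
    idx = np.empty(len(x), dtype=np.int64)
    for i, row in enumerate(x):
        # Squared Euclidean distance from this row to every prototype
        d = np.sum((codebook - row) ** 2, axis=1)
        idx[i] = np.argmin(d)
    return idx

def nearest_matrix(x, codebook):
    """v0.2-style: all distances at once via ||x||^2 + ||e||^2 - 2 x.e."""
    d = (np.sum(x ** 2, axis=1, keepdims=True)   # [N, 1]
         + np.sum(codebook ** 2, axis=1)          # [K], broadcast over rows
         - 2 * x @ codebook.T)                    # [N, K]
    return np.argmin(d, axis=1)

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 8))          # 256 toy encoder outputs, D = 8
codebook = rng.normal(size=(32, 8))    # K = 32 toy prototypes
same = np.array_equal(nearest_loop(x, codebook), nearest_matrix(x, codebook))
print(same)
```

The matrix form trades memory for speed: it materialises the full `[N, K]` distance matrix in one step instead of one row at a time.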




## Imports

In [1]:
# Google Colab only
#!pip install recbole
#!pip install -U ray
#!pip install wandb
#!jupyter nbextension enable --py widgetsnbextension

import os
import sys
import gc
from logging import getLogger
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from recbole.quick_start import run_recbole
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.utils import InputType, init_logger, init_seed, get_flops, set_color
from recbole.model.init import xavier_normal_initialization
from recbole.trainer import Trainer
from recbole.config import Config
from recbole.data import create_dataset, data_preparation
from recbole.data.dataloader.general_dataloader import FullSortEvalDataLoader
from recbole.data.transform import construct_transform

In [2]:
# Google Colab only
#from google.colab import drive
#drive.mount('/content/drive')
#config = 'drive/MyDrive/Colab Notebooks/data/protovq-vae/config.yaml'

config = '/home/matt/SynologyDrive/Code/ProtoVQ-VAE/config.yaml'

## GPU Checks

In [3]:
# Empty GPU cache
gc.collect()
torch.cuda.empty_cache()

# Get GPU info
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Mon Jun 17 21:14:52 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce GTX 1650        Off |   00000000:01:00.0  On |                  N/A |
| 59%   40C    P8             10W /   75W |     577MiB /   4096MiB |     17%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Your runtime has 16.7 gigabytes of available RAM

Not using a high-RAM runtime


## Data Preprocessing
Transform the raw data into the atomic file format required by RecBole (see https://recbole.io/docs/user_guide/data/atomic_files.html).

LFM2B-1MON download:        http://www.cp.jku.at/datasets/LFM-2b/

=> Extract to folder relative to CWD:   data/lfm2b-1mon/original/
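RecBole's atomic files are tab-separated text whose first line declares every column as `name:type` (for example `token`, `float` or `token_seq`), the same convention used by `inter_fields`, `item_fields` and `user_fields` in the conversion classes below. A minimal standard-library sketch with toy IDs and timestamps; the helpers `write_inter` and `read_header` are hypothetical and not part of RecBole:

```python
import csv
import io

def write_inter(rows, fields):
    """Serialise rows into a RecBole-style atomic .inter string.

    fields: list of 'name:type' header entries, e.g. 'user_id:token'.
    """
    buf = io.StringIO()
    writer = csv.writer(buf, delimiter='\t', lineterminator='\n')
    writer.writerow(fields)   # header: one 'name:type' entry per column
    writer.writerows(rows)    # one interaction per line
    return buf.getvalue()

def read_header(text):
    """Split the header line into (name, type) pairs."""
    header = text.split('\n', 1)[0].split('\t')
    return [tuple(col.split(':', 1)) for col in header]

text = write_inter(
    rows=[(0, 5, 100.0), (0, 7, 160.0), (1, 5, 200.0)],   # toy interactions
    fields=['user_id:token', 'item_id:token', 'timestamp:float'],
)
print(text)
print(read_header(text))
```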

In [5]:
# Prepare LFM2B-1M

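def data_to_inter(path: str, user_col: str, item_col: str) -> pd.DataFrame:
    """Turns data into atomic inter format.

    Args:
        path (str): File path of a comma-separated file
        user_col (str): Column name of user IDs
        item_col (str): Column name of item IDs

    Returns:
        pd.DataFrame: Only the user and item ID columns, ready to be written as a tab-separated .inter file
    """
    df = pd.read_csv(path)
    df_inter = df[[user_col, item_col]]

    return df_inter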

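def factorize_lfm2b_1mon():
    """Remap user and item IDs of the converted .inter file to contiguous integers.

    The original IDs are kept in the *_old:token columns. Expects the file written by
    LFM2b1monDataset(..., duplicate_removal=True), which includes num_repeat:float.
    """
    # Same location the LFM2b1monDataset converter writes to (output_path + dataset_name + '.inter')
    df = pd.read_csv('data/lfm2b-1mon/atomic/lfm2b-1mon.inter', sep='\t')
    df.rename(
        columns={'user_id:token': 'user_id_old:token', 'item_id:token': 'item_id_old:token'},
        inplace=True
    )
    # pd.factorize assigns codes 0..n-1 in order of first appearance
    df['user_id:token'] = pd.factorize(df['user_id_old:token'])[0]
    df['item_id:token'] = pd.factorize(df['item_id_old:token'])[0]
    columns_new = ['user_id_old:token', 'item_id_old:token', 'user_id:token', 'item_id:token', 'timestamp:float', 'num_repeat:float']
    df = df[columns_new]
    df.to_csv('data/lfm2b-1mon/atomic/lfm2b-1mon_remapped.inter', index=False, sep='\t')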


# Uncomment for data preprocessing
#data = data_to_inter('data/lfm2b-1mon/listening_history.csv', 'user_id:token', 'item_id:token')
#data.to_csv('data/lfm2b-1mon/atomic/test_data.inter', index=False, sep='\t')
#factorize_lfm2b_1mon()
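`factorize_lfm2b_1mon` relies on `pd.factorize`, which replaces the original LFM-2b IDs with contiguous integer codes (0, 1, 2, ...) assigned in order of first appearance, while the old IDs stay available in the `*_old:token` columns. A dependency-free sketch of the same mapping; the helper `factorize` is hypothetical and not part of pandas or RecBole, and it ignores the missing-value handling that `pd.factorize` performs:

```python
def factorize(values):
    """Map each distinct value to a dense integer code, in order of first appearance.

    Mirrors the codes returned by pd.factorize(values)[0] (default sort=False)
    for inputs without missing values.
    """
    codes, mapping = [], {}
    for v in values:
        # setdefault assigns the next free code only the first time v is seen
        codes.append(mapping.setdefault(v, len(mapping)))
    return codes, list(mapping)

user_ids = ['31435', '2', '31435', '981', '2']   # toy original IDs
codes, uniques = factorize(user_ids)
print(codes)    # [0, 1, 0, 2, 1]
print(uniques)  # ['31435', '2', '981']
```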


In [6]:
# BaseDataset adapted from: https://github.com/RUCAIBox/RecSysDatasets/blob/master/conversion_tools/src/base_dataset.py
class BaseDataset(object):
    def __init__(self, input_path, output_path):
        super(BaseDataset, self).__init__()

        self.dataset_name = ''
        self.input_path = input_path
        self.output_path = output_path
        self.check_output_path()

        # input file
        self.inter_file = os.path.join(self.input_path, 'inters.dat')
        self.item_file = os.path.join(self.input_path, 'items.dat')
        self.user_file = os.path.join(self.input_path, 'users.dat')
        self.sep = '\t'

        # output file
        self.output_inter_file, self.output_item_file, self.output_user_file = self.get_output_files()

        # selected feature fields
        self.inter_fields = {}
        self.item_fields = {}
        self.user_fields = {}

    def check_output_path(self):
        if not os.path.isdir(self.output_path):
            os.makedirs(self.output_path)

    def get_output_files(self):
        output_inter_file = os.path.join(self.output_path, self.dataset_name + '.inter')
        output_item_file = os.path.join(self.output_path, self.dataset_name + '.item')
        output_user_file = os.path.join(self.output_path, self.dataset_name + '.user')
        return output_inter_file, output_item_file, output_user_file

    def load_inter_data(self) -> pd.DataFrame:
        raise NotImplementedError

    def load_item_data(self) -> pd.DataFrame:
        raise NotImplementedError

    def load_user_data(self) -> pd.DataFrame:
        raise NotImplementedError

    def convert_inter(self):
        try:
            input_inter_data = self.load_inter_data()
            self.convert(input_inter_data, self.inter_fields, self.output_inter_file)
        except NotImplementedError:
            print('This dataset can\'t be converted to inter file\n')

    def convert_item(self):
        try:
            input_item_data = self.load_item_data()
            self.convert(input_item_data, self.item_fields, self.output_item_file)
        except NotImplementedError:
            print('This dataset can\'t be converted to item file\n')

    def convert_user(self):
        try:
            input_user_data = self.load_user_data()
            self.convert(input_user_data, self.user_fields, self.output_user_file)
        except NotImplementedError:
            print('This dataset can\'t be converted to user file\n')

    @staticmethod
    def convert(input_data, selected_fields, output_file):
        output_data = pd.DataFrame()
        for column in selected_fields:
            output_data[column] = input_data.iloc[:, column]
        with open(output_file, 'w') as fp:
            fp.write('\t'.join([selected_fields[column] for column in output_data.columns]) + '\n')
            for i in tqdm(range(output_data.shape[0])):
                fp.write('\t'.join([str(output_data.iloc[i, j])
                                    for j in range(output_data.shape[1])]) + '\n')

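    def parse_json(self, data_path):
        # ast.literal_eval parses Python-literal lines without executing arbitrary code like eval() would
        import ast
        with open(data_path, 'r', encoding='utf-8') as g:
            for l in g:
                yield ast.literal_eval(l)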

    def getDF(self, data_path):
        i = 0
        df = {}
        for d in self.parse_json(data_path):
            df[i] = d
            i += 1
        data = pd.DataFrame.from_dict(df, orient='index')

        return data

In [7]:
# LFM2b1monDataset adapted from: https://github.com/RUCAIBox/RecSysDatasets/blob/master/conversion_tools/src/extended_dataset.py
class LFM2b1monDataset(BaseDataset):
    def __init__(self, input_path, output_path, duplicate_removal):
        super(LFM2b1monDataset, self).__init__(input_path, output_path)
        self.input_path = input_path
        self.output_path = output_path

        self.duplicate_removal = duplicate_removal  # merge repeat interactions if 'duplicate_removal' is True

        self.dataset_name = 'lfm2b-1mon'

        # input file
        self.inter_file = os.path.join(self.input_path, 'listening_events.tsv')
        self.item_file = os.path.join(self.input_path, 'tracks.tsv')
        self.user_file = os.path.join(self.input_path, 'users.tsv')

        self.sep = '\t'

        # output file
        self.output_inter_file, self.output_item_file, self.output_user_file = self.get_output_files()

        # selected feature fields
        if self.duplicate_removal:
            self.inter_fields = {0: 'user_id:token',
                                 1: 'item_id:token',
                                 2: 'timestamp:float',
                                 3: 'num_repeat:float'
                                 }
        else:
            self.inter_fields = {0: 'user_id:token',
                                 1: 'item_id:token',
                                 2: 'timestamp:float'
                                 }

        self.item_fields = {0: 'item_id:token',
                            1: 'name:token_seq',
                            2: 'artists_id:token'
                            }

        self.user_fields = {0: 'user_id:token',
                            1: 'country:token',
                            2: 'age:float',
                            3: 'gender:token',
                            4: 'timestamp:float',
                            }

    def convert_inter(self):
        with open(self.output_inter_file, 'w') as fout:
            fout.write('\t'.join([self.inter_fields[i] for i in range(len(self.inter_fields))]) + '\n')

            if self.duplicate_removal:
                self.run_duplicate_removal(fout)
            else:
                with open(self.inter_file, 'r') as f:
                    next(f)  # skip the header row of listening_events.tsv
                    for line in f:
                        userid, itemid, timestamp = line.strip().split('\t')[:3]
                        fout.write(userid + '\t' + itemid + '\t' + timestamp + '\n')

        print(self.output_inter_file + ' is done!')

    def convert_item(self):
        with open(self.output_item_file, 'w') as fout:
            fout.write('\t'.join([self.item_fields[i] for i in range(len(self.item_fields))]) + '\n')
            # Lines are copied verbatim, so tracks.tsv must already match item_fields column by column
            # (if it has its own header row, that row is copied too)
            with open(self.item_file, 'r') as f:
                for line in f:
                    fout.write(line)
        print(self.output_item_file + ' is done!')

    def convert_user(self):
        with open(self.output_user_file, 'w') as fout:
            fout.write('\t'.join([self.user_fields[i] for i in range(len(self.user_fields))]) + '\n')
            with open(self.user_file, 'r') as f:
                next(f)  # skip the header row of users.tsv
                for line in f:
                    fields = line.rstrip('\r\n').split('\t')
                    # Keep the user ID as-is; drop '?' placeholders from the remaining fields
                    fout.write('\t'.join([fields[0]] + [v.replace('?', '') for v in fields[1:]]) + '\n')
        print(self.output_user_file + ' is done!')

    def run_duplicate_removal(self, fout):
        """Merge repeated (user, item) listening events into one row per pair.

        Keeps the timestamp of the first event and counts the repeats in num_repeat.
        Users are written in order of first appearance (dicts preserve insertion order),
        so no hard-coded starting user ID is needed.
        """
        a_user = {}
        with open(self.inter_file, 'r') as f:
            next(f)  # skip the header row of listening_events.tsv
            for line in f:
                userid, itemid, timestamp = line.strip().split('\t')[:3]
                items = a_user.setdefault(userid, {})
                if itemid not in items:
                    items[itemid] = [timestamp, 1]
                else:
                    items[itemid][1] += 1

        for userid, items in a_user.items():
            for itemid, (timestamp, count) in items.items():
                fout.write(str(userid) + '\t' + str(itemid) + '\t' + str(timestamp) + '\t' + str(count) + '\n')

# Uncomment lines below to run once for conversion to RecBole .inter format
#dataset = LFM2b1monDataset(
#    input_path='data/lfm2b-1mon/original/',
#    output_path='data/lfm2b-1mon/atomic/',
#    duplicate_removal=True
#)

#dataset.convert_inter()

# python3 run.py --dataset lfm1b --input_path data --output_path lfm2b_1mon --interaction_type tracks --duplicate_removal --convert_inter
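With `duplicate_removal=True`, repeated listening events of the same (user, track) pair are merged into a single interaction whose `num_repeat:float` column counts the listens, keeping the timestamp of the first listen. A minimal in-memory sketch of that rule on toy events; the helper `merge_repeats` is hypothetical and works on tuples rather than on the TSV files:

```python
def merge_repeats(events):
    """Collapse repeated (user, item) events into one row per pair.

    events: iterable of (user_id, item_id, timestamp) tuples.
    Returns (user_id, item_id, first_timestamp, num_repeat) rows.
    """
    merged = {}  # dicts preserve insertion order (Python 3.7+)
    for user, item, ts in events:
        key = (user, item)
        if key in merged:
            merged[key][1] += 1      # another listen of the same track
        else:
            merged[key] = [ts, 1]    # first listen: remember its timestamp
    return [(u, i, ts, n) for (u, i), (ts, n) in merged.items()]

events = [
    ('u1', 't9', 100), ('u1', 't9', 160), ('u1', 't4', 200),
    ('u2', 't9', 300), ('u1', 't9', 400),
]
for row in merge_repeats(events):
    print(row)
# ('u1', 't9', 100, 3)
# ('u1', 't4', 200, 1)
# ('u2', 't9', 300, 1)
```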

In [8]:
class ProtoVQ_VAE(GeneralRecommender):
    input_type = InputType.POINTWISE

    def __init__(self, config, dataset):
        super(ProtoVQ_VAE, self).__init__(config, dataset)

        # Load dataset info
        self.n_users = dataset.user_num
        self.n_items = dataset.item_num

        # Load config
        self.layers = config['mlp_hidden_size']
        #self.lat_dim = config['latent_dimension']
        #self.lat_split = config['latent_split']

        # latent dimension must be divisible without remainder by latent split
        #assert self.lat_dim % self.lat_split == 0, f'lat_dim ({self.lat_dim}) not divisible without remainder by lat_split ({self.lat_split}).'

        #self.proto_dim = int(self.lat_dim / self.lat_split) # dimension of embedding vectors D
        self.proto_dim = config['proto_dim']
        self.n_proto = config['n_proto']                    # number of vectors in codebook K
        self.topk_proto = config['topk_proto']              # number of most similar prototypes
        self.proto_idx_hist = []
        self.drop_out = config['drop_out']
        #self.anneal_cap = config['anneal_cap']
        #self.total_anneal_steps = config['total_anneal_steps']
        self.commitment_cost = config['commitment_cost']    # beta term in paper, term (3)

        # Load item history
        self.history_item_id, self.history_item_value, _ = dataset.history_item_matrix()
        self.history_item_id = self.history_item_id.to(self.device)
        self.history_item_value = self.history_item_value.to(self.device)

        # Create encoder and decoder
        self.encode_layer_dims = [self.n_items] + self.layers + [self.proto_dim]
        self.decode_layer_dims = [int(self.proto_dim)] + self.encode_layer_dims[::-1][1:]
        self.encoder = self.mlp_layers(self.encode_layer_dims)
        self.decoder = self.mlp_layers(self.decode_layer_dims) 

        # Parameter initialization
        self.apply(xavier_normal_initialization)
        
        # Create prototype embeddings
        self.prototypes = nn.Embedding(self.n_proto, self.proto_dim)
        self.prototypes.weight.data.uniform_(-1/self.n_proto, 1/self.n_proto)

        self.update = 0


    def get_rating_matrix(self, user):
        r"""Get a batch of user's features with the user's id and history interaction matrix.

        Args:
            user (torch.LongTensor): Input tensor that contains user's id, shape: [batch_size, ]

        Returns:
            torch.FloatTensor: The user features of a batch of users, shape: [batch_size, n_items]
        """

        # Construct tensor of shape [batch_size, n_items] using tensor of shape [B, H]
        col_indices = self.history_item_id[user].flatten()
        row_indices = (
            torch.arange(user.shape[0])
            .to(self.device)
            .repeat_interleave(self.history_item_id.shape[1], dim=0)
        )
        rating_matrix = (torch.zeros(1).to(self.device).repeat(user.shape[0], self.n_items))
        rating_matrix.index_put_((row_indices, col_indices), self.history_item_value[user].flatten())

        return rating_matrix


    def mlp_layers(self, layer_dims):
        mlp_modules = []
        for i, (d_in, d_out) in enumerate(zip(layer_dims[:-1], layer_dims[1:])):
            mlp_modules.append(nn.Linear(d_in, d_out))
            if i != len(layer_dims[:-1]) - 1:
                mlp_modules.append(nn.ReLU())
        return nn.Sequential(*mlp_modules)


    def vector_quantizer(self, x):
        x_shape = x.shape

        # Flatten input
        flat_input = x.view(-1, self.proto_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                    + torch.sum(self.prototypes.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self.prototypes.weight.t()))
        #distances = torch.dist(flat_input, self.prototypes.weight)


        # Encoding
        encoding_indices = torch.topk(distances, k=self.topk_proto, largest=False).indices 
        #encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        
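        # Add to proto_idx_hist to check frequency of prototype vectors being used
        # (move to CPU when enabling this, otherwise one GPU tensor per batch accumulates for the whole run)
        #self.proto_idx_hist.append(encoding_indices.detach().cpu())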
        
        # Set index of closest prototype vector to 1
        encodings = torch.zeros(encoding_indices.shape[0], self.n_proto, device=x.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self.prototypes.weight).view(x_shape) #/ self.topk_proto

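        # VQ loss: last two terms of Eq. (3) in Oord et al., codebook loss + beta * commitment loss
        e_latent_loss = F.mse_loss(quantized.detach(), x)   # commitment: pulls encoder outputs towards their prototypes
        q_latent_loss = F.mse_loss(quantized, x.detach())   # codebook: pulls prototypes towards encoder outputs
        vq_loss = q_latent_loss + self.commitment_cost * e_latent_loss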

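        # Straight-through estimator: forward pass uses the prototypes, backward pass copies gradients to x
        quantized = x + (quantized - x).detach()
        # Codebook usage distribution; each row holds topk_proto ones, so divide by topk_proto to sum to 1
        avg_probs = torch.mean(encodings, dim=0) / self.topk_proto
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))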

        return quantized, vq_loss, perplexity, encodings


    def forward(self, x):
        # From MultiVAE
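        x = F.normalize(x)  # L2-normalise each user's row (dim=1), as in Mult-VAE, so the input scale does not depend on how many items the user interacted with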
        x = F.dropout(x, self.drop_out, training=self.training)
        x = self.encoder(x)

        # Split tensor
        #x_split = torch.tensor_split(x, self.lat_split, dim=1)

        #z, vq_loss, perplexity = [], [], []
        #for x_ in x:
        #    quantized_, vq_loss_, perplexity_, _ = self.vector_quantizer(x_)
        #    z.append(self.decoder(quantized_))
        #    vq_loss.append(vq_loss_)
        #    perplexity.append(perplexity_)

        quantized, vq_loss, perplexity, _ = self.vector_quantizer(x)
        z = self.decoder(quantized)
        
        #z = torch.stack(z, dim=0)
        #z = torch.sum(torch.stack(z, dim=0), dim=0)
        vq_loss = torch.mean(vq_loss)
        #vq_loss = torch.stack(vq_loss)
        perplexity = torch.mean(perplexity)
        #perplexity = torch.stack(perplexity)

        return z, vq_loss, perplexity


    def calculate_loss(self, interaction):
        user = interaction[self.USER_ID]
        x = self.get_rating_matrix(user) 

        self.update += 1

        #if self.total_anneal_steps > 0:
        #    anneal = min(self.anneal_cap, 1.0 * self.update / self.total_anneal_steps)
        #else:
        #    anneal = self.anneal_cap

        z, vq_loss, perplexity = self.forward(x)

        #vq_loss *= anneal

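        # Multinomial log-likelihood (cross-entropy against the rating row), as in Mult-VAE
        #recon_error = -(F.log_softmax(x_recon, 1) * x).sum(1).mean()
        recon_error = -(F.log_softmax(z, 1) * x).sum(1).mean()
        del z
        #recon_error = F.cross_entropy(x_recon, x)
        #loss = recon_error + vq_loss

        # RecBole's Trainer sums a returned tuple and logs its parts as train_loss1 / train_loss2
        return recon_error, vq_loss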


    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        rating_matrix = self.get_rating_matrix(user)

        scores, _, _ = self.forward(rating_matrix)

        return scores[[torch.arange(len(item)).to(self.device), item]]


    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]

        rating_matrix = self.get_rating_matrix(user)

        scores, _, _ = self.forward(rating_matrix)

        return scores.view(-1)
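The quantization step in `vector_quantizer` can be traced without a GPU using a NumPy sketch of its forward pass. The helper `topk_quantize` is hypothetical; `np.argsort` stands in for `torch.topk(..., largest=False)`, and the straight-through estimator and VQ loss terms are omitted because they only matter for gradients. With `k > 1` the quantized vector is the sum of the `k` nearest prototypes, matching the un-normalised `encodings @ prototypes` product above; the sketch divides the usage distribution by `k` before computing perplexity so that it sums to 1:

```python
import numpy as np

def topk_quantize(x, codebook, k=1):
    """Forward pass of a top-k vector quantizer (no gradients).

    x: [N, D] encoder outputs; codebook: [K, D] prototypes.
    Returns the quantized vectors, the multi-hot encodings and the codebook perplexity.
    """
    n_proto = codebook.shape[0]
    # Squared Euclidean distances, [N, K]
    dist = (np.sum(x ** 2, axis=1, keepdims=True)
            + np.sum(codebook ** 2, axis=1)
            - 2 * x @ codebook.T)
    # Indices of the k closest prototypes per row
    idx = np.argsort(dist, axis=1)[:, :k]
    # Multi-hot encodings with k ones per row (like encodings.scatter_(1, idx, 1))
    enc = np.zeros((x.shape[0], n_proto))
    np.put_along_axis(enc, idx, 1.0, axis=1)
    # Quantized vector = sum of the k selected prototypes
    quantized = enc @ codebook
    # Usage distribution over prototypes, normalised to sum to 1
    avg_probs = enc.mean(axis=0) / k
    perplexity = np.exp(-np.sum(avg_probs * np.log(avg_probs + 1e-10)))
    return quantized, enc, perplexity

codebook = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0], [10.0, 10.0]])
x = np.array([[0.1, 0.2], [9.8, 0.1], [0.3, 9.9], [9.9, 10.2]])
q, enc, ppl = topk_quantize(x, codebook, k=1)
print(q)
print(ppl)
```

In this toy example each input snaps to its nearest corner and all four prototypes are used once, so the perplexity is close to 4 (the codebook size); a perplexity near 1 would mean every input is mapped to the same prototype.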

In [9]:
def run_recbole(
    model=None, dataset=None, config_file_list=None, config_dict=None, saved=True
):
    r"""A fast-running API covering the complete process of
    training and testing a model on a specified dataset.

    Local variant of ``recbole.quick_start.run_recbole`` (it shadows that import):
    it takes a model class instead of a model name and can print prototype usage after training.

    Args:
        model (type): Model class to instantiate, e.g. ``ProtoVQ_VAE``. Defaults to ``None``.
        dataset (str, optional): Dataset name. Defaults to ``None``.
        config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
        config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
        saved (bool, optional): Whether to save the model. Defaults to ``True``.
    """
    # configurations initialization
    config = Config(
        model=model,
        dataset=dataset,
        config_file_list=config_file_list,
        config_dict=config_dict,
    )
    init_seed(config["seed"], config["reproducibility"])
    # logger initialization
    init_logger(config)
    logger = getLogger()
    logger.info(sys.argv)
    logger.info(config)

    # dataset filtering
    dataset = create_dataset(config)
    logger.info(dataset)

    # dataset splitting
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # model loading and initialization
    init_seed(config["seed"] + config["local_rank"], config["reproducibility"])
    model = model(config, train_data.dataset).to(config['device'])
    logger.info(model)

    transform = construct_transform(config)
    flops = get_flops(model, dataset, config["device"], logger, transform)
    logger.info(set_color("FLOPs", "blue") + f": {flops}")

    # trainer loading and initialization
    trainer = Trainer(config, model)

    # model training
    best_valid_score, best_valid_result = trainer.fit(
        train_data, valid_data, saved=saved, show_progress=config["show_progress"]
    )

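    # Show distribution of prototype indices chosen during all epochs
    # (proto_idx_hist stays empty unless the append in vector_quantizer is enabled; torch.cat rejects an empty list)
    if model.proto_idx_hist:
        print(pd.DataFrame(np.bincount(torch.cat(model.proto_idx_hist).flatten().cpu().numpy())))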

    # model evaluation
    test_result = trainer.evaluate(
        test_data, load_best_model=saved, show_progress=config["show_progress"]
    )

    logger.info(set_color("best valid ", "yellow") + f": {best_valid_result}")
    logger.info(set_color("test result", "yellow") + f": {test_result}")

    return {
        "best_valid_score": best_valid_score,
        "valid_score_bigger": config["valid_metric_bigger"],
        "best_valid_result": best_valid_result,
        "test_result": test_result,
    }
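The prototype histogram printed after `trainer.fit` is an `np.bincount` over all recorded prototype indices. A NumPy sketch of that step; the helper `prototype_usage` is hypothetical. It also covers an empty history, which `torch.cat` rejects, and passes `minlength` (a choice made in this sketch, not in the cell above) so that prototypes that were never selected still appear with a count of 0:

```python
import numpy as np

def prototype_usage(index_batches, n_proto):
    """Count how often each prototype index was selected.

    index_batches: list of integer arrays (e.g. per-batch top-k indices).
    Returns an array of length n_proto; all zeros if nothing was recorded.
    """
    if not index_batches:
        # np.concatenate([]) would raise, just as torch.cat does on an empty list
        return np.zeros(n_proto, dtype=np.int64)
    flat = np.concatenate([np.asarray(b).ravel() for b in index_batches])
    # minlength keeps never-used prototypes visible as zero counts
    return np.bincount(flat, minlength=n_proto)

batches = [np.array([[0, 2], [2, 3]]), np.array([[2, 0]])]   # toy top-2 indices
print(prototype_usage(batches, n_proto=5))
print(prototype_usage([], n_proto=3))
```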

In [10]:
run_recbole(model=ProtoVQ_VAE, dataset='lfm2b-1mon', config_file_list=[config])

17 Jun 21:14    INFO  ['/home/matt/miniconda3/envs/thesis/lib/python3.9/site-packages/ipykernel_launcher.py', '--f=/home/matt/.local/share/jupyter/runtime/kernel-v2-5970FqhvI0ElYJ4c.json']
17 Jun 21:14    INFO  
General Hyper Parameters:
gpu_id = 0
use_gpu = True
seed = 2024
state = INFO
reproducibility = True
data_path = /home/matt/SynologyDrive/Code/ProtoVQ-VAE/data/lfm2b-1mon
checkpoint_dir = saved
show_progress = True
save_dataset = False
dataset_save_path = None
save_dataloaders = False
dataloaders_save_path = None
log_wandb = True

Training Hyper Parameters:
epochs = 3000
train_batch_size = 2048
learner = adam
learning_rate = 0.001
train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}
eval_step = 10
stopping_step = 10
clip_grad_norm = None
weight_decay = 0.0
loss_decimal_place = 4

Evaluation Hyper Parameters:
eval_args = {'split': {'RS': [0.9, 0.05, 0.05]}, 'order': 'RO', 'group_by': 'user', 'mode': 'full'}
repea

17 Jun 21:15    INFO  epoch 0 training [time: 25.80s, train_loss1: 800945.7791, train_loss2: 451028.5852]


17 Jun 21:16    INFO  epoch 1 training [time: 26.42s, train_loss1: 766404.9478, train_loss2: 365207.7835]


17 Jun 21:16    INFO  epoch 2 training [time: 25.54s, train_loss1: 746061.5517, train_loss2: 283931.8888]


17 Jun 21:16    INFO  epoch 3 training [time: 25.54s, train_loss1: 733833.2191, train_loss2: 262978.2529]


17 Jun 21:17    INFO  epoch 4 training [time: 25.46s, train_loss1: 726296.1158, train_loss2: 260712.3419]


17 Jun 21:17    INFO  epoch 5 training [time: 25.33s, train_loss1: 720925.4944, train_loss2: 308263.1979]


17 Jun 21:18    INFO  epoch 6 training [time: 25.74s, train_loss1: 715937.7354, train_loss2: 321919.0356]


17 Jun 21:18    INFO  epoch 7 training [time: 25.88s, train_loss1: 711695.8633, train_loss2: 329243.8239]


17 Jun 21:19    INFO  epoch 8 training [time: 25.99s, train_loss1: 707820.9816, train_loss2: 345470.7531]


17 Jun 21:19    INFO  epoch 9 training [time: 25.82s, train_loss1: 704160.5280, train_loss2: 356706.8347]


17 Jun 21:21    INFO  epoch 9 evaluating [time: 95.84s, valid_score: 0.087600]
17 Jun 21:21    INFO  valid result: 
recall@10 : 0.0799    mrr@10 : 0.0876    ndcg@10 : 0.0604    hit@10 : 0.1998    precision@10 : 0.0246
17 Jun 21:21    INFO  Saving current: saved/ProtoVQ_VAE-Jun-17-2024_21-15-08.pth


17 Jun 21:21    INFO  epoch 10 training [time: 26.45s, train_loss1: 700797.8506, train_loss2: 357338.2409]


17 Jun 21:21    INFO  epoch 11 training [time: 26.03s, train_loss1: 697940.8705, train_loss2: 352414.3676]


17 Jun 21:22    INFO  epoch 12 training [time: 26.07s, train_loss1: 695044.7509, train_loss2: 343778.2518]


17 Jun 21:22    INFO  epoch 13 training [time: 27.23s, train_loss1: 692487.9475, train_loss2: 332088.3042]


17 Jun 21:23    INFO  epoch 14 training [time: 26.20s, train_loss1: 690353.6406, train_loss2: 320110.7780]


17 Jun 21:23    INFO  epoch 15 training [time: 25.79s, train_loss1: 688302.0112, train_loss2: 312056.9519]


OutOfMemoryError: CUDA out of memory. Tried to allocate 22.00 MiB (GPU 0; 3.63 GiB total capacity; 2.23 GiB already allocated; 28.81 MiB free; 2.39 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
