# Imports

In [1]:
import pandas as pd
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td
# from torchmetrics.retrieval.ndcg import RetrievalNormalizedDCG

from sklearn.preprocessing import MultiLabelBinarizer
# from sklearn.metrics import ndcg_score

import shutil
import time
import gc
from pathlib import Path
import ast
import typing as tp
import random
from collections import Counter
from tqdm.autonotebook import tqdm, trange
import wandb
import warnings
from joblib import Parallel, delayed
from math import log2

torch.multiprocessing.set_start_method('fork', force=True)
warnings.filterwarnings("ignore", category=DeprecationWarning)
np.random.seed(31337)

  from tqdm.autonotebook import tqdm, trange


In [2]:
from typing import Callable, Literal
import numpy as np
import torch


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,
                 patience: int = 7,
                 threshold = 0,
                 threshold_mode: Literal['rel', 'abs'] = 'abs',
                 verbose: bool = False,
                 path: str = 'checkpoint.pt',
                 trace_func: Callable = print
                 ):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.threshold_mode = threshold_mode
        self.counter = 0
        self.best_val_loss = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.threshold = threshold
        self.path = path
        self.trace_func = trace_func

    def _significant_improvement(self, val_loss) -> bool:
        if self.threshold_mode == 'abs':
            return self.best_val_loss - val_loss > self.threshold
        else:
            return (self.best_val_loss - val_loss) / self.best_val_loss > self.threshold

    def __call__(self, val_loss, model):
        # Check if validation loss is nan
        if np.isnan(val_loss):
            self.trace_func("Validation loss is NaN. Ignoring this epoch.")
            return

        if self.best_val_loss is None:
            self.best_val_loss = val_loss
            # self.save_checkpoint(val_loss, model)
        elif self._significant_improvement(val_loss):
            self.best_val_loss = val_loss
            # self.save_checkpoint(val_loss, model)
            self.counter = 0  # Reset counter since improvement occurred
        else:
            # No significant improvement
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss: float, model: torch.nn.Module):
        '''Saves model when validation loss decreases.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# Preprocessing

In [3]:
# folder = Path("/content/")
folder = Path("data/data_kion")
users_df = pd.read_csv(folder / "users_processed.csv")
items_df = pd.read_csv(folder / "items_processed.csv")
interactions_df = pd.read_csv(folder / "interactions_processed.csv")

## Users preprocessing

In [4]:
users_df.head()

Unnamed: 0,user_id,age,income,sex,kids_flg
0,973171,age_25_34,income_60_90,M,True
1,962099,age_18_24,income_20_40,M,False
2,1047345,age_45_54,income_40_60,F,False
3,721985,age_45_54,income_20_40,F,False
4,704055,age_35_44,income_60_90,F,False


Закодируем возраст и доход числами (по возрастанию от 0), а тем юзерам, у которых они неизвестны, заполним их медианой (категорией, в которую попадает медиана среднего по диапазонам категорий).

In [5]:
sorted_age_categories = sorted(users_df['age'].unique(), key=lambda s: float(s.split('_')[1] if len(s.split('_')) == 3 else np.inf))
age_mapper = {age: id for id, age in enumerate(sorted_age_categories)}
median_age = users_df[users_df['age'] != 'age_unknown']['age'].map(lambda s: (float(s.split('_')[2] if float(s.split('_')[2]) < np.inf else s.split('_')[1]) + float(s.split('_')[1])) / 2).median()
age_fill_value = None
for age_cat in sorted_age_categories:
    low, high = age_cat.split('_')[1:]
    if int(low) < median_age < int(high):
        age_fill_value = age_cat
        break
age_mapper['age_unknown'] = age_mapper[age_fill_value]

sorted_income_categories = sorted(users_df['income'].unique(), key=lambda s: float(s.split('_')[1] if len(s.split('_')) == 3 else np.inf))
income_mapper = {income: id for id, income in enumerate(sorted_income_categories)}
median_income = users_df[users_df['income'] != 'income_unknown']['income'].map(lambda s: (float(s.split('_')[2] if float(s.split('_')[2]) < np.inf else s.split('_')[1]) + float(s.split('_')[1])) / 2).median()
income_fill_value = None
for income_cat in sorted_income_categories:
    low, high = income_cat.split('_')[1:]
    if int(low) < median_income < int(high):
        income_fill_value = income_cat
        break
income_mapper['income_unknown'] = income_mapper[income_fill_value]
income_mapper

sex_mapper = {'M': -1, 'sex_unknown': 0, 'F': 1}

age_mapper, income_mapper, sex_mapper

({'age_18_24': 0,
  'age_25_34': 1,
  'age_35_44': 2,
  'age_45_54': 3,
  'age_55_64': 4,
  'age_65_inf': 5,
  'age_unknown': 2},
 {'income_0_20': 0,
  'income_20_40': 1,
  'income_40_60': 2,
  'income_60_90': 3,
  'income_90_150': 4,
  'income_150_inf': 5,
  'income_unknown': 1},
 {'M': -1, 'sex_unknown': 0, 'F': 1})

In [6]:
users_df['age'] = users_df['age'].map(age_mapper)
users_df['income'] = users_df['income'].map(income_mapper)
users_df['sex'] = users_df['sex'].map(sex_mapper)
users_df['kids_flg'] = users_df['kids_flg']
users_df = users_df.astype(np.int32)
users_df.head()

Unnamed: 0,user_id,age,income,sex,kids_flg
0,973171,1,3,-1,1
1,962099,0,1,-1,0
2,1047345,3,2,1,0
3,721985,3,1,1,0
4,704055,2,3,1,0


In [7]:
users_df.info(verbose=True)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 840197 entries, 0 to 840196
Data columns (total 5 columns):
 #   Column    Non-Null Count   Dtype
---  ------    --------------   -----
 0   user_id   840197 non-null  int32
 1   age       840197 non-null  int32
 2   income    840197 non-null  int32
 3   sex       840197 non-null  int32
 4   kids_flg  840197 non-null  int32
dtypes: int32(5)
memory usage: 16.0 MB


## Items preprocessing

In [8]:
items_df['genres'] = items_df['genres'].map(ast.literal_eval)
items_df['countries'] = items_df['countries'].map(lambda s: list(set(s.split(', '))))

In [9]:
items_df.head()

Unnamed: 0,item_id,content_type,title,title_orig,genres,countries,for_kids,age_rating,studios,directors,actors,description,keywords,release_year_cat
0,10711,film,поговори с ней,Hable con ella,"[драмы, детективы, мелодрамы]",[испания],False,16,unknown,педро альмодовар,"Адольфо Фернандес, Ана Фернандес, Дарио Гранди...",Мелодрама легендарного Педро Альмодовара «Пого...,"Поговори, ней, 2002, Испания, друзья, любовь, ...",2000_2010
1,2508,film,голые перцы,Search Party,"[приключения, комедии]",[сша],False,16,unknown,скот армстронг,"Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ...",Уморительная современная комедия на популярную...,"Голые, перцы, 2014, США, друзья, свадьбы, прео...",2010_2020
2,10716,film,тактическая сила,Tactical Force,"[криминал, триллеры, боевики, комедии]",[канада],False,16,unknown,адам п. калтраро,"Адриан Холмс, Даррен Шалави, Джерри Вассерман,...",Профессиональный рестлер Стив Остин («Все или ...,"Тактическая, сила, 2011, Канада, бандиты, ганг...",2010_2020
3,7868,film,45 лет,45 Years,"[драмы, мелодрамы]",[великобритания],False,16,unknown,эндрю хэй,"Александра Риддлстон-Барретт, Джеральдин Джейм...","Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей...","45, лет, 2015, Великобритания, брак, жизнь, лю...",2010_2020
4,16268,film,все решает мгновение,,"[драмы, спорт, мелодрамы]",[ссср],False,12,ленфильм,виктор садовский,"Александр Абдулов, Александр Демьяненко, Алекс...",Расчетливая чаровница из советского кинохита «...,"Все, решает, мгновение, 1978, СССР, сильные, ж...",1970_1980


In [10]:
items_df.columns

Index(['item_id', 'content_type', 'title', 'title_orig', 'genres', 'countries',
       'for_kids', 'age_rating', 'studios', 'directors', 'actors',
       'description', 'keywords', 'release_year_cat'],
      dtype='object')

Закодируем жанры (отдельный бинарный столбец для каждого жанра)

In [11]:
genres_mlb = MultiLabelBinarizer()
genres_one_hot = pd.DataFrame(genres_mlb.fit_transform(items_df['genres']),
                              columns=list(map(lambda s: f'genre_{s}', genres_mlb.classes_)),
                              index=items_df.index,
                              dtype=np.int32)
countries_mlb = MultiLabelBinarizer()
countries_one_hot = pd.DataFrame(countries_mlb.fit_transform(items_df['countries']),
                                 columns=list(map(lambda s: f'country_{s}', countries_mlb.classes_)),
                                 index=items_df.index,
                                 dtype=np.int32)
items_df = pd.concat([items_df, genres_one_hot, countries_one_hot], axis=1).drop(columns=['genres', 'genre_no_genre', 'studios', 'countries'])
items_df.head()

Unnamed: 0,item_id,content_type,title,title_orig,for_kids,age_rating,directors,actors,description,keywords,...,country_хорватия,country_чехия,country_чили,country_швейцария,country_швеция,country_эквадор,country_эстония,country_юар,country_югославия,country_япония
0,10711,film,поговори с ней,Hable con ella,False,16,педро альмодовар,"Адольфо Фернандес, Ана Фернандес, Дарио Гранди...",Мелодрама легендарного Педро Альмодовара «Пого...,"Поговори, ней, 2002, Испания, друзья, любовь, ...",...,0,0,0,0,0,0,0,0,0,0
1,2508,film,голые перцы,Search Party,False,16,скот армстронг,"Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ...",Уморительная современная комедия на популярную...,"Голые, перцы, 2014, США, друзья, свадьбы, прео...",...,0,0,0,0,0,0,0,0,0,0
2,10716,film,тактическая сила,Tactical Force,False,16,адам п. калтраро,"Адриан Холмс, Даррен Шалави, Джерри Вассерман,...",Профессиональный рестлер Стив Остин («Все или ...,"Тактическая, сила, 2011, Канада, бандиты, ганг...",...,0,0,0,0,0,0,0,0,0,0
3,7868,film,45 лет,45 Years,False,16,эндрю хэй,"Александра Риддлстон-Барретт, Джеральдин Джейм...","Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей...","45, лет, 2015, Великобритания, брак, жизнь, лю...",...,0,0,0,0,0,0,0,0,0,0
4,16268,film,все решает мгновение,,False,12,виктор садовский,"Александр Абдулов, Александр Демьяненко, Алекс...",Расчетливая чаровница из советского кинохита «...,"Все, решает, мгновение, 1978, СССР, сильные, ж...",...,0,0,0,0,0,0,0,0,0,0


И так же, как с юзерами, закодируем числами категориальные фичи (которых по одному значению на строку)

In [12]:
content_type_mapper = {'film': 0, 'series': 1}
age_rating_mapper = {v: id for id, v in enumerate(sorted(items_df['age_rating'].unique()))}
release_year_cat_mapper = {y: id for id, y in enumerate(sorted(items_df['release_year_cat'].unique(), key=lambda s: float(s.split('_')[1])))}
content_type_mapper, age_rating_mapper, release_year_cat_mapper

({'film': 0, 'series': 1},
 {np.int64(0): 0,
  np.int64(6): 1,
  np.int64(12): 2,
  np.int64(16): 3,
  np.int64(18): 4,
  np.int64(21): 5},
 {'inf_1920': 0,
  '1920_1930': 1,
  '1930_1940': 2,
  '1940_1950': 3,
  '1950_1960': 4,
  '1960_1970': 5,
  '1970_1980': 6,
  '1980_1990': 7,
  '1990_2000': 8,
  '2000_2010': 9,
  '2010_2020': 10,
  '2020_inf': 11})

In [13]:
items_df['for_kids'] = items_df['for_kids'].astype(np.int32)
items_df['age_rating'] = items_df['age_rating'].map(age_rating_mapper).astype(np.int32)
items_df['content_type'] = items_df['content_type'].map(content_type_mapper).astype(np.int32)
items_df['release_year_cat'] = items_df['release_year_cat'].map(release_year_cat_mapper).astype(np.int32)
items_df = items_df.drop(columns=['title', 'title_orig', 'directors', 'actors', 'description', 'keywords']).astype(np.int32)

items_df.head()

Unnamed: 0,item_id,content_type,for_kids,age_rating,release_year_cat,genre_аниме,genre_биография,genre_боевики,genre_военные,genre_детективы,...,country_хорватия,country_чехия,country_чили,country_швейцария,country_швеция,country_эквадор,country_эстония,country_юар,country_югославия,country_япония
0,10711,0,0,3,9,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0
1,2508,0,0,3,10,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,10716,0,0,3,10,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
3,7868,0,0,3,10,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,16268,0,0,2,6,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [14]:
items_df.info(verbose=True)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 15963 entries, 0 to 15962
Data columns (total 125 columns):
 #    Column                    Dtype
---   ------                    -----
 0    item_id                   int32
 1    content_type              int32
 2    for_kids                  int32
 3    age_rating                int32
 4    release_year_cat          int32
 5    genre_аниме               int32
 6    genre_биография           int32
 7    genre_боевики             int32
 8    genre_военные             int32
 9    genre_детективы           int32
 10   genre_детские             int32
 11   genre_для взрослых        int32
 12   genre_документальное      int32
 13   genre_драмы               int32
 14   genre_исторические        int32
 15   genre_комедии             int32
 16   genre_короткометражные    int32
 17   genre_криминал            int32
 18   genre_мелодрамы           int32
 19   genre_музыка              int32
 20   genre_мультфильмы         int32
 21   genre_мюзи

In [15]:
ITEMS_NUM_CAT_FEATURES = 4
ITEMS_NUM_GENRE_FEATURES = len([g for g in items_df.columns if g.startswith('genre_')])

## Interactions preprocessing

In [16]:
interactions_df.head()

Unnamed: 0,user_id,item_id,last_watch_dt,total_dur,watched_pct
0,176549,9506,2021-05-11,4250,72
1,699317,1659,2021-05-29,8317,100
2,656683,7107,2021-05-09,10,0
3,864613,7638,2021-07-05,14483,100
4,964868,9506,2021-04-30,6725,100


Фильтруем малоактивных юзеров и непопулярные фильмы.

In [17]:
interactions_df['item_id'].value_counts(), interactions_df['user_id'].value_counts()

(item_id
 10440    202457
 15297    193123
 9728     132865
 13865    122119
 4151      91167
           ...  
 2435          1
 7978          1
 10642         1
 13008         1
 9286          1
 Name: count, Length: 15706, dtype: int64,
 user_id
 416206     1341
 1010539     764
 555233      685
 11526       676
 409259      625
            ... 
 690921        1
 255412        1
 264195        1
 150067        1
 337469        1
 Name: count, Length: 962179, dtype: int64)

Оставим юзеров, посмотревших хотя бы 5 фильмов, и фильмы, которые посмотрело хотя бы 5 юзеров

In [18]:
before_filtering_users = interactions_df['user_id'].nunique()
before_filtering_items = interactions_df['item_id'].nunique()

interactions_df = interactions_df[interactions_df.watched_pct > 10]

valid_users = []

neg_films_features = Counter(interactions_df['user_id'])
for user_ids, entries in neg_films_features.most_common():
  if entries >= 5:
    valid_users.append(user_ids)

valid_items = []

neg_films_features = Counter(interactions_df['item_id'])
for item_id, entries in neg_films_features.most_common():
  if entries >= 5:
    valid_items.append(item_id)

interactions_df = interactions_df[interactions_df['user_id'].isin(valid_users)]
interactions_df = interactions_df[interactions_df['item_id'].isin(valid_items)]

print(f"Users before filtering: {before_filtering_users:>7}")
print(f"Users after filtering:  {interactions_df['user_id'].nunique():>7}")
print(f"Items before filtering:  {before_filtering_items:>6}")
print(f"Items after filtering:   {interactions_df['item_id'].nunique():>6}")

Users before filtering:  962179
Users after filtering:   207255
Items before filtering:   15706
Items after filtering:     8823


Переведём фичи в 32-битный int

In [19]:
interactions_df['total_dur'].min(), interactions_df['total_dur'].mean(), interactions_df['total_dur'].max()

(np.int64(6), np.float64(11384.85642516867), np.int64(80411672))

In [20]:
interactions_df[['user_id', 'item_id', 'total_dur', 'watched_pct']] = interactions_df[['user_id', 'item_id', 'total_dur', 'watched_pct']].astype(np.int32)
interactions_df.info(verbose=True)

<class 'pandas.core.frame.DataFrame'>
Index: 2646592 entries, 0 to 5476249
Data columns (total 5 columns):
 #   Column         Dtype 
---  ------         ----- 
 0   user_id        int32 
 1   item_id        int32 
 2   last_watch_dt  object
 3   total_dur      int32 
 4   watched_pct    int32 
dtypes: int32(4), object(1)
memory usage: 80.8+ MB


Поскольку `nn.Embedding` принимает количество треков `num_tracks`, и считает, что их индексы будут [0; `num_tracks`) (что сейчас не так), нужно перемаппить индексы треков (и юзеров заодно) в этот диапазон  

Про часть юзеров, которые есть в `interactions_df`, у нас нет информации. Заполним информацию про них медианными значениями.

In [21]:
users_df.columns

Index(['user_id', 'age', 'income', 'sex', 'kids_flg'], dtype='object')

In [22]:
users_df

Unnamed: 0,user_id,age,income,sex,kids_flg
0,973171,1,3,-1,1
1,962099,0,1,-1,0
2,1047345,3,2,1,0
3,721985,3,1,1,0
4,704055,2,3,1,0
...,...,...,...,...,...
840192,339025,5,0,1,0
840193,983617,0,1,1,1
840194,251008,2,1,0,0
840195,590706,2,1,1,0


In [23]:
unknown_users = set(interactions_df["user_id"]).difference(users_df['user_id'])
unknown_users_df = pd.DataFrame({'user_id' : list(unknown_users),
                                 'age': [age_mapper['age_unknown']] * len(unknown_users),
                                 'income': [income_mapper['income_unknown']] * len(unknown_users),
                                 'sex': [0] * len(unknown_users),
                                 'kids_flg': [0] * len(unknown_users),
                                 }
                )
users_df = pd.concat([users_df, unknown_users_df], axis=0)
users_df

Unnamed: 0,user_id,age,income,sex,kids_flg
0,973171,1,3,-1,1
1,962099,0,1,-1,0
2,1047345,3,2,1,0
3,721985,3,1,1,0
4,704055,2,3,1,0
...,...,...,...,...,...
39287,524269,2,1,0,0
39288,786421,2,1,0,0
39289,786425,2,1,0,0
39290,524284,2,1,0,0


In [24]:
unique_user_ids = np.unique(np.concat([interactions_df["user_id"], users_df['user_id']]))
unique_item_ids = np.unique(np.concat([interactions_df["item_id"], items_df['item_id']]))

# Create mappings (old ID → new consecutive ID starting from 0)
user_id_map = {old_id: np.int32(new_id) for new_id, old_id in enumerate(unique_user_ids)}
user_id_map_to_orig = {np.int32(new_id): old_id for new_id, old_id in enumerate(unique_user_ids)}
item_id_map = {old_id: np.int32(new_id) for new_id, old_id in enumerate(unique_item_ids)}
item_id_map_to_orig = {np.int32(new_id): old_id for new_id, old_id in enumerate(unique_item_ids)}

In [25]:
len(user_id_map), len(item_id_map)

(879489, 15963)

In [26]:
users_df['user_id'] = users_df['user_id'].map(user_id_map)
items_df['item_id'] = items_df['item_id'].map(item_id_map)
interactions_df['user_id'] = interactions_df['user_id'].map(user_id_map)
interactions_df['item_id'] = interactions_df['item_id'].map(item_id_map)

users_df.set_index('user_id', inplace=True)
items_df.set_index('item_id', inplace=True)
interactions_df.head()

Unnamed: 0,user_id,item_id,last_watch_dt,total_dur,watched_pct
0,141413,9178,2021-05-11,4250,72
1,560236,1604,2021-05-29,8317,100
3,692633,7376,2021-07-05,14483,100
5,826986,6454,2021-05-13,11286,100
6,814480,343,2021-08-14,1672,25


In [27]:
def get_triplets(interactions_df: pd.DataFrame) -> pd.DataFrame:
    """
    Extract triplets (user_id, item_id, watched_pct) from interactions DataFrame.
    """
    positives = interactions_df[interactions_df["watched_pct"] > 80].copy()
    negatives = interactions_df[interactions_df["watched_pct"] < 30].copy()
    positives = positives[["user_id", "item_id"]].rename(columns={"item_id": "film_pos"})

    NUM_NEGATIVE_SAMPLES = 10
    negatives_grouped = (
        negatives.groupby('user_id')
        .apply(lambda x: x.sample(n=min(len(x), NUM_NEGATIVE_SAMPLES), random_state=42))
        .reset_index(drop=True)
        .rename(columns={"item_id": "film_neg"})
    )

    triplets = positives.merge(negatives_grouped, on="user_id", how="inner")
    return triplets

In [28]:
interactions_df = interactions_df.sort_values(by=['last_watch_dt'], ascending=True)

train_data = get_triplets(interactions_df.iloc[:int(len(interactions_df) * 0.8)])
val_data = get_triplets(interactions_df.iloc[int(len(interactions_df) * 0.8):int(len(interactions_df) * 0.9)])
test_data = get_triplets(interactions_df.iloc[int(len(interactions_df) * 0.9):])

len(train_data), len(val_data), len(test_data)

(3844664, 119954, 127847)

In [29]:
interactions_df.iloc[:int(len(interactions_df) * 0.8)]

Unnamed: 0,user_id,item_id,last_watch_dt,total_dur,watched_pct
1196595,44185,7128,2021-03-13,6664,100
4376993,846792,1229,2021-03-13,2618,45
2971068,732651,8172,2021-03-13,82197,100
4376798,33977,9767,2021-03-13,9520,100
632848,127388,418,2021-03-13,5921,100
...,...,...,...,...,...
5419963,565580,4949,2021-08-02,995,14
1460003,184495,11764,2021-08-02,8862,15
1021880,223970,7159,2021-08-02,777,12
4467899,475182,10945,2021-08-02,2996,40


In [None]:
NO_MOVIE = len(item_id_map)  # id bigger than any item id
PAD_SIZE = 30  # users' interactions are padded to have the same number of items
MOVIE_IDS_DTYPE = torch.int16
USER_IDS_DTYPE = torch.int32
MAIN_DTYPE = torch.float32
NO_MOVIE

15963

In [None]:
def pad_with_specific_value(tensor, target_length, pad_value):
    """`tensor` is considered to have only unique values"""
    shuffled = tensor[torch.randperm(len(tensor))[:target_length]]
    cur_length = len(shuffled)
    if cur_length < target_length:
        padding = torch.full((target_length - cur_length,), pad_value, dtype=shuffled.dtype) 
        return torch.cat([shuffled, padding])
    else:
        return shuffled

def group_by_user(triplets: pd.DataFrame, column: Literal['film_pos', 'film_neg'] = 'film_pos') -> pd.DataFrame:
    """Group interactions by user and convert to DataFrame with column 'interactions' - Tensors of unique film IDs"""
    groupped_users = triplets.loc[:, ['user_id', 'film_pos', 'film_neg']].groupby('user_id').apply(lambda x: x[column].tolist())
    groupped_users = pd.DataFrame({'interactions': groupped_users.values}, index=groupped_users.index)
    groupped_users['interactions'] = groupped_users['interactions'].apply(lambda x: torch.unique(torch.tensor(x, dtype=MOVIE_IDS_DTYPE)))
    return groupped_users

In [32]:
# How to deep copy dataframe with lists

# import copy

# A = pd.DataFrame({'inter': [[1, 2, 3], [4, 3, 7, 9]]})
# B = A.copy(deep=True)
# B['inter'] = B['inter'].apply(copy.deepcopy)  # deep-copy each list inside the column

# A.loc[0, 'inter'][1] = 13
# print("A:", A, "\nB:", B, sep="\n")

# Dataset

In [None]:
def collate_fn(data: list[tuple]):
    return data

class DSSMDataset(td.Dataset):
    def __init__(self, triplets: pd.DataFrame, type: Literal['train', 'val', 'test'] = 'train'):
        super().__init__()
        self.triplets = triplets
        self.grouped_pos_users_interactions = group_by_user(triplets, column='film_pos')  # just like padded users but not padded
        self.padded_users = self.grouped_pos_users_interactions.copy(deep=True)
        self.padded_users['interactions'] = self.padded_users['interactions'].apply(lambda x: pad_with_specific_value(x, PAD_SIZE, NO_MOVIE).to(dtype=torch.int32))  # int32 required by EmbeddingBag

        if type != 'test':
            self.grouped_neg_users_interactions = group_by_user(triplets, column='film_neg')  # all negative interactions grouped by users into tensors
            self.all_users = torch.tensor(self.grouped_neg_users_interactions.index.tolist(), dtype=USER_IDS_DTYPE)  # all unique users from triplets

    def __getitem__(self, index: int):
        cur_triplet = self.triplets.iloc[index]
        user_id = cur_triplet['user_id']
        user_info = torch.tensor(users_df.loc[user_id].values, dtype=MAIN_DTYPE)
        user_interactions = self.padded_users.loc[user_id]['interactions']
        pos_films_features = torch.tensor(items_df.loc[cur_triplet['film_pos']].values, dtype=MAIN_DTYPE)
        neg_films_features = torch.tensor(items_df.loc[cur_triplet['film_neg']].values, dtype=MAIN_DTYPE)
        pos_ids = self.grouped_pos_users_interactions.loc[user_ids]['interactions']
        return user_id, user_interactions, user_info, pos_films_features, neg_films_features, pos_ids

    def __getitems__(self, index: tp.Sequence[int]):
        cur_triplets = self.triplets.iloc[index]
        user_ids = cur_triplets['user_id']

        user_info = torch.tensor(users_df.loc[user_ids].values, dtype=MAIN_DTYPE)
        user_interactions = torch.stack(self.padded_users.loc[user_ids]['interactions'].tolist())
        pos_films_features = torch.tensor(items_df.loc[cur_triplets['film_pos']].values, dtype=MAIN_DTYPE)
        neg_films_features = torch.tensor(items_df.loc[cur_triplets['film_neg']].values, dtype=MAIN_DTYPE)
        pos_ids = self.grouped_pos_users_interactions.loc[user_ids]['interactions'].tolist()

        return user_ids, user_interactions, user_info, pos_films_features, neg_films_features, pos_ids

    def get_all_users(self) -> Tensor:
        """
        Get IDs of all users, that have positive interactions in this dataset
        
        Returns:
            users (Tensor, dtype=torch.int32) : All user IDs
        """
        return self.all_users

    def get_user_data(self, user_id: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Get all data in a dataset of one user
        
        Args:
            user_id (int) : ID of the user
        
        Returns:
            user_interactions (Tensor, dtype=torch.int16) : interactions of this user stored in the dataset
            user_info (Tensor, dtype=torch.float32) : Features of this user
            pos_film_features (Tensor, dtype=torch.float32) : Features of positive films found in triplets for this user in the dataset
            neg_film_features (Tensor, dtype=torch.float32) : Features of negative films found in triplets for this user in the dataset
        """
        # TODO: кажется, не надо возвращать pos_film_features и neg_film_features потому что есть функция DSSM.encode_user(), которая 
        # просто посчитает эмбеддинг юзера по его истории взаимодействий (user_interactions) и его фичам (user_info)
        # TODO: можно ли это сделать для нескольких user_id сразу для ускорения?  
        pass

    def __len__(self):
        return len(self.triplets)

# Model

In [34]:
class ItemNet(nn.Module):
    def __init__(self,
                 dim_embedding: int,
                 dim_input: int,
                 dim_hidden: int = 96,
                 activation: tp.Callable[[Tensor], Tensor] = nn.ReLU()
                 ) -> None:
        super().__init__()
        self.cat_embedding = nn.Linear(ITEMS_NUM_CAT_FEATURES, dim_hidden)
        self.genre_embedding = nn.Linear(ITEMS_NUM_GENRE_FEATURES, dim_hidden)
        self.country_embedding = nn.Linear(dim_input - ITEMS_NUM_CAT_FEATURES - ITEMS_NUM_GENRE_FEATURES, dim_hidden)
        self.dense_block = nn.Sequential(
            nn.Linear(dim_hidden * 3, dim_hidden),
            activation,
        )
        self.output_layer = nn.Linear(dim_hidden, dim_embedding, bias=False)
        self.norm = nn.LayerNorm(dim_embedding)

    def forward(self, item_features: Tensor) -> Tensor:
        cat_features = item_features[:, :ITEMS_NUM_CAT_FEATURES]
        genre_features = item_features[:, ITEMS_NUM_CAT_FEATURES:ITEMS_NUM_CAT_FEATURES + ITEMS_NUM_GENRE_FEATURES]
        country_features = item_features[:, ITEMS_NUM_CAT_FEATURES + ITEMS_NUM_GENRE_FEATURES:]

        cat_emb = self.cat_embedding(cat_features)
        genre_emb = self.genre_embedding(genre_features)
        country_emb = self.country_embedding(country_features)

        pop_genre = torch.concat([cat_emb, genre_emb, country_emb], axis=1)
        features = self.dense_block(pop_genre)
        output = self.output_layer(features)

        return self.norm(output)


class UserNet(nn.Module):
    def __init__(self,
                 dim_embedding: int,
                 num_items: int,
                 dim_user_features: int,
                 activation: tp.Callable[[Tensor], Tensor] = nn.ReLU()
                 ) -> None:                              # | +1 for the NO_MOVIE element
        super().__init__()                               # V
        self.track_embeddings = nn.EmbeddingBag(num_items + 1, dim_embedding, padding_idx=num_items)
        self.info_embedding = nn.Sequential(
            nn.Linear(dim_user_features, int(dim_embedding // 2)),
            nn.ReLU(),
            nn.Linear(int(dim_embedding // 2), dim_embedding),
            activation
        )
        self.dense_layer = nn.Sequential(
            nn.Linear(dim_embedding, int(dim_embedding // 2)),
            nn.ReLU(),
            nn.Linear(int(dim_embedding // 2), dim_embedding),
            activation
        )
        self.output_layer = nn.Linear(3*dim_embedding, dim_embedding, bias=False)
        self.norm = nn.LayerNorm(dim_embedding)
        self.num_items = num_items
        self.dim_user_features = dim_user_features

    def forward(self, user_interactions: Tensor, user_info: Tensor) -> Tensor:
        # print(f'EMBEDDING BAG MAX INPUT: {user_ids.max()} while was ready for {self.num_items}')
        interactions_emb = self.track_embeddings(user_interactions)
        info_emb = self.info_embedding(user_info.float())
        features = self.dense_layer(interactions_emb)
        x = torch.cat([interactions_emb, features, info_emb], dim=1)
        output = self.output_layer(x)
        return self.norm(output)

In [35]:
class DSSM(nn.Module):
    def __init__(self,
                 dim_item_features: int,
                 dim_user_features: int,
                 num_items: int,
                 embedding_dim: int = 100,
                 lr: float = 1e-3,
                 triplet_loss_margin: float = 0.4,
                 weight_decay: float = 1e-3,
                 log_to_prog_bar: bool = True,
                 ) -> None:
        super().__init__()
        self.lr = lr
        self.triplet_loss_margin = triplet_loss_margin
        self.weight_decay = weight_decay
        self.log_to_prog_bar = log_to_prog_bar
        self.item_net = ItemNet(embedding_dim, dim_item_features)
        self.user_net = UserNet(embedding_dim, num_items, dim_user_features)

    def forward(self,
                user_intercations: Tensor,
                user_info: Tensor,
                item_features_pos: Tensor,
                item_features_neg: Tensor,
                ) -> tuple[Tensor, Tensor, Tensor]:
        """Returns embeddings of users, positive items and negative items"""
        anchor = self.user_net(user_intercations, user_info)
        pos = self.item_net(item_features_pos)
        neg = self.item_net(item_features_neg)

        return anchor, pos, neg

    def encode_user(self, user_interactions: Tensor, user_info: Tensor) -> Tensor:
        return self.user_net(user_interactions, user_info)

    def encode_item(self, item_features: Tensor) -> Tensor:
        return self.item_net(item_features)

# Training

## Define metrics

In [None]:
a = torch.tensor([1, 2, 3, 4, 5], dtype=MOVIE_IDS_DTYPE)
b = torch.tensor([1, 4], dtype=MOVIE_IDS_DTYPE)
torch.isin(a, b).sum()

tensor(2)

In [None]:
# # def recall_at_k(preds: list[list[int]], targets: list[int], k: int = 10) -> float:
# #     hits = 0
# #     for pred, true in zip(preds, targets):
# #         if true in pred[:k]:
# #             hits += 1
# #     return hits / len(targets)

# # def ndcg_at_k(preds: list[list[int]], targets: list[int], k: int = 10) -> float:
# #     # return ndcg_score(y_true=targets, y_score=preds, k=k)
# #     total = 0.0
# #     for pred, true in zip(preds, targets):
# #         if true in pred[:k]:
# #             rank = pred.index(true) + 1
# #             total += 1.0 / np.log2(rank + 1)
# #     return total / len(targets)

# # def mrr_at_k(preds: list[list[int]], targets: list[int], k: int = 10) -> float:
# #     total = 0.0
# #     for pred, true in zip(preds, targets):
# #         if true in pred[:k]:
# #             rank = pred.index(true) + 1
# #             total += 1.0 / rank
# #     return total / len(targets)

# def precision_at_k(y_true, y_pred, k):
#     precisions = []
#     for true, pred in zip(y_true, y_pred):
#         pred_k = pred[:k]
#         true_set = set(true)
#         precisions.append(len(set(pred_k) & true_set) / k)
#     return np.mean(precisions)

# def recall_at_k(y_true, y_pred, k):
#     recalls = []
#     for true, pred in tqdm(zip(y_true, y_pred), desc='Calculating Recall@k', total=len(y_true)):
#         pred_k = pred[:k]
#         true_set = set(true)
#         if len(true_set) == 0:
#             recalls.append(0.0)
#         else:
#             recalls.append(len(set(pred_k) & true_set) / len(true_set))
#     return np.mean(recalls)

# def dcg_at_k(rel_scores, k):
#     return np.sum([
#         rel / np.log2(idx + 2) for idx, rel in enumerate(rel_scores[:k])
#     ])

# def ndcg_at_k(y_true, y_pred, k):
#     ndcgs = []
#     for true, pred in zip(y_true, y_pred):
#         rel = [1 if item in true else 0 for item in pred[:k]]
#         dcg = dcg_at_k(rel, k)
#         ideal_rel = sorted(rel, reverse=True)
#         idcg = dcg_at_k(ideal_rel, k)
#         ndcgs.append(dcg / idcg if idcg > 0 else 0.0)
#     return np.mean(ndcgs)

# def mrr_at_k(y_true, y_pred, k):
#     mrrs = []
#     for true, pred in zip(y_true, y_pred):
#         for rank, item in enumerate(pred[:k], start=1):
#             if item in true:
#                 mrrs.append(1 / rank)
#                 break
#         else:
#             mrrs.append(0.0)
#     return np.mean(mrrs)


def make_true_sets(y_true):
    true_sets = []
    for t in y_true:
        if torch.is_tensor(t):
            arr = t.cpu().numpy().tolist()
        else:
            arr = list(t)
        true_sets.append(set(arr))
    return true_sets

def _recall_single(relevant_tensor, pred_tensor, k):
    if (num_relevant_items := len(relevant_tensor)) == 0:
        # TODO: такие сэмплы нам вообще не нужны в валидации, лучше их убрать заранее
        return None
    return torch.isin(pred_tensor[:k], relevant_tensor).sum().float() / num_relevant_items

def recall_at_k(y_true, y_pred, k, n_jobs=-1):
    recalls = Parallel(n_jobs=n_jobs)(
        delayed(_recall_single)(ts, pred, k)
        for ts, pred in tqdm(zip(y_true, y_pred), desc='Calculating Recall@k', total=len(y_true))
    )
    return np.mean(recalls)

def precision_at_k(y_true_sets, y_pred, k):
    precisions = []
    for true_set, pred in zip(y_true_sets, y_pred):
        pred_k = pred[:k]
        precisions.append(len(set(pred_k) & true_set) / k)
    return np.mean(precisions)

def _dcg(rel, k):
    score = 0.0
    for i, r in enumerate(rel[:k]):
        score += r / log2(i + 2)
    return score

def _ndcg_single(true_set, pred, k):
    rel = [1 if item in true_set else 0 for item in pred[:k]]
    dcg = _dcg(rel, k)
    ideal = sorted(rel, reverse=True)
    idcg = _dcg(ideal, k)
    return dcg / idcg if idcg > 0 else 0.0

def ndcg_at_k(y_true_sets, y_pred, k, n_jobs=-1):
    ndcgs = Parallel(n_jobs=n_jobs)(
        delayed(_ndcg_single)(ts, pred, k)
        for ts, pred in tqdm(zip(y_true_sets, y_pred), desc='Calculating NDCG@k', total=len(y_true_sets))
    )
    return np.mean(ndcgs)

def _mrr_single(true_set, pred, k):
    for rank, item in enumerate(pred[:k], start=1):
        if item in true_set:
            return 1.0 / rank
    return 0.0

def mrr_at_k(y_true_sets, y_pred, k, n_jobs=-1):
    mrrs = Parallel(n_jobs=n_jobs)(
        delayed(_mrr_single)(ts, pred, k)
        for ts, pred in tqdm(zip(y_true_sets, y_pred), desc='Calculating MRR@k', total=len(y_true_sets))
    )
    return np.mean(mrrs)

## Configure training

In [None]:
print('Get train dataset')
train_dataset = DSSMDataset(train_data, type='train')
print('Get validation dataset')
val_dataset = DSSMDataset(val_data, type='val')
print('Get test dataset')
test_dataset = DSSMDataset(test_data, type='test')

Get train dataset
Get validation dataset
Get test dataset


In [39]:
# Hyperparams
EPOCHS = 50
BATCH_SIZE = 1024 * 16
NUM_WORKERS = 0
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

EMBEDDING_DIM = 128
LR = 1e-4
WEIGHT_DECAY = 1e-4
TRIPLET_LOSS_MARGIN = 0.4
EXPERIMENT_NAME = 'wtf_1'
K = 20
LOG_TO_WANDB = False

train_dataloader = td.DataLoader(train_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dataloader = td.DataLoader(val_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_dataloader = td.DataLoader(test_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

model = DSSM(dim_item_features=items_df.shape[1],
             dim_user_features=users_df.shape[1],
             num_items=len(item_id_map),
             embedding_dim=EMBEDDING_DIM
             ).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, threshold=3e-3, threshold_mode='abs')
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=3,
    threshold=1e-3,
    threshold_mode='abs'
)
early_stopping = EarlyStopping(
    patience=7,
    threshold=1e-3,
    threshold_mode='rel',
    verbose=True,
    trace_func=tqdm.write,
    path=f'model_weights/{EXPERIMENT_NAME}.ckpt'
)

## Training loop

In [None]:
if LOG_TO_WANDB:
    entity = "xenz5240-higher-school-of-economics"
    wandb.init(entity=entity, project='reelsrec-dssm', name=EXPERIMENT_NAME)

    wandb.config.update({
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "embedding_dim": EMBEDDING_DIM,
        "learning_rate": LR,
        "weight_decay": WEIGHT_DECAY,
        "triplet_loss_margin": TRIPLET_LOSS_MARGIN,
        "device": DEVICE
    })

for epoch in trange(EPOCHS, position=0, desc='Training', unit='epoch'):
    model.train()
    epoch_losses = []
    batch_times = []
    for step, batch in enumerate(tqdm(train_dataloader, position=1, desc=f'Epoch {epoch + 1}/{EPOCHS}', unit='batch')):
        _, user_inters, user_info, pos_features, neg_features, _ = batch
        optimizer.zero_grad()
        batch_users_embs, batch_positive_films_embs, batch_negative_films_embs = model(user_inters.to(DEVICE),
                                                                                       user_info.to(DEVICE),
                                                                                       pos_features.to(DEVICE),
                                                                                       neg_features.to(DEVICE))
        loss = F.triplet_margin_loss(batch_users_embs, batch_positive_films_embs, batch_negative_films_embs, margin=TRIPLET_LOSS_MARGIN)
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

        if LOG_TO_WANDB and step % 20 == 0:
            wandb.log({"train/loss": loss.item(),
                       "epoch": epoch + 1,
                       "step": step})

    mean_train_loss = float(np.mean(epoch_losses))
    if LOG_TO_WANDB:
        wandb.log({"train/epoch_loss": mean_train_loss, "epoch": epoch + 1})

    # Evaluation
    model.eval()
    val_losses = []
    all_predictions = []
    all_references = []
    # Embed all items
    items_df.sort_index(inplace=True)  # Ensure items are sorted by their IDs
    item_features = torch.tensor(items_df.values, dtype=MAIN_DTYPE, device=DEVICE)
    EMPTY_WATCH_HISTORY = torch.empty(0, dtype=MOVIE_IDS_DTYPE)
    with torch.no_grad():
        all_items_embeds = model.encode_item(item_features)
        
        #############################################################################################
        # new validation code (rough draft)
        #############################################################################################
        # TODO: тут нужно что-то такое
        val_users = val_dataset.get_all_users()
        val_user_embs = []
        for user_id in val_users:
            user_inters, user_info, _, _ = train_dataset.get_user_data(user_id)  
            user_emb = model.encode_user(user_inters, user_info)  # TODO проверить размерности и тензоры это должны быть или нет 
            # сохранить эмбеддинги
            val_user_embs.append(user_emb)
            
        # строим матрицу расстояний
        distance_matrix = torch.cdist(val_user_embs, all_items_embeds, p=2.0)
        val_recommendations = distance_matrix.argsort(dim=1, descending=False).to(dtype=MOVIE_IDS_DTYPE, device='cpu')  # (batch_users x all_items)

        # TODO: отфильтровать из рекомендаций интеракции из трейна
        filtered_val_recommendations = ...

        # сравнить полученные рекомендации с позитивными интеракциями валидации
        # TODO написать для этого метод DSSMDataset.get_positive_interactions(user_ids: Tensor)?
        val_positive_interactions = []
        for user_id in val_users:
            user_val_positive_interactions = val_dataset.grouped_pos_users_interactions.loc[user_id, 'interactions']
            val_positive_interactions.append(user_val_positive_interactions)

        # считаем метрики
        recall = recall_at_k(val_positive_interactions, filtered_val_recommendations)
        ...

        #############################################################################################
        # old validation code starts here
        #############################################################################################
        for batch in tqdm(val_dataloader, position=1, desc=f'Validation', unit='batch'):
            user_ids, user_inters, user_info, pos_features, neg_features, pos_ids = batch
            batch_users_embs, batch_positive_films_embs, batch_negative_films_embs = model(user_inters.to(DEVICE),
                                                                                           user_info.to(DEVICE),
                                                                                           pos_features.to(DEVICE),
                                                                                           neg_features.to(DEVICE))
            val_loss = F.triplet_margin_loss(batch_users_embs, batch_positive_films_embs, batch_negative_films_embs, margin=TRIPLET_LOSS_MARGIN)
            val_losses.append(val_loss.item())

            # Films sorted by distance to the user
            distance_matrix = torch.cdist(batch_users_embs, all_items_embeds, p=2.0)
            batch_recommendations = distance_matrix.argsort(dim=1, descending=False).to(dtype=MOVIE_IDS_DTYPE, device='cpu')  # (batch_users x all_items)
            del user_inters, user_info, pos_features, neg_features, batch_positive_films_embs, batch_negative_films_embs, batch_users_embs
            torch.cuda.empty_cache()
            gc.collect()

            # Get all films each user has interacted with (from both film_pos and film_neg)
            user_interacted_films: dict[int, Tensor] = (
                train_dataset.triplets[train_dataset.triplets['user_id'].isin(user_ids.tolist())]
                .groupby('user_id')[['film_pos', 'film_neg']]
                .apply(lambda x: torch.tensor(pd.unique(x.values.ravel('K')).tolist(), dtype=MOVIE_IDS_DTYPE))
            ).to_dict()

            # For each user remove watched films from their recommendations
            filtered_recommendations = [None] * len(user_ids)
            for i, (user_ids, recommendations) in tqdm(enumerate(zip(user_ids, batch_recommendations)),
                                                       desc='Filtering recommendations',
                                                       total=len(user_ids),
                                                       leave=False):
                watched_films = user_interacted_films.get(user_ids, EMPTY_WATCH_HISTORY)
                filtered_recommendations[i] = recommendations[torch.isin(recommendations, watched_films, invert=True)]

            # filtered_recommendations is a list of MOVIES_IDS_DTYPE cpu tensors of variable-length (batch_size x O(all_items))
            all_predictions.extend(filtered_recommendations)
            all_references.extend(pos_ids)

    mean_val_loss = float(np.mean(val_losses))
    torch.cuda.empty_cache()
    gc.collect()

    # tqdm.write('Make true sets')
    # all_references = make_true_sets(all_references)

    recall = recall_at_k(all_predictions, all_references, k=K, n_jobs=8)
    tqdm.write(f'Mean val los: {mean_val_loss:.4f} | Recall: {recall}')
    # precision = precision_at_k(all_predictions, all_references, k=K)
    # ndcg   = ndcg_at_k(all_predictions, all_references, k=K)
    # mrr    = mrr_at_k(all_predictions, all_references, k=K)
    # tqdm.write(f'Mean val los: {mean_val_loss:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | NDCG: {ndcg:.4f} | MRR: {mrr:.4f}')
    
    #############################################################################################
    # old validation code ends here
    #############################################################################################

    if LOG_TO_WANDB:
        wandb.log({
            "val/loss": mean_val_loss,
            f"val/recall@{K}": recall,
            # f"val/precision@{K}": precision,
            # f"val/ndcg@{K}": ndcg,
            # f"val/mrr@{K}": mrr,
            "train/lr": optimizer.param_groups[0]['lr'],
            "epoch": epoch + 1,
        })

    lr_scheduler.step(mean_val_loss)
    early_stopping(mean_val_loss, model)
    if early_stopping.early_stop:
        break

Training:   0%|          | 0/50 [00:00<?, ?epoch/s]

Epoch 1/50:   0%|          | 0/235 [00:00<?, ?batch/s]

Validation:   0%|          | 0/8 [00:00<?, ?batch/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/5266 [00:00<?, ?it/s]

All predictions:
[tensor([14117,  3633, 14452,  ...,  3674,  6675,  8967], dtype=torch.int16), tensor([14117,  3633, 14452,  ...,  3674,  6675,  8967], dtype=torch.int16), tensor([14117,  3633, 14452,  ...,  3674,  6675,  8967], dtype=torch.int16), tensor([ 9879,  4874, 13651,  ..., 13317, 11325,  9193], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  7105, 14543,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  7105, 14543,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  7105, 14543,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  7105, 14543,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  7105, 14543,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  7105, 14543,  9416], dtype=torch.int16)]


Calculating Recall@k:   0%|          | 0/119954 [00:00<?, ?it/s]

Mean val los: 0.1892 | Recall: 0.00042409368325024843


Epoch 2/50:   0%|          | 0/235 [00:00<?, ?batch/s]

Validation:   0%|          | 0/8 [00:00<?, ?batch/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/5266 [00:00<?, ?it/s]

All predictions:
[tensor([14117,  3633, 14452,  ..., 15272,  6675, 15681], dtype=torch.int16), tensor([14117,  3633, 14452,  ..., 15272,  6675, 15681], dtype=torch.int16), tensor([14117,  3633, 14452,  ..., 15272,  6675, 15681], dtype=torch.int16), tensor([ 9879,  4874, 13651,  ..., 11907, 11325,  9193], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  4238, 15488,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  4238, 15488,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  4238, 15488,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  4238, 15488,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  4238, 15488,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  4238, 15488,  9416], dtype=torch.int16)]


Calculating Recall@k:   0%|          | 0/119954 [00:00<?, ?it/s]

Mean val los: 0.1892 | Recall: 0.00042409368325024843
EarlyStopping counter: 1 out of 7


Epoch 3/50:   0%|          | 0/235 [00:00<?, ?batch/s]

Validation:   0%|          | 0/8 [00:00<?, ?batch/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/5266 [00:00<?, ?it/s]

All predictions:
[tensor([14117,  3633, 14452,  ..., 15272, 15681,  6675], dtype=torch.int16), tensor([14117,  3633, 14452,  ..., 15272, 15681,  6675], dtype=torch.int16), tensor([14117,  3633, 14452,  ..., 15272, 15681,  6675], dtype=torch.int16), tensor([ 9879,  4874,  3693,  ..., 11907, 11325,  9193], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  4238, 14543,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  4238, 14543,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  4238, 14543,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  4238, 14543,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  4238, 14543,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ...,  4238, 14543,  9416], dtype=torch.int16)]


Calculating Recall@k:   0%|          | 0/119954 [00:00<?, ?it/s]

Mean val los: 0.1868 | Recall: 0.00042409368325024843


Epoch 4/50:   0%|          | 0/235 [00:00<?, ?batch/s]

Validation:   0%|          | 0/8 [00:00<?, ?batch/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/5266 [00:00<?, ?it/s]

All predictions:
[tensor([14117, 11294, 14452,  ..., 15493, 15681,  6675], dtype=torch.int16), tensor([14117, 11294, 14452,  ..., 15493, 15681,  6675], dtype=torch.int16), tensor([14117, 11294, 14452,  ..., 15493, 15681,  6675], dtype=torch.int16), tensor([ 4874,  9879,  3693,  ..., 11907, 11325,  9193], dtype=torch.int16), tensor([14167, 15220, 10475,  ..., 14543,  4238,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ..., 14543,  4238,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ..., 14543,  4238,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ..., 14543,  4238,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ..., 14543,  4238,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ..., 14543,  4238,  9416], dtype=torch.int16)]


Calculating Recall@k:   0%|          | 0/119954 [00:00<?, ?it/s]

Mean val los: 0.1854 | Recall: 0.00042409368325024843


Epoch 5/50:   0%|          | 0/235 [00:00<?, ?batch/s]

Validation:   0%|          | 0/8 [00:00<?, ?batch/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/5266 [00:00<?, ?it/s]

All predictions:
[tensor([14117, 11294, 14452,  ..., 15493, 15681,  6675], dtype=torch.int16), tensor([14117, 11294, 14452,  ..., 15493, 15681,  6675], dtype=torch.int16), tensor([14117, 11294, 14452,  ..., 15493, 15681,  6675], dtype=torch.int16), tensor([ 4874,  9879,  1679,  ..., 11907, 11325,  9193], dtype=torch.int16), tensor([10475, 14167, 15220,  ..., 14543,  7105,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ..., 14543,  7105,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ..., 14543,  7105,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ..., 14543,  7105,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ..., 14543,  7105,  9416], dtype=torch.int16), tensor([10475, 14167, 15220,  ..., 14543,  7105,  9416], dtype=torch.int16)]


Calculating Recall@k:   0%|          | 0/119954 [00:00<?, ?it/s]

Mean val los: 0.1860 | Recall: 0.00042409368325024843
EarlyStopping counter: 1 out of 7


Epoch 6/50:   0%|          | 0/235 [00:00<?, ?batch/s]

Validation:   0%|          | 0/8 [00:00<?, ?batch/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/5266 [00:00<?, ?it/s]

All predictions:
[tensor([14117, 11294, 14452,  ..., 15272, 15681,  6675], dtype=torch.int16), tensor([14117, 11294, 14452,  ..., 15272, 15681,  6675], dtype=torch.int16), tensor([14117, 11294, 14452,  ..., 15272, 15681,  6675], dtype=torch.int16), tensor([ 4874,  9879,  1133,  ...,  3579, 11907,  9640], dtype=torch.int16), tensor([14167, 15220, 10475,  ..., 14543,  3865,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ..., 14543,  3865,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ..., 14543,  3865,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ..., 14543,  3865,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ..., 14543,  3865,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ..., 14543,  3865,  9416], dtype=torch.int16)]


Calculating Recall@k:   0%|          | 0/119954 [00:00<?, ?it/s]

Mean val los: 0.1841 | Recall: 0.00042409368325024843


Epoch 7/50:   0%|          | 0/235 [00:00<?, ?batch/s]

Validation:   0%|          | 0/8 [00:00<?, ?batch/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/5266 [00:00<?, ?it/s]

All predictions:
[tensor([14117,  7604,   176,  ..., 15272, 15681,  6675], dtype=torch.int16), tensor([14117,  7604,   176,  ..., 15272, 15681,  6675], dtype=torch.int16), tensor([14117,  7604,   176,  ..., 15272, 15681,  6675], dtype=torch.int16), tensor([ 4874,  9879,  1133,  ..., 11325,  3579, 11907], dtype=torch.int16), tensor([14167, 15220, 10475,  ...,  4238, 14543,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ...,  4238, 14543,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ...,  4238, 14543,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ...,  4238, 14543,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ...,  4238, 14543,  9416], dtype=torch.int16), tensor([14167, 15220, 10475,  ...,  4238, 14543,  9416], dtype=torch.int16)]


Calculating Recall@k:   0%|          | 0/119954 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# import numpy as np
# from joblib import Parallel, delayed
# from math import log2
# import torch  # если вы работаете с pytorch-тензорами

# # -------------------------------------------------------------------
# # ШАГ 1: Преобразуем y_true (список тензоров или списков) в список set()
# # -------------------------------------------------------------------
# def make_true_sets(y_true):
#     true_sets = []
#     for t in y_true:
#         if torch.is_tensor(t):
#             arr = t.cpu().numpy().tolist()
#         else:
#             arr = list(t)
#         true_sets.append(set(arr))
#     return true_sets

# # -------------------------------------------------------------------
# # ШАГ 2: Оптимизированные функции, без ambiguity-условий
# # -------------------------------------------------------------------
# def precision_at_k(y_true_sets, y_pred, k):
#     precisions = []
#     for true_set, pred in zip(y_true_sets, y_pred):
#         pred_k = pred[:k]
#         precisions.append(len(set(pred_k) & true_set) / k)
#     return np.mean(precisions)

# def _recall_single(true_set, pred, k):
#     if len(true_set) == 0:
#         return 0.0
#     pred_k = pred[:k]
#     return len(set(pred_k) & true_set) / len(true_set)

# def recall_at_k(y_true_sets, y_pred, k, n_jobs=-1):
#     recalls = Parallel(n_jobs=n_jobs)(
#         delayed(_recall_single)(ts, pred, k)
#         for ts, pred in tqdm(zip(y_true_sets, y_pred), desc='Calculating Recall@k', total=len(y_true_sets))
#     )
#     return np.mean(recalls)

# def _dcg(rel, k):
#     score = 0.0
#     for i, r in enumerate(rel[:k]):
#         score += r / log2(i + 2)
#     return score

# def _ndcg_single(true_set, pred, k):
#     rel = [1 if item in true_set else 0 for item in pred[:k]]
#     dcg = _dcg(rel, k)
#     ideal = sorted(rel, reverse=True)
#     idcg = _dcg(ideal, k)
#     return dcg / idcg if idcg > 0 else 0.0

# def ndcg_at_k(y_true_sets, y_pred, k, n_jobs=-1):
#     ndcgs = Parallel(n_jobs=n_jobs)(
#         delayed(_ndcg_single)(ts, pred, k)
#         for ts, pred in tqdm(zip(y_true_sets, y_pred), desc='Calculating NDCG@k', total=len(y_true_sets))
#     )
#     return np.mean(ndcgs)

# def _mrr_single(true_set, pred, k):
#     for rank, item in enumerate(pred[:k], start=1):
#         if item in true_set:
#             return 1.0 / rank
#     return 0.0

# def mrr_at_k(y_true_sets, y_pred, k, n_jobs=-1):
#     mrrs = Parallel(n_jobs=n_jobs)(
#         delayed(_mrr_single)(ts, pred, k)
#         for ts, pred in tqdm(zip(y_true_sets, y_pred), desc='Calculating MRR@k', total=len(y_true_sets))
#     )
#     return np.mean(mrrs)

# # -------------------------------------------------------------------
# # ШАГ 3: Пример использования
# # -------------------------------------------------------------------
# # допустим, y_true — список тензоров или списков, y_pred — список списков предсказаний
# y_true = all_references.copy()
# y_pred = all_predictions.copy()
# y_true_sets = make_true_sets(y_true)

# # теперь можно вызывать:
# print("Recall@10:", recall_at_k(y_true_sets, y_pred, k=10))
# print("NDCG@10: ", ndcg_at_k(  y_true_sets, y_pred, k=10))
# print("MRR@10:  ", mrr_at_k(   y_true_sets, y_pred, k=10))

Error in callback <bound method _WandbInit._pre_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x73592c781210>> (for pre_run_cell):


BrokenPipeError: [Errno 32] Broken pipe

Calculating Recall@k:   0%|          | 0/119954 [00:00<?, ?it/s]