# Imports

In [1]:
import pandas as pd
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td

import gc
from pathlib import Path
import typing as tp
from tqdm.autonotebook import tqdm, trange
import wandb
import warnings
from joblib import Parallel, delayed
from math import log2

# torch.multiprocessing.set_start_method('fork', force=True)
warnings.filterwarnings("ignore", category=DeprecationWarning)
np.random.seed(31337)

  from tqdm.autonotebook import tqdm, trange


In [2]:
from typing import Callable, Literal
import numpy as np
import torch

class EarlyStopping:
    """Signal that training should stop once the validation loss stops improving.

    Call the instance once per epoch with the current validation loss. After
    ``patience`` consecutive calls without a significant improvement the
    ``early_stop`` flag is set to True; the caller is responsible for checking it.
    """
    def __init__(self,
                 patience: int = 7,
                 threshold: float = 0,
                 threshold_mode: Literal['rel', 'abs'] = 'abs',
                 verbose: bool = False,
                 path: str = 'checkpoint.pt',
                 trace_func: Callable = print
                 ):
        """
        Args:
            patience (int): How many calls without significant improvement to tolerate
                            before setting ``early_stop``.
                            Default: 7
            threshold (float): Minimum decrease of the loss that counts as an improvement.
                            Default: 0
            threshold_mode ('rel' | 'abs'): 'abs' compares the raw decrease with ``threshold``;
                            'rel' compares it with ``threshold * |best loss|``.
                            Default: 'abs'
            verbose (bool): If True, ``save_checkpoint`` prints a message when it saves.
                            Default: False
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.threshold_mode = threshold_mode
        self.counter = 0
        self.best_val_loss = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.threshold = threshold
        self.path = path
        self.trace_func = trace_func

    def _significant_improvement(self, val_loss) -> bool:
        """Return True when ``val_loss`` beats the best loss by more than the threshold."""
        improvement = self.best_val_loss - val_loss
        if self.threshold_mode == 'abs':
            return improvement > self.threshold
        # Relative mode: scale the threshold by the magnitude of the best loss instead of
        # dividing by it. Dividing raised ZeroDivisionError for a best loss of exactly 0
        # and flipped the comparison for a negative best loss.
        return improvement > self.threshold * abs(self.best_val_loss)

    def __call__(self, val_loss, model):
        """Update the patience counter with this epoch's validation loss.

        ``model`` is only needed by the (currently disabled) checkpoint calls.
        """
        # Check if validation loss is nan
        if np.isnan(val_loss):
            self.trace_func("Validation loss is NaN. Ignoring this epoch.")
            return

        if self.best_val_loss is None:
            # First valid observation becomes the baseline.
            self.best_val_loss = val_loss
            # self.save_checkpoint(val_loss, model)
        elif self._significant_improvement(val_loss):
            self.best_val_loss = val_loss
            # self.save_checkpoint(val_loss, model)
            self.counter = 0  # Reset counter since improvement occurred
        else:
            # No significant improvement
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss: float, model: torch.nn.Module):
        '''Saves the model's state_dict to ``self.path`` and records ``val_loss`` as the saved minimum.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# Load data

In [3]:
# Directory that holds the pre-processed parquet files.
folder = Path("/content/")
# folder = Path("data/data_kion")
# users_df / items_df are looked up by id with `.loc` elsewhere in this notebook;
# train_data carries (user_id, film_pos, film_neg) triplets and val_data carries
# (user_id, film_pos) pairs, as consumed by the dataset classes below.
users_df, items_df, train_data, val_data = (
    pd.read_parquet(folder / file_name)
    for file_name in (
        "users_final.parquet",
        "items_final.parquet",
        "train_triplets.parquet",
        "val_pos.parquet",
    )
)

In [4]:
# Sentinel "no movie" id used to pad interaction histories.
NO_MOVIE = int(items_df.index.max()) + 1 # id bigger than any item id
PAD_SIZE = 30  # users' interactions are padded to have the same number of items
# dtype used for tensors of item ids.
# NOTE(review): int16 tops out at 32767, so item ids must stay below that bound.
MOVIE_IDS_DTYPE = torch.int16
# dtype used for tensors of user ids.
USER_IDS_DTYPE = torch.int32
# dtype used for user/item feature tensors.
MAIN_DTYPE = torch.float32

# Layout of an item feature row as sliced in ItemNet.forward:
# the first ITEMS_NUM_CAT_FEATURES columns, then ITEMS_NUM_GENRE_FEATURES columns,
# then every remaining column.
ITEMS_NUM_GENRE_FEATURES = 29
ITEMS_NUM_CAT_FEATURES = 4
# Bare expression: shows the sentinel value as the notebook cell output.
NO_MOVIE

15963

# Dataset

In [5]:
def pad_with_specific_value(tensor, target_length, pad_value):
    """Return a 1-D tensor of exactly `target_length` entries drawn from `tensor`.

    The entries of `tensor` (assumed unique) are put in random order and cut to at
    most `target_length` items; if fewer are available, the tail is filled with
    `pad_value` using the dtype of `tensor`.
    """
    order = torch.randperm(len(tensor))
    sample = tensor[order[:target_length]]
    missing = target_length - len(sample)
    if missing <= 0:
        return sample
    filler = torch.full((missing,), pad_value, dtype=sample.dtype)
    return torch.cat([sample, filler])

def group_by_user(triplets: pd.DataFrame, column: Literal['film_pos', 'film_neg'] = 'film_pos') -> pd.DataFrame:
    """Group interactions by user.

    Args:
        triplets: frame with a 'user_id' column and the requested film `column`.
        column: which film column to collect per user.

    Returns:
        DataFrame indexed by user_id with a single column 'interactions' holding,
        for every user, a sorted tensor of unique film IDs (dtype MOVIE_IDS_DTYPE).
    """
    # Select the column before aggregating: this avoids materialising a sub-frame per
    # user and does not rely on `DataFrameGroupBy.apply` seeing the grouping column.
    films_per_user = triplets.groupby('user_id')[column].agg(list)
    grouped_users = pd.DataFrame({'interactions': films_per_user.values}, index=films_per_user.index)
    grouped_users['interactions'] = grouped_users['interactions'].apply(
        lambda film_ids: torch.unique(torch.tensor(film_ids, dtype=MOVIE_IDS_DTYPE))
    )
    return grouped_users

In [6]:
def collate_fn(data: list[tuple]):
    """Identity collate: hand the fetched batch to the caller untouched.

    The datasets in this notebook assemble whole batches themselves in
    `__getitems__`, so there is nothing left to collate here.
    """
    return data

class BaseDSSMDataset(torch.utils.data.Dataset):
    """Shared base of the DSSM datasets; only initialises common attributes."""

    def __init__(self):
        super().__init__()
        # `type` is overwritten by subclasses ('train' / 'eval');
        # `seen_items` starts as an empty, per-instance dict.
        self.type, self.seen_items = None, {}

class TrainDSSMDataset(BaseDSSMDataset):
    """Triplet dataset: one sample per (user_id, film_pos, film_neg) row of `triplets`.

    User and item features are read from the module-level `users_df` / `items_df`
    frames via `.loc`, so both must be indexed by user id / item id respectively.
    """

    def __init__(self, triplets: pd.DataFrame):
        super().__init__()
        self.type = 'train'
        self.triplets = triplets
        self.grouped_pos_users_interactions = group_by_user(triplets, column='film_pos')  # just like padded users but not padded
        self.padded_users = self.grouped_pos_users_interactions.copy(deep=True)
        # Cast to int32 (required by EmbeddingBag) *before* padding, so that the pad value
        # NO_MOVIE only has to fit into int32 rather than into the narrower id dtype.
        # The histories are sampled/padded once here and then reused for every epoch.
        self.padded_users['interactions'] = self.padded_users['interactions'].apply(
            lambda x: pad_with_specific_value(x.to(dtype=torch.int32), PAD_SIZE, NO_MOVIE)
        )

        self.grouped_neg_users_interactions = group_by_user(triplets, column='film_neg')  # all negative interactions grouped by users into tensors
        self.all_users = torch.tensor(self.grouped_pos_users_interactions.index.tolist(), dtype=USER_IDS_DTYPE)  # all unique users from triplets

    def __getitem__(self, index: int):
        """Return one sample.

        Returns:
            (user_id, padded history, user features, positive film features,
             negative film features, all positive film ids of the user)
        """
        cur_triplet = self.triplets.iloc[index]
        user_id = cur_triplet['user_id']

        user_info = torch.tensor(users_df.loc[user_id].values, dtype=MAIN_DTYPE)
        user_interactions = self.padded_users.loc[user_id, 'interactions']
        pos_films_features = torch.tensor(items_df.loc[cur_triplet['film_pos']].values, dtype=MAIN_DTYPE)
        neg_films_features = torch.tensor(items_df.loc[cur_triplet['film_neg']].values, dtype=MAIN_DTYPE)
        pos_ids = self.grouped_pos_users_interactions.loc[user_id, 'interactions']

        return user_id, user_interactions, user_info, pos_films_features, neg_films_features, pos_ids

    def __getitems__(self, index: tp.Sequence[int]):
        """Batched variant of `__getitem__`: builds the whole batch with vectorised lookups.

        Returns the same six fields as `__getitem__`, stacked along the first
        dimension; `user_ids` stays a pandas Series and `pos_ids` a list of
        variable-length tensors.
        """
        cur_triplets = self.triplets.iloc[index]
        user_ids = cur_triplets['user_id']

        user_info = torch.tensor(users_df.loc[user_ids].values, dtype=MAIN_DTYPE)
        user_interactions = torch.stack(self.padded_users['interactions'].loc[user_ids].tolist())
        pos_films_features = torch.tensor(items_df.loc[cur_triplets['film_pos']].values, dtype=MAIN_DTYPE)
        neg_films_features = torch.tensor(items_df.loc[cur_triplets['film_neg']].values, dtype=MAIN_DTYPE)
        pos_ids = self.grouped_pos_users_interactions['interactions'].loc[user_ids].tolist()

        return user_ids, user_interactions, user_info, pos_films_features, neg_films_features, pos_ids

    def get_user_data(self, user_id: int) -> tuple[Tensor, Tensor]:
        """Return (padded history, user features) for one user.

        A user without train interactions gets a history made only of NO_MOVIE.
        """
        user_info = torch.tensor(users_df.loc[user_id].values, dtype=MAIN_DTYPE)
        if user_id in self.padded_users.index:
            user_interactions = self.padded_users.loc[user_id, 'interactions']
        else:
            user_interactions = torch.full((PAD_SIZE,), NO_MOVIE, dtype=torch.int32)

        return user_interactions, user_info

    def get_users_data(self, user_ids: tp.Sequence[int]) -> tuple[Tensor, Tensor]:
        """Return (stacked padded histories, user features) for several users.

        Users without train interactions get a history made only of NO_MOVIE.
        """
        user_info = torch.tensor(users_df.loc[user_ids].values, dtype=MAIN_DTYPE)

        # Hoist the attribute/column lookups out of the per-user loop.
        known_users = self.padded_users.index
        histories = self.padded_users['interactions']
        inters = [
            histories.loc[uid] if uid in known_users
            else torch.full((PAD_SIZE,), NO_MOVIE, dtype=torch.int32)
            for uid in user_ids
        ]

        user_interactions = torch.stack(inters, dim=0)

        return user_interactions, user_info

    def __len__(self):
        """Number of triplets."""
        return len(self.triplets)

class EvalDSSMDataset(BaseDSSMDataset):
    """Evaluation dataset: one entry per user together with their positive films."""

    def __init__(self, pos_df: pd.DataFrame):
        super().__init__()
        self.type = 'eval'
        self.all_users = pos_df['user_id'].unique().astype(int)
        self.grouped_pos_users_interactions = group_by_user(pos_df, column='film_pos')
        # self.seen_items = train_triplets.grouped_pos_users_interactions | train_triplets.grouped_neg_users_interactions  # seen items are all positive and negative interactions from train triplets

    def __getitem__(self, index: int) -> tuple[int, Tensor]:
        """Return (user id, tensor of that user's positive film ids)."""
        user = int(self.all_users[index])
        # Row lookup must go through `.loc`: plain `df[user]` selects a *column* and
        # raised KeyError. The value is a tensor of ids, so it is returned as is.
        pos = self.grouped_pos_users_interactions.loc[user, 'interactions']
        return user, pos

    def __getitems__(self, index: tp.Sequence[int]):
        """Batched variant: (array of user ids, list of positive-id tensors)."""
        users = self.all_users[index]
        pos   = self.grouped_pos_users_interactions.loc[users]['interactions'].tolist()
        return users, pos

    def get_all_users(self) -> np.ndarray:
        """
        Get IDs of all users, that have positive interactions in this dataset

        Returns:
            users (np.ndarray of int) : All user IDs
        """
        return self.all_users

    def __len__(self):
        """Number of distinct users."""
        return len(self.all_users)


# Model

In [7]:
class ItemNet(nn.Module):
    """Item tower: maps an item feature row to a layer-normalised embedding.

    The feature row is cut into three consecutive slices (ITEMS_NUM_CAT_FEATURES
    columns, ITEMS_NUM_GENRE_FEATURES columns, the rest). Each slice is projected
    by its own linear layer and the projections are concatenated and fused.
    """
    def __init__(self,
                 dim_embedding: int,
                 dim_input: int,
                 dim_hidden: int = 96,
                 activation: tp.Callable[[Tensor], Tensor] = nn.ReLU()
                 ) -> None:
        super().__init__()
        num_remaining_features = dim_input - ITEMS_NUM_CAT_FEATURES - ITEMS_NUM_GENRE_FEATURES
        # Layers are created in a fixed order so seeded initialisation stays reproducible.
        self.cat_embedding = nn.Linear(ITEMS_NUM_CAT_FEATURES, dim_hidden)
        self.genre_embedding = nn.Linear(ITEMS_NUM_GENRE_FEATURES, dim_hidden)
        self.country_embedding = nn.Linear(num_remaining_features, dim_hidden)
        self.dense_block = nn.Sequential(
            nn.Linear(3 * dim_hidden, dim_hidden),
            activation,
        )
        self.output_layer = nn.Linear(dim_hidden, dim_embedding, bias=False)
        self.norm = nn.LayerNorm(dim_embedding)

    def forward(self, item_features: Tensor) -> Tensor:
        """Embed a 2-D batch of item feature rows (columns are sliced along dim 1)."""
        genre_start = ITEMS_NUM_CAT_FEATURES
        country_start = genre_start + ITEMS_NUM_GENRE_FEATURES

        projections = [
            self.cat_embedding(item_features[:, :genre_start]),
            self.genre_embedding(item_features[:, genre_start:country_start]),
            self.country_embedding(item_features[:, country_start:]),
        ]
        hidden = self.dense_block(torch.cat(projections, dim=1))
        return self.norm(self.output_layer(hidden))


class UserNet(nn.Module):
    """User tower: fuses the interaction history and the user features into one embedding."""
    def __init__(self,
                 dim_embedding: int,
                 num_items: int,
                 dim_user_features: int,
                 activation: tp.Callable[[Tensor], Tensor] = nn.ReLU()
                 ) -> None:
        super().__init__()
        half_dim = int(dim_embedding // 2)

        def bottleneck_mlp(in_features: int) -> nn.Sequential:
            # in_features -> half_dim -> dim_embedding; ReLU inside, `activation` on top
            return nn.Sequential(
                nn.Linear(in_features, half_dim),
                nn.ReLU(),
                nn.Linear(half_dim, dim_embedding),
                activation
            )

        # One extra row: index `num_items` is the padding id (the NO_MOVIE element).
        self.track_embeddings = nn.EmbeddingBag(num_items + 1, dim_embedding, padding_idx=num_items)
        self.info_embedding = bottleneck_mlp(dim_user_features)
        self.dense_layer = bottleneck_mlp(dim_embedding)
        self.output_layer = nn.Linear(3*dim_embedding, dim_embedding, bias=False)
        self.norm = nn.LayerNorm(dim_embedding)
        self.num_items = num_items
        self.dim_user_features = dim_user_features

    def forward(self, user_interactions: Tensor, user_info: Tensor) -> Tensor:
        """Return the layer-normalised user embedding.

        Args:
            user_interactions: 2-D tensor of item ids, one bag (row) per user.
            user_info: 2-D tensor of user features; cast to float before use.
        """
        history_emb = self.track_embeddings(user_interactions)
        profile_emb = self.info_embedding(user_info.float())
        refined_history = self.dense_layer(history_emb)
        # The raw bag embedding, its MLP refinement and the profile embedding are fused.
        fused = torch.cat((history_emb, refined_history, profile_emb), dim=1)
        return self.norm(self.output_layer(fused))

In [8]:
class DSSM(nn.Module):
    """Two-tower model: `UserNet` embeds users, `ItemNet` embeds items."""
    def __init__(self,
                 dim_item_features: int,
                 dim_user_features: int,
                 num_items: int,
                 embedding_dim: int = 100,
                 lr: float = 1e-3,
                 triplet_loss_margin: float = 0.4,
                 weight_decay: float = 1e-3,
                 log_to_prog_bar: bool = True,
                 ) -> None:
        super().__init__()
        # The hyper-parameters below are only stored; nothing in this class reads them.
        self.lr = lr
        self.triplet_loss_margin = triplet_loss_margin
        self.weight_decay = weight_decay
        self.log_to_prog_bar = log_to_prog_bar
        # Towers (item tower is built first, then the user tower).
        self.item_net = ItemNet(embedding_dim, dim_item_features)
        self.user_net = UserNet(embedding_dim, num_items, dim_user_features)

    def forward(self,
                user_intercations: Tensor,
                user_info: Tensor,
                item_features_pos: Tensor,
                item_features_neg: Tensor,
                ) -> tuple[Tensor, Tensor, Tensor]:
        """Embed one triplet batch: returns (user, positive item, negative item) embeddings."""
        user_emb = self.encode_user(user_intercations, user_info)
        positive_emb = self.encode_item(item_features_pos)
        negative_emb = self.encode_item(item_features_neg)
        return user_emb, positive_emb, negative_emb

    def encode_user(self, user_interactions: Tensor, user_info: Tensor) -> Tensor:
        """Run the user tower only."""
        return self.user_net(user_interactions, user_info)

    def encode_item(self, item_features: Tensor) -> Tensor:
        """Run the item tower only."""
        return self.item_net(item_features)

# Training

## Metrics

In [9]:
def make_true_sets(y_true):
    """Convert every entry of `y_true` (a tensor or any iterable) into a Python set."""
    return [
        set(entry.cpu().numpy().tolist() if torch.is_tensor(entry) else list(entry))
        for entry in y_true
    ]

def _recall_single(relevant_tensor, pred_tensor, k):
    """Recall@k for one user: share of relevant items found in the first k predictions.

    Returns a 0-dim float tensor, or None when the user has no relevant items.
    """
    num_relevant_items = len(relevant_tensor)
    if num_relevant_items == 0:
        # TODO: such samples are of no use in validation; better to drop them beforehand
        return None
    hits = torch.isin(pred_tensor[:k], relevant_tensor)
    return hits.sum().float() / num_relevant_items

def recall_at_k(y_true, y_pred, k, n_jobs=-1):
    """Mean Recall@k over all users that have at least one relevant item.

    Args:
        y_true: per-user tensors of relevant item ids.
        y_pred: per-user tensors of predicted item ids, best first.
        k: cutoff.
        n_jobs: number of joblib workers.

    Returns:
        Mean recall as a float; 0.0 when no user has relevant items.
    """
    recalls = Parallel(n_jobs=n_jobs)(
        delayed(_recall_single)(ts, pred, k)
        for ts, pred in tqdm(zip(y_true, y_pred), desc='Calculating Recall@k', total=len(y_true))
    )
    # `_recall_single` yields None for users without relevant items; averaging a list
    # that contains None raises, so skip them (as mean_average_precision_full does).
    valid_recalls = [float(r) for r in recalls if r is not None]
    if not valid_recalls:
        return 0.0
    return float(np.mean(valid_recalls))

def precision_at_k(y_true_sets, y_pred, k):
    """Mean Precision@k.

    Args:
        y_true_sets: per-user sets of relevant item ids.
        y_pred: per-user ranked predictions (list, numpy array or tensor), best first.
        k: cutoff; the hit count is always divided by k.

    Returns:
        Mean precision as a float; 0.0 when there are no users.
    """
    precisions = []
    for true_set, pred in zip(y_true_sets, y_pred):
        pred_k = pred[:k]
        # Tensors iterate as 0-dim tensors that hash by identity, so they would never
        # match the ids in `true_set`; convert array-likes to plain Python numbers.
        if hasattr(pred_k, 'tolist'):
            pred_k = pred_k.tolist()
        precisions.append(len(set(pred_k) & true_set) / k)
    if not precisions:
        return 0.0
    return float(np.mean(precisions))

def _dcg(rel, k):
    """Discounted cumulative gain of the first k relevance values (log2 discount)."""
    return sum((r / log2(i + 2) for i, r in enumerate(rel[:k])), 0.0)

def _ndcg_single(true_set, pred, k):
    """NDCG@k for one user with binary relevance.

    Args:
        true_set: set of relevant item ids.
        pred: ranked predictions (list, numpy array or tensor), best first.
        k: cutoff.

    Returns:
        DCG@k divided by the ideal DCG@k, or 0.0 when there are no relevant items.
    """
    top_k = pred[:k]
    # Tensors iterate as 0-dim tensors that hash by identity and never match the ids
    # in `true_set`; convert array-likes to plain Python numbers first.
    if hasattr(top_k, 'tolist'):
        top_k = top_k.tolist()
    rel = [1 if item in true_set else 0 for item in top_k]
    dcg = _dcg(rel, k)
    # The ideal ranking puts *all* relevant items first (at most k of them). Sorting only
    # the retrieved hits would score any "hits first" list as perfect, however few hits.
    ideal = [1] * min(k, len(true_set))
    idcg = _dcg(ideal, k)
    return dcg / idcg if idcg > 0 else 0.0

def ndcg_at_k(y_true_sets, y_pred, k, n_jobs=-1):
    """Mean NDCG@k over all users, computed per user in joblib workers."""
    pairs = tqdm(zip(y_true_sets, y_pred), desc='Calculating NDCG@k', total=len(y_true_sets))
    jobs = (delayed(_ndcg_single)(true_set, ranked, k) for true_set, ranked in pairs)
    per_user_scores = Parallel(n_jobs=n_jobs)(jobs)
    return np.mean(per_user_scores)

def _mrr_single(true_set, pred, k):
    """Reciprocal rank of the first relevant item within the first k predictions.

    Args:
        true_set: set of relevant item ids.
        pred: ranked predictions (list, numpy array or tensor), best first.
        k: cutoff.

    Returns:
        1 / rank of the first hit (ranks start at 1), or 0.0 when there is no hit.
    """
    top_k = pred[:k]
    # Tensors iterate as 0-dim tensors that hash by identity and never match the ids
    # in `true_set`; convert array-likes to plain Python numbers first.
    if hasattr(top_k, 'tolist'):
        top_k = top_k.tolist()
    for rank, item in enumerate(top_k, start=1):
        if item in true_set:
            return 1.0 / rank
    return 0.0

def mrr_at_k(y_true_sets, y_pred, k, n_jobs=-1):
    """Mean reciprocal rank@k over all users, computed per user in joblib workers."""
    pairs = tqdm(zip(y_true_sets, y_pred), desc='Calculating MRR@k', total=len(y_true_sets))
    jobs = (delayed(_mrr_single)(true_set, ranked, k) for true_set, ranked in pairs)
    per_user_scores = Parallel(n_jobs=n_jobs)(jobs)
    return np.mean(per_user_scores)

def _average_precision_single_full(relevant_ids, predicted_ids, k=None):
    """
    Calculate Average Precision (AP) for a single user.

    Args:
        relevant_ids: Tensor of relevant item IDs (ground truth)
        predicted_ids: Tensor of predicted item IDs sorted by relevance
        k: Optional cutoff for top-k items to consider. If None, use all predicted items.

    Returns:
        AP score as a Python float. The sum of precisions at the hit positions is
        divided by the total number of relevant items (also when `k` cuts the list);
        0.0 is returned when there are no relevant items.
    """
    # Avoid division by zero (and any tensor work) if there are no relevant items
    num_relevant = len(relevant_ids)
    if num_relevant == 0:
        return 0.0

    if k is not None:
        predicted_ids = predicted_ids[:k]

    # Create binary relevance tensor (1 for relevant items, 0 otherwise)
    relevance = torch.isin(predicted_ids, relevant_ids).float()

    # Calculate precision at each position; build the positions on the same device as
    # the inputs so that non-CPU tensors do not hit a device mismatch
    cum_relevance = torch.cumsum(relevance, dim=0)
    positions = torch.arange(1, len(predicted_ids) + 1, device=relevance.device)
    precisions = cum_relevance / positions

    # Only consider precision at positions where item is relevant
    relevant_precision = precisions * relevance

    # Sum relevant precisions and divide by total relevant items
    ap = torch.sum(relevant_precision) / num_relevant

    return ap.item()


def mean_average_precision_full(y_true: list[torch.Tensor],
                                y_pred: list[torch.Tensor],
                                n_jobs: int = -1) -> float:
    """
    Compute MAP over the full prediction lists (no top-k cutoff).

    Per-user AP values are computed in joblib workers; None results are ignored
    and 0.0 is returned when nothing is left to average.
    """
    pairs = tqdm(zip(y_true, y_pred),
                 desc='Calculating MAP (no K)',
                 total=len(y_true))
    jobs = (delayed(_average_precision_single_full)(relevant, ranked)
            for relevant, ranked in pairs)
    ap_scores = Parallel(n_jobs=n_jobs)(jobs)

    valid_scores = [score for score in ap_scores if score is not None]
    return float(np.mean(valid_scores)) if valid_scores else 0.0

## Configure training

In [10]:
# Build the triplet training dataset (per-user histories are grouped and padded in its constructor).
print('Get train dataset')
train_dataset = TrainDSSMDataset(train_data)
# Build the evaluation dataset: one entry per user together with their positive films.
print('Get validation dataset')
val_dataset = EvalDSSMDataset(val_data)

Get train dataset
Get validation dataset


In [11]:
# Hyperparams
# --- Run settings ---
EPOCHS = 50
BATCH_SIZE = 1024 * 16
NUM_WORKERS = 0
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# --- Model / optimisation hyper-parameters ---
EMBEDDING_DIM = 128
LR = 1e-4
WEIGHT_DECAY = 1e-4
TRIPLET_LOSS_MARGIN = 0.4  # margin handed to F.triplet_margin_loss in the training loop
EXPERIMENT_NAME = 'wtf_1'  # used as the wandb run name and in the checkpoint path below
K = 100  # cutoff for Recall@K
LOG_TO_WANDB = False

# collate_fn is an identity function: the datasets build whole batches themselves.
train_dataloader = td.DataLoader(train_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dataloader = td.DataLoader(val_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# num_items is NO_MOVIE: UserNet adds one extra embedding row and uses NO_MOVIE as its padding index.
model = DSSM(dim_item_features=items_df.shape[1],
             dim_user_features=users_df.shape[1],
             num_items=NO_MOVIE,
             embedding_dim=EMBEDDING_DIM
             ).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, threshold=3e-3, threshold_mode='abs')
# NOTE: ReduceLROnPlateau only acts when `lr_scheduler.step(metric)` is called with the monitored value.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=3,
    threshold=1e-3,
    threshold_mode='abs'
)
# NOTE: EarlyStopping only acts when the instance is called with the monitored value
# and its `early_stop` flag is checked by the caller.
early_stopping = EarlyStopping(
    patience=7,
    threshold=1e-3,
    threshold_mode='rel',
    verbose=True,
    trace_func=tqdm.write,
    path=f'model_weights/{EXPERIMENT_NAME}.ckpt'
)

## Training loop

In [None]:
if LOG_TO_WANDB:
    entity = "xenz5240-higher-school-of-economics"
    wandb.init(entity=entity, project='reelsrec-dssm', name=EXPERIMENT_NAME)

    wandb.config.update({
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "embedding_dim": EMBEDDING_DIM,
        "learning_rate": LR,
        "weight_decay": WEIGHT_DECAY,
        "triplet_loss_margin": TRIPLET_LOSS_MARGIN,
        "device": DEVICE
    })

# ---------------------------------------------------------------------------
# Evaluation inputs that do not change between epochs: build them once.
# ---------------------------------------------------------------------------
items_df.sort_index(inplace=True)  # Ensure items are sorted by their IDs
item_features = torch.tensor(items_df.values, dtype=MAIN_DTYPE, device=DEVICE)
# Row position -> item id. argsort below yields row positions of `items_df`; mapping
# them through the index keeps the recommendations correct even if ids are not 0..N-1.
item_ids = torch.tensor(items_df.index.to_numpy(), dtype=MOVIE_IDS_DTYPE, device=DEVICE)
EMPTY_WATCH_HISTORY = torch.empty(0, dtype=MOVIE_IDS_DTYPE)
# All films each train user has interacted with (from both film_pos and film_neg).
user_interacted_films: dict[int, Tensor] = (
    train_dataset.triplets
    .groupby('user_id')[['film_pos', 'film_neg']]
    .apply(lambda x: torch.tensor(pd.unique(x.values.ravel('K')).tolist(), dtype=MOVIE_IDS_DTYPE))
).to_dict()

for epoch in trange(EPOCHS, position=0, desc='Training', unit='epoch'):
    # ------------------------------ Training ------------------------------
    model.train()
    epoch_losses = []
    for step, batch in enumerate(tqdm(train_dataloader, position=1, desc=f'Epoch {epoch + 1}/{EPOCHS}', unit='batch')):
        _, user_inters, user_info, pos_features, neg_features, _ = batch
        optimizer.zero_grad()
        batch_users_embs, batch_positive_films_embs, batch_negative_films_embs = model(user_inters.to(DEVICE),
                                                                                       user_info.to(DEVICE),
                                                                                       pos_features.to(DEVICE),
                                                                                       neg_features.to(DEVICE))
        loss = F.triplet_margin_loss(batch_users_embs, batch_positive_films_embs, batch_negative_films_embs, margin=TRIPLET_LOSS_MARGIN)
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

        if LOG_TO_WANDB and step % 20 == 0:
            wandb.log({"train/loss": loss.item(),
                       "epoch": epoch + 1,
                       "step": step})

    mean_train_loss = float(np.mean(epoch_losses))
    if LOG_TO_WANDB:
        wandb.log({"train/epoch_loss": mean_train_loss, "epoch": epoch + 1})

    # ----------------------------- Evaluation -----------------------------
    model.eval()
    all_predictions = []
    all_references = []
    with torch.no_grad():
        # Embed all items
        all_items_embeds = model.encode_item(item_features)

        for batch in tqdm(val_dataloader, position=1, desc=f'Validation {epoch + 1}/{EPOCHS}', unit='batch'):
            user_ids, positive_interactions = batch
            user_inters, user_info = train_dataset.get_users_data(user_ids)
            user_embs = model.encode_user(user_inters.to(DEVICE), user_info.to(DEVICE))

            # Build the user-to-item distance matrix and rank items by increasing distance
            distance_matrix = torch.cdist(user_embs, all_items_embeds, p=2.0)
            ranked_positions = distance_matrix.argsort(dim=1, descending=False)
            val_recommendations = item_ids[ranked_positions].to(device='cpu')  # (batch_users x all_items)

            # Filter the recommendations: for each user remove the films already seen in train
            filtered_val_recommendations = [None] * len(user_ids)
            for i, (user_id, recommendations) in tqdm(enumerate(zip(user_ids, val_recommendations)),
                                                      desc='Filtering recommendations',
                                                      total=len(user_ids),
                                                      leave=False):
                watched_films = user_interacted_films.get(int(user_id), EMPTY_WATCH_HISTORY)
                filtered_val_recommendations[i] = recommendations[torch.isin(recommendations, watched_films, invert=True)]

            all_predictions.extend(filtered_val_recommendations)
            all_references.extend(positive_interactions)

    map_score = mean_average_precision_full(y_true=all_references, y_pred=all_predictions, n_jobs=8)
    tqdm.write(f'mAP: {map_score}')
    recall_score = recall_at_k(y_true=all_references, y_pred=all_predictions, k=K, n_jobs=8)
    tqdm.write(f'Recall: {recall_score}')
    # precision = precision_at_k(all_predictions, all_references, k=K)
    # ndcg   = ndcg_at_k(all_predictions, all_references, k=K)
    # mrr    = mrr_at_k(all_predictions, all_references, k=K)
    # tqdm.write(f'Mean val los: {mean_val_loss:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | NDCG: {ndcg:.4f} | MRR: {mrr:.4f}')

    # Drop the references first, then collect: collecting before `del` frees nothing.
    del all_predictions, all_references, filtered_val_recommendations, positive_interactions, all_items_embeds
    gc.collect()

    if LOG_TO_WANDB:
        wandb.log({
            "val/mAP": map_score,
            f"val/recall@{K}": recall_score,
            # f"val/precision@{K}": precision,
            # f"val/ndcg@{K}": ndcg,
            # f"val/mrr@{K}": mrr,
            "train/lr": optimizer.param_groups[0]['lr'],
            "epoch": epoch + 1,
        })

    # The scheduler and the early stopper were configured but never driven. Both expect
    # a "lower is better" quantity and no validation loss is computed here, so
    # 1 - Recall@K is used as the monitored value.
    # NOTE(review): swap in a real validation loss if one becomes available.
    monitored_value = 1.0 - float(recall_score)
    lr_scheduler.step(monitored_value)
    early_stopping(monitored_value, model)
    if early_stopping.early_stop:
        tqdm.write(f'Early stopping triggered after epoch {epoch + 1}')
        break

Training:   0%|          | 0/50 [00:00<?, ?epoch/s]

Epoch 1/50:   0%|          | 0/235 [00:00<?, ?batch/s]

Validation 1/50:   0%|          | 0/3 [00:00<?, ?batch/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/13111 [00:00<?, ?it/s]

Calculating MAP (no K):   0%|          | 0/45879 [00:00<?, ?it/s]

mAP: 0.0019492063605703988


Calculating Recall@k:   0%|          | 0/45879 [00:00<?, ?it/s]

Recall: 0.026562903076410294


Epoch 2/50:   0%|          | 0/235 [00:00<?, ?batch/s]

Validation 2/50:   0%|          | 0/3 [00:00<?, ?batch/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/13111 [00:00<?, ?it/s]

Calculating MAP (no K):   0%|          | 0/45879 [00:00<?, ?it/s]

mAP: 0.002070709025198728


Calculating Recall@k:   0%|          | 0/45879 [00:00<?, ?it/s]

Recall: 0.022234711796045303


Epoch 3/50:   0%|          | 0/235 [00:00<?, ?batch/s]

Validation 3/50:   0%|          | 0/3 [00:00<?, ?batch/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/13111 [00:00<?, ?it/s]

Calculating MAP (no K):   0%|          | 0/45879 [00:00<?, ?it/s]

mAP: 0.003091146211951498


Calculating Recall@k:   0%|          | 0/45879 [00:00<?, ?it/s]

Recall: 0.04346254840493202


Epoch 4/50:   0%|          | 0/235 [00:00<?, ?batch/s]

Validation 4/50:   0%|          | 0/3 [00:00<?, ?batch/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/13111 [00:00<?, ?it/s]

Calculating MAP (no K):   0%|          | 0/45879 [00:00<?, ?it/s]

mAP: 0.002614462314242501


Calculating Recall@k:   0%|          | 0/45879 [00:00<?, ?it/s]

Recall: 0.0336497537791729


Epoch 5/50:   0%|          | 0/235 [00:00<?, ?batch/s]

Validation 5/50:   0%|          | 0/3 [00:00<?, ?batch/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/16384 [00:00<?, ?it/s]

Filtering recommendations:   0%|          | 0/13111 [00:00<?, ?it/s]

Calculating MAP (no K):   0%|          | 0/45879 [00:00<?, ?it/s]

In [None]:
# Debug: filtered ranking of the first user in the last processed validation batch.
# NOTE: the training cell `del`s this name at the end of every epoch, so it is only
# defined if that cell stopped before reaching the `del`.
filtered_val_recommendations[0]

In [None]:
# Debug: positive film ids of the first user in the last processed validation batch
# (same availability caveat: the training cell `del`s this name after each epoch).
positive_interactions[0]

In [None]:
# Debug: position(s) of item id 3190 inside the first user's filtered ranking.
(filtered_val_recommendations[0] == 3190).nonzero(as_tuple=False)

In [None]:
# import numpy as np
# from joblib import Parallel, delayed
# from math import log2
# import torch  # if you are working with PyTorch tensors

# # -------------------------------------------------------------------
# # STEP 1: Convert y_true (a list of tensors or lists) into a list of set()
# # -------------------------------------------------------------------
# def make_true_sets(y_true):
#     true_sets = []
#     for t in y_true:
#         if torch.is_tensor(t):
#             arr = t.cpu().numpy().tolist()
#         else:
#             arr = list(t)
#         true_sets.append(set(arr))
#     return true_sets

# # -------------------------------------------------------------------
# # STEP 2: Optimized functions, without ambiguous conditions
# # -------------------------------------------------------------------
# def precision_at_k(y_true_sets, y_pred, k):
#     precisions = []
#     for true_set, pred in zip(y_true_sets, y_pred):
#         pred_k = pred[:k]
#         precisions.append(len(set(pred_k) & true_set) / k)
#     return np.mean(precisions)

# def _recall_single(true_set, pred, k):
#     if len(true_set) == 0:
#         return 0.0
#     pred_k = pred[:k]
#     return len(set(pred_k) & true_set) / len(true_set)

# def recall_at_k(y_true_sets, y_pred, k, n_jobs=-1):
#     recalls = Parallel(n_jobs=n_jobs)(
#         delayed(_recall_single)(ts, pred, k)
#         for ts, pred in tqdm(zip(y_true_sets, y_pred), desc='Calculating Recall@k', total=len(y_true_sets))
#     )
#     return np.mean(recalls)

# def _dcg(rel, k):
#     score = 0.0
#     for i, r in enumerate(rel[:k]):
#         score += r / log2(i + 2)
#     return score

# def _ndcg_single(true_set, pred, k):
#     rel = [1 if item in true_set else 0 for item in pred[:k]]
#     dcg = _dcg(rel, k)
#     ideal = sorted(rel, reverse=True)
#     idcg = _dcg(ideal, k)
#     return dcg / idcg if idcg > 0 else 0.0

# def ndcg_at_k(y_true_sets, y_pred, k, n_jobs=-1):
#     ndcgs = Parallel(n_jobs=n_jobs)(
#         delayed(_ndcg_single)(ts, pred, k)
#         for ts, pred in tqdm(zip(y_true_sets, y_pred), desc='Calculating NDCG@k', total=len(y_true_sets))
#     )
#     return np.mean(ndcgs)

# def _mrr_single(true_set, pred, k):
#     for rank, item in enumerate(pred[:k], start=1):
#         if item in true_set:
#             return 1.0 / rank
#     return 0.0

# def mrr_at_k(y_true_sets, y_pred, k, n_jobs=-1):
#     mrrs = Parallel(n_jobs=n_jobs)(
#         delayed(_mrr_single)(ts, pred, k)
#         for ts, pred in tqdm(zip(y_true_sets, y_pred), desc='Calculating MRR@k', total=len(y_true_sets))
#     )
#     return np.mean(mrrs)

# # -------------------------------------------------------------------
# # STEP 3: Usage example
# # -------------------------------------------------------------------
# # suppose y_true is a list of tensors or lists, and y_pred is a list of prediction lists
# y_true = all_references.copy()
# y_pred = all_predictions.copy()
# y_true_sets = make_true_sets(y_true)

# # now we can call:
# print("Recall@10:", recall_at_k(y_true_sets, y_pred, k=10))
# print("NDCG@10: ", ndcg_at_k(  y_true_sets, y_pred, k=10))
# print("MRR@10:  ", mrr_at_k(   y_true_sets, y_pred, k=10))