In [1]:
import pandas as pd
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td

from sklearn.preprocessing import MultiLabelBinarizer

# import tensorboard
import shutil
import ast
import typing as tp
import random
from collections import Counter
from tqdm.autonotebook import tqdm, trange
# from tqdm import tqdm, trange


np.random.seed(31337)

  from tqdm.autonotebook import tqdm, trange


# Preprocessing

In [2]:
users_df = pd.read_csv("data/data_kion/users_processed.csv")
items_df = pd.read_csv("data/data_kion/items_processed.csv")
interactions_df = pd.read_csv("data/data_kion/interactions_processed.csv")

## Users preprocessing

In [3]:
users_df.head()

Unnamed: 0,user_id,age,income,sex,kids_flg
0,973171,age_25_34,income_60_90,M,True
1,962099,age_18_24,income_20_40,M,False
2,1047345,age_45_54,income_40_60,F,False
3,721985,age_45_54,income_20_40,F,False
4,704055,age_35_44,income_60_90,F,False


Закодируем возраст и доход числами (по возрастанию от 0), а тем юзерам, у которых они неизвестны, заполним их медианой (категорией, в которую попадает медиана среднего по диапазонам категорий).

In [4]:
sorted_age_categories = sorted(users_df['age'].unique(), key=lambda s: float(s.split('_')[1] if len(s.split('_')) == 3 else np.inf))
age_mapper = {age: id for id, age in enumerate(sorted_age_categories)}
median_age = users_df[users_df['age'] != 'age_unknown']['age'].map(lambda s: (float(s.split('_')[2] if float(s.split('_')[2]) < np.inf else s.split('_')[1]) + float(s.split('_')[1])) / 2).median()
age_fill_value = None 
for age_cat in sorted_age_categories:
    low, high = age_cat.split('_')[1:]
    if int(low) < median_age < int(high):
        age_fill_value = age_cat
        break
age_mapper['age_unknown'] = age_mapper[age_fill_value]

sorted_income_categories = sorted(users_df['income'].unique(), key=lambda s: float(s.split('_')[1] if len(s.split('_')) == 3 else np.inf))
income_mapper = {income: id for id, income in enumerate(sorted_income_categories)}
median_income = users_df[users_df['income'] != 'income_unknown']['income'].map(lambda s: (float(s.split('_')[2] if float(s.split('_')[2]) < np.inf else s.split('_')[1]) + float(s.split('_')[1])) / 2).median()
income_fill_value = None 
for income_cat in sorted_income_categories:
    low, high = income_cat.split('_')[1:]
    if int(low) < median_income < int(high):
        income_fill_value = income_cat
        break
income_mapper['income_unknown'] = income_mapper[income_fill_value]
income_mapper

sex_mapper = {'M': -1, 'sex_unknown': 0, 'F': 1}

age_mapper, income_mapper, sex_mapper

({'age_18_24': 0,
  'age_25_34': 1,
  'age_35_44': 2,
  'age_45_54': 3,
  'age_55_64': 4,
  'age_65_inf': 5,
  'age_unknown': 2},
 {'income_0_20': 0,
  'income_20_40': 1,
  'income_40_60': 2,
  'income_60_90': 3,
  'income_90_150': 4,
  'income_150_inf': 5,
  'income_unknown': 1},
 {'M': -1, 'sex_unknown': 0, 'F': 1})

In [5]:
users_df['age'] = users_df['age'].map(age_mapper)
users_df['income'] = users_df['income'].map(income_mapper)
users_df['sex'] = users_df['sex'].map(sex_mapper)
users_df['kids_flg'] = users_df['kids_flg']
users_df = users_df.astype(np.int32)
users_df.head()

Unnamed: 0,user_id,age,income,sex,kids_flg
0,973171,1,3,-1,1
1,962099,0,1,-1,0
2,1047345,3,2,1,0
3,721985,3,1,1,0
4,704055,2,3,1,0


In [6]:
users_df.info(verbose=True)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 840197 entries, 0 to 840196
Data columns (total 5 columns):
 #   Column    Non-Null Count   Dtype
---  ------    --------------   -----
 0   user_id   840197 non-null  int32
 1   age       840197 non-null  int32
 2   income    840197 non-null  int32
 3   sex       840197 non-null  int32
 4   kids_flg  840197 non-null  int32
dtypes: int32(5)
memory usage: 16.0 MB


## Items preprocessing

In [7]:
items_df['genres'] = items_df['genres'].map(ast.literal_eval)
items_df['countries'] = items_df['countries'].map(lambda s: list(set(s.split(', '))))

In [8]:
items_df.head()

Unnamed: 0,item_id,content_type,title,title_orig,genres,countries,for_kids,age_rating,studios,directors,actors,description,keywords,release_year_cat
0,10711,film,поговори с ней,Hable con ella,"[драмы, детективы, мелодрамы]",[испания],False,16,unknown,педро альмодовар,"Адольфо Фернандес, Ана Фернандес, Дарио Гранди...",Мелодрама легендарного Педро Альмодовара «Пого...,"Поговори, ней, 2002, Испания, друзья, любовь, ...",2000_2010
1,2508,film,голые перцы,Search Party,"[приключения, комедии]",[сша],False,16,unknown,скот армстронг,"Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ...",Уморительная современная комедия на популярную...,"Голые, перцы, 2014, США, друзья, свадьбы, прео...",2010_2020
2,10716,film,тактическая сила,Tactical Force,"[криминал, триллеры, боевики, комедии]",[канада],False,16,unknown,адам п. калтраро,"Адриан Холмс, Даррен Шалави, Джерри Вассерман,...",Профессиональный рестлер Стив Остин («Все или ...,"Тактическая, сила, 2011, Канада, бандиты, ганг...",2010_2020
3,7868,film,45 лет,45 Years,"[драмы, мелодрамы]",[великобритания],False,16,unknown,эндрю хэй,"Александра Риддлстон-Барретт, Джеральдин Джейм...","Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей...","45, лет, 2015, Великобритания, брак, жизнь, лю...",2010_2020
4,16268,film,все решает мгновение,,"[драмы, спорт, мелодрамы]",[ссср],False,12,ленфильм,виктор садовский,"Александр Абдулов, Александр Демьяненко, Алекс...",Расчетливая чаровница из советского кинохита «...,"Все, решает, мгновение, 1978, СССР, сильные, ж...",1970_1980


In [9]:
items_df.columns

Index(['item_id', 'content_type', 'title', 'title_orig', 'genres', 'countries',
       'for_kids', 'age_rating', 'studios', 'directors', 'actors',
       'description', 'keywords', 'release_year_cat'],
      dtype='object')

Закодируем жанры (отдельный бинарный столбец для каждого жанра)

In [10]:
genres_mlb = MultiLabelBinarizer()
genres_one_hot = pd.DataFrame(genres_mlb.fit_transform(items_df['genres']),
                              columns=list(map(lambda s: f'genre_{s}', genres_mlb.classes_)),
                              index=items_df.index,
                              dtype=np.int32)
countries_mlb = MultiLabelBinarizer()
countries_one_hot = pd.DataFrame(countries_mlb.fit_transform(items_df['countries']),
                                 columns=list(map(lambda s: f'country_{s}', countries_mlb.classes_)),
                                 index=items_df.index,
                                 dtype=np.int32)
items_df = pd.concat([items_df, genres_one_hot, countries_one_hot], axis=1).drop(columns=['genres', 'genre_no_genre', 'studios', 'countries'])
items_df.head()

Unnamed: 0,item_id,content_type,title,title_orig,for_kids,age_rating,directors,actors,description,keywords,...,country_хорватия,country_чехия,country_чили,country_швейцария,country_швеция,country_эквадор,country_эстония,country_юар,country_югославия,country_япония
0,10711,film,поговори с ней,Hable con ella,False,16,педро альмодовар,"Адольфо Фернандес, Ана Фернандес, Дарио Гранди...",Мелодрама легендарного Педро Альмодовара «Пого...,"Поговори, ней, 2002, Испания, друзья, любовь, ...",...,0,0,0,0,0,0,0,0,0,0
1,2508,film,голые перцы,Search Party,False,16,скот армстронг,"Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ...",Уморительная современная комедия на популярную...,"Голые, перцы, 2014, США, друзья, свадьбы, прео...",...,0,0,0,0,0,0,0,0,0,0
2,10716,film,тактическая сила,Tactical Force,False,16,адам п. калтраро,"Адриан Холмс, Даррен Шалави, Джерри Вассерман,...",Профессиональный рестлер Стив Остин («Все или ...,"Тактическая, сила, 2011, Канада, бандиты, ганг...",...,0,0,0,0,0,0,0,0,0,0
3,7868,film,45 лет,45 Years,False,16,эндрю хэй,"Александра Риддлстон-Барретт, Джеральдин Джейм...","Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей...","45, лет, 2015, Великобритания, брак, жизнь, лю...",...,0,0,0,0,0,0,0,0,0,0
4,16268,film,все решает мгновение,,False,12,виктор садовский,"Александр Абдулов, Александр Демьяненко, Алекс...",Расчетливая чаровница из советского кинохита «...,"Все, решает, мгновение, 1978, СССР, сильные, ж...",...,0,0,0,0,0,0,0,0,0,0


И так же, как с юзерами, закодируем числами категориальные фичи (которых по одному значению на строку)

In [11]:
content_type_mapper = {'film': 0, 'series': 1}
age_rating_mapper = {v: id for id, v in enumerate(sorted(items_df['age_rating'].unique()))}
release_year_cat_mapper = {y: id for id, y in enumerate(sorted(items_df['release_year_cat'].unique(), key=lambda s: float(s.split('_')[1])))}
content_type_mapper, age_rating_mapper, release_year_cat_mapper

({'film': 0, 'series': 1},
 {np.int64(0): 0,
  np.int64(6): 1,
  np.int64(12): 2,
  np.int64(16): 3,
  np.int64(18): 4,
  np.int64(21): 5},
 {'inf_1920': 0,
  '1920_1930': 1,
  '1930_1940': 2,
  '1940_1950': 3,
  '1950_1960': 4,
  '1960_1970': 5,
  '1970_1980': 6,
  '1980_1990': 7,
  '1990_2000': 8,
  '2000_2010': 9,
  '2010_2020': 10,
  '2020_inf': 11})

In [12]:
items_df['for_kids'] = items_df['for_kids'].astype(np.int32)
items_df['age_rating'] = items_df['age_rating'].map(age_rating_mapper).astype(np.int32)
items_df['content_type'] = items_df['content_type'].map(content_type_mapper).astype(np.int32)
items_df['release_year_cat'] = items_df['release_year_cat'].map(release_year_cat_mapper).astype(np.int32)
items_df = items_df.drop(columns=['title', 'title_orig', 'directors', 'actors', 'description', 'keywords']).astype(np.int32)

items_df.head()

Unnamed: 0,item_id,content_type,for_kids,age_rating,release_year_cat,genre_аниме,genre_биография,genre_боевики,genre_военные,genre_детективы,...,country_хорватия,country_чехия,country_чили,country_швейцария,country_швеция,country_эквадор,country_эстония,country_юар,country_югославия,country_япония
0,10711,0,0,3,9,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0
1,2508,0,0,3,10,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,10716,0,0,3,10,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
3,7868,0,0,3,10,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,16268,0,0,2,6,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [13]:
items_df.info(verbose=True)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 15963 entries, 0 to 15962
Data columns (total 125 columns):
 #    Column                    Dtype
---   ------                    -----
 0    item_id                   int32
 1    content_type              int32
 2    for_kids                  int32
 3    age_rating                int32
 4    release_year_cat          int32
 5    genre_аниме               int32
 6    genre_биография           int32
 7    genre_боевики             int32
 8    genre_военные             int32
 9    genre_детективы           int32
 10   genre_детские             int32
 11   genre_для взрослых        int32
 12   genre_документальное      int32
 13   genre_драмы               int32
 14   genre_исторические        int32
 15   genre_комедии             int32
 16   genre_короткометражные    int32
 17   genre_криминал            int32
 18   genre_мелодрамы           int32
 19   genre_музыка              int32
 20   genre_мультфильмы         int32
 21   genre_мюзи

## Interactions preprocessing

In [14]:
interactions_df.head()

Unnamed: 0,user_id,item_id,last_watch_dt,total_dur,watched_pct
0,176549,9506,2021-05-11,4250,72
1,699317,1659,2021-05-29,8317,100
2,656683,7107,2021-05-09,10,0
3,864613,7638,2021-07-05,14483,100
4,964868,9506,2021-04-30,6725,100


Фильтруем малоактивных юзеров и непопулярные фильмы.

In [15]:
interactions_df['item_id'].value_counts(), interactions_df['user_id'].value_counts()

(item_id
 10440    202457
 15297    193123
 9728     132865
 13865    122119
 4151      91167
           ...  
 2435          1
 7978          1
 10642         1
 13008         1
 9286          1
 Name: count, Length: 15706, dtype: int64,
 user_id
 416206     1341
 1010539     764
 555233      685
 11526       676
 409259      625
            ... 
 690921        1
 255412        1
 264195        1
 150067        1
 337469        1
 Name: count, Length: 962179, dtype: int64)

Оставим юзеров, посмотревших хотя бы 5 фильмов, и фильмы, которые посмотрело хотя бы 5 юзеров

In [16]:
before_filtering_users = interactions_df['user_id'].nunique()
before_filtering_items = interactions_df['item_id'].nunique()

interactions_df = interactions_df[interactions_df.watched_pct > 10]

valid_users = []

neg_films_features = Counter(interactions_df['user_id'])
for user_id, entries in neg_films_features.most_common():
  if entries >= 5:
    valid_users.append(user_id)

valid_items = []

neg_films_features = Counter(interactions_df['item_id'])
for item_id, entries in neg_films_features.most_common():
  if entries >= 5:
    valid_items.append(item_id)

interactions_df = interactions_df[interactions_df['user_id'].isin(valid_users)]
interactions_df = interactions_df[interactions_df['item_id'].isin(valid_items)]

print(f"Users before filtering: {before_filtering_users:>7}")
print(f"Users after filtering:  {interactions_df['user_id'].nunique():>7}")
print(f"Items before filtering:  {before_filtering_items:>6}")
print(f"Items after filtering:   {interactions_df['item_id'].nunique():>6}")

Users before filtering:  962179
Users after filtering:   207255
Items before filtering:   15706
Items after filtering:     8823


Переведём фичи в 32-битный int

In [17]:
interactions_df['total_dur'].min(), interactions_df['total_dur'].mean(), interactions_df['total_dur'].max()

(np.int64(6), np.float64(11384.85642516867), np.int64(80411672))

In [18]:
interactions_df[['user_id', 'item_id', 'total_dur', 'watched_pct']] = interactions_df[['user_id', 'item_id', 'total_dur', 'watched_pct']].astype(np.int32)
interactions_df.info(verbose=True)

<class 'pandas.core.frame.DataFrame'>
Index: 2646592 entries, 0 to 5476249
Data columns (total 5 columns):
 #   Column         Dtype 
---  ------         ----- 
 0   user_id        int32 
 1   item_id        int32 
 2   last_watch_dt  object
 3   total_dur      int32 
 4   watched_pct    int32 
dtypes: int32(4), object(1)
memory usage: 80.8+ MB


Поскольку `nn.Embedding` принимает количество треков `num_tracks`, и считает, что их индексы будут [0; `num_tracks`) (что сейчас не так), нужно перемаппить индексы треков (и юзеров заодно) в этот диапазон  

In [19]:
unique_user_ids = np.unique(np.concat([interactions_df["user_id"], users_df['user_id']]))
unique_item_ids = np.unique(np.concat([interactions_df["item_id"], items_df['item_id']]))

# Create mappings (old ID → new consecutive ID starting from 0)
user_id_map = {old_id: np.int32(new_id) for new_id, old_id in enumerate(unique_user_ids)}
user_id_map_to_orig = {np.int32(new_id): old_id for new_id, old_id in enumerate(unique_user_ids)}
item_id_map = {old_id: np.int32(new_id) for new_id, old_id in enumerate(unique_item_ids)}
item_id_map_to_orig = {np.int32(new_id): old_id for new_id, old_id in enumerate(unique_item_ids)}

In [20]:
len(user_id_map), len(item_id_map)

(879489, 15963)

In [21]:
users_df['user_id'] = users_df['user_id'].map(user_id_map)
items_df['item_id'] = items_df['item_id'].map(item_id_map)
interactions_df['user_id'] = interactions_df['user_id'].map(user_id_map)
interactions_df['item_id'] = interactions_df['item_id'].map(item_id_map)

users_df.set_index('user_id', inplace=True)
items_df.set_index('item_id', inplace=True)
interactions_df.head()

Unnamed: 0,user_id,item_id,last_watch_dt,total_dur,watched_pct
0,141413,9178,2021-05-11,4250,72
1,560236,1604,2021-05-29,8317,100
3,692633,7376,2021-07-05,14483,100
5,826986,6454,2021-05-13,11286,100
6,814480,343,2021-08-14,1672,25


### TODO: а нам вот это надо?

In [22]:
positives = interactions_df[interactions_df["watched_pct"] > 80].copy()
film_counts = positives.groupby("item_id").size()
films = set(film_counts[film_counts >= 5].index.values)

interactions_filt = positives[positives["item_id"].isin(films)]
del positives
len(interactions_filt), len(films)

(1491848, 6801)

In [23]:
interactions_filt

Unnamed: 0,user_id,item_id,last_watch_dt,total_dur,watched_pct
1,560236,1604,2021-05-29,8317,100
3,692633,7376,2021-07-05,14483,100
5,826986,6454,2021-05-13,11286,100
11,792263,7309,2021-07-07,6558,100
14,4271,8140,2021-04-18,6598,92
...,...,...,...,...,...
5476240,642768,496,2021-08-08,6990,100
5476241,860383,9585,2021-08-07,6425,97
5476242,215082,2958,2021-04-21,5752,98
5476244,351522,7560,2021-08-02,6804,100


In [24]:
triplets = interactions_filt[["user_id", "item_id"]]
del interactions_filt

In [25]:
NUM_NEGATIVE_SAMPLES = 10
triplets = pd.concat([triplets] * NUM_NEGATIVE_SAMPLES).sort_index().reset_index(drop=True)
triplets["film_neg"] = np.random.choice(items_df.index.unique(), len(triplets))
triplets = triplets.rename(columns={ "item_id": "film_pos"}).astype(np.int32)
triplets

Unnamed: 0,user_id,film_pos,film_neg
0,560236,1604,3177
1,560236,1604,10720
2,560236,1604,14144
3,560236,1604,12754
4,560236,1604,4022
...,...,...,...
14918475,308069,15647,8101
14918476,308069,15647,12111
14918477,308069,15647,384
14918478,308069,15647,6854


In [26]:
rdm = np.random.random(len(triplets))
rdm2 = np.random.random(len(triplets))           # use only 20% of data
train_data = triplets[(rdm < 0.8) &              (rdm2 < 1)]
val_data = triplets[(rdm >= 0.8) & (rdm < 0.9) & (rdm2 < 1)]
test_data = triplets[(rdm >= 0.9) &              (rdm2 < 1)]

len(train_data), len(val_data), len(test_data)

(11935155, 1491378, 1491947)

In [27]:
items_df.shape, interactions_df.shape

((15963, 124), (2646592, 5))

In [28]:
NO_MOVIE = len(item_id_map) # id bigger than any item id
NO_MOVIE

15963

Для каждого юзера выпишем по 30 его позитивных айтемов (если у него нет столько, добавим в конец специальные значения, обозначающие отсутствие фильма).

In [29]:
def pad_with_specific_value(lst, size, val):
    lst = list(set(lst))
    random.shuffle(lst)
    lst = lst[:size]
    return np.pad(lst, (0, size - len(lst)), 'constant', constant_values=val)

padded_users = triplets.groupby("user_id").apply(lambda x: (
    pad_with_specific_value(x['film_pos'].tolist(), 30, NO_MOVIE).tolist()
))
padded_users = pd.DataFrame({'interactions': padded_users.values}, index=padded_users.index)
padded_users

  padded_users = triplets.groupby("user_id").apply(lambda x: (


Unnamed: 0_level_0,interactions
user_id,Unnamed: 1_level_1
2,"[9016, 12006, 3413, 4317, 2844, 9764, 239, 568..."
3,"[937, 15175, 5986, 4000, 1961, 9879, 7968, 911..."
9,"[13387, 9393, 15963, 15963, 15963, 15963, 1596..."
11,"[14141, 4571, 15495, 2699, 14769, 9770, 6576, ..."
12,"[10839, 13933, 15040, 10096, 15963, 15963, 159..."
...,...
879454,"[2411, 9861, 7309, 11340, 15963, 15963, 15963,..."
879467,"[9651, 13933, 9393, 139, 15963, 15963, 15963, ..."
879476,"[7560, 14927, 6084, 15963, 15963, 15963, 15963..."
879478,"[139, 11909, 12539, 5535, 15963, 15963, 15963,..."


# Dataset

In [30]:
def collate_fn(data: list[tuple]):
    return data    

class DSSMDataset(td.Dataset):
    def __init__(self, triplets: pd.DataFrame):
        super().__init__()
        self.triplets = triplets
    
    def __getitem__(self, index: int) -> tuple[Tensor, Tensor, Tensor]:
        triplet = self.triplets.iloc[index]
        user_interractions = torch.tensor(padded_users.loc[triplet['user_id']]['interactions'], dtype=torch.int32)
        pos_films_features = torch.tensor(items_df.loc[triplet['film_pos']].values, dtype=torch.float32)
        neg_films_features = torch.tensor(items_df.loc[triplet['film_neg']].values, dtype=torch.float32)
        return user_interractions, pos_films_features, neg_films_features

    def __getitems__(self, index: tp.Sequence[int]) -> tuple[Tensor, Tensor, Tensor]:
        triplet = self.triplets.iloc[index]
        user_interractions = torch.tensor(padded_users.loc[triplet['user_id']]['interactions'].tolist(), dtype=torch.int32)
        pos_films_features = torch.tensor(items_df.loc[triplet['film_pos']].values, dtype=torch.float32)
        neg_films_features = torch.tensor(items_df.loc[triplet['film_neg']].values, dtype=torch.float32)
        return user_interractions, pos_films_features, neg_films_features
    
    def __len__(self):
        return len(self.triplets)

# Model

In [31]:
class ItemNet(nn.Module):
    def __init__(self,
                 dim_embedding: int,
                 dim_input: int,
                 dim_hidden: int = 64, # 32
                 activation: tp.Callable[[Tensor], Tensor] = nn.ReLU()
                 ) -> None:
        super().__init__()
        self.embedding_layer = nn.Linear(dim_input - 1, dim_hidden, bias=False)
        # self.dense_layer = nn.Linear(dim_hidden + 1, dim_embedding, bias=False)
        self.dense_block = nn.Sequential(
            nn.Linear(dim_hidden + 1, int(dim_embedding // 2)),
            nn.ReLU(),
            # nn.Linear(int(dim_embedding // 2), int(dim_embedding // 2)),
            # nn.ReLU(),
            nn.Linear(int(dim_embedding // 2), dim_embedding),
            activation
        )
        self.output_layer = nn.Linear(dim_embedding + dim_hidden, dim_embedding, bias=False)
        self.norm = nn.LayerNorm(dim_embedding)

    def forward(self, item_features: Tensor) -> Tensor:
        popularity = item_features[:, 0].view(-1, 1)
        genre_emb = self.embedding_layer(item_features[:, 1:])

        pop_genre = torch.concat([popularity, genre_emb], axis=1)
        features = self.dense_block(pop_genre)

        genre_features = torch.concat([genre_emb, features], axis=1)
        output = self.output_layer(genre_features)
        return self.norm(output)


class UserNet(nn.Module):
    def __init__(self,
                 dim_embedding: int,
                 num_items: int,
                 activation: tp.Callable[[Tensor], Tensor] = nn.ReLU()
                 ) -> None:                              # | +1 for the NO_MOVIE element
        super().__init__()                               # V
        self.track_embeddings = nn.EmbeddingBag(num_items + 1, dim_embedding, padding_idx=num_items)
        # self.dense_layer = nn.Linear(dim_embedding, dim_embedding, bias=False)
        self.dense_layer = nn.Sequential(
            nn.Linear(dim_embedding, int(dim_embedding // 2)),
            nn.ReLU(),
            # nn.Linear(int(dim_embedding // 2), int(dim_embedding // 2)),
            # nn.ReLU(),
            nn.Linear(int(dim_embedding // 2), dim_embedding),
            activation
        )
        self.output_layer = nn.Linear(dim_embedding + dim_embedding, dim_embedding, bias=False)
        self.norm = nn.LayerNorm(dim_embedding)
        self.num_items = num_items

    def forward(self, user_ids: Tensor) -> Tensor:
        # print(f'EMBEDDING BAG MAX INPUT: {user_ids.max()} while was ready for {self.num_items}')
        interactions_emb = self.track_embeddings(user_ids)
        features = self.dense_layer(interactions_emb)
        x = torch.concat([interactions_emb, features], axis=1)
        output = self.output_layer(x)
        return self.norm(output)

In [32]:
class DSSM(nn.Module):
    def __init__(self,
                 dim_item_features: int,
                 num_items: int,
                 embedding_dim: int = 100,
                 lr: float = 1e-3,
                 triplet_loss_margin: float = 0.4,
                 weight_decay: float = 1e-3,
                 log_to_prog_bar: bool = True,
                 ) -> None:
        super().__init__()
        self.lr = lr
        self.triplet_loss_margin = triplet_loss_margin
        self.weight_decay = weight_decay
        self.log_to_prog_bar = log_to_prog_bar
        self.item_net = ItemNet(embedding_dim, dim_item_features)
        self.user_net = UserNet(embedding_dim, num_items)

    def forward(self,
                user_ids: Tensor,
                item_features_pos: Tensor,
                item_features_neg: Tensor,
                ) -> tuple[Tensor, Tensor, Tensor]:
        anchor = self.user_net(user_ids)
        pos = self.item_net(item_features_pos)
        neg = self.item_net(item_features_neg)

        return anchor, pos, neg

    # def _step(self, batch, batch_idx, metric, prog_bar=False):
    #     user_ids, pos, neg = batch
    #     anchor, positive, negative = self(user_ids, pos, neg)
    #     loss = F.triplet_margin_loss(anchor, positive, negative, margin=self.triplet_loss_margin)
    #     self.log(metric, loss, prog_bar=prog_bar)
    #     return loss

    # def training_step(self, batch: tp.Sequence[Tensor], batch_idx: int) -> Tensor:
    #     return self._step(batch, batch_idx, "train_loss")

    # def validation_step(self, batch: tp.Sequence[Tensor], batch_idx: int) -> Tensor:
    #     return self._step(batch, batch_idx, "val_loss", self.log_to_prog_bar)

    # def test_step(self, batch, batch_idx, prog_bar=False):
    #     return self._step(batch, batch_idx, "test_loss", self.log_to_prog_bar)

    # def inference(self, dataloader: td.DataLoader[tp.Any], mode: str = "item") -> np.ndarray:
    #     if mode == "user":
    #       model = self.user_net
    #     elif mode == "item":
    #       model = self.item_net
    #     else:
    #       raise ValueError(f"Unsupported mode {mode}!")

    #     self.eval()

    #     batches = []
    #     user_ids = []
    #     for ids, features in dataloader:
    #         with torch.no_grad():
    #             batch_embeddings = model(features.to(self.device))
    #         batches.append(batch_embeddings)
    #         user_ids.append(ids)
    #     vectors = torch.cat(batches, dim=0).cpu().numpy()
    #     vectors_ids = torch.cat(user_ids, dim=0).cpu().numpy()
    #     return vectors_ids, vectors

    # def configure_optimizers(self):
    #     optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    #     # lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, threshold=2.5e-2, threshold_mode='rel')
    #     lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, threshold=2e-3, threshold_mode='abs')
    #     scheduler = {
    #         'scheduler': lr_scheduler,
    #         'monitor': 'val_loss'
    #     }
    #     return [optimizer], [scheduler]

# Training

In [36]:
# Hyperparams
EXPERIMENT_NAME = 'reelsrec_dssm_1'
EPOCHS = 10
BATCH_SIZE = 16384
NUM_WORKERS=8
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

EMBEDDING_DIM = 96
LR = 5e-3
WEIGHT_DECAY = 1e-5
TRIPLET_LOSS_MARGIN = 0.4


train_dataset = DSSMDataset(train_data)
train_dataloader = td.DataLoader(train_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

val_dataset = DSSMDataset(val_data)
val_dataloader = td.DataLoader(val_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

test_dataset = DSSMDataset(test_data)
test_dataloader = td.DataLoader(test_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

model = DSSM(dim_item_features=items_df.shape[1],
             num_items=len(item_id_map),
             embedding_dim=EMBEDDING_DIM)

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

model.to(DEVICE)
model.train()
for epoch in trange(EPOCHS, position=0, desc='Training', unit='epoch'):
    losses = []
    for batch in tqdm(train_dataloader, position=1, desc=f'Epoch {epoch + 1}/{EPOCHS}', unit='batch'):
        user_inters, pos, neg = batch
        optimizer.zero_grad()
        anchor, positive, negative = model(user_inters.to(DEVICE),
                                           pos.to(DEVICE),
                                           neg.to(DEVICE))
        loss = F.triplet_margin_loss(anchor, positive, negative, margin=TRIPLET_LOSS_MARGIN)
        loss.backward()
        losses.append(loss.item())

        optimizer.step()
    
    # TODO: валидация
    # for batch in tqdm(val_dataloader, position=1, desc=f'Validation', unit='batch'):
    #     user_inters, pos, neg = batch
    #     user_inters = user_inters.to(DEVICE)
    #     pos = pos.to(DEVICE)
    #     neg = neg.to(DEVICE)
    #     losses = []
    #     optimizer.zero_grad()
    #     anchor, positive, negative = model(user_inters, pos, neg)
    #     loss = F.triplet_margin_loss(anchor, positive, negative, margin=TRIPLET_LOSS_MARGIN)
    #     loss.backward()
    #     losses.append(loss.item())

    #     optimizer.step()

    tqdm.write(f'Average train loss: {sum(losses) / len(losses):.4f}')

Training:   0%|          | 0/10 [00:00<?, ?epoch/s]

Epoch 1/10:   0%|          | 0/729 [00:00<?, ?batch/s]

Average train loss: 0.0561


Epoch 2/10:   0%|          | 0/729 [00:00<?, ?batch/s]

Average train loss: 0.0276


Epoch 3/10:   0%|          | 0/729 [00:00<?, ?batch/s]

Average train loss: 0.0239


Epoch 4/10:   0%|          | 0/729 [00:00<?, ?batch/s]

Average train loss: 0.0222


Epoch 5/10:   0%|          | 0/729 [00:00<?, ?batch/s]

Average train loss: 0.0212


Epoch 6/10:   0%|          | 0/729 [00:00<?, ?batch/s]

KeyboardInterrupt: 

In [None]:
trainer.fit(
    model,
    data_module, 
    # ckpt_path='/home/serg_fedchn/Homework/6_semester/RecSys/recsys-course-spring-2025/jupyter/lightning_logs/version_13/checkpoints/epoch=32-step=63393.ckpt'
)

NameError: name 'trainer' is not defined

In [None]:
checkpoint_callback.best_model_path

'/home/serg_fedchn/Homework/6_semester/DL/reelrecs/lightning_logs/version_11/checkpoints/epoch=11-step=14004.ckpt'

In [None]:
shutil.move(checkpoint_callback.best_model_path, f"/home/serg_fedchn/Homework/6_semester/DL/reelrecs/model_weights/{EXPERIMENT_NAME}.ckpt")

'/home/serg_fedchn/Homework/6_semester/DL/reelrecs/model_weights/dssm_layernorm_params_arch_5.ckpt'

In [None]:
best_dssm = DSSM.load_from_checkpoint(f"/home/serg_fedchn/Homework/6_semester/DL/reelrecs/model_weights/{EXPERIMENT_NAME}.ckpt",
                                      dim_item_features=items_df.shape[1],
                                      num_items=len(item_id_map),
                                      embedding_dim=EMBEDDING_DIM)

NameError: name 'items_ohe_df' is not defined