In [None]:
!pip install pytorch-lightning==1.1.0rc1

In [None]:
!wget http://files.grouplens.org/datasets/movielens/ml-1m.zip
!unzip ml-1m.zip

In [3]:
import pandas as pd
import numpy as np
import scipy.sparse as sp
import multiprocessing
from collections import namedtuple
# from lightfm import LightFM
from IPython.display import display

from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
from sklearn.metrics import ndcg_score

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

Данные из 1ого дз: movie lens 1m. Из explicit делаем implicit данные, считая что 4-5 explicit = 1 implicit. 
<br>
Из https://arxiv.org/abs/1708.05031 :
* В качестве метрик используем метрики: Hit Ratio (HR) и Normalized Discounted Cumulative Gain (NDCG) все @k. HR@k дает скор в зависимости от того, есть ли positive айтем в top-k. Чем ближе positive айтем к 1-ому месту в полученном ранжировании, тем больший скор дает NDCG@k. Один семпл: 1 pos + 99 neg айтемов, метрики считаем по отранжированному списку длины 10.
* Train, val датасеты составляем, так, чтобы на каждый positive айтем приходилось 4 negative. В test выделяем 1 последний positive айтем для каждого пользователя.

####  Reading data, explicit -> implicit

In [8]:
ratings = pd.read_csv('ml-1m/ratings.dat', delimiter='::', header=None, 
        names=['user_id', 'movie_id', 'rating', 'timestamp'], 
        usecols=['user_id', 'movie_id', 'rating'], engine='python')

movie_info = pd.read_csv('ml-1m/movies.dat', delimiter='::', header=None, 
        names=['movie_id', 'name', 'category'], engine='python')

ratings = ratings.loc[(ratings['rating'] >= 4)]

#### Sparse matrix

In [4]:
def to_csr(data, rows, cols):
    user_item = sp.coo_matrix((data, (rows, cols)))
    user_item_csr = user_item.tocsr()
    return user_item_csr

In [9]:
users = ratings['user_id']
movies = ratings['movie_id']
ratings_sparse = to_csr(np.ones_like(users), users, movies)

#### Dict {user_id: (pos_ids, neg_ids)}

In [10]:
def pos_neg_dataset(ratings):
    user_ids = np.unique(users.values)
    movie_ids = np.unique(movies.values)
    User = namedtuple('User', 'pos_train, pos_test, neg')
    Ds = dict()
    
    for u in user_ids:
        user_rating = ratings[u]
        rated = list(user_rating.indices)
        if len(rated) > 1:
            not_rated = list(np.setdiff1d(movie_ids, rated))
            user = User(rated[:-1], rated[-1], not_rated)
            Ds[u] = user
    return Ds

In [11]:
Ds = pos_neg_dataset(ratings_sparse)

#### Sparce train matrix, test set 

In [12]:
user_train = np.hstack([np.full(len(Ds[user_id].pos_train), user_id) for user_id in Ds.keys()])
user_test = np.hstack([np.full(1, user_id) for user_id in Ds.keys()])

item_train = np.hstack([np.array(Ds[user_id].pos_train) for user_id in Ds.keys()])
item_test = np.array([Ds[user_id].pos_test for user_id in Ds.keys()])

ratings_train_sparse = to_csr(np.ones_like(user_train), user_train, item_train)

#### Recommender + metrics

In [27]:
def title_by_id(id):
    return movie_info[movie_info['movie_id'] == id]['name'].item()

def category_by_id(id):
    return movie_info[movie_info['movie_id'] == id]['category'].item()

def rec_print(title_list, category_list):
    display(pd.DataFrame(zip(title_list, category_list), columns=['Title', 'Category']))
    
    
class Recommender:
    def __init__(self, user_embs, item_embs, user_test_ds, item_test_ds):
        #global vars
        self.ratings = ratings_sparse
        self.Ds = Ds
        
        self.user_test = user_test_ds
        self.item_test = item_test_ds
        self.user_embs = user_embs
        self.item_embs = item_embs
        self.pred_ratings = self.user_embs @ self.item_embs.T
               
    def similar_movies(self, movie_id=1, k=10):
        movie_id = 1
        title = title_by_id(movie_id)
        category = category_by_id(movie_id)
        print('For movie')
        rec_print([title], [category])
        print('similars are')

        score = cosine_similarity(np.expand_dims(self.item_embs[movie_id], axis=0), self.item_embs)[0]
        ranked_similars = np.argsort(score)[::-1]
        top_k_ids = [s for s in ranked_similars if s != movie_id and s in movie_info['movie_id'].values][:k]
        top_k_titles = [title_by_id(r) for r in top_k_ids]
        top_k_categories = [category_by_id(r) for r in top_k_ids]
        rec_print(top_k_titles, top_k_categories)
        
    def prediction_for_user(self, user_id=4):
        real_movie_id = self.ratings[user_id].indices   
        print('--- User\'s choice ---')
        real_titles = [title_by_id(r) for r in real_movie_id[:10]]
        real_category = [category_by_id(r) for r in real_movie_id[:10]]
        rec_print(real_titles, real_category)
        print()

        print('--- Our recommendations ---')
        user_pred = self.pred_ratings[user_id]
        ranked_movie_id = np.argsort(user_pred)[::-1]
        not_rated_movie_id = [i for i in ranked_movie_id if i not in real_movie_id and i in movie_info['movie_id'].values]
        pred_titles = [title_by_id(r) for r in not_rated_movie_id[:10]]
        pred_category = [category_by_id(r) for r in not_rated_movie_id[:10]]
        rec_print(pred_titles, pred_category)
    
    def one_sample_metric(self, user_id, pos_item, score_fn, preprocess_user, preprocess_item, k=10):
        np.random.shuffle(self.Ds[user_id].neg)
        neg_items = self.Ds[user_id].neg[:99]
        items = np.array([pos_item, *neg_items])
        y_target = np.array([1, *[0] * 99])
        perm = np.random.permutation(len(items))
        items = items[perm]
        y_target = y_target[perm]
        if preprocess_user:
            user_id = preprocess_user(user_id)
        if preprocess_item:
            items = preprocess_item(items)
        y_pred = score_fn(user_id, items)

        if preprocess_item:
            items = items.detach().cpu().numpy()
            y_pred = y_pred.detach().cpu().numpy().reshape(-1, )
        top_k_items = items[np.argsort(y_pred)[-k:]]
        hr = int(pos_item in top_k_items)
        ndcg = ndcg_score([y_target], [y_pred], k=k)
        return hr, ndcg

    def compute_metrics(self, score_fn, preprocess_user=None, preprocess_item=None):
        hr_list = []
        ndcg_list = []
        for u, i in zip(self.user_test, self.item_test):
            hr, ndcg = self.one_sample_metric(u, i, score_fn, preprocess_user, preprocess_item)
            hr_list.append(hr)
            ndcg_list.append(ndcg)
        print('Avg hr: {0:.3f}'.format(np.mean(hr_list)))
        print('Avg ndcg: {0:.3f}'.format(np.mean(ndcg_list)))

## 1. WARP

In [None]:
warp = LightFM(no_components=64, loss='warp', max_sampled=200)
warp.fit_partial(ratings_sparse, epochs=100)
U, V = warp.user_embeddings, warp.item_embeddings
warp_recommender = Recommender(U, V, user_test, item_test)
warp_recommender.compute_metrics(warp.predict)

Avg hr: 0.915
Avg ndcg: 0.729


#### похожие фильмы

In [None]:
warp_recommender.similar_movies()

For movie


Unnamed: 0,Title,Category
0,Toy Story (1995),Animation|Children's|Comedy


similars are


Unnamed: 0,Title,Category
0,Toy Story 2 (1999),Animation|Children's|Comedy
1,"Bug's Life, A (1998)",Animation|Children's|Comedy
2,"Lion King, The (1994)",Animation|Children's|Musical
3,Hercules (1997),Adventure|Animation|Children's|Comedy|Musical
4,Aladdin (1992),Animation|Children's|Comedy|Musical
5,Babe (1995),Children's|Comedy|Drama
6,Tarzan (1999),Animation|Children's
7,"Iron Giant, The (1999)",Animation|Children's
8,Mulan (1998),Animation|Children's
9,Antz (1998),Animation|Children's


#### предсказания для пользователя

In [None]:
warp_recommender.prediction_for_user()

--- User's choice ---


Unnamed: 0,Title,Category
0,Star Wars: Episode IV - A New Hope (1977),Action|Adventure|Fantasy|Sci-Fi
1,Jurassic Park (1993),Action|Adventure|Sci-Fi
2,Die Hard (1988),Action|Thriller
3,E.T. the Extra-Terrestrial (1982),Children's|Drama|Fantasy|Sci-Fi
4,Raiders of the Lost Ark (1981),Action|Adventure
5,"Good, The Bad and The Ugly, The (1966)",Action|Western
6,Alien (1979),Action|Horror|Sci-Fi|Thriller
7,"Terminator, The (1984)",Action|Sci-Fi|Thriller
8,Jaws (1975),Action|Horror
9,Rocky (1976),Action|Drama



--- Our recommendations ---


Unnamed: 0,Title,Category
0,Ben-Hur (1959),Action|Adventure|Drama
1,For a Few Dollars More (1965),Western
2,Dr. No (1962),Action
3,Close Encounters of the Third Kind (1977),Drama|Sci-Fi
4,"Graduate, The (1967)",Drama|Romance
5,Deliverance (1972),Adventure|Thriller
6,Butch Cassidy and the Sundance Kid (1969),Action|Comedy|Western
7,Top Gun (1986),Action|Romance
8,"Searchers, The (1956)",Western
9,Predator (1987),Action|Sci-Fi|Thriller


Рекомендации хорошие, симилары нормальные, метрики выглядят большими.

## 2.  NCF

NCG = Generalized Matrix Factorization(GMF) + MLP. В GMF elementwise перемножаются gmf эмбеддинги юзера и айтема.
По полученному эмбеддингу делают предсказание. В MLP конкатенируются mlp эмбеддинги юзера и айтема и пропускаются
через нейросеть. GMF и MLP предобучаются отдельно. В NCF эмбеддинги пропускаются через соответствующие модели 
и общее предсказание делается по сконкатенированному выходу из этих моделей.

In [None]:
class GMF(nn.Module):
    def __init__(self, n_users, n_items, embed_dim):
        super().__init__()
        self.user_embedder = nn.Embedding(n_users, embed_dim)
        self.item_embedder = nn.Embedding(n_items, embed_dim)
        self.out = nn.Sequential(nn.Linear(embed_dim, 1),
                                 nn.Sigmoid())
        
    def forward(self, user_idx, item_idx, return_embs=False):
        user_embs = self.user_embedder(user_idx)
        item_embs = self.item_embedder(item_idx)
        el_prod = torch.mul(user_embs, item_embs)
        if return_embs:
            return el_prod
        y = self.out(el_prod)
        return y
    
        
class MLP(nn.Module):
    def __init__(self, n_users, n_items, embed_dim, hidden_dims):
        super().__init__()
        assert 2 * embed_dim == hidden_dims[0]
        self.user_embedder = nn.Embedding(n_users, embed_dim)
        self.item_embedder = nn.Embedding(n_items, embed_dim)
        
        self.net = []
        all_dims = [2 * embed_dim, *hidden_dims] 
        for i, dim in enumerate(all_dims[:-1]):
            self.net.extend([nn.Linear(all_dims[i], all_dims[i + 1]), nn.ReLU()])
        self.net = nn.Sequential(*self.net)
        
        self.out = nn.Sequential(nn.Linear(all_dims[-1], 1),
                                 nn.Sigmoid())
        
    def forward(self, user_idx, item_idx, return_embs=False):
        user_embs = self.user_embedder(user_idx)
        item_embs = self.item_embedder(item_idx)
        embs = torch.cat([user_embs, item_embs], dim=-1) 
        embs = self.net(embs)
        if return_embs:
            return embs
        y = self.out(embs)
        return y
            
        
class NCF(nn.Module):
    def __init__(self, gmf_params, mlp_params, alpha, gmf_ckpt=None, mlp_ckpt=None):
        super().__init__()
        self.gmf = GMF(**gmf_params)
        self.mlp = MLP(**mlp_params)
        hidden_dim = gmf_params['embed_dim'] + mlp_params['hidden_dims'][-1]
        self.out = nn.Sequential(nn.Linear(hidden_dim, 1),
                                 nn.Sigmoid())
        self.alpha = alpha
        
        if gmf_ckpt and mlp_ckpt:
            self.load_weights(gmf_ckpt, mlp_ckpt) 
            
    def forward(self, user_idx, item_idx):
        emb_gmf = self.gmf(user_idx, item_idx, return_embs=True).squeeze(1)
        emb_mlp = self.mlp(user_idx, item_idx, return_embs=True)
        emb = torch.cat([emb_gmf, emb_mlp], dim=-1)   
        y = self.out(emb)
        return y  

    def load_weights(self, gmf_ckpt, mlp_ckpt):
        gmf_state_dict = torch.load(gmf_ckpt)['state_dict']
        mlp_state_dict = torch.load(mlp_ckpt)['state_dict']

        for key in list(gmf_state_dict.keys()):
            new_key = key[key.index('.') + 1:]
            self.gmf.state_dict()[new_key] = gmf_state_dict.pop(key) 

        for key in list(mlp_state_dict.keys()):
            new_key = key[key.index('.') + 1:]
            self.mlp.state_dict()[new_key] = mlp_state_dict.pop(key) 

        last_layer_weight = torch.cat([self.alpha * self.gmf.out[0].weight.data, 
                                    (1 - self.alpha) * self.mlp.out[0].weight.data], dim=-1)
        last_layer_bias = self.alpha * self.gmf.out[0].bias.data + (1 - self.alpha) * self.mlp.out[0].bias.data
        with torch.no_grad():
            self.out[0].weight.copy_(last_layer_weight)
            self.out[0].bias.copy_(last_layer_bias) 

#### Data preprocessing

In [None]:
user_pos_train = torch.LongTensor(user_train)
item_pos_train = torch.LongTensor(item_train)

# Train & val: 1 pos + 4 neg
neg_len = [min(len(Ds[user_id].neg), 4 * len(Ds[user_id].pos_train)) for user_id in Ds.keys()]
user_neg_train = [torch.full((neg_len[i], ), user_id) for i, user_id in enumerate(Ds.keys())]
item_neg_train = []
for i, user_id in enumerate(Ds.keys()):
    np.random.shuffle(Ds[user_id].neg)
    item_neg_train.append(torch.tensor(Ds[user_id].neg)[:neg_len[i]])
    
# Dataloaders for train & val: 85% of left data - train, 15% of left data - val
user_neg_train = torch.cat(user_neg_train, dim=0)
item_neg_train = torch.cat(item_neg_train, dim=0)
y_pos = torch.ones_like(user_pos_train)
y_neg = torch.zeros_like(user_neg_train)

user_train = torch.cat([user_pos_train, user_neg_train], dim=0)
item_train = torch.cat([item_pos_train, item_neg_train], dim=0)
y = torch.cat([y_pos, y_neg], dim=0)

random_sample = torch.randperm(len(user_train))
user_train = user_train[random_sample]
item_train = item_train[random_sample]
y = y[random_sample]

user_train, user_val, item_train, item_val, y_train, y_val = train_test_split(user_train, 
                                                                              item_train, y,
                                                                              test_size=0.15)
train_dataset = TensorDataset(user_train, item_train, y_train)
val_dataset = TensorDataset(user_val, item_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, 
                          num_workers=multiprocessing.cpu_count())
val_loader = DataLoader(val_dataset, batch_size=1024, num_workers=multiprocessing.cpu_count())

In [None]:
class Learner(pl.LightningModule):
    def __init__(self, model, params, lr=1e-3, weight_decay=0):
        super().__init__()
        self.model = model(**params)
        self.lr = lr 
        self.weight_decay = weight_decay
        self.loss_fn = nn.BCELoss()  
        
    def forward(self, user_idx, item_idx):
        return self.model(user_idx, item_idx)
    
    def training_step(self, batch, *args):
        user_idx, item_idx, y_target = batch
        y_pred = self(user_idx, item_idx)
        loss = self.loss_fn(y_pred, y_target.unsqueeze(1).to(torch.float32))
        return {'loss': loss}      
    
    def validation_step(self, batch, *args):
        user_idx, item_idx, y_target = batch
        y_pred = self(user_idx, item_idx)
        loss = self.loss_fn(y_pred, y_target.unsqueeze(1).to(torch.float32))
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

In [None]:
n_users = np.max(users) + 1
n_items = np.max(movies) + 1

gmf_params = {'n_users': n_users,
              'n_items': n_items,
              'embed_dim': 32}

mlp_params = {'n_users': n_users,
              'n_items': n_items,
              'embed_dim': 32,
              'hidden_dims': [64, 32, 16, 32]}

In [None]:
# pre-train GMF
gmf_checkpoint = ModelCheckpoint(filepath='gmf_checkpoints/gmf_{epoch}',
                                 save_top_k=3, monitor='val_loss', save_weights_only=True, verbose=True)
num_gpus = torch.cuda.device_count()

gmf = Learner(GMF, gmf_params)
trainer = pl.Trainer(gpus=num_gpus, max_epochs=50, checkpoint_callback=gmf_checkpoint, progress_bar_refresh_rate=30)
trainer.fit(gmf, train_loader, val_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | model   | GMF     | 319 K 
1 | loss_fn | BCELoss | 0     
------------------------------------
319 K     Trainable params
0         Non-trainable params
319 K     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 0, global step 2351: val_loss reached 0.50224 (best 0.50224), saving model to "/content/gmf_checkpoints/gmf_epoch=0.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 1, global step 4703: val_loss reached 0.50160 (best 0.50160), saving model to "/content/gmf_checkpoints/gmf_epoch=1.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 2, global step 7055: val_loss reached 0.50146 (best 0.50146), saving model to "/content/gmf_checkpoints/gmf_epoch=2.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 3, global step 9407: val_loss reached 0.49368 (best 0.49368), saving model to "/content/gmf_checkpoints/gmf_epoch=3.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 4, global step 11759: val_loss reached 0.43985 (best 0.43985), saving model to "/content/gmf_checkpoints/gmf_epoch=4.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 5, global step 14111: val_loss reached 0.38929 (best 0.38929), saving model to "/content/gmf_checkpoints/gmf_epoch=5.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 6, global step 16463: val_loss reached 0.36618 (best 0.36618), saving model to "/content/gmf_checkpoints/gmf_epoch=6.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 7, global step 18815: val_loss reached 0.35626 (best 0.35626), saving model to "/content/gmf_checkpoints/gmf_epoch=7.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 8, global step 21167: val_loss reached 0.35077 (best 0.35077), saving model to "/content/gmf_checkpoints/gmf_epoch=8.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 9, global step 23519: val_loss reached 0.34707 (best 0.34707), saving model to "/content/gmf_checkpoints/gmf_epoch=9.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 10, global step 25871: val_loss reached 0.34331 (best 0.34331), saving model to "/content/gmf_checkpoints/gmf_epoch=10.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 11, global step 28223: val_loss reached 0.33946 (best 0.33946), saving model to "/content/gmf_checkpoints/gmf_epoch=11.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 12, global step 30575: val_loss reached 0.33537 (best 0.33537), saving model to "/content/gmf_checkpoints/gmf_epoch=12.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 13, global step 32927: val_loss reached 0.33218 (best 0.33218), saving model to "/content/gmf_checkpoints/gmf_epoch=13.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 14, global step 35279: val_loss reached 0.32915 (best 0.32915), saving model to "/content/gmf_checkpoints/gmf_epoch=14.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 15, global step 37631: val_loss reached 0.32725 (best 0.32725), saving model to "/content/gmf_checkpoints/gmf_epoch=15.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 16, global step 39983: val_loss reached 0.32454 (best 0.32454), saving model to "/content/gmf_checkpoints/gmf_epoch=16.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 17, global step 42335: val_loss reached 0.32279 (best 0.32279), saving model to "/content/gmf_checkpoints/gmf_epoch=17.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 18, global step 44687: val_loss reached 0.32074 (best 0.32074), saving model to "/content/gmf_checkpoints/gmf_epoch=18.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 19, global step 47039: val_loss reached 0.31925 (best 0.31925), saving model to "/content/gmf_checkpoints/gmf_epoch=19.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 20, global step 49391: val_loss reached 0.31883 (best 0.31883), saving model to "/content/gmf_checkpoints/gmf_epoch=20.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 21, global step 51743: val_loss reached 0.31728 (best 0.31728), saving model to "/content/gmf_checkpoints/gmf_epoch=21.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 22, global step 54095: val_loss reached 0.31661 (best 0.31661), saving model to "/content/gmf_checkpoints/gmf_epoch=22.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 23, global step 56447: val_loss reached 0.31654 (best 0.31654), saving model to "/content/gmf_checkpoints/gmf_epoch=23.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 24, global step 58799: val_loss reached 0.31592 (best 0.31592), saving model to "/content/gmf_checkpoints/gmf_epoch=24.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 25, global step 61151: val_loss reached 0.31604 (best 0.31592), saving model to "/content/gmf_checkpoints/gmf_epoch=25.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 26, global step 63503: val_loss reached 0.31609 (best 0.31592), saving model to "/content/gmf_checkpoints/gmf_epoch=26.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 27, global step 65855: val_loss reached 0.31606 (best 0.31592), saving model to "/content/gmf_checkpoints/gmf_epoch=27.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 28, step 68207: val_loss was not in top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 29, step 70559: val_loss was not in top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 30, step 72911: val_loss was not in top 3
Epoch 31, step 73128: val_loss was not in top 3





1

In [None]:
# pre-train MLP
mlp_checkpoint = ModelCheckpoint(filepath='mlp_checkpoints/mlp_{epoch}',
                                 save_top_k=3, monitor='val_loss', save_weights_only=True, verbose=True)
num_gpus = torch.cuda.device_count()

mlp = Learner(MLP, mlp_params)
trainer = pl.Trainer(gpus=num_gpus, max_epochs=30, checkpoint_callback=mlp_checkpoint, progress_bar_refresh_rate=30)
trainer.fit(mlp, train_loader, val_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | model   | MLP     | 327 K 
1 | loss_fn | BCELoss | 0     
------------------------------------
327 K     Trainable params
0         Non-trainable params
327 K     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 0, global step 2351: val_loss reached 0.34526 (best 0.34526), saving model to "/content/mlp_checkpoints/mlp_epoch=0.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 1, global step 4703: val_loss reached 0.34214 (best 0.34214), saving model to "/content/mlp_checkpoints/mlp_epoch=1.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 2, global step 7055: val_loss reached 0.33991 (best 0.33991), saving model to "/content/mlp_checkpoints/mlp_epoch=2.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 3, global step 9407: val_loss reached 0.33927 (best 0.33927), saving model to "/content/mlp_checkpoints/mlp_epoch=3.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 4, global step 11759: val_loss reached 0.33881 (best 0.33881), saving model to "/content/mlp_checkpoints/mlp_epoch=4.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 5, global step 14111: val_loss reached 0.33847 (best 0.33847), saving model to "/content/mlp_checkpoints/mlp_epoch=5.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 6, global step 16463: val_loss reached 0.33806 (best 0.33806), saving model to "/content/mlp_checkpoints/mlp_epoch=6.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 7, global step 18815: val_loss reached 0.33802 (best 0.33802), saving model to "/content/mlp_checkpoints/mlp_epoch=7.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 8, global step 21167: val_loss reached 0.33823 (best 0.33802), saving model to "/content/mlp_checkpoints/mlp_epoch=8.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 9, global step 23519: val_loss reached 0.33552 (best 0.33552), saving model to "/content/mlp_checkpoints/mlp_epoch=9.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 10, global step 25871: val_loss reached 0.33172 (best 0.33172), saving model to "/content/mlp_checkpoints/mlp_epoch=10.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 11, global step 28223: val_loss reached 0.32783 (best 0.32783), saving model to "/content/mlp_checkpoints/mlp_epoch=11.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 12, global step 30575: val_loss reached 0.32370 (best 0.32370), saving model to "/content/mlp_checkpoints/mlp_epoch=12.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 13, global step 32927: val_loss reached 0.32062 (best 0.32062), saving model to "/content/mlp_checkpoints/mlp_epoch=13.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 14, global step 35279: val_loss reached 0.31945 (best 0.31945), saving model to "/content/mlp_checkpoints/mlp_epoch=14.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 15, global step 37631: val_loss reached 0.31792 (best 0.31792), saving model to "/content/mlp_checkpoints/mlp_epoch=15.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 16, global step 39983: val_loss reached 0.31594 (best 0.31594), saving model to "/content/mlp_checkpoints/mlp_epoch=16.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 17, global step 42335: val_loss reached 0.31585 (best 0.31585), saving model to "/content/mlp_checkpoints/mlp_epoch=17.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 18, global step 44687: val_loss reached 0.31567 (best 0.31567), saving model to "/content/mlp_checkpoints/mlp_epoch=18.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 19, global step 47039: val_loss reached 0.31373 (best 0.31373), saving model to "/content/mlp_checkpoints/mlp_epoch=19.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 20, global step 49391: val_loss reached 0.31372 (best 0.31372), saving model to "/content/mlp_checkpoints/mlp_epoch=20.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 21, global step 51743: val_loss reached 0.31400 (best 0.31372), saving model to "/content/mlp_checkpoints/mlp_epoch=21.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 22, global step 54095: val_loss reached 0.31314 (best 0.31314), saving model to "/content/mlp_checkpoints/mlp_epoch=22.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 23, global step 56447: val_loss reached 0.31357 (best 0.31314), saving model to "/content/mlp_checkpoints/mlp_epoch=23.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 24, global step 58799: val_loss reached 0.31336 (best 0.31314), saving model to "/content/mlp_checkpoints/mlp_epoch=24.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 25, global step 61151: val_loss reached 0.31297 (best 0.31297), saving model to "/content/mlp_checkpoints/mlp_epoch=25.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 26, step 63503: val_loss was not in top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 27, step 65855: val_loss was not in top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 28, step 68207: val_loss was not in top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 29, step 70559: val_loss was not in top 3





1

In [None]:
# train NCF
ncf_params = {'gmf_params': gmf_params,
              'mlp_params': mlp_params,
              'alpha': 0.5, 
              'gmf_ckpt': '/content/gmf_checkpoints/gmf_epoch=24.ckpt',
              'mlp_ckpt': '/content/mlp_checkpoints/mlp_epoch=25.ckpt'}

ncf_checkpoint = ModelCheckpoint(filepath='ncf_checkpoints/ncf_{epoch}',
                                 save_top_k=3, monitor='val_loss', verbose=True)

num_gpus = torch.cuda.device_count()

ncf = Learner(NCF, ncf_params)
trainer = pl.Trainer(gpus=num_gpus, max_epochs=10, checkpoint_callback=ncf_checkpoint, progress_bar_refresh_rate=30)
trainer.fit(ncf, train_loader, val_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | model   | NCF     | 647 K 
1 | loss_fn | BCELoss | 0     
------------------------------------
647 K     Trainable params
0         Non-trainable params
647 K     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 0, global step 2351: val_loss reached 0.34559 (best 0.34559), saving model to "/content/ncf_checkpoints/ncf_epoch=0.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 1, global step 4703: val_loss reached 0.34266 (best 0.34266), saving model to "/content/ncf_checkpoints/ncf_epoch=1.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 2, global step 7055: val_loss reached 0.34069 (best 0.34069), saving model to "/content/ncf_checkpoints/ncf_epoch=2-v0.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 3, global step 9407: val_loss reached 0.34010 (best 0.34010), saving model to "/content/ncf_checkpoints/ncf_epoch=3-v0.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 4, global step 11759: val_loss reached 0.34053 (best 0.34010), saving model to "/content/ncf_checkpoints/ncf_epoch=4-v0.ckpt" as top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 5, step 14111: val_loss was not in top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 6, step 16463: val_loss was not in top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 7, step 18815: val_loss was not in top 3


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 8, step 21167: val_loss was not in top 3
Epoch 9, step 22754: val_loss was not in top 3





1

In [None]:
pretrained_ncf = ncf.model
item_embs_gmf = pretrained_ncf.gmf.item_embedder.weight.detach()
item_embs_mlp = pretrained_ncf.mlp.item_embedder.weight.detach()
user_embs_gmf = pretrained_ncf.gmf.user_embedder.weight.detach()
user_embs_mlp = pretrained_ncf.mlp.user_embedder.weight.detach()

item_embs = torch.cat([item_embs_gmf, item_embs_mlp], dim=-1).numpy()
user_embs = torch.cat([user_embs_gmf, user_embs_mlp], dim=-1).numpy()

In [None]:
ncf_recommender = Recommender(user_embs, item_embs, user_test, item_test)
preprocess_user = lambda x: torch.full((100,), x)
preprocess_item = lambda x: torch.LongTensor(x).view(len(x))
ncf_recommender.compute_metrics(score_fn=ncf.forward, preprocess_user=preprocess_user, preprocess_item=preprocess_item)

Avg hr: 0.291
Avg ndcg: 0.135


In [None]:
ncf_recommender.similar_movies()

For movie


Unnamed: 0,Title,Category
0,Toy Story (1995),Animation|Children's|Comedy


similars are


Unnamed: 0,Title,Category
0,Rude (1995),Drama
1,Star Wars: Episode VI - Return of the Jedi (1983),Action|Adventure|Romance|Sci-Fi|War
2,Ferris Bueller's Day Off (1986),Comedy
3,Set It Off (1996),Action|Crime
4,"Great Muppet Caper, The (1981)",Children's|Comedy
5,Dear God (1996),Comedy
6,Disclosure (1994),Drama|Thriller
7,Killing Zoe (1994),Thriller
8,Vermin (1998),Comedy
9,Shiloh (1997),Children's|Drama


In [None]:
ncf_recommender.prediction_for_user()

--- User's choice ---


Unnamed: 0,Title,Category
0,Star Wars: Episode IV - A New Hope (1977),Action|Adventure|Fantasy|Sci-Fi
1,Jurassic Park (1993),Action|Adventure|Sci-Fi
2,Die Hard (1988),Action|Thriller
3,E.T. the Extra-Terrestrial (1982),Children's|Drama|Fantasy|Sci-Fi
4,Raiders of the Lost Ark (1981),Action|Adventure
5,"Good, The Bad and The Ugly, The (1966)",Action|Western
6,Alien (1979),Action|Horror|Sci-Fi|Thriller
7,"Terminator, The (1984)",Action|Sci-Fi|Thriller
8,Jaws (1975),Action|Horror
9,Rocky (1976),Action|Drama



--- Our recommendations ---


Unnamed: 0,Title,Category
0,"Contender, The (2000)",Drama|Thriller
1,Jefferson in Paris (1995),Drama
2,One Little Indian (1973),Comedy|Drama|Western
3,Exit to Eden (1994),Comedy
4,"$1,000,000 Duck (1971)",Children's|Comedy
5,Deceiver (1997),Crime
6,Eternity and a Day (Mia eoniotita ke mia mera ...,Drama
7,I Don't Want to Talk About It (De eso no se ha...,Drama
8,See the Sea (Regarde la mer) (1997),Drama
9,"Steal Big, Steal Little (1995)",Comedy


Все стало хуже. Может быть, стоило брать hidden dims или negative sampling поменьше.


## 3. Attention

Для каждого юзера надо составить датасет из (последовательность предыдущих просмотренных фильм, следующий фильм). Будем поочередно в качестве следующего фильма брать positive айтем(т.е. фильм который был реально просмотрен после последовательности фильмов) и negative айтем(т.е фильм, кторый не был просмотрен вообще). Получится negative sampling как в NCF. Но здесь возьмем меньшее количество negative pairs. В тест отправим одну пару (последовательность, positive фильм) для каждого юзера. 

#### New dataset

In [20]:
def attention_dataset(ratings, train_seq_len, test_seq_len=10):
    user_ids = np.unique(users.values)
    movie_ids = np.unique(movies.values)
    Sample = namedtuple('Sample', 'seq, next_item, target')
    Ds_train = []
    Ds_test = []
    
    for u in user_ids:
        user_rating = ratings[u]
        rated = list(user_rating.indices)
        if len(rated) > train_seq_len + 1:
            not_rated = list(np.setdiff1d(movie_ids, rated))
            
            user_sequence = [rated[i:i + train_seq_len] for i in range(len(rated) - train_seq_len)]
            next_pos = [rated[i] for i in range(train_seq_len, len(rated))]
            next_neg = list(np.random.choice(not_rated, len(next_pos), replace=False))
            next_items = [next_neg, next_pos]
            
            for i, seq in enumerate(user_sequence[:-1]):
                # if 0 - then for sequence take negative next item & target = 0
                # if 1 - then for sequence take positive next item & target = 1
                target = np.random.choice([0, 1])
                Ds_train.append(Sample(seq, next_items[target][i], target))
            Ds_test.append((rated[-test_seq_len - 1:-1], rated[-1]))
    return Ds_train, Ds_test

In [21]:
Ds_attn_train, Ds_attn_test = attention_dataset(ratings_sparse, train_seq_len=15)

In [22]:
sequences_train, next_items_train, targets_train = map(torch.LongTensor, zip(*Ds_attn_train))
train_dataset = TensorDataset(sequences_train, next_items_train, targets_train)
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, 
                          num_workers=multiprocessing.cpu_count())

sequences_test, next_items_test = map(torch.LongTensor, zip(*Ds_attn_test))
test_dataset = TensorDataset(sequences_test, next_items_test)
test_loader = DataLoader(test_dataset, shuffle=True, num_workers=multiprocessing.cpu_count())

#### Attention model

In [13]:
class SelfAttention(nn.Module):
    def __init__(self, n_items, embed_dim, num_heads, hidden_dims):
        super().__init__()
        self.num_heads = num_heads
        self.item_embedder = nn.Embedding(n_items, embed_dim)
        self.attention_layer = nn.MultiheadAttention(embed_dim * num_heads, num_heads, kdim=embed_dim,
                                                    vdim=embed_dim) 
        
        self.out = []
        all_dims = [embed_dim, *hidden_dims]
        for i, dim in enumerate(all_dims[:-1]):
            self.out.extend([nn.Linear(all_dims[i], all_dims[i + 1]), nn.ReLU()])
        self.out.extend([nn.Linear(all_dims[-1], 1), nn.Sigmoid()])
        self.out = nn.Sequential(*self.out)
        
    def forward(self, sequence, next_item):
        seq_embs = self.item_embedder(sequence)  # key, value 
        next_emb = self.item_embedder(next_item) # query 
        
        # sizes: seq_len, bs, emd_dim
        seq_embs_inp = seq_embs.permute(1, 0, 2)
        next_emb_inp = next_emb.view(1, *next_emb.shape).repeat(1, 1, self.num_heads)
        
        _, attn_output_weights = self.attention_layer(next_emb_inp, seq_embs_inp, seq_embs_inp)
        
        # weights * values
        attn_next = (attn_output_weights * seq_embs_inp.permute(1, 2, 0)).sum(dim=2)
        return self.out(attn_next)    

#### Train

In [14]:
class AttentionLearner(pl.LightningModule):
    def __init__(self, model, params, lr=1e-3):
        super().__init__()
        self.model = model(**params)
        self.lr = lr 
        self.loss_fn = nn.BCELoss()  
        
    def forward(self, seq, next_item):
        return self.model(seq, next_item)
    
    def training_step(self, batch, *args):
        seq, next_item, y_target = batch
        y_pred = self(seq, next_item)
        loss = self.loss_fn(y_pred, y_target.unsqueeze(1).to(torch.float32))
        return {'loss': loss}      
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [23]:
attn_params = {'n_items': np.max(movies.values) + 1,
                'embed_dim': 32 * 5,
                'num_heads': 5, 
                'hidden_dims': []}

num_gpus = torch.cuda.device_count()

attn_checkpoint = ModelCheckpoint(save_top_k=3, monitor='val_loss', verbose=True)

attn = AttentionLearner(SelfAttention, attn_params)
trainer = pl.Trainer(gpus=num_gpus, max_epochs=30, progress_bar_refresh_rate=30)
trainer.fit(attn, train_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type          | Params
------------------------------------------
0 | model   | SelfAttention | 2.2 M 
1 | loss_fn | BCELoss       | 0     
------------------------------------------
2.2 M     Trainable params
0         Non-trainable params
2.2 M     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

#### New recommender

In [24]:
class AttentionRecommender:
    def __init__(self, ratings_sparse, Ds, model):
        self.ratings = ratings_sparse
        self.pos_neg_Ds = Ds

        self.item_embs = model.item_embedder.weight.detach()
        self.predict_score = model.forward
               
    def similar_movies(self, movie_id=1, k=10):
        title = title_by_id(movie_id)
        category = category_by_id(movie_id)
        print('For movie')
        rec_print([title], [category])
        print('similars are')
        
        movie_emb = self.item_embs[movie_id].numpy()
        all_movies_emb = self.item_embs
            
        score = cosine_similarity(np.expand_dims(movie_emb, axis=0), all_movies_emb)[0]
        ranked_similars = np.argsort(score)[::-1]
        top_k_ids = [s for s in ranked_similars if s != movie_id and s in movie_info['movie_id'].values][:k]
        top_k_titles = [title_by_id(r) for r in top_k_ids]
        top_k_categories = [category_by_id(r) for r in top_k_ids]
        rec_print(top_k_titles, top_k_categories)
        
    def prediction_for_user(self, user_id=4):
        real_movie_id = self.ratings[user_id].indices   
        print('--- User\'s choice ---')
        real_titles = [title_by_id(r) for r in real_movie_id[:10]]
        real_category = [category_by_id(r) for r in real_movie_id[:10]]
        rec_print(real_titles, real_category)
        print()

        print('--- Our recommendations ---')
        # user emb = sum of embs of his items
        user_pos_items = self.pos_neg_Ds[user_id].pos_train
        user_embs = self.item_embs[user_pos_items]
        user_emb = torch.sum(user_embs, dim=0).unsqueeze(0)
        
        user_pred = (user_emb @ self.item_embs.T).numpy().squeeze()
        ranked_movie_id = np.argsort(user_pred)[::-1]
        not_rated_movie_id = [i for i in ranked_movie_id if i not in real_movie_id and i in movie_info['movie_id'].values]
        pred_titles = [title_by_id(r) for r in not_rated_movie_id[:10]]
        pred_category = [category_by_id(r) for r in not_rated_movie_id[:10]]
        rec_print(pred_titles, pred_category)
    
    def one_sample_metric(self, batch, k=10):
        seq, next_item_pos = batch
        all_movies = np.unique(movies.values)
        next_item_neg = torch.LongTensor(np.random.choice(all_movies, 99, replace=False))
        y_target_pos = torch.ones_like(next_item_pos)
        y_target_neg = torch.zeros_like(next_item_neg)
        seq = seq.repeat(100, 1)
        next_items = torch.cat([next_item_pos, next_item_neg], dim=-1)
        y_target = torch.cat([y_target_pos, y_target_neg], dim=-1).numpy()
        y_pred = self.predict_score(seq, next_items).detach().numpy().squeeze()

        items = next_items.numpy()
        top_k_items = items[np.argsort(y_pred)[-k:]]
        hr = int(next_item_pos.item() in top_k_items)
        ndcg = ndcg_score([y_target], [y_pred], k=k)
        return hr, ndcg

    def compute_metrics(self, test_loader):
        hr_list = []
        ndcg_list = []
        for batch in test_loader:
            hr, ndcg = self.one_sample_metric(batch)
            hr_list.append(hr)
            ndcg_list.append(ndcg)
        print('Avg hr: {0:.3f}'.format(np.mean(hr_list)))
        print('Avg ndcg: {0:.3f}'.format(np.mean(ndcg_list)))

In [25]:
attn_rec = AttentionRecommender(ratings_sparse, Ds, attn.model)
attn_rec.compute_metrics(test_loader)

Avg hr: 0.711
Avg ndcg: 0.435


In [28]:
attn_rec.similar_movies()

For movie


Unnamed: 0,Title,Category
0,Toy Story (1995),Animation|Children's|Comedy


similars are


Unnamed: 0,Title,Category
0,Hotel de Love (1996),Comedy|Romance
1,"Pawnbroker, The (1965)",Drama
2,Dracula (1931),Horror
3,Cemetery Man (Dellamorte Dellamore) (1994),Comedy|Horror
4,Two Family House (2000),Drama
5,"Gambler, The (A J�t�kos) (1997)",Drama
6,Death Wish 4: The Crackdown (1987),Action|Drama
7,"City, The (1998)",Drama
8,Down to You (2000),Comedy|Romance
9,Conceiving Ada (1997),Drama|Sci-Fi


In [29]:
attn_rec.prediction_for_user()

--- User's choice ---


Unnamed: 0,Title,Category
0,Star Wars: Episode IV - A New Hope (1977),Action|Adventure|Fantasy|Sci-Fi
1,Jurassic Park (1993),Action|Adventure|Sci-Fi
2,Die Hard (1988),Action|Thriller
3,E.T. the Extra-Terrestrial (1982),Children's|Drama|Fantasy|Sci-Fi
4,Raiders of the Lost Ark (1981),Action|Adventure
5,"Good, The Bad and The Ugly, The (1966)",Action|Western
6,Alien (1979),Action|Horror|Sci-Fi|Thriller
7,"Terminator, The (1984)",Action|Sci-Fi|Thriller
8,Jaws (1975),Action|Horror
9,Rocky (1976),Action|Drama



--- Our recommendations ---


Unnamed: 0,Title,Category
0,Raising Arizona (1987),Comedy
1,Star Wars: Episode V - The Empire Strikes Back...,Action|Adventure|Drama|Sci-Fi|War
2,Waltzes from Vienna (1933),Comedy|Musical
3,Pretty Woman (1990),Comedy|Romance
4,Key Largo (1948),Crime|Drama|Film-Noir|Thriller
5,Theodore Rex (1995),Comedy
6,"Tickle in the Heart, A (1996)",Documentary
7,Children of a Lesser God (1986),Drama
8,Harlem (1993),Drama
9,Mille bolle blu (1993),Comedy


Метрики лучше, чем в NCF, но хуже чем в WARP. Симилары и рекомендации не очень. 
Вероятно, стоит подобрать гипермапарметры и по-другому предсказывать скор с помощью эмбеддингов из атеншн.