In [1]:
!wget http://files.grouplens.org/datasets/movielens/ml-1m.zip

--2020-12-08 20:08:45--  http://files.grouplens.org/datasets/movielens/ml-1m.zip
Resolving files.grouplens.org (files.grouplens.org)... 128.101.65.152
Connecting to files.grouplens.org (files.grouplens.org)|128.101.65.152|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5917549 (5.6M) [application/zip]
Saving to: ‘ml-1m.zip’


2020-12-08 20:08:45 (27.0 MB/s) - ‘ml-1m.zip’ saved [5917549/5917549]



In [2]:
!unzip -q ml-1m.zip -d .

In [3]:
!pip install implicit

Collecting implicit
[?25l  Downloading https://files.pythonhosted.org/packages/bc/07/c0121884722d16e2c5beeb815f6b84b41cbf22e738e4075f1475be2791bc/implicit-0.4.4.tar.gz (1.1MB)
[K     |▎                               | 10kB 23.3MB/s eta 0:00:01[K     |▋                               | 20kB 30.3MB/s eta 0:00:01[K     |▉                               | 30kB 35.5MB/s eta 0:00:01[K     |█▏                              | 40kB 28.7MB/s eta 0:00:01[K     |█▌                              | 51kB 30.6MB/s eta 0:00:01[K     |█▊                              | 61kB 33.2MB/s eta 0:00:01[K     |██                              | 71kB 26.0MB/s eta 0:00:01[K     |██▍                             | 81kB 26.9MB/s eta 0:00:01[K     |██▋                             | 92kB 28.4MB/s eta 0:00:01[K     |███                             | 102kB 24.9MB/s eta 0:00:01[K     |███▎                            | 112kB 24.9MB/s eta 0:00:01[K     |███▌                            | 122kB 24.9MB/s eta

#Imports

In [4]:
import pandas as pd
import numpy as np
import scipy.sparse as sp
import implicit
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm, trange
from sklearn.metrics import ndcg_score
from collections import defaultdict
from copy import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

DEVICE = torch.device('cuda')

# Datasets

Выбрал MovieLens, потому что его часто используют в статьях и будет с чем сравнить результаты.

In [5]:
datapath = 'ml-1m/'
ratings = pd.read_csv(datapath + 'ratings.dat', delimiter='::', header=None, 
        names=['user_id', 'movie_id', 'rating', 'timestamp'], 
        usecols=['user_id', 'movie_id', 'rating', 'timestamp'], engine='python')
movie_info = pd.read_csv(datapath + 'movies.dat', delimiter='::', header=None, 
        names=['movie_id', 'name', 'category'], engine='python')
ratings['user_id'] -= 1
ratings['movie_id'] -= 1
movie_info['movie_id'] -= 1

In [6]:
ratings

Unnamed: 0,user_id,movie_id,rating,timestamp
0,0,1192,5,978300760
1,0,660,3,978302109
2,0,913,3,978301968
3,0,3407,4,978300275
4,0,2354,5,978824291
...,...,...,...,...
1000204,6039,1090,1,956716541
1000205,6039,1093,5,956704887
1000206,6039,561,5,956704746
1000207,6039,1095,4,956715648


In [7]:
s_rs = ratings.sort_values(['user_id', 'timestamp'])
train, test = [], []
for user_id in s_rs['user_id'].unique():
    urs = s_rs.loc[s_rs['user_id'] == user_id]
    urs = urs.loc[urs.rating > 0]
    if len(urs) > 1:
        tr, tst = train_test_split(urs, shuffle=False, test_size=1)
        train.append(tr)
        test.append(tst)
    else:
        print(user_id)
train = pd.concat(train)
test = pd.concat(test)

In [8]:
train.rating.describe()

count    994169.000000
mean          3.581411
std           1.116737
min           1.000000
25%           3.000000
50%           4.000000
75%           4.000000
max           5.000000
Name: rating, dtype: float64

In [9]:
test.rating.describe()

count    6040.000000
mean        3.606788
std         1.175389
min         1.000000
25%         3.000000
50%         4.000000
75%         5.000000
max         5.000000
Name: rating, dtype: float64

In [10]:
users = train["user_id"]
movies = train["movie_id"]
user_item = sp.coo_matrix((np.ones_like(users), (users, movies)))
item_user_csr = user_item.T.tocsr()
user_item_csr = user_item.tocsr()

In [11]:
user_item_csr

<6040x3952 sparse matrix of type '<class 'numpy.longlong'>'
	with 994169 stored elements in Compressed Sparse Row format>

In [12]:
negatives = {}
all_items = set(range(user_item_csr.shape[1]))
for user_id in trange(user_item_csr.shape[0]):
    pos = user_item_csr[user_id].nonzero()[1]
    negatives[user_id] = list(all_items - set(pos))

HBox(children=(FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [13]:
def get_movies(idxs):
    return movie_info.set_index('movie_id').loc[[i for i in idxs if i in set(movie_info.movie_id)]]

#BPR

In [13]:
class BaseModelWithMetrics:
    def rmse(self, user_item, test=None):
        if test is None:
            i, j = user_item.nonzero()
        else:
            i, j = test.user_id, test.movie_id
        error = (self.score(i, j) - 1) ** 2
        return np.sqrt(error.mean())

    def hr_ndcg(self, test, negatives, k=10):
        hr = []
        ndcg = []
        for user_id in tqdm(test.user_id.unique()):
            items = [test.set_index('user_id').loc[user_id].movie_id]
            neg = list(np.random.choice(negatives[user_id], size=99, replace=False))
            items.extend(neg)
            pred = self.score(user_id, items)
            target = np.zeros(100)
            target[0] = 1
            hr.append((pred.argsort()[::-1] == 0)[:k].sum())
            ndcg.append(ndcg_score([target], [pred], k=k))
        print(f"HR@{k} = {np.mean(hr):.4f}")
        print(f"NDCG@{k} = {np.mean(ndcg):.4f}")        

    def recommend(self, user_id, user_item, top_n=10):
        raise NotImplementedError()

    def similar_items(self, item_id, top_n=10):
        raise NotImplementedError()

    @staticmethod
    def user_history(user_id, user_item):
        return [i for i in user_item[user_id].nonzero()[1]]

    def score(self, i, j):
        raise NotImplementedError()


class BPR(BaseModelWithMetrics):
    def __init__(self, factors=63, iters=1000, reg=0.01, lr=1e-3):
        self.model = implicit.bpr.BayesianPersonalizedRanking(
            factors=factors,
            use_gpu=True,
            learning_rate=lr,
            regularization=reg,
            verify_negative_samples=True,
            random_state=42,
            iterations=iters
        )
    
    def fit(self, item_user):
        self.model.fit(item_user)
            
    def recommend(self, user_id, negatives, top_n=10):
        not_recommended = np.array(negatives[user_id])
        score = self.score(user_id, not_recommended)
        return not_recommended[sorted(np.arange(len(not_recommended)), key=lambda x: -score[x])[:top_n]]

    def similar_items(self, item_id, top_n=10):
        return np.argsort(np.linalg.norm(self.item_factors - self.item_factors[item_id], axis=1))[:top_n]

    def score(self, i, j):
        return np.sum(self.user_factors[i] * self.item_factors[j], axis=1)

    @property
    def item_factors(self):
        return self.model.item_factors
    
    @property
    def user_factors(self):
        return self.model.user_factors

In [32]:
bpr = BPR(factors=63, iters=1000, reg=0.01, lr=1e-3)
bpr.fit(item_user_csr)

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [33]:
print('RMSE:', bpr.rmse(user_item_csr, test))

RMSE: 1.2510865


In [34]:
bpr.hr_ndcg(test, negatives, 10)

HBox(children=(FloatProgress(value=0.0, max=6040.0), HTML(value='')))


HR@10 = 0.6308
NDCG@10 = 0.3767


In [None]:
user_id = 0

In [None]:
get_movies(bpr.user_history(user_id, user_item_csr))

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,Toy Story (1995),Animation|Children's|Comedy
149,Apollo 13 (1995),Drama
259,Star Wars: Episode IV - A New Hope (1977),Action|Adventure|Fantasy|Sci-Fi
526,Schindler's List (1993),Drama|War
530,"Secret Garden, The (1993)",Children's|Drama
587,Aladdin (1992),Animation|Children's|Comedy|Musical
593,Snow White and the Seven Dwarfs (1937),Animation|Children's|Musical
594,Beauty and the Beast (1991),Animation|Children's|Musical
607,Fargo (1996),Crime|Drama|Thriller
660,James and the Giant Peach (1996),Animation|Children's|Musical


In [None]:
get_movies(test.loc[test.user_id == user_id].movie_id)

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
47,Pocahontas (1995),Animation|Children's|Musical|Romance


In [None]:
get_movies(bpr.recommend(user_id, negatives))

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
363,"Lion King, The (1994)",Animation|Children's|Musical
1281,Fantasia (1940),Animation|Children's|Musical
2080,"Little Mermaid, The (1989)",Animation|Children's|Comedy|Musical|Romance
261,"Little Princess, A (1995)",Children's|Drama
2084,101 Dalmatians (1961),Animation|Children's
33,Babe (1995),Children's|Comedy|Drama
317,"Shawshank Redemption, The (1994)",Drama
595,Pinocchio (1940),Animation|Children's
2095,Sleeping Beauty (1959),Animation|Children's|Musical
2086,Peter Pan (1953),Animation|Children's|Fantasy|Musical


In [None]:
get_movies(bpr.similar_items(0))

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,Toy Story (1995),Animation|Children's|Comedy
3113,Toy Story 2 (1999),Animation|Children's|Comedy
33,Babe (1995),Children's|Comedy|Drama
2354,"Bug's Life, A (1998)",Animation|Children's|Comedy
2383,Babe: Pig in the City (1998),Children's|Comedy
587,Aladdin (1992),Animation|Children's|Comedy|Musical
2320,Pleasantville (1998),Comedy
836,Matilda (1996),Children's|Comedy
1918,Madeline (1998),Children's|Comedy
1264,Groundhog Day (1993),Comedy|Romance


Получились вполне адекватные метрики, рекомендации и симилары.

#NCF

In [18]:
class NCF(nn.Module):
    def __init__(self, num_users, num_items, 
                 GMF_factors=64, 
                 MLP_factors=64, 
                 layers_dim=[128, 128], 
                 dropout=0.2):
        super(NCF, self).__init__()
        self.dropout = dropout

        self.u_emb_GMF = nn.Embedding(num_users, GMF_factors).to(DEVICE)
        self.i_emb_GMF = nn.Embedding(num_items, GMF_factors).to(DEVICE)
        self.u_emb_MLP = nn.Embedding(num_users, MLP_factors).to(DEVICE)
        self.i_emb_MLP = nn.Embedding(num_items, MLP_factors).to(DEVICE)

        MLP_modules = [nn.Dropout(p=self.dropout), nn.Linear(MLP_factors * 2, layers_dim[0])]
        for i in range(len(layers_dim) - 1):
            MLP_modules.extend([nn.ReLU(), 
                                nn.Dropout(p=self.dropout),
                                nn.Linear(layers_dim[i], layers_dim[i + 1])])
        self.MLP_layers = nn.Sequential(*MLP_modules).to(DEVICE)
        self.NMF_layer = nn.Sequential(nn.Linear(GMF_factors + layers_dim[-1], 1),
                                       nn.Sigmoid()).to(DEVICE)

        self._init_weight_()

    def _init_weight_(self):
        nn.init.normal_(self.u_emb_GMF.weight, std=0.01)
        nn.init.normal_(self.u_emb_MLP.weight, std=0.01)
        nn.init.normal_(self.i_emb_GMF.weight, std=0.01)
        nn.init.normal_(self.i_emb_MLP.weight, std=0.01)

        for m in self.MLP_layers:
            if isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)
        nn.init.kaiming_uniform_(self.NMF_layer[0].weight)

        for m in self.modules():
            if isinstance(m, nn.Linear) and m.bias is not None:
                m.bias.data.zero_()

    def forward(self, user, item):
        u_emb_GMF = self.u_emb_GMF(user)
        i_emb_GMF = self.i_emb_GMF(item)
        GMF_out = u_emb_GMF * i_emb_GMF

        u_emb_MLP = self.u_emb_MLP(user)
        i_emb_MLP = self.i_emb_MLP(item)
        interaction = torch.cat((u_emb_MLP, i_emb_MLP), -1)
        MLP_out = self.MLP_layers(interaction)

        NMF_in = torch.cat((GMF_out, MLP_out), -1)

        prediction = self.NMF_layer(NMF_in)
        return prediction.view(-1)


class NCFModel(BaseModelWithMetrics):
    def __init__(self, *args, neg_size=5, batch_size=64, lr=1e-4, **kargs):
        self.batch_size = batch_size
        self.num_users = kargs.get("num_users")
        self.num_items = kargs.get("num_items")
        self.neg_size = neg_size
        self.model = NCF(*args, **kargs).to(DEVICE)
        self.positives = defaultdict(list)
        self.negatives = defaultdict(list)
        self.loss = nn.BCELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def fit(self, user_item, test, iters=10):
        self.model.train()
        t = trange(iters)
        t.set_description("Total")
        for iteration in t:
            self.train_epoch(user_item, iteration)

    def train_epoch(self, user_item, iteration):
        dataloader = self.build_dataloader(user_item)
        t = tqdm(dataloader)
        t.set_description(f"Epoch {iteration}")
        losses = []
        for batch in t:
            self.optimizer.zero_grad()
            users, items, targets = batch.T.to(DEVICE)
            preds = self.model(users, items)
            loss = self.loss(preds, targets.float())
            loss.backward()
            self.optimizer.step()
            losses.append(loss.detach().cpu().item())
        t.set_postfix(loss=np.mean(losses))

    def score(self, i, j):
        with torch.no_grad():
            if isinstance(i, (np.int64, np.int, int)):
                i = [i] * np.array(j).shape[0]
            elif isinstance(i, pd.Series):
                i = i.to_numpy()
            if isinstance(j, pd.Series):
                j = j.to_numpy()
            i, j = torch.tensor(i).to(DEVICE), torch.tensor(j).to(DEVICE)
            scores = []
            for i_batch, j_batch in zip(torch.split(i, self.batch_size), torch.split(j, self.batch_size)):
                scores.append(self.model(i_batch, j_batch).cpu().numpy())
            return np.hstack(scores)

    def build_dataloader(self, user_item):
        if not self.positives:
            nonzeros = user_item.nonzero()
            for user_id, pos in zip(*nonzeros):
                self.positives[user_id].append(pos)
            _all_items = set(range(self.num_items))
            for user_id, pos in self.positives.items():
                self.negatives[user_id] = list(_all_items - set(pos))
        d = []
        for uid, ps in self.positives.items():
            nonrecommended = self.negatives[uid]
            n_size = min(len(ps) * self.neg_size, len(nonrecommended))
            neg_items = np.random.choice(nonrecommended,
                                        size=n_size,
                                        replace=True)
            pos, negs = np.ones((len(ps), 3), dtype=int), np.zeros((n_size, 3), dtype=int)
            pos[:, 0], negs[:, 0] = uid, uid
            pos[:, 1], negs[:, 1] = ps, neg_items
            d.append(np.vstack((pos, negs)))
        dataset = np.vstack(d)
        return data.DataLoader(dataset, shuffle=True, batch_size=self.batch_size)

    def recommend(self, user_id, negatives, top_n=10):
        not_recommended = np.array(negatives[user_id])
        score = self.score(user_id, not_recommended)
        return not_recommended[sorted(np.arange(len(not_recommended)), key=lambda x: -score[x])[:top_n]]

    def similar_items(self, item_id, top_n=10, emb_type=None):
        if emb_type == 'MLP':
            emb = self.model.i_emb_MLP.weight.cpu()
        elif emb_type == 'GMF':
            emb = self.model.i_emb_GMF.weight.cpu()
        else:
            emb = self.model.i_emb_MLP.weight.cpu(), self.model.i_emb_GMF.weight.cpu()
            emb = torch.cat(emb, dim=-1)
        return torch.linalg.norm(emb - emb[item_id], dim=-1).argsort()[:top_n].numpy()

In [22]:
ncf = NCFModel(num_users=user_item.shape[0], num_items=user_item.shape[1],
               batch_size=1024, lr=1e-4)
ncf.fit(user_item_csr, test, iters=10)

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=5559.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5559.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5559.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5559.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5559.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5559.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5559.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5559.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5559.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5559.0), HTML(value='')))





In [23]:
print('RMSE:', ncf.rmse(user_item_csr, test))

RMSE: 0.6198951


In [24]:
ncf.hr_ndcg(test, negatives, 10)

HBox(children=(FloatProgress(value=0.0, max=6040.0), HTML(value='')))


HR@10 = 0.6666
NDCG@10 = 0.3867


In [25]:
user_id = 0

In [26]:
get_movies(ncf.user_history(user_id, user_item_csr))

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,Toy Story (1995),Animation|Children's|Comedy
149,Apollo 13 (1995),Drama
259,Star Wars: Episode IV - A New Hope (1977),Action|Adventure|Fantasy|Sci-Fi
526,Schindler's List (1993),Drama|War
530,"Secret Garden, The (1993)",Children's|Drama
587,Aladdin (1992),Animation|Children's|Comedy|Musical
593,Snow White and the Seven Dwarfs (1937),Animation|Children's|Musical
594,Beauty and the Beast (1991),Animation|Children's|Musical
607,Fargo (1996),Crime|Drama|Thriller
660,James and the Giant Peach (1996),Animation|Children's|Musical


In [27]:
get_movies(test.loc[test.user_id == user_id].movie_id)

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
47,Pocahontas (1995),Animation|Children's|Musical|Romance


In [28]:
get_movies(ncf.recommend(user_id, negatives))

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
33,Babe (1995),Children's|Comedy|Drama
1281,Fantasia (1940),Animation|Children's|Musical
1264,Groundhog Day (1993),Comedy|Romance
2080,"Little Mermaid, The (1989)",Animation|Children's|Comedy|Musical|Romance
2395,Shakespeare in Love (1998),Comedy|Romance
317,"Shawshank Redemption, The (1994)",Drama
355,Forrest Gump (1994),Comedy|Romance|War
363,"Lion King, The (1994)",Animation|Children's|Musical
2857,American Beauty (1999),Comedy|Drama
2323,Life Is Beautiful (La Vita � bella) (1997),Comedy|Drama


In [29]:
get_movies(ncf.similar_items(0))

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,Toy Story (1995),Animation|Children's|Comedy
550,"Nightmare Before Christmas, The (1993)",Children's|Comedy|Musical
3113,Toy Story 2 (1999),Animation|Children's|Comedy
2383,Babe: Pig in the City (1998),Children's|Comedy
587,Aladdin (1992),Animation|Children's|Comedy|Musical
33,Babe (1995),Children's|Comedy|Drama
2293,Antz (1998),Animation|Children's
836,Matilda (1996),Children's|Comedy
660,James and the Giant Peach (1996),Animation|Children's|Musical
3395,"Muppet Movie, The (1979)",Children's|Comedy


In [30]:
get_movies(ncf.similar_items(0, emb_type='GMF'))

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,Toy Story (1995),Animation|Children's|Comedy
2383,Babe: Pig in the City (1998),Children's|Comedy
836,Matilda (1996),Children's|Comedy
550,"Nightmare Before Christmas, The (1993)",Children's|Comedy|Musical
3113,Toy Story 2 (1999),Animation|Children's|Comedy
587,Aladdin (1992),Animation|Children's|Comedy|Musical
800,Harriet the Spy (1996),Children's|Comedy
3398,Sesame Street Presents Follow That Bird (1985),Children's|Comedy
33,Babe (1995),Children's|Comedy|Drama
2293,Antz (1998),Animation|Children's


In [31]:
get_movies(ncf.similar_items(0, emb_type='MLP'))

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,Toy Story (1995),Animation|Children's|Comedy
1386,Jaws (1975),Action|Horror
1196,"Princess Bride, The (1987)",Action|Adventure|Comedy|Romance
1616,L.A. Confidential (1997),Crime|Film-Noir|Mystery|Thriller
1264,Groundhog Day (1993),Comedy|Romance
2915,Total Recall (1990),Action|Adventure|Sci-Fi|Thriller
2627,Star Wars: Episode I - The Phantom Menace (1999),Action|Adventure|Fantasy|Sci-Fi
317,"Shawshank Redemption, The (1994)",Drama
355,Forrest Gump (1994),Comedy|Romance|War
2570,"Matrix, The (1999)",Action|Sci-Fi|Thriller


HR@10 и NDCG@10 получились выше, чем у WARP, что не может не радовать. В симиларах Toy Story 2 правда не попал на 2е место, как это было в WARP.

# Simple Attention

In [14]:
hist_length = 10
s_rs = ratings.sort_values(['user_id', 'timestamp'])
att_train, att_test = [], []
for user_id in tqdm(s_rs['user_id'].unique()):
    u_hist = s_rs.loc[s_rs['user_id'] == user_id]
    u_hist = u_hist.loc[u_hist.rating > 0].movie_id.to_numpy()
    if len(u_hist) > hist_length:
        for i in range(len(u_hist) - hist_length):
            att_train.append((user_id, tuple(u_hist[i:i+hist_length]), (u_hist[i+hist_length],)))
        att_test.append(att_train.pop())
    else:
        print(user_id)

HBox(children=(FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [62]:
class Attention(nn.Module):
    def __init__(self, num_items, emb_size=64, h_size=64, layers_dim=(64,), temp=0.9):
        super().__init__()
        self.num_items = num_items
        self.emb_size = emb_size
        self.h_size = h_size
        self.temp = temp
        self.h = nn.Parameter(torch.zeros(h_size))
        self.p = nn.Embedding(num_items, emb_size)
        self.q = nn.Embedding(num_items, emb_size)

        if layers_dim is None:
            self.mlp = nn.Sequential(nn.Linear(emb_size, h_size), nn.ReLU())
        else:
            mlp_modules = [nn.Linear(emb_size, layers_dim[0]), nn.ReLU()]
            for i in range(len(layers_dim) - 1):
                mlp_modules.extend([nn.Linear(layers_dim[i], layers_dim[i + 1]),
                                    nn.ReLU()])
            mlp_modules.extend([nn.Linear(layers_dim[-1], h_size),
                                nn.ReLU()])
            self.mlp = nn.Sequential(*mlp_modules)

    def forward(self, history, item):  # history: batch_size x hist_length, item: batch_size x num_items
        p_i = self.p(item)  # batch_size x num_items x emb_size
        q_h = self.q(history)  # batch_size x hist_length x emb_size
        prod = torch.einsum('bne,bhe->bnhe', p_i, q_h)  # batch_size x num_items x hist_length x emb_size
        f = self.mlp(prod)  # batch_size x num_items x hist_length x h_size
        f = f.matmul(self.h)  # batch_size x num_items x hist_length
        soft = F.gumbel_softmax(f, tau=self.temp, hard=False, dim=-1)  # batch_size x num_items x hist_length
        a = soft.bmm(q_h)  # batch_size x num_items x emb_size
        res = torch.einsum('bne,bne->bn', p_i, a)  # batch_size x num_items
        return torch.sigmoid(res)


class NAIS():
    def __init__(self, num_items, negatives, emb_size=64, h_size=64, layers_dim=(64,), temp=0.9, batch_size=64, lr=1e-4):
        self.batch_size = batch_size
        self.num_items = num_items
        self.model = Attention(num_items, emb_size, h_size, layers_dim, temp).to(DEVICE)
        self.negatives = negatives
        self.loss = nn.BCELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def fit(self, train, iters=10):
        self.model.train()
        t = trange(iters)
        t.set_description("Total")
        for iteration in t:
            self.train_epoch(train, iteration)

    def train_epoch(self, train, iteration):
        self.model.train()
        dataloader = self.build_dataloader(train)
        t = tqdm(dataloader)
        t.set_description(f"Epoch {iteration}")
        losses = []
        i = 0
        for batch in t:
            i += 1
            self.optimizer.zero_grad()
            history, items = batch
            targets = torch.zeros(items.shape).float().to(DEVICE)
            targets[:, 0] = 1.
            preds = self.model(history, items)
            loss = self.loss(preds, targets)
            loss.backward()
            self.optimizer.step()
            losses.append(loss.detach().cpu().item())
            if i % 500 == 0:
                t.set_postfix(loss=losses[-1])
        t.set_postfix(loss=np.mean(losses))

    def score(self, dataset):
        self.model.eval()
        with torch.no_grad():
            dataloader = self.build_dataloader(dataset)
            scores = []
            for hist, it in dataloader:
                scores.append(self.model(hist, it).cpu().numpy())
            return np.vstack(scores)

    def build_dataloader(self, dataset):
        return data.DataLoader(dataset, shuffle=True, batch_size=self.batch_size, collate_fn=self.collate)

    def collate(self, batch):
        _, history, items = zip(*batch)
        history = torch.tensor(history).to(DEVICE)
        items = torch.tensor(items)
        assert history.shape[0] == items.shape[0], 'collate_fn mistake'
        if self.model.training:
            negs = []
            for i in range(len(batch)):
                uid = batch[i][0]
                negs.append([np.random.choice(self.negatives[uid])])
            items = torch.hstack((items, torch.tensor(negs)))
        return history, items.to(DEVICE)

    def recommend(self, user_id, history, negatives, top_n=10):
        not_recommended = np.array(negatives[user_id])
        dataset = [(user_id, history, not_recommended)]
        score = self.score(dataset).reshape(-1)
        return not_recommended[sorted(np.arange(len(not_recommended)), key=lambda x: -score[x])[:top_n]]

    def similar_items(self, item_id, top_n=10, emb_type='p'):
        assert emb_type in ['p', 'q']
        emb = None
        if emb_type == 'p':
            emb = self.model.p.weight.cpu()
        elif emb_type == 'q':
            emb = self.model.q.weight.cpu()
        return torch.linalg.norm(emb - emb[item_id], dim=-1).argsort()[:top_n].numpy()

    def hr_ndcg(self, test, k=10):
        hr = []
        ndcg = []
        for user_id, hist, item in tqdm(test):
            neg = tuple(np.random.choice(self.negatives[user_id], size=99, replace=False))
            items = item + neg
            dataset = [[user_id, hist, items]]
            pred = self.score(dataset).reshape(-1)
            target = np.zeros(100)
            target[0] = 1
            hr.append((pred.argsort()[::-1] == 0)[:k].sum())
            ndcg.append(ndcg_score([target], [pred], k=k))
        print(f"HR@{k} = {np.mean(hr):.4f}")
        print(f"NDCG@{k} = {np.mean(ndcg):.4f}")


In [None]:
nais = NAIS(num_items=user_item.shape[1], negatives=negatives,
               batch_size=512, lr=1e-4)
state = torch.load('/content/drive/MyDrive/AU/recsys_attention/attention.model')
nais.model.load_state_dict(state())
nais.fit(att_train, iters=5)

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1824.0), HTML(value='')))

In [55]:
torch.save(nais.model.state_dict, '/content/drive/MyDrive/AU/recsys_attention/attention.model')

In [56]:
nais.hr_ndcg(att_test, 10)

HBox(children=(FloatProgress(value=0.0, max=6040.0), HTML(value='')))


HR@10 = 0.1359
NDCG@10 = 0.0581


In [50]:
user_id = 0
history = att_test[user_id][1]

In [51]:
get_movies(test.loc[test.user_id == user_id].movie_id)

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
47,Pocahontas (1995),Animation|Children's|Musical|Romance


In [57]:
get_movies(nais.recommend(user_id, history, negatives))

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
514,"Remains of the Day, The (1993)",Drama
2201,Lifeboat (1944),Drama|Thriller|War
3134,"Great Santini, The (1979)",Drama
3689,Porky's Revenge (1985),Comedy
2772,Alice and Martin (Alice et Martin) (1998),Drama
2204,Mr. & Mrs. Smith (1941),Comedy
1916,Armageddon (1998),Action|Adventure|Sci-Fi|Thriller
3441,Band of the Hand (1986),Action
695,Butterfly Kiss (1995),Thriller


In [59]:
get_movies(nais.similar_items(0))

Unnamed: 0_level_0,name,category
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,Toy Story (1995),Animation|Children's|Comedy
1838,My Giant (1998),Comedy
2,Grumpier Old Men (1995),Comedy|Romance
669,"World of Apu, The (Apur Sansar) (1959)",Drama
2270,Permanent Midnight (1998),Drama
1932,"Life of �mile Zola, The (1937)",Drama
865,Bound (1996),Crime|Drama|Romance|Thriller
1672,Boogie Nights (1997),Drama
1242,Rosencrantz and Guildenstern Are Dead (1990),Comedy|Drama
1869,"Dancer, Texas Pop. 81 (1998)",Comedy|Drama
