In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
#import os
import shutil
import sys

import numpy as np
from scipy import sparse

import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sn
sn.set()
from collections import Counter
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim

from torch.autograd import Variable
import pickle
import warnings
warnings.filterwarnings('ignore')
import time
import spacy
import math, copy
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
import scipy
from scipy.sparse import csr_matrix, hstack, save_npz, load_npz

In [3]:
PATH = 'data/user_based'
PATH

'data/user_based'

In [5]:
va_tr_data = load_npz(os.path.join(PATH,'ub_va_tr_data.npz'))
va_te_data = load_npz(os.path.join(PATH,'ub_va_te_data.npz'))

va_tr_data.shape, va_te_data.shape

((10000, 20497), (10000, 20497))

In [4]:
data = load_npz(os.path.join(PATH,'ub_tr_data.npz'))

data.shape

(115806, 20497)

In [6]:
te_tr_data = load_npz(os.path.join(PATH,'ub_te_tr_data.npz'))
te_te_data = load_npz(os.path.join(PATH,'ub_te_te_data.npz'))

te_tr_data.shape, te_te_data.shape

((10000, 20497), (10000, 20497))

In [7]:
unique_mid = []
with open(os.path.join(PATH, 'ub_unique_mid.txt'), 'r') as f:
    for line in f:
        unique_mid.append(line.strip())

n_items = len(unique_mid)
n_items

20497

In [8]:
mid2id = dict((mid, i) for (i, mid) in enumerate(unique_mid))
id2mid = dict((i, mid) for (i, mid) in enumerate(unique_mid))

In [9]:
movie_descriptions = pd.read_csv('data/item_embeddings/sum_embs.csv')
movie_descriptions['sum_embs'] = movie_descriptions['sum_embs'].map(lambda x : np.fromstring(x[1:-1], sep=','))
movie_descriptions.head(2)

Unnamed: 0,movieId,id,sum_embs
0,1,0,"[0.2797650992870331, -0.12578724324703217, -0...."
1,2,1,"[-0.045451000332832336, -0.2998088002204895, -..."


In [10]:
movie_embs = movie_descriptions['sum_embs'].tolist()
movie_embs = np.array(movie_embs)
movie_embs.shape

(20497, 300)

In [1]:
import numpy as np

In [2]:
a = np.array([1,2,3])

In [4]:
np.random.shuffle(a)
a

array([2, 3, 1])

In [11]:
device = torch.device('cuda')
device

device(type='cuda')

In [12]:
def recall_at_k(pred, true, k):
    tmp = np.intersect1d(pred[:k], true)
    return len(tmp)/np.minimum(len(true), k)

In [13]:
class User_Embeddings(nn.Module):
    
    def __init__(self, movies_embeddings, att_dims, fc_dims, drop=0.2):
        super(User_Embeddings, self).__init__()
        
        self.movies_embeddings = movies_embeddings
        
        "Initialization Layers", 
        self.att_layers = nn.ModuleList([nn.Linear(d_in, d_out, bias=False) for
                                        (d_in, d_out) in zip(att_dims[:-1], att_dims[1:])])
        
        self.fc_layers = nn.ModuleList([nn.Linear(d_in, d_out) for
                                        (d_in, d_out) in zip(fc_dims[:-1], fc_dims[1:])])
        self.drop = nn.Dropout(drop)
        self.init_weights()

        
    def attention(self, data): 
        h = F.normalize(data)
        #h = data
        for i, layer in enumerate(self.att_layers):
            h = self.drop(h)
            h = layer(h)
            if not i == len(self.att_layers) - 1:
                h = torch.tanh(h)
            else:
                h = torch.softmax(h, 0)           
        return h
    
    
    def prediction(self, data):
        h = torch.flatten(data)
        #h = self.drop(h)
        for i, layer in enumerate(self.fc_layers):
            h = self.drop(h)
            h = layer(h)
            if not i == len(self.fc_layers) - 1:
                h = torch.relu(h)
            else:
                h = torch.sigmoid(h)

        return h
        
        
    def forward(self, ind):
        h = self.movies_embeddings[ind]
        z = self.attention(h)
        h = torch.mm(h.t(), z)
        h = self.prediction(h)

        return h
    
    
    def init_weights(self):
        for layer in self.att_layers:
            fan_out, fan_in = layer.weight.size()
            std = np.sqrt(2.0/(fan_in + fan_out))
            layer.weight.data.normal_(0.0, std)
            #layer.bias.data.normal_(0.0, 0.001)
            
        for layer in self.fc_layers:
            fan_out, fan_in = layer.weight.size()
            std = np.sqrt(2.0/(fan_in + fan_out))
            layer.weight.data.normal_(0.0, std)
            layer.bias.data.normal_(0.0, 0.001)

In [14]:
n_users, _ = data.shape

idxlist = list(range(n_users))

In [15]:
def att_train(data):
    model.train()
    np.random.shuffle(idxlist)
    for batch_ind, start_idx in enumerate(range(0, n_users, batch_size)):
        end_idx = min(start_idx + batch_size, n_users)
        batch_data = data[idxlist[start_idx:end_idx]]
        
        optimizer.zero_grad()
        batch_loss =  None
        for d in batch_data:
            item_like_tr = torch.LongTensor(d.nonzero()[1]).to(device)
            input_ = torch.FloatTensor(d.toarray()).view(-1,1).to(device)
            att_pred = model(item_like_tr)
            if batch_loss :
                batch_loss += criterion(att_pred, input_) 
            else:
                batch_loss = criterion(att_pred, input_)
        batch_loss.backward()
        optimizer.step()


def att_test(val_tr, val_te):
    model.eval()
    r_20 = []
    r_50 = []
    val_loss = 0.0
    
    with torch.no_grad():
        for i in range(val_tr.shape[0]):
            input_ = val_tr[i]
            output = val_te[i]

            item_like_tr = input_.nonzero()[1]
            item_like_te = output.nonzero()[1]

            input_ = torch.FloatTensor(input_.toarray()).to(device) 

            att_pred = model(torch.LongTensor(item_like_tr).to(device)) 
            
            att_pred_np = att_pred.data.cpu().numpy()
            att_pred_np[item_like_tr] = -np.inf

            "Compute loss"
            loss = criterion(att_pred, input_)
            val_loss += loss.item()
            
            "Compute recall"
            pred_item = (-att_pred_np).argsort(axis=0)
                
            r_20.append(recall_at_k(pred_item, item_like_te, 20))
            r_50.append(recall_at_k(pred_item, item_like_te, 50))

    val_loss /= val_tr.shape[0]
    return np.mean(r_20), np.mean(r_50), val_loss

In [16]:
batch_size = 256
best_recall = -np.inf
n_epochs = 300

In [17]:
emb_size = movie_embs.shape[1]
emb_size

300

In [18]:
tensor_movie_embs = torch.FloatTensor(movie_embs).to(device)
att_dims = [emb_size, 100, 50]
fc_dims = [att_dims[0]*att_dims[-1], 600, n_items]
att_dims, fc_dims

([300, 100, 50], [15000, 600, 20497])

In [19]:
"Create Attention model, define optimization and loss"

model = User_Embeddings(tensor_movie_embs, att_dims, fc_dims).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.BCELoss()

In [20]:
log_dir = 'models/user_embeddings/'
log_dir

'models/user_embeddings/'

In [21]:
"Batch 1"
try:
    r_20 = []
    r_50 = []
    print('Train with {} users and validate with {} users'.format(data.shape[0], va_tr_data.shape[0]))
    for epoch in range(1, n_epochs + 1):
        epoch_start_time = time.time()
        att_train(data)
        r2, r5, loss = att_test(va_tr_data, va_te_data)
        r_20.append(r2)
        r_50.append(r5)
        print('-' * 94)
        print('| end of epoch {:3d} | time: {:4.2f}s | val loss: {:4.4f} | recall@20 {:4.4f} | recall@50 {:4.4f} |'.format(epoch, time.time() - epoch_start_time, loss, r2, r5))
        print('-' * 94)

        if r2 > best_recall:
            with open(log_dir + 'user_embs_sum.pt', 'wb') as f:
                torch.save(model, f)
            best_recall = r2

    print(np.mean(r_20))
    print(np.mean(r_50))

except KeyboardInterrupt:
    print('-' * 94)
    print('Exiting from training early')

Train with 115806 users and validate with 10000 users
----------------------------------------------------------------------------------------------
| end of epoch   1 | time: 1160.34s | val loss: 0.0130 | recall@20 0.1823 | recall@50 0.2673 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch   2 | time: 865.71s | val loss: 0.0125 | recall@20 0.2085 | recall@50 0.3033 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch   3 | time: 879.35s | val loss: 0.0121 | recall@20 0.2140 | recall@50 0.3118 |
----------------------------------------------------------------------------------------------
---------------------------------------------------------------------------------------------

----------------------------------------------------------------------------------------------
| end of epoch  30 | time: 858.74s | val loss: 0.0080 | recall@20 0.2599 | recall@50 0.3652 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch  31 | time: 919.46s | val loss: 0.0079 | recall@20 0.2583 | recall@50 0.3644 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch  32 | time: 916.76s | val loss: 0.0079 | recall@20 0.2591 | recall@50 0.3654 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch  33 | time: 862.84s | val loss: 0.0078

----------------------------------------------------------------------------------------------
| end of epoch  59 | time: 863.05s | val loss: 0.0068 | recall@20 0.2340 | recall@50 0.3336 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch  60 | time: 860.49s | val loss: 0.0068 | recall@20 0.2357 | recall@50 0.3348 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch  61 | time: 859.80s | val loss: 0.0067 | recall@20 0.2340 | recall@50 0.3344 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch  62 | time: 863.33s | val loss: 0.0067

----------------------------------------------------------------------------------------------
| end of epoch  88 | time: 859.51s | val loss: 0.0062 | recall@20 0.2209 | recall@50 0.3149 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch  89 | time: 858.69s | val loss: 0.0062 | recall@20 0.2224 | recall@50 0.3174 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch  90 | time: 860.76s | val loss: 0.0061 | recall@20 0.2213 | recall@50 0.3155 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch  91 | time: 858.46s | val loss: 0.0061

----------------------------------------------------------------------------------------------
| end of epoch 117 | time: 859.24s | val loss: 0.0058 | recall@20 0.2115 | recall@50 0.3026 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 118 | time: 860.61s | val loss: 0.0058 | recall@20 0.2121 | recall@50 0.3020 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 119 | time: 855.91s | val loss: 0.0058 | recall@20 0.2104 | recall@50 0.3010 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 120 | time: 858.87s | val loss: 0.0058

----------------------------------------------------------------------------------------------
| end of epoch 146 | time: 861.56s | val loss: 0.0055 | recall@20 0.2039 | recall@50 0.2909 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 147 | time: 860.18s | val loss: 0.0055 | recall@20 0.2020 | recall@50 0.2897 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 148 | time: 858.92s | val loss: 0.0055 | recall@20 0.2035 | recall@50 0.2917 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 149 | time: 860.72s | val loss: 0.0055

----------------------------------------------------------------------------------------------
| end of epoch 175 | time: 861.21s | val loss: 0.0053 | recall@20 0.2012 | recall@50 0.2882 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 176 | time: 859.45s | val loss: 0.0053 | recall@20 0.2015 | recall@50 0.2894 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 177 | time: 856.87s | val loss: 0.0053 | recall@20 0.2025 | recall@50 0.2919 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 178 | time: 864.78s | val loss: 0.0053

----------------------------------------------------------------------------------------------
| end of epoch 204 | time: 858.84s | val loss: 0.0052 | recall@20 0.1995 | recall@50 0.2880 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 205 | time: 860.33s | val loss: 0.0052 | recall@20 0.2012 | recall@50 0.2895 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 206 | time: 860.02s | val loss: 0.0052 | recall@20 0.1990 | recall@50 0.2873 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 207 | time: 857.93s | val loss: 0.0052

----------------------------------------------------------------------------------------------
| end of epoch 233 | time: 863.81s | val loss: 0.0051 | recall@20 0.1967 | recall@50 0.2833 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 234 | time: 858.98s | val loss: 0.0051 | recall@20 0.1994 | recall@50 0.2891 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 235 | time: 857.37s | val loss: 0.0051 | recall@20 0.1950 | recall@50 0.2823 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 236 | time: 861.29s | val loss: 0.0051

----------------------------------------------------------------------------------------------
| end of epoch 262 | time: 858.83s | val loss: 0.0050 | recall@20 0.1952 | recall@50 0.2833 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 263 | time: 857.19s | val loss: 0.0050 | recall@20 0.1951 | recall@50 0.2828 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 264 | time: 857.11s | val loss: 0.0050 | recall@20 0.1993 | recall@50 0.2886 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 265 | time: 860.73s | val loss: 0.0050

----------------------------------------------------------------------------------------------
| end of epoch 291 | time: 858.11s | val loss: 0.0050 | recall@20 0.1979 | recall@50 0.2881 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 292 | time: 859.28s | val loss: 0.0049 | recall@20 0.1978 | recall@50 0.2859 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 293 | time: 858.29s | val loss: 0.0050 | recall@20 0.1968 | recall@50 0.2842 |
----------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------
| end of epoch 294 | time: 857.46s | val loss: 0.0050

In [22]:
"Batch 1"

# Load the best saved model.
with open(log_dir + 'user_embs_sum.pt', 'rb') as f:
    model = torch.load(f)

recall_20, recall_50, _ = att_test(te_tr_data, te_te_data)

print(recall_20, recall_50)

0.2660992136355217 0.3736796119788061
