In [90]:
import torch
import numpy as np
from tqdm import tqdm_notebook as tqdm
from scipy.sparse import rand as sprand

In [29]:
# Synthesize a dense user-by-item ratings matrix in which roughly 1% of the
# entries carry a rating and every other entry is 0 (meaning "unobserved").
n_users, n_items = 1000, 1000

# Step 1: pick which (user, item) cells are observed.
sparse_ratings = sprand(n_users, n_items, density=0.01, format='csr')

# Step 2: overwrite the stored values with integer-valued ratings.
# NOTE: the upper bound of np.random.randint is exclusive, so the ratings
# drawn here are 1, 2, 3 or 4.
sparse_ratings.data = np.random.randint(
    low=1, high=5, size=sparse_ratings.nnz
).astype(float)

# Step 3: densify so individual ratings can be read as ratings[user, item].
ratings = sparse_ratings.toarray()

In [153]:
class MatrixFactorization(torch.nn.Module):
    """Dot-product matrix-factorization model.

    Every user and every item owns one learnable ``n_factors``-dimensional
    vector. Both lookup tables are ``torch.nn.Embedding`` layers created with
    ``sparse=True``.

    NOTE: per-user / per-item bias embeddings and a sigmoid-scaled output are
    not part of this model; scores are the raw products of the factor vectors.
    """

    def __init__(self, n_users, n_items, n_factors=20):
        """Create the user and item embedding tables.

        Args:
            n_users: number of rows in the user embedding table.
            n_items: number of rows in the item embedding table.
            n_factors: width of each embedding vector (default 20).
        """
        super().__init__()
        table_options = dict(embedding_dim=n_factors, sparse=True)
        self.user_factors = torch.nn.Embedding(n_users, **table_options)
        self.item_factors = torch.nn.Embedding(n_items, **table_options)

    def forward(self, user, item):
        """Score the given users against the given items.

        ``torch.mm`` only accepts 2-D operands, so ``user`` and ``item`` are
        expected to be 1-D index tensors (each lookup then yields a
        ``(len, n_factors)`` matrix).

        Args:
            user: index tensor into the user embedding table.
            item: index tensor into the item embedding table.

        Returns:
            The matrix product of the user vectors with the transposed item
            vectors, i.e. one score per (user, item) combination, with the
            last dimension squeezed away when it has size 1.
        """
        user_vectors = self.user_factors(user)
        item_vectors = self.item_factors(item)
        scores = torch.mm(user_vectors, item_vectors.transpose(0, 1))
        return scores.squeeze(-1)

In [154]:
# Model: one 20-dimensional factor vector per user and per item.
mf = MatrixFactorization(n_users=n_users, n_items=n_items, n_factors=20)

# Objective: squared error between predicted and observed rating.
criterion = torch.nn.MSELoss()

# The model's embeddings are created with sparse=True and therefore produce
# sparse gradients, so the optimizer has to be one that accepts them;
# Adagrad does.
optimizer = torch.optim.Adagrad(params=mf.parameters(), lr=0.001)

In [157]:
# Collect the coordinates of every non-zero (observed) rating and put them
# in a random order; `p` is the permutation applied to both index arrays so
# that each (row, col) pair stays together.
rows, cols = np.nonzero(ratings)
p = np.random.permutation(rows.size)
rows = rows[p]
cols = cols[p]

In [158]:
# One pass of per-sample gradient descent over the shuffled observed ratings.
t_loss = 0.0  # running sum of per-sample losses, used for the displayed mean

trange = tqdm(enumerate(zip(rows, cols)), total=rows.shape[0], desc='Training')

for idx, (row, col) in trange:
    # Build a batch of exactly one example. The target is cast to float32 so
    # that it matches the dtype of the model's prediction.
    rating = torch.tensor([ratings[row, col]], dtype=torch.float32)
    user = torch.tensor([row], dtype=torch.long)
    item = torch.tensor([col], dtype=torch.long)

    # backward() *adds* into each parameter's .grad, so the gradients of the
    # previous sample must be cleared first; otherwise every step would apply
    # the accumulated gradient of all samples seen so far.
    optimizer.zero_grad()

    pred = mf(user, item)
    loss = criterion(pred, rating)

    # Show the mean loss over the samples processed so far.
    t_loss += loss.item()
    trange.set_postfix(loss="{:.5f}".format(t_loss / (idx + 1)))

    loss.backward()
    optimizer.step()

HBox(children=(IntProgress(value=0, description='Training', max=10000, style=ProgressStyle(description_width='…


