### Load preprocessed data

In [66]:
import numpy as np
fh = np.load('../data/dataset.npz')
# We have a bunch of feature columns and last column is the y-target
train_x = fh['train_x'][::1000].astype(np.int64)
train_y = fh['train_y'][::1000]
train_xy = fh['train_xy']
n_user = fh['n_user']
n_item = fh['n_item']

In [67]:
def loader(x, y, n):
    for epoch in range(100):
        for i in range(0, len(y), n):
            xs = torch.from_numpy(x[i:i + n, :-1])
            ys = torch.from_numpy(y[i:i + n, -1])
            print(epoch, i)
            yield (xs, ys)

In [68]:
train_loader = loader(train_x, train_y, 64)

### Define the MF Model

In [69]:
import torch
from torch import nn
import torch.nn.functional as F

def l2_regularize(array):
    loss = torch.sum(array ** 2.0)
    return loss


class MF(nn.Module):
    def __init__(self, n_user, n_item, k=18, c_prior=1.0):
        super(MF, self).__init__()
        self.k = k
        self.n_user = n_user
        self.n_item = n_item
        self.c_prior = c_prior
        self.user = nn.Embedding(n_user, k)
        self.item = nn.Embedding(n_item, k)
    
    def __call__(self, train_x):
        user_id = train_x[:, 0]
        item_id = train_x[:, 1]
        vector_user = self.user(user_id)
        vector_item = self.item(item_id)
        prediction = torch.sum(vector_user * vector_item, dim=1)
        return prediction
    
    def loss(self, prediction, target):
        likelihood = F.mse_loss(prediction, target)
        prior =  l2_regularize(self.user.weight) + l2_regularize(self.item.weight)
        total =  prior * self.c_prior + likelihood
        return total

### Train model

In [70]:
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Loss

In [71]:
model = MF(n_user, n_item)
optimizer = torch.optim.Adam(model.parameters())
model

MF(
  (user): Embedding(6041, 18)
  (item): Embedding(3953, 18)
)

In [73]:
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    iter = (engine.state.iteration - 1) % len(train_loader) + 1
    if iter % log_interval == 0:
        print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
              "".format(engine.state.epoch, iter, len(train_loader), engine.state.output))

In [74]:
trainer = create_supervised_trainer(model, optimizer, model.loss)

In [75]:
trainer.run(train_loader, max_epochs=100)

0 0
0 64
0 128
0 192
0 256
0 320
0 384
0 448
0 512
0 576
0 640
0 704
0 768
0 832
0 896
1 0
1 64
1 128
1 192
1 256
1 320
1 384
1 448
1 512
1 576
1 640
1 704
1 768
1 832
1 896
2 0
2 64
2 128
2 192
2 256
2 320
2 384
2 448
2 512
2 576
2 640
2 704
2 768
2 832
2 896
3 0
3 64
3 128
3 192
3 256
3 320
3 384
3 448
3 512
3 576
3 640
3 704
3 768
3 832
3 896
4 0
4 64
4 128
4 192
4 256
4 320
4 384
4 448
4 512
4 576
4 640
4 704
4 768
4 832
4 896
5 0
5 64
5 128
5 192
5 256
5 320
5 384
5 448
5 512
5 576
5 640
5 704
5 768
5 832
5 896
6 0
6 64
6 128
6 192
6 256
6 320
6 384
6 448
6 512
6 576
6 640
6 704
6 768
6 832
6 896
7 0
7 64
7 128
7 192
7 256
7 320
7 384
7 448
7 512
7 576
7 640
7 704
7 768
7 832
7 896
8 0
8 64
8 128
8 192
8 256
8 320
8 384
8 448
8 512
8 576
8 640
8 704
8 768
8 832
8 896
9 0
9 64
9 128
9 192
9 256
9 320
9 384
9 448
9 512
9 576
9 640
9 704
9 768
9 832
9 896
10 0
10 64
10 128
10 192
10 256
10 320
10 384
10 448
10 512
10 576
10 640
10 704
10 768
10 832
10 896
11 0
11 64
11 128
11 192
11 

85 640
85 704
85 768
85 832
85 896
86 0
86 64
86 128
86 192
86 256
86 320
86 384
86 448
86 512
86 576
86 640
86 704
86 768
86 832
86 896
87 0
87 64
87 128
87 192
87 256
87 320
87 384
87 448
87 512
87 576
87 640
87 704
87 768
87 832
87 896
88 0
88 64
88 128
88 192
88 256
88 320
88 384
88 448
88 512
88 576
88 640
88 704
88 768
88 832
88 896
89 0
89 64
89 128
89 192
89 256
89 320
89 384
89 448
89 512
89 576
89 640
89 704
89 768
89 832
89 896
90 0
90 64
90 128
90 192
90 256
90 320
90 384
90 448
90 512
90 576
90 640
90 704
90 768
90 832
90 896
91 0
91 64
91 128
91 192
91 256
91 320
91 384
91 448
91 512
91 576
91 640
91 704
91 768
91 832
91 896
92 0
92 64
92 128
92 192
92 256
92 320
92 384
92 448
92 512
92 576
92 640
92 704
92 768
92 832
92 896
93 0
93 64
93 128
93 192
93 256
93 320
93 384
93 448
93 512
93 576
93 640
93 704
93 768
93 832
93 896
94 0
94 64
94 128
94 192
94 256
94 320
94 384
94 448
94 512
94 576
94 640
94 704
94 768
94 832
94 896
95 0
95 64
95 128
95 192
95 256
95 320
95 384
9

<ignite.engine.engine.State at 0x11b989ba8>