### Load preprocessed data

The next cell runs the script that downloads and preprocesses the MovieLens data.
It is commented out by default; uncomment it to run the download & processing script.

In [1]:
#!python ../src/download.py

In [7]:
import numpy as np
from sklearn.model_selection import train_test_split
from torch import from_numpy
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch.utils.data import BatchSampler
from torch.utils.data import SequentialSampler

fh = np.load('data/dataset.npz')

# Features (train_x / test_x) and targets (train_y / test_y) are stored as
# separate arrays in the archive.
# Note pytorch is finicky about needing int64 types for index tensors.
# `.astype` returns a copy, so the in-place offsets below never touch `fh`.
train_x = fh['train_x'].astype(np.int64)
train_y = fh['train_y']
test_x = fh['test_x'].astype(np.int64)
test_y = fh['test_y']

n_user = int(fh['n_user'])
n_item = int(fh['n_item'])
n_occu = int(fh['n_occu'])
n_rank = int(fh['n_ranks'])

# Total number of distinct feature ids once every column shares one id space.
n_feat = n_user + n_item + n_occu + n_rank

# Shift each column into its own disjoint range of a single index space so
# that one embedding table of size n_feat can serve all of them:
#   col 0 -> [0, n_user)
#   col 1 -> [n_user, n_user + n_item)
#   col 2 -> next n_occu ids, col 3 -> next n_rank ids
train_x[:, 1] += n_user
train_x[:, 2] += n_user + n_item
train_x[:, 3] += n_user + n_item + n_occu
test_x[:, 1] += n_user
test_x[:, 2] += n_user + n_item
test_x[:, 3] += n_user + n_item + n_occu

# The test split must use the *offset* indices, exactly like train/val.
# (Re-reading fh['test_x'] here would yield raw per-column ids that collide
# with the user-id range of the shared embedding table.)
X_test = test_x
Y_test = test_y

# We already had (train, test), now split train => (train, val)
# so now we have train, test, val
X_train, X_val, Y_train, Y_val = train_test_split(train_x, train_y)

# Columns are user_id, item_id and two further categorical features.
# NOTE(review): all four columns are passed to the model below; confirm
# against AbstractModel that none of them is dropped before `forward`.
print(X_train)
print(' ')
print(Y_train)



def dataloader(*arrs, batch_size=1024):
    """Wrap equally-long tensors in a DataLoader that yields ordered batches.

    Rows are visited in their stored order (no shuffling) and the final,
    possibly smaller, batch is kept rather than dropped.

    Args:
        *arrs: tensors sharing the same first-dimension length.
        batch_size: maximum number of rows per batch.

    Returns:
        A ``DataLoader`` whose items are tuples with one batched tensor per
        input tensor.
    """
    wrapped = TensorDataset(*arrs)
    n_rows = len(arrs[0])
    in_order = SequentialSampler(range(n_rows))
    batches = BatchSampler(in_order, batch_size=batch_size, drop_last=False)
    return DataLoader(wrapped, batch_sampler=batches, shuffle=False)
 
# One loader per split; each pairs the feature-index tensor with its targets.
train, test, val = (
    dataloader(from_numpy(features), from_numpy(targets))
    for features, targets in (
        (X_train, Y_train),
        (X_test, Y_test),
        (X_val, Y_val),
    )
)

[[  438  8074 10171 10026]
 [ 1362  9989 10012 10027]
 [ 2976  7052 10089 10035]
 ...
 [ 4386  8992 10513 10020]
 [ 1421  7466 10013 10025]
 [ 4501  6630 10019 10032]]
 
[[4.]
 [3.]
 [3.]
 ...
 [4.]
 [4.]
 [5.]]


In [8]:
from abstract_model import AbstractModel

In [9]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable


def l2_regularize(array):
    """Return the sum of squared entries of ``array`` as a scalar tensor."""
    return array.pow(2.0).sum()


def index_into(arr, idx):
    """ Look up rows of a 2D array with a 2D array of indices.

    With a 1D index the lookup is plain row selection:
    idx = [0, 2, 4]
    x = [1, 2, 3, 4, 5]
    y = x[idx]
    # y = [1, 3, 5]

    This function does the same when idx is a 2D array: every entry of
    idx is replaced by the corresponding row of arr.

    Args:
        arr: 2D tensor of shape (n_rows, dim).
        idx: 2D integer tensor of shape (a, b) holding row numbers of arr.

    Returns:
        Tensor of shape (a, b, dim) with out[i, j] == arr[idx[i, j]].
    """
    new_shape = (idx.size()[0], idx.size()[1], arr.size()[1])
    # Flatten the indices with reshape(-1); the previous non-inplace
    # Tensor.resize is deprecated and emits a warning on every call.
    flat_idx = idx.reshape(-1)
    return arr[flat_idx].view(new_shape)


def factorization_machine(v, x=None):
    """Sum of pairwise feature interactions, computed per embedding dimension.

    Uses Rendle's trick to evaluate all pairs of features in linear time:
    0.5 * ((sum_i v_i x_i)^2 - sum_i (v_i x_i)^2), reduced over the feature
    axis.

    Args:
        v: tensor of shape (batchsize, n_features, dim), one vector per feature.
        x: optional per-feature weights of shape (batchsize, n_features);
            a tensor already shaped (batchsize, n_features, 1) or
            (batchsize, n_features, dim) is accepted as well. Treated as
            all ones when missing.

    Returns:
        Tensor of shape (batchsize, dim).
    """
    batchsize = v.size()[0]
    n_features = v.size()[1]
    n_dim = v.size()[2]
    if x is None:
        # Allocated directly with v's shape, dtype and device.
        x = torch.ones_like(v)
    else:
        if x.dim() == 2:
            # expand() aligns trailing dimensions, so a (batch, n_features)
            # weight needs an explicit trailing axis before it can be
            # broadcast across the embedding dimension.
            x = x.unsqueeze(-1)
        x = x.expand(batchsize, n_features, n_dim)
    t0 = (v * x).sum(dim=1)**2.0
    t1 = (v**2.0 * x**2.0).sum(dim=1)
    return 0.5 * (t0 - t1)

In [10]:
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl



class MF(AbstractModel):
    """Factorization-machine model over a single shared feature-id space.

    Every input row is a list of feature ids; each id owns a k-dimensional
    vector and a scalar bias. The prediction is the sum of the row's biases
    plus the sum of all pairwise vector interactions.

    Args:
        n_feat: number of distinct feature ids (size of both embeddings).
        k: dimensionality of each feature vector.
        c_feat: L2 penalty weight applied to the feature vectors in `reg`.
        c_bias: stored on the instance.
            NOTE(review): not used by `reg` below; confirm whether the bias
            embedding was meant to be regularized with it.
        batch_size: stored on the instance; not read by the methods defined
            in this class.
    """

    def __init__(self, n_feat,
                 k=18, c_feat=1.0, c_bias=1.0,
                 batch_size=128):
        super().__init__()
        self.k = k
        self.n_feat = n_feat
        self.feat = nn.Embedding(n_feat, k)
        self.bias_feat = nn.Embedding(n_feat, 1)
        self.c_feat = c_feat
        self.c_bias = c_bias
        self.batch_size = batch_size

    def forward(self, inputs):
        """Score a batch of rows.

        Args:
            inputs: 2D integer tensor (batch, n_fields) of feature ids.

        Returns:
            1D tensor of length batch with one prediction per row.
        """
        # Lookup gives (batch, n_fields, 1). Drop only the trailing axis:
        # a bare squeeze() would also drop the batch axis when the batch has
        # a single row, and the following sum(dim=1) would then fail.
        biases = index_into(self.bias_feat.weight, inputs).squeeze(-1).sum(dim=1)
        # vectrs is a 3D array of vectors for each index in the 2D index array
        vectrs = index_into(self.feat.weight, inputs)
        # Interaction of all feature-vectors is (batch, k); reduce over k.
        # No squeeze here either, so batch == 1 and k == 1 both work.
        interx = factorization_machine(vectrs).sum(dim=1)
        logodds = biases + interx
        return logodds

    def loss(self, prediction, target):
        """Mean squared error between predictions and targets.

        Returns:
            Tuple of (loss tensor, dict with the same value under "mse").
        """
        # Match the target's shape to the prediction explicitly; squeeze()
        # would turn a (1, 1) target into a 0-d tensor and make mse_loss
        # broadcast instead of comparing element-wise.
        loss_mse = F.mse_loss(prediction, target.reshape(prediction.shape))
        return loss_mse, {"mse": loss_mse}

    def reg(self):
        """L2 penalty on the feature vectors, scaled by c_feat.

        Returns:
            Tuple of (penalty tensor, dict with the same value under
            "reg_feat").
        """
        reg_feat = l2_regularize(self.feat.weight) * self.c_feat
        log = dict(reg_feat=reg_feat)
        return reg_feat, log

In [11]:
from pytorch_lightning.loggers.wandb import WandbLogger


# Hyperparameters: embedding size and the two L2 penalty weights.
k, c_bias, c_feat = 8, 1e-3, 1e-5

model = MF(n_feat, k=k, c_bias=c_bias, c_feat=c_feat, batch_size=1024)

# Send metrics to Weights & Biases under the "simple_mf" project.
logger = WandbLogger(name="08_mf", project="simple_mf")

# Up to 100 epochs on one GPU, with early stopping enabled.
trainer = pl.Trainer(
    logger=logger,
    max_epochs=100,
    early_stop_callback=True,
    gpus=1,
    progress_bar_refresh_rate=1,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


#### Run model

In [12]:
# Fit the model on the `train` loader, passing `val` as the validation loader.
trainer.fit(model, train, val)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Currently logged in as: [33msf-moody[0m (use `wandb login --relogin` to force relogin)



  | Name      | Type      | Params
----------------------------------------
0 | feat      | Embedding | 98 K  
1 | bias_feat | Embedding | 12 K  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..





1