# Collaborative filtering

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
import fastai
import fastai.collab
import fastai.datasets
import fastai.tabular.transform
import math
import numpy
import os
import pandas
import time
import torch
from torch import nn
import typing
import matplotlib.pyplot as plt

In [2]:
GPU = torch.device("cuda:0")
CPU = torch.device("cpu")
dev = CPU # Seems to be much faster for this application

In [3]:
import zipfile
zip_path = fastai.datasets.download_data("http://files.grouplens.org/datasets/movielens/ml-100k.zip", ext="")
dest_dir = zip_path.parent
data_dir = os.path.splitext(zip_path)[0]
if not os.path.exists(data_dir):
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)

In [4]:
col_names = ("user", "item", "rating", "timestamp")
train_df = pandas.read_csv(os.path.join(data_dir, "ua.base"), sep="\t", names=col_names)
test_df = pandas.read_csv(os.path.join(data_dir, "ua.test"), sep="\t", names=col_names)
concat_df = pandas.concat((train_df, test_df))
n_item = max(train_df["item"].max(), test_df["item"].max())
n_user = max(train_df["user"].max(), test_df["user"].max())
print(f"n_item: {n_item}, n_user: {n_user}")

class MovieLensDataset(torch.utils.data.Dataset):
    def __init__(self, df: pandas.DataFrame, device: torch.device):
        # Indices into embeddings need to have dtype "long".
        self.ids_tensor = torch.tensor(df[["user", "item"]].to_numpy(), dtype=torch.long, device=device)
        self.ratings_tensor = torch.tensor(df[["rating"]].to_numpy(), dtype=torch.float, device=device)
        
    def __len__(self):
        return len(self.ids_tensor)

    def __getitem__(self, idx):
        return self.ids_tensor[idx], self.ratings_tensor[idx]
        

train_dataset = MovieLensDataset(train_df, dev)
test_dataset = MovieLensDataset(test_df, dev)

n_item: 1682, n_user: 943


In [5]:
batch_size = 64
num_epochs = 10

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
test_inputs, test_labels = test_dataset[:]

In [14]:
# Use fastai as a benchmark.
# See https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson4-collab.ipynb
# Hacked up copy of CollabDataBunch.from_df because I want to use test_df as the validation set.
user_name   = concat_df.columns[0]
item_name   = concat_df.columns[1]
rating_name = concat_df.columns[2]
cat_names = [user_name,item_name]
num_train = len(train_df)
src = (fastai.collab.CollabList.from_df(concat_df, cat_names=cat_names, procs=fastai.tabular.transform.Categorify)
        .split_by_idxs(train_idx=numpy.arange(num_train), valid_idx=numpy.arange(num_train, num_train + len(test_df)))
        .label_from_df(cols=rating_name))
data_bunch = src.databunch(path=".", bs=batch_size, val_bs=batch_size, device=dev)
assert len(data_bunch.dl(fastai.basic_data.DatasetType.Train).dl.dataset.x) == num_train
data_bunch.show_batch()

user,item,target
13,692,4.0
308,378,3.0
747,268,5.0
56,410,4.0
417,582,3.0


Help on DeviceDataLoader in module fastai.basic_data object:

class DeviceDataLoader(builtins.object)
 |  DeviceDataLoader(dl: torch.utils.data.dataloader.DataLoader, device: torch.device, tfms: List[Callable] = None, collate_fn: Callable = <function data_collate at 0x7f82d2a75820>) -> None
 |  
 |  Bind a `DataLoader` to a `torch.device`.
 |  
 |  Methods defined here:
 |  
 |  __eq__(self, other)
 |  
 |  __getattr__(self, k: str) -> Any
 |  
 |  __init__(self, dl: torch.utils.data.dataloader.DataLoader, device: torch.device, tfms: List[Callable] = None, collate_fn: Callable = <function data_collate at 0x7f82d2a75820>) -> None
 |  
 |  __iter__(self)
 |      Process and returns items from `DataLoader`.
 |  
 |  __len__(self) -> int
 |  
 |  __post_init__(self)
 |  
 |  __repr__(self)
 |  
 |  __setstate__(self, data: Any)
 |  
 |  add_tfm(self, tfm: Callable) -> None
 |      Add `tfm` to `self.tfms`.
 |  
 |  collate_fn = data_collate(batch: Collection[Union[torch.Tensor, fastai.core

In [22]:
fastai_learn = fastai.collab.collab_learner(data_bunch, n_factors=40, y_range=[0,5.5], wd=1e-1)
fastai_learn.fit_one_cycle(num_epochs, 1e-2)

epoch,train_loss,valid_loss,time
0,0.947493,1.052847,00:07
1,0.877937,0.965214,00:07
2,0.877414,0.963883,00:08
3,0.854461,0.941553,00:07
4,0.773689,0.939386,00:08
5,0.72579,0.912324,00:07
6,0.61582,0.902836,00:07
7,0.509327,0.895727,00:07
8,0.419247,0.896176,00:08
9,0.33099,0.896819,00:07


In [27]:
fastai_pred = fastai_learn.get_preds(ds_type=fastai.data_block.DatasetType.Valid)
print("final fastai valid_loss = %.3f" % torch.nn.functional.mse_loss(*fastai_pred).item())

In [8]:
class DotProdBias(nn.Module):
    """Each user and item have embedding_dim params and a bias.

    The predicted rating for (user, item) is
      (x_user • x_item) + b_user + b_item
    """
    def __init__(self, n_user: int, n_item: int, embedding_dim: int):
        super().__init__()
        self.user_emb = nn.Embedding(num_embeddings=n_user, embedding_dim=embedding_dim)
        self.item_emb = nn.Embedding(num_embeddings=n_item, embedding_dim=embedding_dim)
        self.user_bias = nn.Embedding(num_embeddings=n_user, embedding_dim=1)
        self.item_bias = nn.Embedding(num_embeddings=n_item, embedding_dim=1)
    
    def forward(self, users: torch.LongTensor, items: torch.LongTensor) -> torch.FloatTensor:
        # Convert from 1-based to 0-based index.
        users, items = users - 1, items - 1
        dot_prods = (self.user_emb(users) * self.item_emb(items)).sum(dim=1)
        return dot_prods + self.user_bias(users) + self.item_bias(items)

In [9]:
class Fitter:
    def __init__(self, model: nn.Module, loss_func: nn.Module, optim: torch.optim.Optimizer):
        self.model = model
        self.loss_func = loss_func
        self.optim = optim

    def fit(self, num_epochs: int):
        print("epoch | train_loss | test_loss | time")
        for epoch in range(num_epochs):
            start = time.time()
            train_loss = torch.tensor([0.0], dtype=float, device=dev)
            for batch_idx, (inputs, targets) in enumerate(train_loader, 0):
                train_loss += self._one_batch(inputs, targets)

            with torch.no_grad():
                pred = self.model(test_inputs[:, 0], test_inputs[:, 1])
                test_loss = self.loss_func(pred, test_labels).item()

            num_batches = batch_idx + 1
            print("%5d |      %.3f |     %.3f |   %ds |" % (
                epoch,
                train_loss / num_batches,
                test_loss,
                int(time.time() - start)))

    def _one_batch(self, inputs: torch.tensor, targets: torch.tensor) -> torch.tensor:
        self.optim.zero_grad()
        pred = self.model(inputs[:, 0], inputs[:, 1])
        loss = self.loss_func(pred, targets)
        loss.backward()
        self.optim.step()
        return loss

In [68]:
model_dot_prod_bias = DotProdBias(n_user, n_item, 40).to(dev)
fitter_dot_prod_bias = Fitter(model_dot_prod_bias, nn.MSELoss(), torch.optim.Adam(model_dot_prod_bias.parameters(), lr=5e-3, betas=(0.9, 0.99)))
fitter_dot_prod_bias.fit(num_epochs)

epoch | train_loss | test_loss | time | num_batches
    0 |      32.385 |     20.587 |  2s | 1416
    1 |      7.093 |     9.406 |  3s | 1416
    2 |      2.267 |     5.821 |  3s | 1416
    3 |      1.325 |     4.394 |  3s | 1416
    4 |      1.089 |     3.734 |  3s | 1416
    5 |      1.017 |     3.262 |  3s | 1416
    6 |      0.985 |     2.941 |  3s | 1416
    7 |      0.965 |     2.685 |  2s | 1416
    8 |      0.954 |     2.465 |  3s | 1416
    9 |      0.940 |     2.284 |  3s | 1416


In [51]:
model_dot_prod_bias_path = "model_dot_prod_bias.pth"
torch.save(model_dot_prod_bias.state_dict(), model_dot_prod_bias_path)

In [10]:
def trunc_normal_(x: torch.tensor, mean: float=0., std: float=1.) -> torch.tensor:
    "Truncated normal initialization."
    # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12
    return x.normal_().fmod_(2).mul_(std).add_(mean)

In [11]:
class ScaledDotProdBias(nn.Module):
    """Same as DotProdBias, but scale the output to be within y_range."""
    def __init__(self, n_user: int, n_item: int, embedding_dim: int, y_range: typing.Tuple[int, int], trunc_normal: bool=False):
        super().__init__()
        self.user_emb = nn.Embedding(num_embeddings=n_user, embedding_dim=embedding_dim)
        self.item_emb = nn.Embedding(num_embeddings=n_item, embedding_dim=embedding_dim)
        self.user_bias = nn.Embedding(num_embeddings=n_user, embedding_dim=1)
        self.item_bias = nn.Embedding(num_embeddings=n_item, embedding_dim=1)
        if trunc_normal:
            # https://github.com/fastai/fastai1/blob/6a5102ef7bdefa9058d0481ab311f48b21cbc6fc/fastai/layers.py#L285
            for e in (self.user_emb, self.item_emb, self.user_bias, self.item_bias):
                with torch.no_grad(): trunc_normal_(e.weight, std=0.01)
        self.y_min, self.y_max = y_range
    
    def forward(self, users: torch.LongTensor, items: torch.LongTensor) -> torch.FloatTensor:
        # Convert from 1-based to 0-based index.
        users, items = users - 1, items - 1
        dot_prods = (self.user_emb(users) * self.item_emb(items)).sum(dim=1)
        biased = dot_prods + self.user_bias(users) + self.item_bias(items)
        return self.y_min + (self.y_max - self.y_min) * nn.functional.sigmoid(biased)

In [53]:
model_scaled_dot_prod_bias = ScaledDotProdBias(n_user, n_item, 40, (-0.5, 5.5)).to(dev)
fitter_scaled_dot_prod_bias = Fitter(model_scaled_dot_prod_bias, nn.MSELoss(), torch.optim.Adam(model_scaled_dot_prod_bias.parameters(), lr=5e-3, betas=(0.9, 0.99)))
fitter_scaled_dot_prod_bias.fit(num_epochs)

epoch | train_loss | test_loss | time
    0 |      7.755 |     6.754 | 4s
    1 |      5.117 |     5.289 | 3s
    2 |      3.834 |     4.599 | 3s
    3 |      2.908 |     3.990 | 3s
    4 |      2.060 |     3.375 | 3s
    5 |      1.512 |     2.918 | 3s
    6 |      1.257 |     2.601 | 3s
    7 |      1.138 |     2.368 | 3s
    8 |      1.079 |     2.189 | 3s
    9 |      1.045 |     2.043 | 3s


In [54]:
model_scaled_dot_prod_bias_path = "model_scaled_dot_prod_bias.pth"
torch.save(model_scaled_dot_prod_bias.state_dict(), model_scaled_dot_prod_bias_path)

In [12]:
class FitterOneCycle(Fitter):
    def __init__(self, *args):
        super().__init__(*args)
        self.scheduler = None
    
    def fit(self, num_epochs: int):
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optim, self.optim.defaults["lr"], epochs=num_epochs, steps_per_epoch=math.ceil(len(train_df) / batch_size))
        super().fit(num_epochs)

    def _one_batch(self, inputs: torch.tensor, targets: torch.tensor) -> torch.tensor:
        loss = super()._one_batch(inputs, targets)
        self.scheduler.step()
        return loss


In [77]:
model_scaled_dot_prod_bias_one_cycle = ScaledDotProdBias(n_user, n_item, 40, (-0.5, 5.5)).to(dev)
fitter_one_cycle = FitterOneCycle(model_scaled_dot_prod_bias_one_cycle, nn.MSELoss(), torch.optim.Adam(model_scaled_dot_prod_bias_one_cycle.parameters(), lr=5e-3, betas=(0.9, 0.99)))
fitter_one_cycle.fit(num_epochs)

epoch | train_loss | test_loss | time
    0 |      9.008 |     8.912 |   2s |
    1 |      8.019 |     7.623 |   3s |
    2 |      5.932 |     5.857 |   3s |
    3 |      4.281 |     4.870 |   3s |
    4 |      3.307 |     4.285 |   3s |
    5 |      2.513 |     3.843 |   3s |
    6 |      1.925 |     3.539 |   4s |
    7 |      1.580 |     3.375 |   3s |
    8 |      1.414 |     3.313 |   3s |
    9 |      1.355 |     3.304 |   3s |


In [78]:
# Let's try increasing lr
model_scaled_dot_prod_bias_one_cycle = ScaledDotProdBias(n_user, n_item, 40, (-0.5, 5.5)).to(dev)
fitter_one_cycle = FitterOneCycle(model_scaled_dot_prod_bias_one_cycle, nn.MSELoss(), torch.optim.Adam(model_scaled_dot_prod_bias_one_cycle.parameters(), lr=1e-2, betas=(0.9, 0.99)))
fitter_one_cycle.fit(num_epochs)

epoch | train_loss | test_loss | time
    0 |      8.962 |     8.560 |   3s |
    1 |      7.094 |     6.255 |   3s |
    2 |      4.448 |     4.438 |   3s |
    3 |      2.863 |     3.301 |   3s |
    4 |      1.745 |     2.520 |   3s |
    5 |      1.307 |     2.088 |   3s |
    6 |      1.104 |     1.874 |   3s |
    7 |      0.978 |     1.763 |   3s |
    8 |      0.904 |     1.738 |   3s |
    9 |      0.873 |     1.734 |   3s |


In [79]:
# Let's try increasing lr again
model_scaled_dot_prod_bias_one_cycle = ScaledDotProdBias(n_user, n_item, 40, (-0.5, 5.5)).to(dev)
fitter_one_cycle = FitterOneCycle(model_scaled_dot_prod_bias_one_cycle, nn.MSELoss(), torch.optim.Adam(model_scaled_dot_prod_bias_one_cycle.parameters(), lr=5e-2, betas=(0.9, 0.99)))
fitter_one_cycle.fit(num_epochs)

epoch | train_loss | test_loss | time
    0 |      8.032 |     6.355 |   3s |
    1 |      4.281 |     4.149 |   3s |
    2 |      3.784 |     4.565 |   3s |
    3 |      4.159 |     4.648 |   3s |
    4 |      4.131 |     4.498 |   3s |
    5 |      3.825 |     4.278 |   3s |
    6 |      3.194 |     3.716 |   3s |
    7 |      2.350 |     2.941 |   3s |
    8 |      1.725 |     2.435 |   3s |
    9 |      1.399 |     2.355 |   3s |


In [84]:
# That was not good, let's go back to 1e-2, and try training longer
model_scaled_dot_prod_bias_one_cycle = ScaledDotProdBias(n_user, n_item, 40, (-0.5, 5.5)).to(dev)
fitter_one_cycle = FitterOneCycle(model_scaled_dot_prod_bias_one_cycle, nn.MSELoss(), torch.optim.Adam(model_scaled_dot_prod_bias_one_cycle.parameters(), lr=1e-2, betas=(0.9, 0.99)))
fitter_one_cycle.fit(num_epochs*2)

epoch | train_loss | test_loss | time
    0 |      8.959 |     8.992 |   2s |
    1 |      8.165 |     8.119 |   3s |
    2 |      6.458 |     6.432 |   3s |
    3 |      4.583 |     4.893 |   3s |
    4 |      3.204 |     3.809 |   3s |
    5 |      2.045 |     2.845 |   3s |
    6 |      1.483 |     2.264 |   3s |
    7 |      1.273 |     1.872 |   3s |
    8 |      1.150 |     1.660 |   3s |
    9 |      1.075 |     1.477 |   3s |
   10 |      1.018 |     1.342 |   3s |
   11 |      0.974 |     1.268 |   3s |
   12 |      0.940 |     1.210 |   3s |
   13 |      0.910 |     1.165 |   3s |
   14 |      0.886 |     1.135 |   3s |
   15 |      0.867 |     1.116 |   3s |
   16 |      0.853 |     1.104 |   3s |
   17 |      0.842 |     1.100 |   3s |
   18 |      0.836 |     1.097 |   3s |
   19 |      0.832 |     1.097 |   3s |


In [85]:
# Let's tweak y-range to start at 0
model_scaled_dot_prod_bias_one_cycle = ScaledDotProdBias(n_user, n_item, 40, (0., 5.5)).to(dev)
fitter_one_cycle = FitterOneCycle(model_scaled_dot_prod_bias_one_cycle, nn.MSELoss(), torch.optim.Adam(model_scaled_dot_prod_bias_one_cycle.parameters(), lr=1e-2, betas=(0.9, 0.99)))
fitter_one_cycle.fit(num_epochs*2)

epoch | train_loss | test_loss | time
    0 |      7.453 |     7.434 |   3s |
    1 |      6.883 |     6.810 |   3s |
    2 |      5.626 |     5.615 |   3s |
    3 |      4.108 |     4.364 |   3s |
    4 |      2.801 |     3.274 |   3s |
    5 |      1.780 |     2.465 |   3s |
    6 |      1.353 |     1.999 |   4s |
    7 |      1.179 |     1.677 |   3s |
    8 |      1.072 |     1.441 |   3s |
    9 |      1.007 |     1.305 |   3s |
   10 |      0.959 |     1.196 |   3s |
   11 |      0.926 |     1.124 |   3s |
   12 |      0.903 |     1.087 |   3s |
   13 |      0.884 |     1.052 |   3s |
   14 |      0.870 |     1.035 |   3s |
   15 |      0.856 |     1.023 |   3s |
   16 |      0.846 |     1.016 |   3s |
   17 |      0.838 |     1.013 |   3s |
   18 |      0.834 |     1.012 |   3s |
   19 |      0.831 |     1.011 |   3s |


In [10]:
# Let's tweak y-range to start at 0.5
model_scaled_dot_prod_bias_one_cycle = ScaledDotProdBias(n_user, n_item, 40, (0.5, 5.5)).to(dev)
fitter_one_cycle = FitterOneCycle(model_scaled_dot_prod_bias_one_cycle, nn.MSELoss(), torch.optim.Adam(model_scaled_dot_prod_bias_one_cycle.parameters(), lr=1e-2, betas=(0.9, 0.99)))
fitter_one_cycle.fit(num_epochs*2)

epoch | train_loss | test_loss | time
    0 |      6.203 |     6.216 |   3s |
    1 |      5.811 |     5.814 |   3s |
    2 |      4.914 |     5.001 |   3s |
    3 |      3.687 |     3.921 |   3s |
    4 |      2.425 |     2.863 |   3s |
    5 |      1.561 |     2.172 |   3s |
    6 |      1.259 |     1.759 |   3s |
    7 |      1.114 |     1.512 |   3s |
    8 |      1.028 |     1.323 |   3s |
    9 |      0.974 |     1.204 |   3s |
   10 |      0.937 |     1.125 |   3s |
   11 |      0.910 |     1.077 |   3s |
   12 |      0.890 |     1.044 |   3s |
   13 |      0.874 |     1.018 |   3s |
   14 |      0.862 |     1.003 |   3s |
   15 |      0.851 |     0.994 |   3s |
   16 |      0.843 |     0.988 |   3s |
   17 |      0.837 |     0.987 |   3s |
   18 |      0.832 |     0.986 |   3s |
   19 |      0.830 |     0.986 |   3s |


In [11]:
# Let's try weight decay
model_scaled_dot_prod_bias_one_cycle = ScaledDotProdBias(n_user, n_item, 40, (0.5, 5.5)).to(dev)
fitter_one_cycle = FitterOneCycle(model_scaled_dot_prod_bias_one_cycle, nn.MSELoss(), torch.optim.Adam(model_scaled_dot_prod_bias_one_cycle.parameters(), lr=1e-2, betas=(0.9, 0.99), weight_decay=1e-2))
fitter_one_cycle.fit(num_epochs*2)

epoch | train_loss | test_loss | time
    0 |      5.271 |     3.430 |   2s |
    1 |      1.727 |     1.257 |   3s |
    2 |      1.129 |     1.227 |   3s |
    3 |      1.124 |     1.225 |   3s |
    4 |      1.126 |     1.232 |   3s |
    5 |      1.127 |     1.227 |   3s |
    6 |      1.127 |     1.231 |   3s |
    7 |      1.127 |     1.228 |   3s |
    8 |      1.127 |     1.230 |   4s |
    9 |      1.127 |     1.228 |   3s |
   10 |      1.125 |     1.233 |   3s |
   11 |      1.125 |     1.227 |   3s |
   12 |      1.124 |     1.226 |   3s |
   13 |      1.123 |     1.228 |   3s |
   14 |      1.122 |     1.226 |   3s |
   15 |      1.122 |     1.226 |   3s |
   16 |      1.121 |     1.226 |   3s |
   17 |      1.118 |     1.228 |   3s |
   18 |      1.117 |     1.228 |   3s |
   19 |      1.116 |     1.228 |   2s |


In [13]:
# Let's try a little bit less weight decay
model_scaled_dot_prod_bias_one_cycle = ScaledDotProdBias(n_user, n_item, 40, (0.5, 5.5)).to(dev)
fitter_one_cycle = FitterOneCycle(model_scaled_dot_prod_bias_one_cycle, nn.MSELoss(), torch.optim.Adam(model_scaled_dot_prod_bias_one_cycle.parameters(), lr=1e-2, betas=(0.9, 0.99), weight_decay=5e-3))
fitter_one_cycle.fit(num_epochs*2)

epoch | train_loss | test_loss | time
    0 |      5.423 |     3.773 |   2s |
    1 |      1.922 |     1.205 |   2s |
    2 |      1.051 |     1.135 |   3s |
    3 |      1.037 |     1.132 |   3s |
    4 |      1.039 |     1.136 |   3s |
    5 |      1.043 |     1.143 |   3s |
    6 |      1.042 |     1.136 |   3s |
    7 |      1.041 |     1.138 |   3s |
    8 |      1.041 |     1.141 |   3s |
    9 |      1.040 |     1.144 |   3s |
   10 |      1.039 |     1.132 |   3s |
   11 |      1.037 |     1.136 |   3s |
   12 |      1.037 |     1.130 |   3s |
   13 |      1.034 |     1.135 |   3s |
   14 |      1.033 |     1.136 |   3s |
   15 |      1.032 |     1.134 |   3s |
   16 |      1.029 |     1.135 |   3s |
   17 |      1.027 |     1.135 |   3s |
   18 |      1.025 |     1.135 |   3s |
   19 |      1.022 |     1.135 |   2s |


In [14]:
# Seems weight decay isn't really helping. Let's go back to our best performing model and see how many epochs it takes before we over-fit
model_scaled_dot_prod_bias_one_cycle = ScaledDotProdBias(n_user, n_item, 40, (0.5, 5.5)).to(dev)
fitter_one_cycle = FitterOneCycle(model_scaled_dot_prod_bias_one_cycle, nn.MSELoss(), torch.optim.Adam(model_scaled_dot_prod_bias_one_cycle.parameters(), lr=1e-2, betas=(0.9, 0.99)))
fitter_one_cycle.fit(num_epochs*4)

epoch | train_loss | test_loss | time
    0 |      6.271 |     6.275 |   3s |
    1 |      6.041 |     6.117 |   2s |
    2 |      5.651 |     5.828 |   3s |
    3 |      5.048 |     5.367 |   3s |
    4 |      4.274 |     4.760 |   3s |
    5 |      3.418 |     4.057 |   3s |
    6 |      2.542 |     3.316 |   3s |
    7 |      1.783 |     2.626 |   3s |
    8 |      1.379 |     2.169 |   3s |
    9 |      1.212 |     1.847 |   3s |
   10 |      1.131 |     1.583 |   3s |
   11 |      1.061 |     1.393 |   3s |
   12 |      1.007 |     1.247 |   3s |
   13 |      0.968 |     1.165 |   3s |
   14 |      0.944 |     1.100 |   3s |
   15 |      0.927 |     1.063 |   3s |
   16 |      0.917 |     1.032 |   3s |
   17 |      0.907 |     1.010 |   3s |
   18 |      0.902 |     0.998 |   3s |
   19 |      0.894 |     0.985 |   3s |
   20 |      0.890 |     0.979 |   3s |
   21 |      0.883 |     0.969 |   3s |
   22 |      0.879 |     0.965 |   3s |
   23 |      0.874 |     0.958 |   3s |
  

In [15]:
fitter_one_cycle.fit(num_epochs)

epoch | train_loss | test_loss | time
    0 |      0.834 |     0.938 |   2s |
    1 |      0.856 |     0.959 |   3s |
    2 |      0.888 |     0.970 |   3s |
    3 |      0.898 |     0.974 |   3s |
    4 |      0.893 |     0.972 |   3s |
    5 |      0.880 |     0.957 |   3s |
    6 |      0.865 |     0.947 |   3s |
    7 |      0.851 |     0.941 |   3s |
    8 |      0.839 |     0.938 |   3s |
    9 |      0.831 |     0.938 |   3s |


Seems that after 35 epochs we don't see any improvement in test set, and after 40 we start to over-fit.

In [13]:
# Let's try trunc_normal initialization
model_scaled_dot_prod_bias_one_cycle = ScaledDotProdBias(n_user, n_item, 40, (0.5, 5.5), trunc_normal=True).to(dev)
fitter_one_cycle = FitterOneCycle(model_scaled_dot_prod_bias_one_cycle, nn.MSELoss(), torch.optim.Adam(model_scaled_dot_prod_bias_one_cycle.parameters(), lr=1e-2, betas=(0.9, 0.99)))
fitter_one_cycle.fit(num_epochs*2)

epoch | train_loss | test_loss | time
    0 |      1.317 |     1.161 |   3s |
    1 |      1.009 |     0.997 |   2s |
    2 |      0.900 |     0.944 |   3s |
    3 |      0.878 |     0.944 |   3s |
    4 |      0.885 |     0.955 |   3s |
    5 |      0.894 |     0.970 |   3s |
    6 |      0.895 |     0.971 |   3s |
    7 |      0.896 |     0.972 |   3s |
    8 |      0.891 |     0.965 |   3s |
    9 |      0.887 |     0.962 |   4s |
   10 |      0.879 |     0.959 |   3s |
   11 |      0.873 |     0.954 |   3s |
   12 |      0.865 |     0.946 |   3s |
   13 |      0.858 |     0.942 |   3s |
   14 |      0.851 |     0.937 |   3s |
   15 |      0.845 |     0.935 |   3s |
   16 |      0.839 |     0.935 |   3s |
   17 |      0.834 |     0.934 |   3s |
   18 |      0.831 |     0.934 |   3s |
   19 |      0.829 |     0.934 |   3s |


In [14]:
# Wow, that really sped things up! Seems we're over-fitting after 5 epochs now. Let's try weight_decay again.
model_scaled_dot_prod_bias_one_cycle = ScaledDotProdBias(n_user, n_item, 40, (0.5, 5.5), trunc_normal=True).to(dev)
fitter_one_cycle = FitterOneCycle(model_scaled_dot_prod_bias_one_cycle, nn.MSELoss(), torch.optim.Adam(model_scaled_dot_prod_bias_one_cycle.parameters(), lr=1e-2, betas=(0.9, 0.99), weight_decay=1e-2))
fitter_one_cycle.fit(num_epochs*2)

epoch | train_loss | test_loss | time
    0 |      1.392 |     1.336 |   3s |
    1 |      1.164 |     1.230 |   3s |
    2 |      1.122 |     1.226 |   3s |
    3 |      1.125 |     1.227 |   3s |
    4 |      1.126 |     1.230 |   3s |
    5 |      1.127 |     1.226 |   3s |
    6 |      1.127 |     1.235 |   3s |
    7 |      1.128 |     1.229 |   3s |
    8 |      1.127 |     1.227 |   3s |
    9 |      1.126 |     1.230 |   3s |
   10 |      1.125 |     1.230 |   3s |
   11 |      1.125 |     1.233 |   3s |
   12 |      1.125 |     1.222 |   3s |
   13 |      1.124 |     1.228 |   3s |
   14 |      1.122 |     1.223 |   3s |
   15 |      1.122 |     1.227 |   3s |
   16 |      1.121 |     1.227 |   3s |
   17 |      1.119 |     1.227 |   3s |
   18 |      1.117 |     1.228 |   3s |
   19 |      1.115 |     1.228 |   3s |


In [15]:
# Hmm, under-fitting now :-(
# Let's try using fastai's "true_wd" algorithm.
# I think that's been added to torch as AdamW, so let's use that instead of Adam.
model_scaled_dot_prod_bias_one_cycle = ScaledDotProdBias(n_user, n_item, 40, (0.5, 5.5), trunc_normal=True).to(dev)
fitter_one_cycle = FitterOneCycle(model_scaled_dot_prod_bias_one_cycle, nn.MSELoss(), torch.optim.AdamW(model_scaled_dot_prod_bias_one_cycle.parameters(), lr=1e-2, betas=(0.9, 0.99)))
fitter_one_cycle.fit(num_epochs*2)

epoch | train_loss | test_loss | time
    0 |      1.314 |     1.161 |   3s |
    1 |      1.009 |     0.997 |   3s |
    2 |      0.902 |     0.947 |   3s |
    3 |      0.881 |     0.942 |   3s |
    4 |      0.883 |     0.951 |   3s |
    5 |      0.890 |     0.952 |   3s |
    6 |      0.893 |     0.959 |   3s |
    7 |      0.890 |     0.958 |   3s |
    8 |      0.888 |     0.956 |   3s |
    9 |      0.882 |     0.949 |   3s |
   10 |      0.877 |     0.942 |   3s |
   11 |      0.870 |     0.941 |   3s |
   12 |      0.863 |     0.934 |   3s |
   13 |      0.857 |     0.928 |   3s |
   14 |      0.851 |     0.925 |   3s |
   15 |      0.845 |     0.923 |   3s |
   16 |      0.839 |     0.922 |   3s |
   17 |      0.834 |     0.922 |   3s |
   18 |      0.831 |     0.923 |   3s |
   19 |      0.829 |     0.923 |   3s |


Benchmark is fastai's results from https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson4-collab.ipynb: 0.815 for validation set.

TODOs
 * look into weight decay
   * fastai has their own implementation that's here: https://github.com/fastai/fastai1/blob/bcef12e95405655481bb309761f8c552b51b2bd2/fastai/callback.py#L48
   * this seemed to help in first few epochs but lead to under-fitting. In general I'm not seeing over-fitting so I don't think this would be helpful.ab
 * look into OneCycleLR parameters vs fastai's implementation
