### Bug Hunt: Index Error

There's a  deliberate bug in the code below. Find  it.

Debugging PyTorch is niche skill -- this will help you hone it.

*Hint*: when you hit the bug, type in `debug` in  the next cell, and it will drop you in a debugger. If you're outside of IPython, you can also do `import pdb; pdb.set_trace()` or `python -m pdb script.py`.

In [None]:
# !python ../src/download.py

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from torch import from_numpy
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch.utils.data import BatchSampler
from torch.utils.data import RandomSampler

fh = np.load('data/dataset.npz')

# We have a bunch of feature columns and last column is the y-target
# Note pytorch is finicky about need int64 types
train_x = fh['train_x'].astype(np.int64) - 1000
train_y = fh['train_y']

# We've already split into train & test
X_test = fh['test_x'].astype(np.int64)
Y_test = fh['test_y']

X_train, X_val, Y_train, Y_val = train_test_split(train_x, train_y)


n_user = int(fh['n_user'])
n_item = int(fh['n_item'])

# columns are user_id, item_id and other features 
# we won't use the 3rd and 4th columns
print(X_train)
print(' ')
print(Y_train)



def dataloader(*arrs, batch_size=32):
    dataset = TensorDataset(*arrs)
    bs = BatchSampler(RandomSampler(dataset), 
                      batch_size=batch_size, drop_last=False)
    return DataLoader(dataset, batch_sampler=bs, num_workers=8)
 
train = dataloader(from_numpy(X_train), from_numpy(Y_train))
test = dataloader(from_numpy(X_test), from_numpy(Y_test))
val = dataloader(from_numpy(X_val), from_numpy(Y_val))

[[4258 1515 -995 -996]
 [4112  950 -881 -987]
 [2628  911 -980 -993]
 ...
 [3364 1026 -808 -982]
 [3217 -961 -966 -996]
 [3328 1001 -801 -997]]
 
[[3.]
 [4.]
 [3.]
 ...
 [2.]
 [4.]
 [3.]]


In [None]:
from abstract_model import AbstractModel

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

from pytorch_lightning.loggers import TensorBoardLogger


def l2_regularize(array):
    return torch.sum(array ** 2.0)


class MF(AbstractModel):
    def __init__(self, n_user, n_item, k=18, c_vector=1.0, c_bias=1.0, batch_size=128):
        super().__init__()
        # These are simple hyperparameters
        self.k = k
        self.n_user = n_user
        self.n_item = n_item
        self.c_vector = c_vector
        self.c_bias = c_bias
        self.batch_size = batch_size
        self.save_hyperparameters()
        
        # These are learned and fit by PyTorch
        self.user = nn.Embedding(n_user, k)
        self.item = nn.Embedding(n_item, k)
        
        # We've added new terms here:
        self.bias_user = nn.Embedding(n_user, 1)
        self.bias_item = nn.Embedding(n_item, 1)
        self.bias = nn.Parameter(torch.ones(1))
    
    def forward(self, inputs):
        # This is the most import function in this script
        # These are the user indices, and correspond to "u" variable
        user_id = inputs[:, 0]
        # Item indices, correspond to the "i" variable
        item_id = inputs[:, 1]
        # vector user = p_u
        vector_user = self.user(user_id)
        # vector item = q_i
        vector_item = self.item(item_id)
        # this is a dot product & a user-item interaction: p_u * q_i
        ui_interaction = torch.sum(vector_user * vector_item, dim=1)
        
        # Pull out biases
        bias_user = self.bias_user(user_id).squeeze()
        bias_item = self.bias_item(item_id).squeeze()
        biases = (self.bias + bias_user + bias_item)

        # Add bias prediction to the interaction prediction
        prediction = ui_interaction + biases
        return prediction
    
    def loss(self, prediction, target):
        # MSE error between target = R_ui and prediction = p_u * q_i
        loss_mse = F.mse_loss(prediction, target.squeeze())
        return loss_mse, {"mse": loss_mse}
    
    def reg(self):
        # Add new regularization to the biases
        reg_bias_user =  l2_regularize(self.bias_user.weight) * self.c_bias
        reg_bias_item = l2_regularize(self.bias_item.weight) * self.c_bias
        # Compute L2 reularization over user (P) and item (Q) matrices
        reg_user =  l2_regularize(self.user.weight) * self.c_vector
        reg_item = l2_regularize(self.item.weight) * self.c_vector
        # Add up the MSE loss + user & item regularization
        log = {"reg_user": reg_user, "reg_item": reg_item,
               "reg_bias_user": reg_bias_user, "reg_bias_item": reg_bias_item}
        total = reg_user + reg_item + reg_bias_user + reg_bias_item
        return total, log

In [None]:
from pytorch_lightning.loggers.wandb import WandbLogger

k = 5
c_vector = 1e-5
c_bias = 5e-8
model = MF(n_user, n_item, k=k, c_bias=c_bias, c_vector=c_vector,
          batch_size=1024)

# add a logger
logger = WandbLogger(name="02_mf", project="simple_mf")

trainer = pl.Trainer(max_epochs=100, logger=logger,
                     early_stop_callback=True,
                     progress_bar_refresh_rate=1) 

GPU available: True, used: False
TPU available: False, using: 0 TPU cores


In [None]:
trainer.fit(model, train, val)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Currently logged in as: [33msf-moody[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: Tracking run with wandb version 0.10.2
[34m[1mwandb[0m: Run data is saved locally in wandb/run-20200925_154404-82vxtmjr
[34m[1mwandb[0m: Syncing run [33m02_mf[0m



  | Name      | Type      | Params
----------------------------------------
0 | user      | Embedding | 30 K  
1 | item      | Embedding | 19 K  
2 | bias_user | Embedding | 6 K   
3 | bias_item | Embedding | 3 K   





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

RuntimeError: index out of range: Tried to access index -758 out of table with 6040 rows. at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:418

In [None]:
debug

> [0;32m/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py[0m(1484)[0;36membedding[0;34m()[0m
[0;32m   1482 [0;31m        [0;31m# remove once script supports set_grad_enabled[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1483 [0;31m        [0m_no_grad_embedding_renorm_[0m[0;34m([0m[0mweight[0m[0;34m,[0m [0minput[0m[0;34m,[0m [0mmax_norm[0m[0;34m,[0m [0mnorm_type[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1484 [0;31m    [0;32mreturn[0m [0mtorch[0m[0;34m.[0m[0membedding[0m[0;34m([0m[0mweight[0m[0;34m,[0m [0minput[0m[0;34m,[0m [0mpadding_idx[0m[0;34m,[0m [0mscale_grad_by_freq[0m[0;34m,[0m [0msparse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1485 [0;31m[0;34m[0m[0m
[0m[0;32m   1486 [0;31m[0;34m[0m[0m
[0m


ipdb>  help



Documented commands (type help <topic>):
EOF    cl         disable  interact  next    psource  rv         unt   
a      clear      display  j         p       q        s          until 
alias  commands   down     jump      pdef    quit     source     up    
args   condition  enable   l         pdoc    r        step       w     
b      cont       exit     list      pfile   restart  tbreak     whatis
break  continue   h        ll        pinfo   return   u          where 
bt     d          help     longlist  pinfo2  retval   unalias  
c      debug      ignore   n         pp      run      undisplay

Miscellaneous help topics:
exec  pdb



ipdb>  up


> [0;32m/opt/conda/lib/python3.6/site-packages/torch/nn/modules/sparse.py[0m(114)[0;36mforward[0;34m()[0m
[0;32m    112 [0;31m        return F.embedding(
[0m[0;32m    113 [0;31m            [0minput[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mweight[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mpadding_idx[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mmax_norm[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 114 [0;31m            self.norm_type, self.scale_grad_by_freq, self.sparse)
[0m[0;32m    115 [0;31m[0;34m[0m[0m
[0m[0;32m    116 [0;31m    [0;32mdef[0m [0mextra_repr[0m[0;34m([0m[0mself[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  up


> [0;32m/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py[0m(532)[0;36m__call__[0;34m()[0m
[0;32m    530 [0;31m            [0mresult[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_slow_forward[0m[0;34m([0m[0;34m*[0m[0minput[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    531 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 532 [0;31m            [0mresult[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mforward[0m[0;34m([0m[0;34m*[0m[0minput[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    533 [0;31m        [0;32mfor[0m [0mhook[0m [0;32min[0m [0mself[0m[0;34m.[0m[0m_forward_hooks[0m[0;34m.[0m[0mvalues[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    534 [0;31m            [0mhook_result[0m [0;34m=[0m [0mhook[0m[0;34m([0m[0mself[0m[0;34m,[0m [0minput[0m[0;34m,[0m

ipdb>  up


> [0;32m<ipython-input-4-be3f80c4bae4>[0m(41)[0;36mforward[0;34m()[0m
[0;32m     39 [0;31m        [0mitem_id[0m [0;34m=[0m [0minputs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;36m1[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     40 [0;31m        [0;31m# vector user = p_u[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 41 [0;31m        [0mvector_user[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0muser[0m[0;34m([0m[0muser_id[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     42 [0;31m        [0;31m# vector item = q_i[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     43 [0;31m        [0mvector_item[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0mitem_id[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  print(user_id)


tensor([3448, 3773,  757, 3673, 1611, 4964, -758, 5040, 3894, 2475, 2524, -366,
        -178, -118,  689, 4005, 3600, 3193, -614, 3193, -707, 5010, 1206, 2121,
        3546, 4466, 3115, -142, 2650, 3344, 2414, 4647])


ipdb>  q
