In [1]:
from typing import List, Optional, Tuple

import pandas as pd
import torch
import torch.nn as nn
from tqdm.notebook import tqdm

In [2]:
# Number of users / movies; passed to fit() as n_users / n_movies to size
# the factor matrices and bias vectors.
users = 3974
movies = 3564

# Fixed seed so that the train/validation split is the same on every run.
SPLIT_SEED = 0

train_csv = pd.read_csv("../../../data/train_data.csv")
test_csv = pd.read_csv("../../../data/test_data.csv")

# Shift the ids down by one (vectorised instead of a per-element lambda).
# NOTE(review): presumably the raw ids are 1-based and this makes them usable
# as 0-based row indices — confirm against the data files.
train_csv["user_id"] = train_csv["user_id"] - 1
train_csv["movie_id"] = train_csv["movie_id"] - 1

test_csv["user_id"] = test_csv["user_id"] - 1
test_csv["movie_id"] = test_csv["movie_id"] - 1

# Split into train and validation: 80 % of the rows are sampled for training,
# the rows that were not sampled form the validation set.
train_data = train_csv.drop(columns=["timestamp"]).sample(
    frac=0.8, random_state=SPLIT_SEED
)
validation_data = train_csv.drop(index=train_data.index).drop(columns=["timestamp"])

# Every row must land in exactly one of the two splits.
assert train_data.shape[0] + validation_data.shape[0] == train_csv.shape[0]

In [131]:
# Preview the first rows of the training frame.
train_csv.head()

Unnamed: 0,user_id,movie_id,rating,timestamp
0,0,1159,5,974769817
1,0,1128,3,974769817
2,0,3327,4,974769817
3,0,2658,2,974769817
4,0,979,3,974769817


In [3]:
# NumPy arrays of the two splits, one row per interaction.
train_uir = train_data.to_numpy()
val_uir = validation_data.to_numpy()

# Whole training frame without its last column, and the test frame without
# its first and last columns.
total_uir = train_csv.to_numpy()[:, :-1]
test_ui = test_csv.to_numpy()[:, 1:-1]

## Model

In [4]:
def reconstruct(
    P: torch.FloatTensor,
    Q: torch.FloatTensor,
    bu: torch.FloatTensor,
    bi: torch.FloatTensor,
    mu: float,
) -> torch.FloatTensor:
    """Build the dense matrix of predicted ratings from the fitted parameters.

    Entry ``[u, i]`` of the result is ``mu + bu[u] + bi[i] + P[u] · Q[i]``,
    i.e. the same prediction rule that ``fit`` optimises, clipped to the
    rating range ``[1, 5]``.

    Args:
        P: user factors, shape ``(n_users, k)``.
        Q: item factors, shape ``(n_items, k)``.
        bu: user biases, shape ``(n_users,)`` or ``(n_users, 1)``.
        bi: item biases, shape ``(n_items,)`` or ``(n_items, 1)``.
        mu: global offset added to every prediction (float or 0-dim tensor).

    Returns:
        A CPU tensor of shape ``(n_users, n_items)`` detached from autograd.
    """
    # Use the GPU when there is one, but keep working on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The result is only used for prediction, so no graph is recorded for the
    # (potentially very large) dense matrix.
    with torch.no_grad():
        P = P.to(device)
        Q = Q.to(device)
        # Column vector of user biases and row vector of item biases, so that
        # broadcasting yields bu[u] + bi[i] at position [u, i].
        bu_col = bu.to(device).reshape(-1, 1)
        bi_row = bi.to(device).reshape(1, -1)
        offset = torch.as_tensor(mu, dtype=P.dtype, device=device)

        mat = offset + bu_col + bi_row + P @ Q.T

        return torch.clip(mat, 1, 5).cpu()

In [132]:
def fit(
    uir_mat: torch.IntTensor, # User Item rating mat
    k: int,
    lr: float,
    λ: float,
    iters: int,
    n_users: int,
    n_movies: int,
    mu: Optional[float] = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Fit a biased matrix factorisation with per-interaction SGD.

    A rating is modelled as ``mu + bu[u] + bi[i] + P[u] · Q[i]``. For every
    observed ``(u, i, r)`` the parameters take one gradient step on

        ``(pred - r)**2 + λ/2 * (||P[u]||**2 + ||Q[i]||**2)``

    Only the rows that take part in the interaction are regularised and
    updated. Penalising the whole of ``P`` and ``Q`` on every single-sample
    step would shrink every row once per interaction and would require dense
    gradients over both matrices at each step.

    Args:
        uir_mat: 2-D tensor whose first three columns are user index, item
            index and rating.
        k: number of latent factors.
        lr: SGD learning rate.
        λ: L2 regularisation strength on the latent factors (biases are not
            regularised).
        iters: number of passes over ``uir_mat``.
        n_users: number of rows of ``P`` / length of ``bu``.
        n_movies: number of rows of ``Q`` / length of ``bi``.
        mu: global offset. When ``None`` the mean of the given ratings is
            used (``0.0`` if there are no ratings). It is not returned, so a
            caller that relies on the default has to recompute it.

    Returns:
        ``(P, Q, bu, bi)`` with shapes ``(n_users, k)``, ``(n_movies, k)``,
        ``(n_users,)`` and ``(n_movies,)``. The tensors live on the GPU when
        one is available, otherwise on the CPU.
    """
    # Initialize params
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    uir_mat = uir_mat.to(device)
    # Plain Python ints for indexing: avoids a 0-dim tensor index per sample.
    user_ids = uir_mat[:, 0].long().tolist()
    item_ids = uir_mat[:, 1].long().tolist()
    expected = uir_mat[:, 2].float()
    n_interactions = expected.shape[0]

    if mu is None:
        mu = expected.mean() if n_interactions else 0.0
    mu = float(mu)

    P = torch.randn(n_users, k, requires_grad=True, device=device)
    Q = torch.randn(n_movies, k, requires_grad=True, device=device)
    with torch.no_grad():
        P *= .1
        Q *= .1
    bu = torch.zeros(n_users, requires_grad=True, device=device)
    bi = torch.zeros(n_movies, requires_grad=True, device=device)

    # Fit
    min_loss = torch.inf
    for _ in tqdm(range(iters)):
        sq_err_sum = torch.zeros((), device=device)
        # Gradients are written out by hand, so autograd is not needed here.
        with torch.no_grad():
            samples = zip(user_ids, item_ids, expected)
            for u, i, r in tqdm(samples, total=n_interactions, leave=False):
                # Snapshot both rows so each update uses the pre-step values.
                p_u = P[u].clone()
                q_i = Q[i].clone()

                err = mu + bu[u] + bi[i] + torch.dot(p_u, q_i) - r

                P[u] -= lr * (2 * err * q_i + λ * p_u)
                Q[i] -= lr * (2 * err * p_u + λ * q_i)
                bu[u] -= lr * 2 * err
                bi[i] -= lr * 2 * err

                sq_err_sum += err ** 2

        if n_interactions:
            # Mean squared error of the predictions made during this pass.
            epoch_loss = float(sq_err_sum / n_interactions)
            min_loss = min(min_loss, epoch_loss)
            print(epoch_loss)

    print(min_loss)
    return P, Q, bu, bi

In [133]:
# Train on the training split with 5 latent factors and a single pass over
# the data. mu is not passed, so fit() uses the mean of the given ratings.
# The result is the tuple (P, Q, bu, bi) returned by fit().
fitted_params = fit(
    uir_mat=torch.from_numpy(train_uir),
    k=5,
    lr=0.005, 
    λ=0.02, 
    iters=1, 
    n_users=users, 
    n_movies=movies, 
)

  0%|          | 0/1 [00:00<?, ?it/s]

0it [00:00, ?it/s]

KeyboardInterrupt: 

In [113]:
# Show the fitted user biases (index 2) and item biases (index 3), one per line.
print(fitted_params[2], fitted_params[3], sep="\n")

tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', requires_grad=True)


## Fit

In [48]:
# Two independent trainable column vectors, both starting at zero.
arr_1, arr_2 = (torch.zeros(4, 1, requires_grad=True) for _ in range(2))

# Constant column of ones (not trainable) used to augment the vectors:
# arr_3 has rows [arr_1, 1] and arr_4 has rows [1, arr_2].
ones = torch.ones(4, 1)
arr_3 = torch.cat((arr_1, ones), dim=1)
arr_4 = torch.cat((ones, arr_2), dim=1)

In [49]:
# Product of the two augmented matrices, then back-propagate the squared
# error of a single entry against the target value 2.
out = torch.matmul(arr_3, arr_4.T)
(out[0, 0] - 2).pow(2).backward()

# arr_3 is the result of torch.cat rather than a tensor created directly, and
# the recorded output of this cell is None: the gradient is not stored on it.
print(arr_3.grad)

None


In [50]:
# Display the product. Every row of arr_3 is [0, 1] and every row of arr_4 is
# [1, 0], so all entries are 0.
out

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], grad_fn=<MmBackward0>)