In [1]:
# TODO:
# 1. Preprocess user-movie ratings:
#   All ratings >=4: 1
#   All ratings < 4 or not rated: 0

# 2. Convert the csv file to sparse matrix (storing non-zero values) only
#   - Convert the sparse matrix into numpy and feed them into the VAE


# 3. Build the VAE architecture:
#   Input layer -> hidden layer -> -latent -> hidden layer -> output layer
#       I -> 600 (relu) -> 200 (relu) -> 600 (relu) -> I (Sigmoid)

# 4. Loss function:
#   P(X|z) binomial, BCE + KLD

# 5. Configuration:
#   batch_size=500, epochs=200, optimizer=Adam, lr=1e-3

In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

In [2]:
rating_df = pd.read_csv('data/MovieLens/rating.csv')

In [3]:
rating_df.head()

Unnamed: 0,userId,movieId,rating,timestamp
0,1,2,3.5,2005-04-02 23:53:47
1,1,29,3.5,2005-04-02 23:31:16
2,1,32,3.5,2005-04-02 23:33:39
3,1,47,3.5,2005-04-02 23:32:07
4,1,50,3.5,2005-04-02 23:29:40


In [4]:

rating_df['implicit'] = np.where(rating_df.rating < 4, 0, 1)
rating_df.head()

Unnamed: 0,userId,movieId,rating,timestamp,implicit
0,1,2,3.5,2005-04-02 23:53:47,0
1,1,29,3.5,2005-04-02 23:31:16,0
2,1,32,3.5,2005-04-02 23:33:39,0
3,1,47,3.5,2005-04-02 23:32:07,0
4,1,50,3.5,2005-04-02 23:29:40,0


In [5]:

rows = rating_df.userId - min(rating_df.userId)
cols = rating_df.movieId - min(rating_df.movieId)
user_movie_coo_tensor = torch.sparse_coo_tensor([rows, cols], rating_df.implicit, dtype=torch.float)
user_movie_dataset = TensorDataset(user_movie_coo_tensor.to_dense())
user_movie_dataloader = DataLoader(user_movie_dataset, batch_size=500, shuffle=True)

In [6]:
class MovieLensVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=hidden_dim, bias=True),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        )
        self.mu = nn.Linear(in_features=hidden_dim, out_features=latent_dim, bias=True)
        self.logvar = nn.Linear(in_features=hidden_dim, out_features=latent_dim, bias=True)
        self.decoder = nn.Sequential(
            nn.Linear(in_features=latent_dim, out_features=hidden_dim, bias=True),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=input_dim, bias=True),
            nn.Sigmoid(), # Return values in the range of 0-1
        )
    
    def forward(self, x):
        encoded = self.encoder(x)
        mu = self.mu(encoded)
        logvar = self.logvar(encoded)
        z = self.reparameterization(mu, logvar)
        output = self.decoder(z)
        return output, mu, logvar

    def reparameterization(self, mu, logvar):
        e = torch.randn_like(mu)
        std = torch.exp(logvar/2)
        return mu + std * e


In [8]:
# Sum of loss over all elements in the same batch
# For each element, the loss is summed over all dimension
def loss_func(x, x_output, mu, logvar):
    # With binomial distribution at every pixel, the likelihood function becomes negative of 
    BCE_loss = F.binary_cross_entropy(x_output, x, reduction='sum')
    # KL divergence: -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    KLD = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - 1 - logvar)
    return BCE_loss + KLD

In [None]:
input_dim = max(cols) + 1
movielens_vae = MovieLensVAE(input_dim=input_dim, hidden_dim=600, latent_dim=200)
epochs = 200
optimizer = optim.Adam(movielens_vae.parameters(), lr=1e-3)
mu, logvar = None, None

for epoch in range(epochs):
    loss_error = 0
    for dl in user_movie_dataloader:
        x_output, mu_output, logvar_output = movielens_vae(dl[0])
        mu, logvar = mu_output, logvar_output
        loss = loss_func(dl[0], x_output, mu, logvar)

        loss_error += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    if epoch % 10 == 0:
        print(f"Epoch={epoch}: Error={loss_error}")

Epoch=0: Error=393560091.5625


KeyboardInterrupt: 

: 