In [None]:
import torch
from torch import nn, optim
from tqdm import tqdm
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as data
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np

%matplotlib inline

In [5]:
# Run configuration: plain module-level constants.
# NOTE(review): none of the model classes defined in this notebook's visible
# cells read these values yet -- confirm where they are consumed.

# Optimisation / schedule
iterations = 400_000
learning_rate = 0.0005
weight_decay = 0.01
batch_size = 64

# Model-shape settings (the ``vae_`` prefix is the original naming)
vae_input_dim = 768
vae_n_cat_feats = 0
vae_hidden_dims = [512, 256, 128]
vae_embed_dim = 32
vae_codebook_size = 256
vae_codebook_normalize = False
vae_sim_vq = False

# Checkpoint / evaluation cadence
save_model_every = 5000
eval_every = 5000

# Paths
data_folder = "data/lfm"
# NOTE(review): this rebinds ``data``, which the import cell defines as an
# alias for ``torch.utils.data``. After this cell runs, ``data.<anything>``
# resolves against the string below, not the module.
data = "data/"
save_dir_root = "out/rqvae/lfm/"

# Remaining switches
wandb_logging = True
commitment_weight = 0.25
vae_n_layers = 3
# Stored as a plain dotted-path string; nothing in this cell resolves it.
vae_codebook_mode = "modules.quantize.QuantizeForwardMode.ROTATION_TRICK"
force_data_process = False
data_split = "beauty"
do_eval = True

# Generate Semantic IDs

In [None]:
class Encoder(nn.Module):
    """MLP that projects a ``num_items``-wide input to ``embedding_dim`` features.

    Layer stack, in order:
    Linear(num_items, 256) -> GELU -> LayerNorm(256) -> Dropout(dropout_prob)
    -> Linear(256, 128) -> GELU -> Linear(128, embedding_dim).

    Args:
        num_items: Width of the last dimension of the input.
        embedding_dim: Width of the last dimension of the output.
        dropout_prob: Drop probability of the single dropout layer.
    """

    def __init__(self, num_items, embedding_dim=16, dropout_prob=0.5):
        super().__init__()
        wide, narrow = 256, 128
        # Built in this exact order so parameter names (``mlp.0``, ``mlp.2``, ...)
        # and seeded initialisation stay the same.
        stack = [
            nn.Linear(num_items, wide),
            nn.GELU(),
            nn.LayerNorm(wide),
            nn.Dropout(dropout_prob),
            nn.Linear(wide, narrow),
            nn.GELU(),
            nn.Linear(narrow, embedding_dim),
        ]
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Return the MLP output for ``x``."""
        projected = self.mlp(x)
        return projected

class Decoder(nn.Module):
    """MLP that expands an ``embedding_dim`` vector to ``num_items`` values in (0, 1).

    Layer stack, in order:
    Linear(embedding_dim, 128) -> GELU -> LayerNorm(128) -> Dropout(dropout_prob)
    -> Linear(128, 256) -> GELU -> Linear(256, num_items), followed by a sigmoid.

    Args:
        num_items: Width of the last dimension of the output.
        embedding_dim: Width of the last dimension of the input.
        dropout_prob: Drop probability of the single dropout layer.
    """

    def __init__(self, num_items, embedding_dim=16, dropout_prob=0.5):
        super().__init__()
        narrow, wide = 128, 256
        # Built in this exact order so parameter names (``mlp.0``, ``mlp.2``, ...)
        # and seeded initialisation stay the same.
        stack = [
            nn.Linear(embedding_dim, narrow),
            nn.GELU(),
            nn.LayerNorm(narrow),
            nn.Dropout(dropout_prob),
            nn.Linear(narrow, wide),
            nn.GELU(),
            nn.Linear(wide, num_items),
        ]
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Return the element-wise sigmoid of the MLP output for ``x``."""
        logits = self.mlp(x)
        return logits.sigmoid()
    
class Quanitzation(nn.Module):
    """One vector-quantization layer: replace each latent row by its nearest codebook row.

    The class name keeps its original spelling because other code in this
    notebook constructs it under that name.

    Args:
        num_embeddings: Number of rows in the codebook.
        embedding_dim: Width of each codebook row; the input is flattened to
            rows of this width before the nearest-neighbour search.
        train: Caller-supplied flag, stored as ``train_flag``. Nothing in this
            class reads it.
        commitment_cost: Weight applied to the commitment term of the loss.
    """

    def __init__(self, num_embeddings, embedding_dim, train, commitment_cost=0.25):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        bound = 1.0 / num_embeddings
        nn.init.uniform_(self.embedding.weight, -bound, bound)
        # Must not be stored as ``self.train``: that would replace the
        # ``nn.Module.train`` method with a bool on this instance, and every
        # ``.train()`` / ``.eval()`` call on this module or on a parent that
        # contains it would fail with "'bool' object is not callable".
        self.train_flag = train

    def forward(self, z_e):
        """Quantize ``z_e`` against the codebook.

        Args:
            z_e: Tensor whose total element count is a multiple of
                ``embedding_dim``.

        Returns:
            A tuple ``(z_q, encoding_indices, loss)``:
            ``z_q`` has the shape of ``z_e`` and holds the selected codebook
            rows, wired so that gradients pass straight through to ``z_e``;
            ``encoding_indices`` is a 1-D tensor with one codebook index per
            flattened row; ``loss`` is the scalar
            ``codebook_loss + commitment_cost * commitment_loss``.
        """
        codebook = self.embedding.weight

        # ``reshape`` instead of ``view`` so non-contiguous inputs work too.
        flat = z_e.reshape(-1, self.embedding_dim)

        # The index search is not differentiable, so skip building a graph for it.
        with torch.no_grad():
            distances = torch.cdist(flat, codebook)
            encoding_indices = torch.argmin(distances, dim=1)

        quantized = self.embedding(encoding_indices).reshape(z_e.shape)

        # Codebook term: moves the selected codebook rows toward the (frozen) input.
        codebook_loss = F.mse_loss(quantized, z_e.detach())
        # Commitment term: moves the input toward the (frozen) codebook rows.
        # This is the term that ``commitment_cost`` scales.
        commitment_loss = F.mse_loss(z_e, quantized.detach())
        loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: the forward value is the codebook row,
        # the backward pass treats the quantization step as the identity so
        # downstream losses still produce gradients for ``z_e``.
        z_q = z_e + (quantized - z_e).detach()

        return z_q, encoding_indices, loss


class RQ_VAE(nn.Module):
    """Encoder -> stack of residual quantizers -> decoder.

    Args:
        num_items: Width of the model input and of the reconstruction.
        q_layers: Number of quantization layers applied in sequence (>= 1).
        embedding_dim: Width of the latent vector and of every codebook row.
        commitment_cost: Passed unchanged to every quantization layer.
        codebook_size: Number of rows in each layer's codebook. ``None``
            (the default) falls back to ``num_items``, which is the size the
            codebooks were previously hard-wired to.

    Raises:
        ValueError: If ``q_layers`` is smaller than 1.
    """

    def __init__(self, num_items, q_layers=3, embedding_dim=16, commitment_cost=0.25,
                 codebook_size=None):
        super().__init__()
        if q_layers < 1:
            raise ValueError("q_layers must be at least 1")
        if codebook_size is None:
            codebook_size = num_items
        self.encoder = Encoder(num_items, embedding_dim)
        self.layers = nn.ModuleList(modules=[
            Quanitzation(codebook_size, embedding_dim, True, commitment_cost)
            for _ in range(q_layers)
        ])
        self.decoder = Decoder(num_items, embedding_dim)

    def forward(self, x):
        """Encode ``x``, quantize the latent residually, and decode.

        Each layer quantizes what the previous layers left unexplained; the
        decoder receives the sum of all selected codes.

        Returns:
            A tuple ``(x_hat, encoding_indices, loss)``: the decoder output,
            the per-layer code indices stacked along a new last dimension
            (one column per quantization layer), and the sum of the layers'
            losses.
        """
        z_e = self.encoder(x)

        residual = z_e
        code_sum = torch.zeros_like(z_e)
        per_layer_indices = []
        loss = 0
        for layer in self.layers:
            quantized, indices, layer_loss = layer(residual)
            quantized = quantized.detach()
            code_sum = code_sum + quantized
            # Later layers see only the part of the latent not yet explained.
            residual = residual - quantized
            per_layer_indices.append(indices)
            loss = loss + layer_loss

        # One straight-through step on the summed codes: the forward value is
        # the sum of selected codes, the backward pass is the identity w.r.t.
        # ``z_e``, independent of how the individual layers wire gradients.
        z_q = z_e + (code_sum - z_e).detach()

        x_hat = self.decoder(z_q)
        encoding_indices = torch.stack(per_layer_indices, dim=-1)
        return x_hat, encoding_indices, loss