In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import tqdm
import polars as pl
import numpy as np

In [None]:
class DNN_RQVAE(nn.Module):
    """Residual-quantized autoencoder with a small MLP encoder and decoder.

    The encoder maps an input embedding to a latent vector
    (``Linear -> GELU -> Linear``). The latent is then approximated by a sum
    of ``quantization_depth`` code vectors, one taken from each codebook, where
    every codebook quantizes the residual left by the previous ones. The
    decoder mirrors the encoder (``Linear -> GELU -> Linear``).

    Args:
        embedding_dim: Size of the last dimension of the input embeddings.
        latent_dim: Size of the latent vectors and of every codebook entry.
        codebook_size: Number of entries in each codebook.
        quantization_depth: Number of codebooks (residual quantization levels).
        dnn_latent_dim: Width of the hidden layer in the encoder and decoder.
    """

    def __init__(self, embedding_dim, latent_dim, codebook_size, quantization_depth=4, dnn_latent_dim=256):
        # Zero-argument super() is required here: super(nn.Module, self) would
        # skip nn.Module.__init__ and make every submodule assignment fail.
        super().__init__()

        # nn.Sequential takes modules as positional arguments, not a list.
        self.encoder_dnn = nn.Sequential(
            nn.GELU(),
            nn.Linear(dnn_latent_dim, latent_dim),
        )
        self.decoder_dnn = nn.Sequential(
            nn.Linear(latent_dim, dnn_latent_dim),
            nn.GELU(),
        )

        self.encoder_adapter = nn.Linear(embedding_dim, dnn_latent_dim)
        self.decoder_adapter = nn.Linear(dnn_latent_dim, embedding_dim)

        # One independent codebook per quantization level.
        self.codebooks = nn.ModuleList([nn.Embedding(codebook_size, latent_dim) for _ in range(quantization_depth)])
        self.quantization_depth = quantization_depth

    def forward(self, x):
        """Encode, residual-quantize and reconstruct ``x``.

        Args:
            x: Tensor of shape ``(..., embedding_dim)``, e.g. ``(B, T, E)``.

        Returns:
            A tuple ``(x_hat, hat_z, z, indices)``:
              * ``x_hat``   - reconstruction, same shape as ``x``;
              * ``hat_z``   - sum of the selected code vectors
                ``(..., latent_dim)``; it is connected to the codebook weights
                only, so a loss on it never back-propagates into the encoder;
              * ``z``       - encoder output ``(..., latent_dim)``;
              * ``indices`` - selected code index per level
                ``(..., quantization_depth)``, dtype ``torch.long``.
        """
        z = self.encoder_dnn(self.encoder_adapter(x))  # (..., latent_dim)
        hat_z = torch.zeros_like(z)  # running quantized approximation of z
        indices_list = []  # one index tensor per quantization level

        for d in range(self.quantization_depth):
            codebook = self.codebooks[d]
            # The nearest-neighbour search is not differentiable, so no graph
            # is needed for the (..., codebook_size, latent_dim) intermediate.
            with torch.no_grad():
                residual = z - hat_z
                distances = torch.sum((residual.unsqueeze(-2) - codebook.weight) ** 2, dim=-1)
                c_d = torch.argmin(distances, dim=-1)  # (...)
            q_d = codebook(c_d)  # (..., latent_dim), differentiable w.r.t. the codebook
            hat_z = hat_z + q_d
            indices_list.append(c_d)

        # One token per quantization level for every position.
        indices = torch.stack(indices_list, dim=-1)

        # Straight-through estimator: the decoder sees the quantized values,
        # while the reconstruction gradient also reaches the encoder through z.
        # It is kept in a separate tensor so the returned hat_z stays purely
        # quantized and latent-space losses get correctly-signed gradients.
        decoder_input = hat_z + (z - z.detach())

        x_hat = self.decoder_adapter(self.decoder_dnn(decoder_input))
        return x_hat, hat_z, z, indices

In [None]:
def rqvae_combined_loss(x, x_hat, hat_z, z, beta = 0.25):
    """Reconstruction + codebook + commitment loss for a (residual) VQ autoencoder.

    Args:
        x: Original input embeddings.
        x_hat: Reconstructed embeddings, same shape as ``x``.
        hat_z: Quantized latents (sum of the selected code vectors).
        z: Continuous encoder output, same shape as ``hat_z``.
        beta: Weight of the commitment term.

    Returns:
        Scalar tensor ``recon + codebook + beta * commit``.
    """
    recon_loss = F.mse_loss(x_hat, x)
    # Codebook term: encoder output is frozen, so only the codes move towards it.
    codebook_loss = F.mse_loss(hat_z, z.detach())
    # Commitment term: codes are frozen, so only the encoder is pulled towards
    # its assigned codes. The stop-gradient must sit on the quantized side;
    # detaching z here would give the encoder no commitment signal at all.
    commit_loss = F.mse_loss(z, hat_z.detach())

    return recon_loss + codebook_loss + beta * commit_loss

Горячий случай

In [None]:
class TextEmbeddingDataset(Dataset):
    """Map-style dataset that serves one row of an item table per index.

    The wrapped ``df`` must support ``len(df)`` and
    ``df.row(idx, named=True)``, and each row must expose the keys
    ``'item_id'`` and ``'emb'``.
    """

    def __init__(self, df):
        """Keep a reference to the backing table; nothing is copied."""
        self.df = df

    def __len__(self):
        """Number of rows in the backing table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(item_id, embedding)`` for row ``idx``.

        The embedding is converted to a float32 tensor and given a leading
        axis of size 1, so a sample has shape ``(1, len(emb))``.
        """
        record = self.df.row(idx, named=True)
        embedding = torch.tensor(record['emb'], dtype=torch.float32)
        return record['item_id'], embedding.unsqueeze(0)

In [None]:
# Data: read the item table from parquet and serve it in shuffled mini-batches.
batch_size = 128
df = pl.read_parquet("./data/lvl2_data/items.parquet")
dataset = TextEmbeddingDataset(df)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,  # new sample order on every pass over the loader
)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the quantizer and move its parameters to the selected device.
model = DNN_RQVAE(
    embedding_dim=512,
    latent_dim=256,
    codebook_size=128,
    quantization_depth=4,
    dnn_latent_dim=256,
).to(device)

In [None]:
# Train loop: Adam on the combined RQ-VAE loss, logging the sample-weighted
# mean loss of every epoch in `loss_hist`.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.train()
loss_hist = []
num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0.0
    # The dataset yields (item_id, embedding) pairs, so every collated batch is
    # a pair as well; calling .to() on the batch itself would fail. Only the
    # embeddings are fed to the model.
    for _item_ids, x in tqdm.tqdm(dataloader):
        x = x.to(device)  # Shape: (batch_size, 1, embedding_dim)
        optimizer.zero_grad()
        x_hat, hat_z, z, indices = model(x)
        loss = rqvae_combined_loss(x, x_hat, hat_z, z)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += loss.item() * x.size(0)
    avg_loss = total_loss / len(dataloader.dataset)
    loss_hist.append(avg_loss)
    print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.6f}")

In [None]:
# Persist only the learned parameters (state dict), not the module object.
torch.save(obj=model.state_dict(), f="./rqvae_trained.pth")

In [None]:
#inference