In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import pytorch_lightning as pl
import json
from pathlib import Path

# --- 1. Dataset Class ---
class RecSysDataset(Dataset):
    def __init__(self, parquet_file):
        self.data = pd.read_parquet(parquet_file)
        
        # Features NumÃ©ricas normalizadas (Gambi de engenheiro: dividir pelo max)
        self.users = torch.LongTensor(self.data['user_index'].values)
        self.items = torch.LongTensor(self.data['item_index'].values)
        
        # Features extras (Contexto)
        self.user_features = torch.FloatTensor(self.data[['avg_spend', 'purchase_count']].values)
        self.item_features = torch.FloatTensor(self.data[['popularity_score', 'avg_price']].values)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return {
            'user_id': self.users[idx],
            'item_id': self.items[idx],
            'user_feats': self.user_features[idx],
            'item_feats': self.item_features[idx]
        }

# --- 2. A Arquitetura Two-Tower ---
class TwoTowerModel(pl.LightningModule):
    def __init__(self, num_users, num_items, embedding_dim=32):
        super().__init__()
        self.save_hyperparameters()
        
        # --- Torre do UsuÃ¡rio ---
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        # Rede Neural que combina ID + Features NumÃ©ricas
        self.user_mlp = nn.Sequential(
            nn.Linear(embedding_dim + 2, 64), # +2 pois temos 2 features numÃ©ricas de user
            nn.ReLU(),
            nn.Linear(64, 32) # SaÃ­da final: vetor de tamanho 32
        )
        
        # --- Torre do Item ---
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        self.item_mlp = nn.Sequential(
            nn.Linear(embedding_dim + 2, 64),
            nn.ReLU(),
            nn.Linear(64, 32)
        )
        
    
    def forward(self, batch):
        # Embeddings -> Concatenate -> MLP -> Normalize
        # 1. Gerar Embeddings de ID
        u_emb = self.user_embedding(batch['user_id'])
        i_emb = self.item_embedding(batch['item_id'])
        
        # 2. Concatenar com features numÃ©ricas
        u_input = torch.cat([u_emb, batch['user_feats']], dim=1)
        i_input = torch.cat([i_emb, batch['item_feats']], dim=1)
        
        # 3. Passar pelos MLPs
        user_vector = self.user_mlp(u_input)
        item_vector = self.item_mlp(i_input)
        
        # 4. Normalizar vetores (para usar Cosine Similarity)
        user_vector = F.normalize(user_vector, p=2, dim=1)
        item_vector = F.normalize(item_vector, p=2, dim=1)
        
        return user_vector, item_vector

    def training_step(self, batch, batch_idx):
        user_vector, item_vector = self(batch)
        
        # --- In-Batch Negatives Loss (O Segredo do Retrieval) ---
        # Em vez de criar negativos manualmente, usamos os outros itens do batch como negativos.
        # Se o batch tem tamanho 128, para cada usuÃ¡rio temos 1 positivo e 127 negativos.
        
        # Matriz de similaridade (Batch x Batch)
        # U x I
        scores = torch.matmul(user_vector, item_vector.T)
        
        # O objetivo Ã© que a diagonal principal (user i com item i) tenha score alto
        labels = torch.arange(scores.size(0), device=self.device)
        
        loss = F.cross_entropy(scores * 10, labels) # *10 Ã© a "temperatura"
        
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [2]:
with open("./data/model_metadata.json", "r") as f:
    meta = json.load(f)
    
print(f"ðŸš€ Iniciando Treino Two-Tower. Users: {meta['num_users']}, Items: {meta['num_items']}")

dataset = RecSysDataset("./data/training_dataset.parquet")
dataloader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=4)

model = TwoTowerModel(num_users=meta['num_users'], num_items=meta['num_items'])

print(model)

ðŸš€ Iniciando Treino Two-Tower. Users: 14761, Items: 8451
TwoTowerModel(
  (user_embedding): Embedding(14761, 32)
  (user_mlp): Sequential(
    (0): Linear(in_features=34, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=32, bias=True)
  )
  (item_embedding): Embedding(8451, 32)
  (item_mlp): Sequential(
    (0): Linear(in_features=34, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=32, bias=True)
  )
)


In [3]:
trainer = pl.Trainer(max_epochs=5, accelerator="gpu", devices=1)
trainer.fit(model, dataloader)

print("âœ… Modelo Treinado! Salvando artefatos...")
trainer.save_checkpoint("./data/two_tower_model.ckpt")

ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type       | Params | Mode  | FLOPs
--------------------------------------------------------------
0 | user_embedding | Embedding  | 472 K  | train | 0    
1 | user_mlp       | Sequential | 4.3 K  | train | 0    
2 | item_embedding | Embedding  | 270 K  | train | 0    
3 | item_mlp       | Sequential | 4.3 K  | train | 0    
--------------------------------------------------------------
751 K     Trainable params
0         Non-trainable params
751 K     Total params
3.006     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode
0         Total Flops
/usr/local/lib/python3.10/dist-packages/pyt

Epoch 4: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 15/15 [00:01<00:00, 13.21it/s, v_num=0, train_loss=6.050]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 15/15 [00:01<00:00,  9.12it/s, v_num=0, train_loss=6.050]

`weights_only` was not set, defaulting to `False`.



âœ… Modelo Treinado! Salvando artefatos...
