In [2]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

# --- 1. Cargar dataset ---
with open("dataset_mercadona.json", "r", encoding="utf-8") as f:
    data = json.load(f)

pares = data["pares"]
tickets = data["tickets"]

# --- 2. Crear vocabulario ---
productos = list({p for ticket in tickets for p in ticket})
productos.sort()
token2id = {p: i for i, p in enumerate(productos)}
id2token = {i: p for p, i in token2id.items()}
vocab_size = len(productos)

# --- 3. Dataset PyTorch ---
class ProductosDataset(Dataset):
    def __init__(self, pares, token2id):
        self.inputs = []
        self.targets = []
        for par in pares:
            entrada_ids = [token2id[p] for p in par["entrada"]]
            self.inputs.append(torch.tensor(entrada_ids, dtype=torch.long))
            self.targets.append(token2id[par["siguiente"]])
        
    def __len__(self):
        return len(self.inputs)
    
    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

# --- 4. Collate function para padding ---
def collate_fn(batch):
    inputs, targets = zip(*batch)
    lengths = [len(seq) for seq in inputs]
    max_len = max(lengths)
    padded_inputs = torch.zeros(len(inputs), max_len, dtype=torch.long)
    for i, seq in enumerate(inputs):
        padded_inputs[i, :lengths[i]] = seq
    targets = torch.tensor(targets, dtype=torch.long)
    return padded_inputs, targets

dataset = ProductosDataset(pares, token2id)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

In [4]:
# --- 5. Definir transformer pequeño ---
class MiniTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, nhead=4, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: [batch, seq_len] -> transformer espera [seq_len, batch, d_model]
        x = self.embedding(x).permute(1,0,2)
        x = self.transformer(x)
        # predecimos solo el último token de la secuencia
        x = self.fc(x[-1])
        return F.log_softmax(x, dim=-1)

model = MiniTransformer(vocab_size)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()



In [5]:
# --- 6. Entrenamiento ---
epochs = 5
for epoch in range(epochs):
    total_loss = 0
    for batch_inputs, batch_targets in dataloader:
        optimizer.zero_grad()
        output = model(batch_inputs)
        loss = criterion(output, batch_targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss/len(dataloader):.4f}")

Epoch 1/5 - Loss: 3.4047
Epoch 2/5 - Loss: 3.2376
Epoch 3/5 - Loss: 3.1951
Epoch 4/5 - Loss: 3.1620
Epoch 5/5 - Loss: 3.1441


In [18]:
# --- 7. Ejemplo de predicción ---
def predecir_siguiente(model, secuencia, token2id, id2token):
    model.eval()
    entrada_ids = torch.tensor([[token2id[p] for p in secuencia]], dtype=torch.long)
    with torch.no_grad():
        output = model(entrada_ids)
        idx = torch.argmax(output, dim=-1).item()
    return id2token[idx]

def predecir_siguiente_filtrado(model, secuencia, token2id, id2token):
    model.eval()
    entrada_ids = torch.tensor([[token2id[p] for p in secuencia]], dtype=torch.long)
    with torch.no_grad():
        output = model(entrada_ids)  # log_softmax
        probs = output.exp().squeeze()
        # Filtrar productos ya comprados
        for p in secuencia:
            probs[token2id[p]] = 0.0
        probs = probs / probs.sum()  # re-normalizar
        idx = torch.argmax(probs).item()
    return id2token[idx]



ejemplo = ["Chocolate negro", "Arroz redondo"]
print("Entrada:", ejemplo)
print("Predicción siguiente producto:", predecir_siguiente_filtrado(model, ejemplo, token2id, id2token))

Entrada: ['Chocolate negro', 'Arroz redondo']
Predicción siguiente producto: Helado de vainilla
