### 1. Importación de librerías y configuración

In [None]:
import os
import json
import random
import requests
from pathlib import Path
from io import BytesIO
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm.notebook import tqdm
import open_clip
import faiss

# Global configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the Python, NumPy and PyTorch random number generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# On-disk layout: every artefact lives under ROOT_DIR.
ROOT_DIR = Path("bundle_recognition")
IMG_DIR = ROOT_DIR / "images"
BUNDLE_DIR = IMG_DIR / "bundles"
PRODUCT_DIR = IMG_DIR / "products"
CHECKPOINT_DIR = ROOT_DIR / "checkpoints"
# Directory for the FAISS product index and its metadata. The indexing step
# referenced INDEX_DIR without it ever being defined (NameError).
INDEX_DIR = ROOT_DIR / "index"

# Create the whole tree up front so later cells can write without checks.
for d in [BUNDLE_DIR, PRODUCT_DIR, CHECKPOINT_DIR, INDEX_DIR]:
    d.mkdir(parents=True, exist_ok=True)

### 2. Descarga de datos e imágenes

In [None]:
# Master training table. Later cells read its bundle_url, bundle_asset_id,
# product_url, product_asset_id and product_description columns.
df = pd.read_csv("data/master_train_dataset.csv")

def download_image(args):
    """Download one image and store it as an RGB JPEG (best-effort).

    Args:
        args: ``(url, path)`` tuple; packed into a single argument so the
            function can be used directly with ``executor.map``. ``path`` is
            a ``pathlib.Path`` for the destination file.

    Returns:
        ``True`` if ``path`` already existed or the download succeeded,
        ``False`` if anything went wrong (network error, bad status code,
        undecodable image, write failure).
    """
    url, path = args
    if path.exists():
        return True

    # Write to a sibling temp file and rename it at the end, so an interrupted
    # save never leaves a truncated JPEG that the exists() check above would
    # later mistake for a finished download.
    tmp_path = path.with_name(path.name + ".part")
    try:
        r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, timeout=10)
        r.raise_for_status()
        Image.open(BytesIO(r.content)).convert('RGB').save(tmp_path, 'JPEG')
        tmp_path.replace(path)
        return True
    except Exception:
        # Deliberately broad: one bad URL must not abort the whole batch.
        # Unlike a bare ``except:``, KeyboardInterrupt/SystemExit still propagate.
        tmp_path.unlink(missing_ok=True)
        return False

# Build the de-duplicated list of (url, destination path) downloads. The same
# bundle or product appears in many rows, so a set collapses the repeats.
# Iterating the columns directly avoids the per-row Series that iterrows() builds.
tasks = list(
    {
        (url, BUNDLE_DIR / f"{asset_id}.jpg")
        for url, asset_id in zip(df['bundle_url'], df['bundle_asset_id'])
    }
    | {
        (url, PRODUCT_DIR / f"{asset_id}.jpg")
        for url, asset_id in zip(df['product_url'], df['product_asset_id'])
    }
)

# Downloads are I/O-bound, so a small thread pool overlaps the network waits.
with ThreadPoolExecutor(max_workers=10) as executor:
    results = list(tqdm(executor.map(download_image, tasks), total=len(tasks), desc="Descargando im√°genes"))

# download_image() reports failures by returning False; surface the count
# instead of discarding it, since later cells open these files unconditionally.
failed_downloads = results.count(False)
if failed_downloads:
    print(f"{failed_downloads} de {len(tasks)} descargas fallaron.")

### Dataset y modelado

In [None]:
class BundleProductDataset(Dataset):
    """Pairs of (bundle image, product image) described by a dataframe.

    Every row supplies ``bundle_asset_id`` and ``product_asset_id``; the
    matching images are read from ``BUNDLE_DIR`` and ``PRODUCT_DIR`` as
    ``<asset_id>.jpg``. When ``transform`` is given it is applied to both
    images before they are returned.
    """

    def __init__(self, dataframe, transform=None):
        self.df = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read_rgb(folder, asset_id):
        # Load ``<asset_id>.jpg`` from ``folder`` as an RGB PIL image.
        return Image.open(folder / f"{asset_id}.jpg").convert('RGB')

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        bundle = self._read_rgb(BUNDLE_DIR, record['bundle_asset_id'])
        product = self._read_rgb(PRODUCT_DIR, record['product_asset_id'])

        if not self.transform:
            return bundle, product
        return self.transform(bundle), self.transform(product)

# Load the pretrained CLIP model together with its inference preprocessing.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained="openai",
    device=DEVICE,
)

# Freeze the whole network, then re-enable gradients only for the last four
# residual blocks of the visual transformer.
model.requires_grad_(False)
for block in list(model.visual.transformer.resblocks)[-4:]:
    block.requires_grad_(True)

### Función de loss y optimización

In [None]:
class ContrastiveLoss(nn.Module):
    """Symmetric cross-entropy loss over a batch of embedding pairs.

    Row ``i`` of ``b_emb`` and row ``i`` of ``p_emb`` form the positive pair;
    every other row of the batch is used as a negative. The loss is the mean
    of the bundle->product and product->bundle cross-entropies.

    Args:
        temperature: Initial softmax temperature. The learnable
            ``logit_scale`` parameter starts at ``log(1 / temperature)``.
        max_scale: Upper bound applied to ``exp(logit_scale)`` so the logits
            cannot grow without limit if the scale is trained.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, temperature=0.07, max_scale=100.0):
        super().__init__()
        if temperature <= 0:
            raise ValueError("temperature must be strictly positive")
        self.max_scale = max_scale
        self.logit_scale = nn.Parameter(torch.ones([]) * float(np.log(1 / temperature)))

    def forward(self, b_emb, p_emb):
        """Return the scalar loss for two ``(batch, dim)`` embedding matrices."""
        # Unit-length rows make the dot products below cosine similarities.
        b_emb = F.normalize(b_emb, dim=-1)
        p_emb = F.normalize(p_emb, dim=-1)

        scale = self.logit_scale.exp().clamp(max=self.max_scale)
        logits = scale * (b_emb @ p_emb.T)
        # The matching pair sits on the diagonal. Build the targets on the
        # logits' own device rather than a global one, so the loss works
        # wherever its inputs live.
        labels = torch.arange(logits.shape[0], device=logits.device)

        loss_b = F.cross_entropy(logits, labels)
        loss_p = F.cross_entropy(logits.T, labels)
        return (loss_b + loss_p) / 2

criterion = ContrastiveLoss().to(DEVICE)

# Optimise the unfrozen CLIP weights together with the loss's learnable
# logit_scale. The scale is an nn.Parameter of the criterion, so passing only
# model.parameters() would leave it stuck at its initial value. It gets its own
# group without weight decay so AdamW does not shrink it toward zero.
optimizer = torch.optim.AdamW(
    [
        {"params": [p for p in model.parameters() if p.requires_grad]},
        {"params": list(criterion.parameters()), "weight_decay": 0.0},
    ],
    lr=1e-5,
)

### Bucle de entrenamiento por épocas

In [None]:
train_loader = DataLoader(BundleProductDataset(df, transform=preprocess), batch_size=32, shuffle=True)

# Fine-tune for five epochs, saving a checkpoint after each one.
for epoch in range(1, 6):
    model.train()
    total_loss = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")

    for step, (b_imgs, p_imgs) in enumerate(pbar, start=1):
        b_imgs, p_imgs = b_imgs.to(DEVICE), p_imgs.to(DEVICE)

        optimizer.zero_grad()
        # Both views go through the same image encoder.
        b_features = model.encode_image(b_imgs)
        p_features = model.encode_image(p_imgs)

        loss = criterion(b_features, p_features)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Running mean over the batches seen so far. Dividing by
        # len(train_loader) under-reported the loss until the last step.
        pbar.set_postfix(loss=total_loss / step)

    torch.save(model.state_dict(), CHECKPOINT_DIR / f"clip_bundle_epoch_{epoch}.pt")

### Generación de embeddings e indexación

In [None]:
def build_product_index(model, df, batch_size=64, index_dir=None):
    """Embed every unique product image and store the vectors in a FAISS index.

    Args:
        model: Model exposing ``eval()`` and ``encode_image()``.
        df: Table with ``product_asset_id`` and ``product_description``
            columns; one index entry is created per distinct product id.
        batch_size: Number of images encoded per forward pass.
        index_dir: Directory that receives ``products.index`` and
            ``metadata.json``. Defaults to ``ROOT_DIR / "index"``; it is
            created if missing.

    Returns:
        ``(index, index_metadata)``: the FAISS ``IndexFlatIP`` and a list of
        dicts with ``product_asset_id`` and ``product_description`` keys,
        where entry ``i`` describes vector ``i`` of the index.

    Raises:
        ValueError: If no product with an image on disk is left to index.
    """
    # 1. One row per product so no image is encoded twice.
    unique_products = df.drop_duplicates('product_asset_id').copy()

    # Downloads are best-effort, so drop products whose image never reached
    # disk. Rows and paths are filtered together to keep index positions and
    # metadata entries aligned.
    image_paths = [PRODUCT_DIR / f"{pid}.jpg" for pid in unique_products['product_asset_id']]
    available = np.array([p.exists() for p in image_paths], dtype=bool)
    missing = int((~available).sum())
    if missing:
        print(f"Se omiten {missing} productos sin imagen en disco.")
    unique_products = unique_products.loc[available]
    image_paths = [p for p, ok in zip(image_paths, available) if ok]

    if unique_products.empty:
        raise ValueError("no product images available to build the index")

    # The output directory was previously an undefined global (NameError).
    index_dir = Path(index_dir) if index_dir is not None else ROOT_DIR / "index"
    index_dir.mkdir(parents=True, exist_ok=True)

    model.eval()
    product_embeddings = []

    # 2. Extract the embeddings batch by batch.
    print(f"Generando embeddings para {len(unique_products)} productos...")
    with torch.no_grad():
        for start in tqdm(range(0, len(image_paths), batch_size)):
            batch_imgs = []
            for img_path in image_paths[start:start + batch_size]:
                # The context manager closes the file as soon as it is decoded.
                with Image.open(img_path) as img:
                    batch_imgs.append(preprocess(img.convert('RGB')))

            # Move to the device and encode.
            batch_tensor = torch.stack(batch_imgs).to(DEVICE)
            # L2-normalise so the inner product equals cosine similarity.
            features = F.normalize(model.encode_image(batch_tensor), dim=-1)
            product_embeddings.append(features.float().cpu().numpy())

    # 3. Build the FAISS index (FAISS expects float32 vectors).
    embeddings_np = np.vstack(product_embeddings).astype('float32')
    dimension = embeddings_np.shape[1]

    # IndexFlatIP on unit vectors ranks by cosine similarity.
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings_np)

    # 4. Keep the metadata needed to map a hit back to its product.
    index_metadata = unique_products[['product_asset_id', 'product_description']].to_dict('records')

    # Persist both artefacts side by side.
    faiss.write_index(index, str(index_dir / "products.index"))
    with open(index_dir / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(index_metadata, f)

    print(f"√çndice creado y guardado con {index.ntotal} vectores.")
    return index, index_metadata

# Build the product index from the model trained above, using every product
# in the training table.
product_index, metadata = build_product_index(model, df)

### Función de búsqueda