In [6]:
# ==========================================
# 0. IMPORTS
# ==========================================

# (Commit test; second test made from another PC.)
# Note from intiti: if you are going to work from a local environment (Visual),
# make sure the required libraries are installed.
# How-to: press Ctrl + ñ to open the terminal, then paste the following commands.
# Command to install torch, in case you want to use the GPU:
#   pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from ortools.constraint_solver import pywrapcp, routing_enums_pb2
import concurrent.futures # LIBRER√çA MAGICA PARA PARALELISMO
import multiprocessing
import numpy as np
import os
import glob
import math
from tqdm import tqdm
import os
import requests 
import gc # Garbage Collector para gesti√≥n de memoria

In [7]:
# ==========================================
# 1. CONFIGURACI√ìN, GPU Y DESCARGA DE DATOS
# ==========================================


# --- A. HARDWARE CONFIGURATION (DEVICE) ---
# DEVICE is read by the training cell to decide where the model and batches live.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    print(f"‚úÖ GPU DETECTADA: {torch.cuda.get_device_name(0)}")
    print(f"   (Memoria disponible: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB)")
else:
    print("‚ö†Ô∏è GPU NO DETECTADA: Entrenando en CPU (ser√° lento).")

# --- B. REPOSITORY CONFIGURATION ---
REPO_USER = "felipe-astudillo-s"
REPO_NAME = "TransformerTSP"
# IMPORTANT: if the data does not live on 'main', set this to the branch or commit that holds it.
BRANCH = "main"

# Phase name -> folder inside the repository that holds that phase's data.
REPO_FOLDERS = {
    phase_key: f"Data/{phase_key.capitalize()}"
    for phase_key in ("EASY", "MEDIUM", "HARD")
}

# Local mirror of the repository data, created under the current working directory.
BASE_LOCAL_DIR = os.path.join(os.getcwd(), "data_repo")

def download_folder_from_github(user, repo, repo_folder_path, local_output_dir, branch="main", timeout=30):
    """Download every ``.npz`` file of a GitHub folder through the contents API.

    Files that already exist locally are not downloaded again. The function
    is best-effort: every failure is reported with ``print`` and the function
    still returns, so the caller can decide what to do with an empty folder.

    Args:
        user: GitHub account that owns the repository.
        repo: Repository name.
        repo_folder_path: Folder path inside the repository.
        local_output_dir: Local directory that receives the files.
        branch: Branch (or ref) to read from.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        ``local_output_dir``, in every case (success or failure).
    """
    api_url = f"https://api.github.com/repos/{user}/{repo}/contents/{repo_folder_path}?ref={branch}"

    print(f"üîç Consultando API para: {repo_folder_path}...")
    try:
        response = requests.get(api_url, timeout=timeout)
        if response.status_code == 404:
            print(f"‚ùå Error 404: No existe la carpeta '{repo_folder_path}' en la rama '{branch}'.")
            return local_output_dir
        if response.status_code != 200:
            print(f"‚ùå Error API ({response.status_code}): {response.text}")
            return local_output_dir

        files_list = response.json()

        # A folder listing is a JSON array. Any other payload (an error object,
        # or the object describing a single file) means the path is not a folder.
        if not isinstance(files_list, list):
            print("‚ùå Error: La ruta parece no ser una carpeta v√°lida.")
            return local_output_dir

        os.makedirs(local_output_dir, exist_ok=True)

        count = 0
        for item in files_list:
            if item.get('type') != 'file' or not item.get('name', '').endswith('.npz'):
                continue

            local_path = os.path.join(local_output_dir, item['name'])
            if os.path.exists(local_path):
                count += 1  # Already downloaded in a previous run
                continue

            # Download into a temporary name and rename only on success, so a
            # failed or interrupted transfer never leaves a corrupt .npz that
            # later runs would mistake for a valid, already-present file.
            tmp_path = local_path + ".part"
            try:
                r = requests.get(item['download_url'], timeout=timeout)
                r.raise_for_status()  # An HTTP error body must not be saved as data
                with open(tmp_path, 'wb') as f:
                    f.write(r.content)
                os.replace(tmp_path, local_path)
                count += 1
            except Exception as e:
                print(f"  ‚ùå Fall√≥ {item['name']}: {e}")
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)

        print(f"‚úÖ Fase {repo_folder_path}: {count} archivos listos en {local_output_dir}")
        return local_output_dir

    except Exception as e:
        print(f"‚ùå Error de conexi√≥n: {e}")
        return local_output_dir

# --- C. RUN THE DOWNLOAD ---
print(f"\n‚öôÔ∏è Sincronizando con GitHub ({REPO_USER}/{REPO_NAME})...")

# Phase name -> local folder returned by the downloader for that phase.
PATHS = {
    phase_name: download_folder_from_github(
        REPO_USER,
        REPO_NAME,
        repo_path,
        os.path.join(BASE_LOCAL_DIR, phase_name),
        BRANCH,
    )
    for phase_name, repo_path in REPO_FOLDERS.items()
}

# --- D. CURRICULUM ---
# One entry per training stage, in execution order: (phase, epochs, learning rate, batch size).
CURRICULUM = [
    dict(phase=stage_phase, epochs=stage_epochs, lr=stage_lr, bs=stage_bs)
    for stage_phase, stage_epochs, stage_lr, stage_bs in (
        ("EASY", 20, 1e-3, 128),
        ("MEDIUM", 15, 1e-4, 64),
        ("HARD", 30, 1e-4, 32),
    )
]

print(f"\nüìÇ Rutas configuradas correctamente.")
print(f"üöÄ Listo para ejecutar el Bloque de Entrenamiento.")

‚ö†Ô∏è GPU NO DETECTADA: Entrenando en CPU (ser√° lento).

‚öôÔ∏è Sincronizando con GitHub (felipe-astudillo-s/TransformerTSP)...
üîç Consultando API para: Data/Easy...
‚úÖ Fase Data/Easy: 20 archivos listos en c:\Users\intix\OneDrive\Documentos\DeltaDefiitive\TransformerTSP\data_repo\EASY
üîç Consultando API para: Data/Medium...
‚úÖ Fase Data/Medium: 20 archivos listos en c:\Users\intix\OneDrive\Documentos\DeltaDefiitive\TransformerTSP\data_repo\MEDIUM
üîç Consultando API para: Data/Hard...
‚úÖ Fase Data/Hard: 10 archivos listos en c:\Users\intix\OneDrive\Documentos\DeltaDefiitive\TransformerTSP\data_repo\HARD

üìÇ Rutas configuradas correctamente.
üöÄ Listo para ejecutar el Bloque de Entrenamiento.


In [17]:
# ==========================================
# 2. Modelo
# ==========================================

class EncoderPointerModel(nn.Module):
    """Transformer encoder/decoder with a dot-product pointer over the input nodes.

    The encoder turns the node features into a memory of per-node vectors.
    The decoder is then run once per output step on a single token (a learned
    start token first, afterwards the memory vector of the previously chosen
    node) and its projected output is scored against every memory vector to
    "point" at the next node.

    Args:
        input_dim: Number of features per node (2 for planar coordinates).
        d_model: Width of the embeddings; must be divisible by ``nhead``.
        nhead: Number of attention heads in every Transformer layer.
        enc_layers: Number of encoder layers.
        dec_layers: Number of decoder layers.
        dropout: Dropout probability used inside the Transformer layers.
    """

    def __init__(self, input_dim=2, d_model=64, nhead=8, enc_layers=3, dec_layers=2, dropout=0.1):
        super().__init__()

        # 1. EMBEDDING & ENCODER (the "map")
        self.embedding = nn.Linear(input_dim, d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=512, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=enc_layers)

        # 2. DECODER (the "navigator"); it attends over the encoder output (memory)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=512, dropout=dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=dec_layers)

        # 3. POINTER COMPONENTS
        self.start_token = nn.Parameter(torch.randn(1, 1, d_model))  # Learned first decoder input
        self.W_q = nn.Linear(d_model, d_model)  # Projection that produces the pointer query
        self.d_model = d_model

    def forward(self, x, tgt_indices=None, teacher_forcing=True):
        """Produce one row of pointer logits per decoding step.

        Args:
            x: Node features, indexed as [batch, node, feature].
            tgt_indices: Reference tours indexed as [batch, step]; required
                for teacher forcing.
            teacher_forcing: When True *and* ``tgt_indices`` is given, the
                reference node is fed back at every step and the logits are
                left unmasked. Otherwise the model decodes greedily and
                already-visited nodes are masked with ``-inf``.

        Returns:
            Tensor indexed as [batch, step, node] with the logits of every
            step; ``argmax`` over the last axis gives the chosen node.
        """
        batch_size, n_nodes, _ = x.size()
        device = x.device

        # Teacher forcing needs reference tours. Without them the loop below
        # feeds back its own choices, so it has to be treated as free-running
        # decoding (with masking) no matter what the flag says. Otherwise the
        # default call model(x) would revisit cities.
        use_teacher = teacher_forcing and tgt_indices is not None

        # --- A. ENCODER ---
        memory = self.encoder(self.embedding(x))  # [B, N, d_model]
        memory_t = memory.transpose(1, 2)         # [B, d_model, N], reused at every step
        scale = math.sqrt(self.d_model)

        # --- B. DECODER LOOP ---
        decoder_input = self.start_token.expand(batch_size, -1, -1)  # [B, 1, d_model]
        visited_mask = torch.zeros(batch_size, n_nodes, dtype=torch.bool, device=device)
        logits_list = []

        # One decoding step per node, in both modes.
        for t in range(n_nodes):
            # 1. Decoder state for the current step: [B, 1, d_model]
            dec_out = self.decoder(tgt=decoder_input, memory=memory)

            # 2. Pointer: scaled dot product between the query and every node vector
            query = self.W_q(dec_out)                                      # [B, 1, d_model]
            scores = (torch.matmul(query, memory_t) / scale).squeeze(1)    # [B, N]

            # 3. Forbid nodes that were already chosen (free-running mode only)
            if not use_teacher:
                scores = scores.masked_fill(visited_mask, float('-inf'))

            logits_list.append(scores)

            # 4. Select the node that is fed back at the next step
            if use_teacher:
                next_idx = tgt_indices[:, t]
            else:
                # argmax of the logits equals argmax of their softmax
                next_idx = scores.argmax(dim=-1)
                # Out-of-place update: the previous mask was consumed by
                # masked_fill above and must stay intact for autograd.
                visited_mask = visited_mask.scatter(1, next_idx.unsqueeze(1), True)

            # Memory vector of the selected node: [B, 1, d_model]
            gather_index = next_idx.reshape(batch_size, 1, 1).expand(-1, -1, self.d_model)
            decoder_input = torch.gather(memory, 1, gather_index)

        # All steps stacked: [B, N, N]
        return torch.stack(logits_list, dim=1)

In [16]:

# ==========================================
# 3. UTILIDADES DE EVALUACI√ìN
# ==========================================
def calculate_gap(model, loader, device):
    """Estimate the optimality gap (%) of greedy decoding on one batch.

    Only the first batch produced by ``loader`` is evaluated so the call stays
    cheap inside a training loop. The model is switched to eval mode for the
    forward pass and then returned to whatever mode it was in, so calling this
    function does not silently disable dropout for the caller.

    Args:
        model: Callable module accepting ``model(batch_x, teacher_forcing=False)``
            and returning logits indexed as [batch, step, node].
        loader: Iterable yielding ``(points, reference_tour)`` pairs.
        device: Device the batch is moved to before inference.

    Returns:
        Mean of ``(model_cost - reference_cost) / reference_cost`` over the
        batch, times 100. ``0.0`` when the loader yields nothing.
    """
    try:
        # Only the first batch, to avoid slowing training down
        batch_x, batch_y = next(iter(loader))
    except StopIteration:
        return 0.0  # Empty loader

    batch_x, batch_y = batch_x.to(device), batch_y.to(device)

    def get_dist(pts, idx):
        # pts: [B, N, 2], idx: [B, N] -> length of each closed tour, [B]
        gathered = torch.gather(pts, 1, idx.unsqueeze(-1).expand(-1, -1, 2))
        next_pts = torch.roll(gathered, -1, dims=1)  # successor of every stop, wrapping around
        return torch.norm(gathered - next_pts, dim=2).sum(dim=1)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Greedy inference: the model generates the index sequence itself
            logits = model(batch_x, teacher_forcing=False)  # [B, N, N]
            pred_tour = logits.argmax(dim=2)                # [B, N]
    finally:
        model.train(was_training)

    cost_model = get_dist(batch_x, pred_tour)
    cost_oracle = get_dist(batch_x, batch_y)

    gap = ((cost_model - cost_oracle) / cost_oracle).mean().item() * 100
    return gap

In [21]:
# ==========================================
# 4. BUCLE DE ENTRENAMIENTO (LAZY LOADING) - SIMPLIFICADO
# ==========================================

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import glob
import os
from tqdm import tqdm
import gc

# Instantiate the pointer network defined in section 2 of this notebook
# (EncoderPointerModel is the only model class the notebook defines).
# input_dim=2 for (x, y) coordinates; the validation cells rebuild the model
# with the same d_model/nhead/layer counts before loading the checkpoint.
model = EncoderPointerModel(input_dim=2, d_model=128, nhead=8, enc_layers=3, dec_layers=2).to(DEVICE)
criterion = nn.CrossEntropyLoss()

print("\nüöÄ INICIANDO ENTRENAMIENTO (Simplified Pointer Network)")

for stage in CURRICULUM:
    phase = stage['phase']
    folder_path = PATHS[phase]

    print(f"\n{'='*60}")
    print(f"üéì FASE ACTUAL: {phase} | Epochs: {stage['epochs']}")
    print(f"{'='*60}")

    # Find the .npz files. Sorted so the file order (and therefore which file
    # counts as "the last one" below) is the same on every run.
    all_files = sorted(glob.glob(os.path.join(folder_path, "*.npz")))

    if not all_files:
        print(f"‚ö†Ô∏è ALERTA: No encontr√© datos en {folder_path}. Saltando fase.")
        continue

    print(f"üìÇ Archivos detectados: {len(all_files)}")

    optimizer = optim.Adam(model.parameters(), lr=stage['lr'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

    # The checkpoint is named "*_best", so it is only overwritten when the
    # epoch loss improves on the best value seen so far in this phase.
    save_file = os.path.join(folder_path, f"checkpoint_{phase}_best.pth")
    best_loss = float('inf')

    for epoch in range(stage['epochs']):
        model.train()
        epoch_loss_accum = 0
        total_batches = 0
        current_gap = 0

        # --- LOOP OVER FILES (lazy loading: one file in RAM at a time) ---
        for file_idx, file_path in enumerate(all_files):
            # 1. Load the file. Only data problems are handled here; errors
            #    raised by the training step itself are left to propagate
            #    instead of being reported as file-read errors.
            try:
                with np.load(file_path) as data:
                    points = torch.FloatTensor(data['points'])
                    solutions = torch.LongTensor(data['solutions'])

                # Defensive normalization into the 0-1 range
                max_coord = points.max()
                if max_coord > 1.0:
                    points /= max_coord

                dataset = TensorDataset(points, solutions)
            except Exception as e:
                print(f"‚ùå Error leyendo archivo {file_path}: {e}")
                continue

            loader = DataLoader(dataset, batch_size=stage['bs'], shuffle=True)

            # 2. Train on this file
            pbar = tqdm(loader, desc=f"Ep {epoch+1} | {os.path.basename(file_path)}", leave=False)

            for batch_x, batch_y in pbar:
                batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
                optimizer.zero_grad()

                # A. Forward pass with teacher forcing (reference tour fed back)
                #    batch_x: [Batch, N, 2]; batch_y: [Batch, N]; logits: [Batch, N, N]
                logits = model(batch_x, tgt_indices=batch_y, teacher_forcing=True)

                # B. Loss over every step of every tour:
                #    logits -> [Batch * N, N], targets -> [Batch * N]
                loss = criterion(logits.reshape(-1, logits.size(-1)), batch_y.reshape(-1))

                # C. Backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Guards against exploding gradients
                optimizer.step()

                epoch_loss_accum += loss.item()
                pbar.set_postfix({'loss': loss.item()})

            total_batches += len(loader)

            # The GAP is computed only on the last file of the epoch to save time
            if file_idx == len(all_files) - 1:
                current_gap = calculate_gap(model, loader, DEVICE)

            # 3. Free memory before loading the next file
            del points, solutions, dataset, loader
            gc.collect()
            torch.cuda.empty_cache()

        # --- EPOCH REPORT ---
        avg_loss = epoch_loss_accum / total_batches if total_batches > 0 else 0
        print(f"    üìâ Epoca {epoch+1} Terminada | Loss: {avg_loss:.4f} | üìä GAP: {current_gap:.2f}%")

        # Scheduler step
        scheduler.step(avg_loss)

        # Save the checkpoint only when at least one batch was trained and the
        # epoch improved on the best loss of this phase.
        if total_batches > 0 and avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), save_file)

print("\nüèÜ ENTRENAMIENTO COMPLETADO EXITOSAMENTE.")


üöÄ INICIANDO ENTRENAMIENTO (Simplified Pointer Network)

üéì FASE ACTUAL: EASY | Epochs: 20
üìÇ Archivos detectados: 20


                                                                                     

KeyboardInterrupt: 

In [22]:
# ==========================================
# 5. VALIDACI√ìN FINAL COMPLETA (MULTI-PART)
# ==========================================

# --- PATH CONFIGURATION ---
# Per phase: checkpoint file to load, folder holding the validation parts and
# the filename prefix those parts share.
PATHS_CONFIG = {
    phase_key: {
        "ckpt": f"data_repo/{phase_key}/checkpoint_{phase_key}_best.pth",
        "val_folder": f"Data/Validation/{phase_key.capitalize()}",
        "val_prefix": f"tsp_{phase_key.lower()}",
    }
    for phase_key in ("EASY", "MEDIUM", "HARD")
}

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_tour_distance(points, tour_indices):
    """Return the length of every closed tour in the batch.

    ``points`` is indexed as [batch, node, 2] and ``tour_indices`` as
    [batch, step]; the result holds one total distance per batch entry.
    """
    gather_index = tour_indices.unsqueeze(-1).expand(-1, -1, 2)
    ordered = points.gather(1, gather_index)
    # Shifting by -1 pairs each stop with its successor and wraps last -> first
    successors = ordered.roll(shifts=-1, dims=1)
    leg_lengths = torch.norm(ordered - successors, dim=2)
    return leg_lengths.sum(dim=1)

def load_all_validation_parts(folder, prefix):
    """Load every ``{prefix}*.npz`` file in ``folder`` and concatenate them.

    Each file must provide a ``points`` and a ``solutions`` array. A file that
    cannot be read completely is skipped as a whole, so points and solutions
    always stay aligned sample by sample.

    Args:
        folder: Directory containing the validation parts.
        prefix: Filename prefix shared by the parts.

    Returns:
        ``(points, solutions)`` as a float tensor and an int64 tensor, or
        ``(None, None)`` when the folder is missing, no file matches, or no
        file could be read.
    """
    if not os.path.exists(folder):
        print(f"‚ùå Carpeta no existe: {folder}")
        return None, None

    # Collect every matching file, in a stable order
    search_pattern = os.path.join(folder, f"{prefix}*.npz")
    all_files = sorted(glob.glob(search_pattern))

    if not all_files:
        print(f"‚ùå No encontr√© archivos {prefix}*.npz en {folder}")
        return None, None

    print(f"üìö Uniendo {len(all_files)} archivos de validaci√≥n encontrados...")

    all_points = []
    all_solutions = []

    for f_path in all_files:
        try:
            # SECURITY NOTE: allow_pickle=True can execute arbitrary code from
            # a crafted file; only load validation files from a trusted source.
            with np.load(f_path, allow_pickle=True) as data:
                file_points = data['points']
                raw_sols = data['solutions']

            # Solutions may be stored as an object array of per-sample lists;
            # stack those into a regular int64 matrix.
            if raw_sols.dtype == np.object_:
                sols_mat = np.vstack(raw_sols).astype(np.int64)
            else:
                sols_mat = raw_sols.astype(np.int64)

            if len(file_points) != len(sols_mat):
                raise ValueError(f"{len(file_points)} points vs {len(sols_mat)} solutions")

        except Exception as e:
            print(f"‚ö†Ô∏è Error leyendo {os.path.basename(f_path)}: {e}")
            continue

        # Append both arrays only after the whole file parsed successfully
        all_points.append(file_points)
        all_solutions.append(sols_mat)

    if not all_points:
        return None, None

    # Join the per-file arrays one after another along the sample axis
    final_points = np.concatenate(all_points)
    final_solutions = np.concatenate(all_solutions)

    return torch.FloatTensor(final_points), torch.from_numpy(final_solutions)

def validate_phase(phase_name, config):
    """Report the mean optimality gap of a checkpoint over a validation set.

    Args:
        phase_name: Label used in the printed report.
        config: Mapping with the keys ``"ckpt"`` (checkpoint path),
            ``"val_folder"`` and ``"val_prefix"`` (location of the data).

    Returns:
        None. Results and problems are reported with ``print``; the function
        returns early when the checkpoint or the data is missing.
    """
    print(f"\n{'='*60}")
    print(f"üìä VALIDANDO FASE: {phase_name} (MODO COMPLETO)")
    print(f"{'='*60}")

    # 1. Load the checkpoint
    if not os.path.exists(config["ckpt"]):
        print(f"‚ö†Ô∏è Salto Fase: No existe checkpoint en {config['ckpt']}")
        return

    # Same hyper-parameters as the training cell. EncoderPointerModel has no
    # max_seq_len argument, so none is passed.
    model = EncoderPointerModel(input_dim=2, d_model=128, nhead=8, enc_layers=3, dec_layers=2).to(DEVICE)

    try:
        # The checkpoint is a plain state_dict, so the restricted
        # weights_only loader is sufficient and avoids unpickling arbitrary objects.
        model.load_state_dict(torch.load(config["ckpt"], map_location=DEVICE, weights_only=True))
        model.eval()
        print(f"üß† Modelo cargado: {os.path.basename(config['ckpt'])}")
    except Exception as e:
        print(f"‚ùå Error cargando modelo: {e}")
        return

    # 2. Load ALL the data
    points, solutions = load_all_validation_parts(config["val_folder"], config["val_prefix"])

    if points is None:
        return

    print(f"üìÇ Total muestras cargadas: {len(points)}")

    # Nothing to evaluate; also avoids dividing by zero in the average below
    if len(points) == 0:
        return

    # Normalization into the 0-1 range, as in training
    max_coord = points.max()
    if max_coord > 1.0:
        points /= max_coord

    # Full dataset, in file order
    dataset = TensorDataset(points, solutions)
    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    # 3. Inference
    gap_accum = 0
    total_samples = 0

    pbar = tqdm(loader, desc="Benchmarking")

    with torch.no_grad():
        for bx, by in pbar:
            bx, by = bx.to(DEVICE), by.to(DEVICE)

            # Greedy decoding; the argmax of each step is the chosen node
            logits = model(bx, teacher_forcing=False)
            pred_tour = logits.argmax(dim=2)

            cost_model = get_tour_distance(bx, pred_tour)
            cost_ortools = get_tour_distance(bx, by)

            gap = ((cost_model - cost_ortools) / cost_ortools)
            gap_accum += gap.sum().item()
            total_samples += bx.size(0)

            pbar.set_postfix({'GAP Acum': f"{(gap_accum/total_samples)*100:.2f}%"})

    final_gap = (gap_accum / total_samples) * 100
    print(f"\nüèÜ RESULTADO FINAL {phase_name}: GAP GLOBAL {final_gap:.2f}%")

# --- RUN ---
# Note: if training was stopped at MEDIUM, probably only EASY will work well.
for phase in ("EASY", "MEDIUM", "HARD"):
    validate_phase(phase_name=phase, config=PATHS_CONFIG[phase])


üìä VALIDANDO FASE: EASY (MODO COMPLETO)
‚ö†Ô∏è Salto Fase: No existe checkpoint en data_repo/EASY/checkpoint_EASY_best.pth

üìä VALIDANDO FASE: MEDIUM (MODO COMPLETO)
‚ö†Ô∏è Salto Fase: No existe checkpoint en data_repo/MEDIUM/checkpoint_MEDIUM_best.pth

üìä VALIDANDO FASE: HARD (MODO COMPLETO)
‚ö†Ô∏è Salto Fase: No existe checkpoint en data_repo/HARD/checkpoint_HARD_best.pth


In [12]:
# ==========================================
# üß™ PRUEBA DE GENERALIZACI√ìN (EASY -> MEDIUM)
# ==========================================

# Hybrid configuration (sizes as stated in the original notes):
#   model: the EASY checkpoint (trained with 20 nodes)
#   data:  the MEDIUM validation set (50-node problems)
CROSS_TEST_CONFIG = dict(
    ckpt="data_repo/EASY/checkpoint_EASY_best.pth",   # the small model
    val_folder="Data/Validation/Medium",              # the medium-size data
    val_prefix="tsp_medium",
)

print(
    f"\n{'#'*60}",
    "üß™ EXPERIMENTO: ¬øPuede un modelo de 20 ciudades resolver uno de 50?",
    f"{'#'*60}",
    sep="\n",
)

# Reuse the existing validation routine
validate_phase("GENERALIZATION_TEST", CROSS_TEST_CONFIG)


############################################################
üß™ EXPERIMENTO: ¬øPuede un modelo de 20 ciudades resolver uno de 50?
############################################################

üìä VALIDANDO FASE: GENERALIZATION_TEST (MODO COMPLETO)
‚ö†Ô∏è Salto Fase: No existe checkpoint en data_repo/EASY/checkpoint_EASY_best.pth


In [13]:
# ==========================================
# 6. VISUALIZACI√ìN COMPARATIVA (VISUALIZER)
# ==========================================


# --- CONFIGURATION ---
# Same path layout as the validation cell: checkpoint, validation folder and
# filename prefix for each phase.
PATHS_CONFIG = {
    phase_key: {
        "ckpt": f"data_repo/{phase_key}/checkpoint_{phase_key}_best.pth",
        "val_folder": f"Data/Validation/{phase_key.capitalize()}",
        "val_prefix": f"tsp_{phase_key.lower()}",
    }
    for phase_key in ("EASY", "MEDIUM", "HARD")
}

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def plot_route(ax, points, tour, title, color):
    """Draw one closed tour on the given axes.

    ``points`` is a NumPy array indexed as [node, coordinate] and ``tour`` an
    index array giving the visiting order. The route is drawn in ``color``,
    every node in black and the first stop in red.
    """
    # Nodes in visiting order, with the first stop repeated to close the loop
    visit_order = points[tour]
    closed_path = np.concatenate([visit_order, visit_order[0][np.newaxis, :]], axis=0)
    xs, ys = closed_path[:, 0], closed_path[:, 1]

    # Route, then all nodes, then the start marker on top
    ax.plot(xs, ys, c=color, linewidth=1.5, linestyle='-')
    ax.scatter(points[:, 0], points[:, 1], c='black', s=15, zorder=5)
    ax.scatter(xs[0], ys[0], c='red', s=40, zorder=6, label='Inicio')

    ax.set_title(title)
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])
    ax.set_aspect('equal')

def visualize_comparison(phase_name, config):
    """Plot model tours next to the reference tours for a few validation files.

    For each validation file (at most three) the first sample is decoded
    greedily by the checkpointed model and drawn beside the stored solution.

    Args:
        phase_name: Label used in the printed messages.
        config: Mapping with the keys ``"ckpt"`` (checkpoint path),
            ``"val_folder"`` and ``"val_prefix"`` (location of the data).

    Returns:
        None. Figures are shown with Matplotlib; problems are reported with
        ``print`` and the affected file is skipped.
    """
    print(f"\nüé® GENERANDO VISUALIZACIONES PARA: {phase_name}")

    if not os.path.exists(config["ckpt"]):
        print(f"‚ö†Ô∏è No hay modelo para {phase_name}, saltando...")
        return

    # The notebook's import cell does not import Matplotlib, so it is imported
    # here where it is needed; without this every plot fails with a NameError
    # that the per-file handler below would hide.
    import matplotlib.pyplot as plt

    # 1. Load the model. Same hyper-parameters as the training cell;
    #    EncoderPointerModel has no max_seq_len argument, so none is passed.
    model = EncoderPointerModel(input_dim=2, d_model=128, nhead=8, enc_layers=3, dec_layers=2).to(DEVICE)
    try:
        # The checkpoint is a plain state_dict, so weights_only loading suffices
        model.load_state_dict(torch.load(config["ckpt"], map_location=DEVICE, weights_only=True))
        model.eval()
    except Exception as e:
        print(f"‚ùå Error cargando modelo: {e}")
        return

    # 2. Find the validation parts
    search_pattern = os.path.join(config["val_folder"], f"{config['val_prefix']}*.npz")
    files = sorted(glob.glob(search_pattern))

    if not files:
        print("‚ùå No encontr√© archivos de validaci√≥n.")
        return

    # 3. One example per file
    print(f"üì∏ Se encontraron {len(files)} archivos. Generando 1 ejemplo de cada uno...")

    for i, f_path in enumerate(files):
        try:
            # SECURITY NOTE: allow_pickle=True can execute arbitrary code from
            # a crafted file; only open validation files from a trusted source.
            with np.load(f_path, allow_pickle=True) as data:
                points_all = data['points']
                sols_all = data['solutions']

            # First sample of the file (swap for np.random.randint to vary it)
            idx = 0
            sample_points = points_all[idx]  # [N, 2]

            # The stored tour may be a list or an object array; coerce to int64
            sample_sol_true = np.asarray(sols_all[idx]).astype(np.int64)

            # Scale the model input like the training loop does: divide by the
            # file-wide maximum only when coordinates exceed 1.0. This also
            # avoids dividing by zero on degenerate samples.
            max_val = points_all.max()
            model_points = sample_points / max_val if max_val > 1.0 else sample_points
            input_points = torch.tensor(model_points, dtype=torch.float32).unsqueeze(0).to(DEVICE)

            # --- MODEL INFERENCE (greedy) ---
            with torch.no_grad():
                logits = model(input_points, teacher_forcing=False)
                sample_sol_pred = logits.argmax(dim=2).squeeze(0).cpu().numpy()

            # --- DRAW ---
            fig, axs = plt.subplots(1, 2, figsize=(10, 5))

            # Left: the model's tour
            plot_route(axs[0], sample_points, sample_sol_pred, f"Tu Modelo (IA)\nArchivo: {os.path.basename(f_path)}", 'blue')

            # Right: the reference tour
            plot_route(axs[1], sample_points, sample_sol_true, "OR-Tools (Ground Truth)", 'green')

            plt.tight_layout()
            plt.show()
            plt.close(fig)  # Release the figure so repeated calls do not accumulate memory

            # Safety limit: stop after three figures. Remove this check to see every file.
            if i >= 2:
                print("üõë Deteniendo visualizaci√≥n para no saturar la pantalla (3 ejemplos mostrados).")
                break

        except Exception as e:
            print(f"‚ö†Ô∏è Error visualizando {os.path.basename(f_path)}: {e}")
            continue

# --- RUN ---
for viz_phase in ("EASY", "MEDIUM"):
    visualize_comparison(viz_phase, PATHS_CONFIG[viz_phase])


üé® GENERANDO VISUALIZACIONES PARA: EASY
‚ö†Ô∏è No hay modelo para EASY, saltando...

üé® GENERANDO VISUALIZACIONES PARA: MEDIUM
‚ö†Ô∏è No hay modelo para MEDIUM, saltando...
