In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
from torch.optim import Adam, AdamW
import h5py
import numpy as np
from tqdm import tqdm
import pickle
import math
from sklearn.linear_model import Ridge, LogisticRegression
from sklearn.preprocessing import KBinsDiscretizer, StandardScaler
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score, 
    confusion_matrix, roc_auc_score, precision_recall_curve, auc, roc_curve
)
import random
import copy
import warnings
from sklearn.exceptions import ConvergenceWarning  # Importar ConvergenceWarning
import matplotlib.pyplot as plt

# ============================
# Configuración de la Semilla para Reproducibilidad
# ============================

def set_seed(seed):
    """
    Establece la semilla para diferentes librerías para asegurar la reproducibilidad.

    Args:
        seed (int): Valor de la semilla a establecer.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True  # Garantiza determinismo en cuDNN
    torch.backends.cudnn.benchmark = False     # Desactiva el benchmark para evitar variaciones

SEED = 42
set_seed(SEED)

# Función para inicializar la semilla en los workers de DataLoader
def worker_init_fn(worker_id):
    """
    Inicializa la semilla para cada worker del DataLoader.

    Args:
        worker_id (int): Identificador del worker.
    """
    np.random.seed(SEED + worker_id)
    random.seed(SEED + worker_id)

# Suprimir advertencias específicas si es necesario
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
# warnings.filterwarnings("ignore", category=ConvergenceWarning, module="sklearn")  # Eliminado para no suprimir ConvergenceWarning

############################################################
# Dataset Clases
############################################################

class ShapesDataset(Dataset):
    def __init__(self, file_path):
        """
        Clase personalizada para manejar el conjunto de datos de formas.
        Carga imágenes y etiquetas desde un archivo HDF5.

        Args:
            file_path (str): Ruta al archivo HDF5 que contiene los datos.
        """
        with h5py.File(file_path, 'r') as data:
            # Cargar y normalizar las imágenes
            self.images = torch.tensor(data['images'][:], dtype=torch.float32).div(255.0)
            # Cargar las etiquetas
            self.labels = torch.tensor(data['labels'][:], dtype=torch.float32)

    def __len__(self):
        """
        Retorna la cantidad total de muestras en el dataset.
        """
        return len(self.images)

    def __getitem__(self, idx):
        """
        Retorna una imagen y su etiqueta correspondiente.

        Args:
            idx (int): Índice de la muestra a obtener.

        Returns:
            tuple: (imagen, etiqueta)
        """
        img = self.images[idx].permute(2, 0, 1)  # Cambiar de [H, W, C] a [C, H, W]
        label = self.labels[idx]
        return img, label

class DynamicTripletDataset(Dataset):
    def __init__(self, dataset):
        """
        Clase personalizada para generar tripletes dinámicamente durante el entrenamiento.
        Cada triplete consiste en una imagen ancla, una imagen positiva (misma clase) y una imagen negativa (clase diferente).

        Args:
            dataset (Dataset): Instancia de ShapesDataset o Subset de ShapesDataset.
        """
        self.dataset = dataset
        # Verificar si el dataset es una instancia de Subset
        if isinstance(dataset, torch.utils.data.Subset):
            # Acceder a los labels del dataset original usando los índices del subconjunto
            self.labels = dataset.dataset.labels[dataset.indices]
        else:
            self.labels = dataset.labels

    def __len__(self):
        """
        Retorna la cantidad total de muestras en el dataset.
        """
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Genera un triplete (ancla, positivo, negativo).

        Args:
            idx (int): Índice de la muestra ancla.

        Returns:
            tuple: (ancla, positivo, negativo)
        """
        anchor_img, anchor_label = self.dataset[idx]
        shape_val = anchor_label[4].item()  # Suponiendo que la columna 4 es 'shape'

        # Encontrar índices con la misma forma
        same_shape_indices = (self.labels[:, 4] == shape_val).nonzero(as_tuple=True)[0].numpy()
        # Encontrar índices con forma diferente
        diff_shape_indices = (self.labels[:, 4] != shape_val).nonzero(as_tuple=True)[0].numpy()

        # Asegurarse de que haya suficientes índices para positivos y negativos
        if len(same_shape_indices) < 1:
            positive_idx = idx  # Usar el mismo índice si no hay otros con la misma forma
        else:
            positive_idx = np.random.choice(same_shape_indices)

        if len(diff_shape_indices) < 1:
            negative_idx = idx  # Usar el mismo índice si no hay otros con forma diferente
        else:
            negative_idx = np.random.choice(diff_shape_indices)

        positive_img, _ = self.dataset[positive_idx]
        negative_img, _ = self.dataset[negative_idx]
        return anchor_img, positive_img, negative_img

############################################################
# Modelos TAE y TVAE
############################################################

class TAE(nn.Module):
    def __init__(self, input_shape, latent_dim):
        """
        Triplet AutoEncoder (TAE).

        Args:
            input_shape (tuple): Forma de entrada de las imágenes (C, H, W).
            latent_dim (int): Dimensionalidad del espacio latente.
        """
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64*16*16, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64*16*16),
            nn.Unflatten(1, (64, 16, 16)),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(32, input_shape[0], 4, 2, 1), nn.Sigmoid()
        )

    def forward(self, x):
        """
        Paso hacia adelante del modelo.

        Args:
            x (torch.Tensor): Entrada de imágenes.

        Returns:
            tuple: (imagen reconstruida, representación latente)
        """
        z = self.encoder(x)
        x_recon = self.decoder(z)
        return x_recon, z

    def loss(self, x, x_recon, z_anchor, z_positive, z_negative):
        """
        Función de pérdida para TAE que combina la pérdida de reconstrucción y la pérdida triplete.

        Args:
            x (torch.Tensor): Imagen original.
            x_recon (torch.Tensor): Imagen reconstruida.
            z_anchor (torch.Tensor): Representación latente de la imagen ancla.
            z_positive (torch.Tensor): Representación latente de la imagen positiva.
            z_negative (torch.Tensor): Representación latente de la imagen negativa.

        Returns:
            torch.Tensor: Pérdida total.
        """
        recon_loss = nn.functional.mse_loss(x_recon, x, reduction="mean")
        margin = 1.0
        pos_dist = torch.norm(z_anchor - z_positive, dim=-1)
        neg_dist = torch.norm(z_anchor - z_negative, dim=-1)
        triplet_loss = torch.mean(torch.relu(pos_dist - neg_dist + margin))
        return recon_loss + triplet_loss

class TVAE(nn.Module):
    def __init__(self, input_shape, latent_dim):
        """
        Triplet Variational AutoEncoder (TVAE).

        Args:
            input_shape (tuple): Forma de entrada de las imágenes (C, H, W).
            latent_dim (int): Dimensionalidad del espacio latente.
        """
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64*16*16, latent_dim*2)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64*16*16),
            nn.Unflatten(1, (64, 16, 16)),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(32, input_shape[0], 4, 2, 1), nn.Sigmoid()
        )
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        """
        Reparametrización para el TVAE.

        Args:
            mu (torch.Tensor): Media de la distribución latente.
            log_var (torch.Tensor): Log-variancia de la distribución latente.

        Returns:
            torch.Tensor: Muestra reparametrizada de la distribución latente.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """
        Paso hacia adelante del modelo.

        Args:
            x (torch.Tensor): Entrada de imágenes.

        Returns:
            tuple: (imagen reconstruida, mu, log_var, representación latente)
        """
        q = self.encoder(x)
        mu, log_var = torch.chunk(q, 2, dim=-1)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z)
        return x_recon, mu, log_var, z

    def loss(self, x, x_recon, mu, log_var, z_anchor, z_positive, z_negative):
        """
        Función de pérdida para TVAE que combina la pérdida de reconstrucción, la pérdida KL y la pérdida triplete.

        Args:
            x (torch.Tensor): Imagen original.
            x_recon (torch.Tensor): Imagen reconstruida.
            mu (torch.Tensor): Media de la distribución latente.
            log_var (torch.Tensor): Log-variancia de la distribución latente.
            z_anchor (torch.Tensor): Representación latente de la imagen ancla.
            z_positive (torch.Tensor): Representación latente de la imagen positiva.
            z_negative (torch.Tensor): Representación latente de la imagen negativa.

        Returns:
            torch.Tensor: Pérdida total.
        """
        recon_loss = nn.functional.mse_loss(x_recon, x, reduction="mean")
        kl_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
        margin = 1.0
        pos_dist = torch.norm(z_anchor - z_positive, dim=-1)
        neg_dist = torch.norm(z_anchor - z_negative, dim=-1)
        triplet_loss = torch.mean(torch.relu(pos_dist - neg_dist + margin))
        return recon_loss + kl_loss + triplet_loss

############################################################
# Entrenamiento y guardado
############################################################

def train_model(model, dataloader, optimizer, device, epochs=10):
    """
    Entrena el modelo utilizando los datos proporcionados.

    Args:
        model (nn.Module): Modelo a entrenar.
        dataloader (DataLoader): DataLoader que proporciona los datos.
        optimizer (torch.optim.Optimizer): Optimizador para actualizar los pesos del modelo.
        device (torch.device): Dispositivo (CPU o GPU) para entrenar el modelo.
        epochs (int, optional): Número de épocas de entrenamiento. Defaults to 10.

    Returns:
        nn.Module: Modelo entrenado.
    """
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for anchor, positive, negative in pbar:
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
            optimizer.zero_grad()
            if isinstance(model, TAE):
                x_recon, z_anchor = model(anchor)
                _, z_positive = model(positive)
                _, z_negative = model(negative)
                loss = model.loss(anchor, x_recon, z_anchor, z_positive, z_negative)
            elif isinstance(model, TVAE):
                x_recon, mu, log_var, z_anchor = model(anchor)
                _, _, _, z_positive = model(positive)
                _, _, _, z_negative = model(negative)
                loss = model.loss(anchor, x_recon, mu, log_var, z_anchor, z_positive, z_negative)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            pbar.set_postfix({"Loss": epoch_loss / len(dataloader)})
    return model

def save_representations(model, dataset, device, save_path):
    """
    Guarda las representaciones latentes y las etiquetas del dataset proporcionado.

    Args:
        model (nn.Module): Modelo entrenado para generar las representaciones latentes.
        dataset (Dataset): Dataset del cual generar las representaciones.
        device (torch.device): Dispositivo para procesar los datos.
        save_path (str): Ruta donde guardar las representaciones.
    """
    latents = []
    factors = []
    model.eval()
    loader = DataLoader(dataset, batch_size=1024, shuffle=False, worker_init_fn=worker_init_fn)
    with torch.no_grad():
        for img, label in tqdm(loader, desc=f"Saving {save_path}"):
            img = img.to(device)
            if isinstance(model, TAE):
                _, z = model(img)
            elif isinstance(model, TVAE):
                _, mu, log_var, z = model(img)
            latents.append(z.cpu().numpy())
            factors.append(label.numpy())
    latents = np.vstack(latents)
    factors = np.vstack(factors)
    with open(save_path, "wb") as f:
        pickle.dump({"latents": latents, "factors": factors}, f)
    print(f"Representations saved to {save_path}")

############################################################
# MIG y DCI
############################################################

def compute_mig(latents, factors, num_bins=10):  # Reducido de 20 a 10
    """
    Calcula la métrica Mutual Information Gap (MIG) para evaluar la calidad de las representaciones latentes.

    Args:
        latents (np.ndarray): Representaciones latentes del modelo.
        factors (np.ndarray): Factores de variación originales.
        num_bins (int, optional): Número de bins para discretizar las representaciones. Defaults to 10.

    Returns:
        float: Valor promedio de MIG.
    """
    num_factors = factors.shape[1]
    num_latents = latents.shape[1]

    est_lat = KBinsDiscretizer(n_bins=num_bins, encode='ordinal', strategy='quantile')
    lat_disc = est_lat.fit_transform(latents).astype(int)

    factors_disc = []
    for f_idx in range(num_factors):
        f_vals = factors[:, f_idx]
        uniq = np.unique(f_vals)
        if len(uniq) > 10:
            est_f = KBinsDiscretizer(n_bins=min(num_bins, len(uniq)), encode='ordinal', strategy='quantile')
            f_disc = est_f.fit_transform(f_vals.reshape(-1, 1)).astype(int).flatten()
        else:
            val2idx = {val: i for i, val in enumerate(np.sort(uniq))}
            f_disc = np.array([val2idx[v] for v in f_vals])
        factors_disc.append(f_disc)
    factors_disc = np.stack(factors_disc, axis=1).astype(int)

    def mutual_information(z, v):
        pz = np.bincount(z) / len(z)
        pv = np.bincount(v) / len(v)
        pzv, _, _ = np.histogram2d(z, v, bins=(np.arange(z.max()+2)-0.5, np.arange(v.max()+2)-0.5))
        pzv = pzv / pzv.sum()
        mi = 0
        nz, nv = pzv.shape
        for i in range(nz):
            for j in range(nv):
                if pzv[i, j] > 0:
                    mi += pzv[i, j] * math.log(pzv[i, j] / (pz[i] * pv[j] + 1e-12) + 1e-12)
        return mi / math.log(2 + 1e-9)

    def entropy(v):
        pv = np.bincount(v) / len(v)
        h = 0
        for p in pv:
            if p > 0:
                h -= p * math.log(p + 1e-12)
        return h / math.log(2 + 1e-9)

    migs = []
    for f_idx in range(num_factors):
        v = factors_disc[:, f_idx]
        h_v = entropy(v)
        mi_scores = []
        for z_j in range(num_latents):
            z = lat_disc[:, z_j]
            mi_scores.append(mutual_information(z, v))
        mi_scores = sorted(mi_scores, reverse=True)
        if len(mi_scores) > 1:
            migs.append((mi_scores[0] - mi_scores[1]) / (h_v + 1e-12))
        else:
            migs.append(mi_scores[0] / (h_v + 1e-12))
    return np.mean(migs)

def compute_dci(latents, factors):
    """
    Calcula la métrica Disentanglement Completeness Informativeness (DCI) para evaluar la calidad de las representaciones latentes.

    Args:
        latents (np.ndarray): Representaciones latentes del modelo.
        factors (np.ndarray): Factores de variación originales.

    Returns:
        float: Valor promedio de DCI.
    """
    num_factors = factors.shape[1]
    disent_scores = []
    scaler = StandardScaler()  # Escalador para las representaciones latentes
    latents_scaled = scaler.fit_transform(latents)  # Escalar solo una vez para todo el proceso

    for f_idx in range(num_factors):
        f_vals = factors[:, f_idx]
        uniq = np.unique(f_vals)
        if len(uniq) > 10:
            model = Ridge(alpha=1.0)
            model.fit(latents_scaled, f_vals)
            coefs = model.coef_
        else:
            val2idx = {val: i for i, val in enumerate(np.sort(uniq))}
            y_disc = np.array([val2idx[v] for v in f_vals])
            model = LogisticRegression(max_iter=1000, multi_class='auto')  # Aumentado max_iter a 1000
            model.fit(latents_scaled, y_disc)
            coefs = model.coef_
            if coefs.shape[0] > 1:
                coefs = np.mean(np.abs(coefs), axis=0)
            else:
                coefs = coefs.flatten()
        importances = np.abs(coefs)
        p = importances / (importances.sum() + 1e-12)
        entropy_val = -np.sum(p * np.log(p + 1e-12))
        max_entropy = math.log(len(p) + 1e-12)
        disent = 1 - (entropy_val / (max_entropy + 1e-12))
        disent_scores.append(disent)
    return np.mean(disent_scores)

############################################################
# create_stratified_pairs
############################################################

def create_stratified_pairs(x, y, num_pairs):
    """
    Crea pares de imágenes estratificados en dos clases:
    SAME: ambas imágenes tienen la misma forma (y[:,4] igual)
    DIFFERENT: las imágenes tienen distinta forma (y[:,4] distinto)

    Genera la mitad de pares SAME y la mitad DIFFERENT.

    Args:
        x (torch.Tensor): Tensores de imágenes [N, C, H, W].
        y (torch.Tensor): Tensores de etiquetas [N, num_factors].
        num_pairs (int): Número total de pares a generar.

    Returns:
        tuple: ([X1, X2], y_pairs)
    """
    shape_labels = y[:,4].cpu().numpy()  # El factor de la forma está en la columna 4
    unique_shapes = np.unique(shape_labels)
    shape_to_indices = {sh: np.where(shape_labels == sh)[0] for sh in unique_shapes}

    half_pairs = num_pairs // 2

    # Generar pares SAME
    same_x1 = []
    same_x2 = []
    for _ in range(half_pairs):
        sh = np.random.choice(unique_shapes)
        inds = shape_to_indices[sh]
        if len(inds) < 2:
            continue
        i1, i2 = np.random.choice(inds, 2, replace=False)
        same_x1.append(i1)
        same_x2.append(i2)

    # Generar pares DIFFERENT
    diff_x1 = []
    diff_x2 = []
    for _ in range(half_pairs):
        if len(unique_shapes) < 2:
            break
        sh1, sh2 = np.random.choice(unique_shapes, 2, replace=False)
        inds1 = shape_to_indices[sh1]
        inds2 = shape_to_indices[sh2]
        if len(inds1) == 0 or len(inds2) == 0:
            continue
        i1 = np.random.choice(inds1)
        i2 = np.random.choice(inds2)
        diff_x1.append(i1)
        diff_x2.append(i2)

    same_x1 = np.array(same_x1)
    same_x2 = np.array(same_x2)
    diff_x1 = np.array(diff_x1)
    diff_x2 = np.array(diff_x2)

    # Ajustar por si alguno se quedó corto
    min_len = min(len(same_x1), len(diff_x1))
    if min_len == 0:
        raise ValueError("No hay suficientes pares para generar.")
    same_x1 = same_x1[:min_len]
    same_x2 = same_x2[:min_len]
    diff_x1 = diff_x1[:min_len]
    diff_x2 = diff_x2[:min_len]

    X1 = np.concatenate([same_x1, diff_x1])
    X2 = np.concatenate([same_x2, diff_x2])

    # Mantener los datos en CPU
    X1_t = x[X1].cpu()
    X2_t = x[X2].cpu()

    # Crear etiquetas: 0 para SAME, 1 para DIFFERENT
    y_pairs = torch.cat([
        torch.zeros(min_len, dtype=torch.long),
        torch.ones(min_len, dtype=torch.long)
    ])

    x_pairs = [X1_t, X2_t]
    return x_pairs, y_pairs

############################################################
# JEPA Predictor
############################################################

class IWM_JEPA_Predictor(nn.Module):
    def __init__(self, latent_dim=6, action_dim=4):
        """
        Predictor IWM JEPA que predice la representación latente futura basada en la actual y en ciertas acciones.

        Args:
            latent_dim (int, optional): Dimensionalidad del espacio latente. Defaults to 6.
            action_dim (int, optional): Dimensionalidad de los parámetros de acción. Defaults to 4.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(latent_dim + action_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.5),
            nn.Linear(128, latent_dim)
        )

    def forward(self, z_x, a_xy):
        """
        Paso hacia adelante del predictor.

        Args:
            z_x (torch.Tensor): Representación latente actual.
            a_xy (torch.Tensor): Parámetros de acción.

        Returns:
            torch.Tensor: Representación latente predicha.
        """
        inp = torch.cat([z_x, a_xy], dim=1)
        return self.mlp(inp)

def iwm_jepa_loss(z_pred, z_y):
    """
    Función de pérdida para el predictor IWM JEPA.

    Args:
        z_pred (torch.Tensor): Representación latente predicha.
        z_y (torch.Tensor): Representación latente real.

    Returns:
        torch.Tensor: Pérdida MSE.
    """
    return nn.functional.mse_loss(z_pred, z_y)

############################################################
# Funciones para IWM JEPA
############################################################

def generate_views_for_iwm(imgs, device):
    """
    Genera dos vistas de cada imagen: una con leves transformaciones y otra con transformaciones más fuertes.

    Args:
        imgs (torch.Tensor): Batch de imágenes [B, C, H, W].
        device (torch.device): Dispositivo para procesar las imágenes.

    Returns:
        tuple: (x_imgs, y_imgs)
    """
    B = imgs.shape[0]
    def color_jitter(img, brightness=0.2, contrast=0.2):
        """
        Aplica jitter de color a una imagen.

        Args:
            img (torch.Tensor): Imagen [1, C, H, W].
            brightness (float, optional): Factor de brillo. Defaults to 0.2.
            contrast (float, optional): Factor de contraste. Defaults to 0.2.

        Returns:
            torch.Tensor: Imagen transformada.
        """
        b_factor = 1.0 + (2 * random.random() - 1) * brightness
        c_factor = 1.0 + (2 * random.random() - 1) * contrast
        mean_img = img.mean(dim=(1,2,3), keepdim=True)
        img2 = img * b_factor
        img2 = (img2 - mean_img) * c_factor + mean_img
        img2 = torch.clamp(img2, 0, 1)
        return img2

    # y_imgs con leve jitter
    y_imgs = []
    for i in range(B):
        y_imgs.append(color_jitter(imgs[i:i+1], 0.1, 0.1))
    y_imgs = torch.cat(y_imgs, dim=0)

    # x con fuerte jitter+blur+mask
    x_imgs = []
    for i in range(B):
        x_img = imgs[i:i+1]
        x_img = color_jitter(x_img, 0.4, 0.4)
        if random.random() < 0.5:
            kernel = 3
            pad = (kernel - 1) // 2
            x_img = torch.nn.functional.avg_pool2d(x_img, kernel, stride=1, padding=pad)
        if random.random() < 0.5:
            C, H, W = x_img.shape[1], x_img.shape[2], x_img.shape[3]
            mh, mw = H // 4, W // 4
            sy = random.randint(0, H - mh)
            sx = random.randint(0, W - mw)
            x_img[:, :, sy:sy+mh, sx:sx+mw] = 0.0
        x_imgs.append(x_img)
    x_imgs = torch.cat(x_imgs, dim=0)
    return x_imgs.to(device), y_imgs.to(device)

def generate_action_params(x_imgs, y_imgs):
    """
    Genera parámetros de acción basados en las diferencias de brillo y contraste entre x_imgs y y_imgs.

    Args:
        x_imgs (torch.Tensor): Primer conjunto de imágenes [B, C, H, W].
        y_imgs (torch.Tensor): Segundo conjunto de imágenes [B, C, H, W].

    Returns:
        torch.Tensor: Parámetros de acción [B, 4].
    """
    B = x_imgs.shape[0]
    def brightness_contrast(img):
        """
        Calcula el brillo y contraste de una imagen.

        Args:
            img (torch.Tensor): Imagen [B, C, H, W].

        Returns:
            tuple: (brillo medio, desviación estándar)
        """
        mean_val = img.mean(dim=(1,2,3))
        std_val = img.std(dim=(1,2,3)) + 1e-6
        return mean_val, std_val
    mean_x, std_x = brightness_contrast(x_imgs)
    mean_y, std_y = brightness_contrast(y_imgs)
    a_xy = torch.stack([mean_y - mean_x, std_y - std_x, torch.zeros_like(mean_x), torch.zeros_like(mean_x)], dim=1)
    return a_xy

############################################################
# MAIN
############################################################

if __name__ == "__main__":
    # Configuración del dispositivo y parámetros
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    latent_dim = 6

    # Cargar los datasets de entrenamiento y prueba
    train_dataset_full = ShapesDataset('/home/gperaltag/3dshapes_data/3dshapes_abstraction_train.h5')
    test_dataset = ShapesDataset('/home/gperaltag/3dshapes_data/3dshapes_abstraction_test.h5')

    # ============================
    # Crear Conjunto de Validación
    # ============================
    val_size = int(0.1 * len(train_dataset_full))  # 10% para validación
    train_size = len(train_dataset_full) - val_size
    train_dataset, val_dataset = random_split(train_dataset_full, [train_size, val_size], generator=torch.Generator().manual_seed(SEED))
    print(f"Train size: {train_size}, Validation size: {val_size}")

    # ============================
    # Etapa 1: Entrenar TAE y TVAE
    # ============================
    # Crear el dataset triplet para entrenamiento
    triplet_dataset = DynamicTripletDataset(train_dataset)
    triplet_loader = DataLoader(
        triplet_dataset,
        batch_size=512,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        worker_init_fn=worker_init_fn  # Asegurar reproducibilidad en los workers
    )

    # Entrenar TAE
    print("Entrenando TAE...")
    tae = TAE((3, 64, 64), latent_dim).to(device)
    tae_optimizer = Adam(tae.parameters(), lr=1e-4)
    train_model(tae, triplet_loader, tae_optimizer, device, epochs=10)
    save_representations(tae, test_dataset, device, "tae_representations.pkl")

    # Entrenar TVAE
    print("Entrenando TVAE...")
    tvae = TVAE((3, 64, 64), latent_dim).to(device)
    tvae_optimizer = Adam(tvae.parameters(), lr=1e-4)  # Corregido para usar tvae.parameters()
    train_model(tvae, triplet_loader, tvae_optimizer, device, epochs=10)
    save_representations(tvae, test_dataset, device, "tvae_representations.pkl")

    # ============================
    # Evaluar MIG y DCI
    # ============================
    print("Evaluando MIG y DCI para TAE y TVAE...")
    with open("tae_representations.pkl", "rb") as f:
        tae_data = pickle.load(f)
    tae_latents = tae_data["latents"]
    tae_factors = tae_data["factors"]

    with open("tvae_representations.pkl", "rb") as f:
        tvae_data = pickle.load(f)
    tvae_latents = tvae_data["latents"]
    tvae_factors = tvae_data["factors"]

    tae_mig = compute_mig(tae_latents, tae_factors)
    tae_dci = compute_dci(tae_latents, tae_factors)
    tvae_mig = compute_mig(tvae_latents, tvae_factors)
    tvae_dci = compute_dci(tvae_latents, tvae_factors)

    print(f"TAE MIG: {tae_mig}, TAE DCI: {tae_dci}")
    print(f"TVAE MIG: {tvae_mig}, TVAE DCI: {tvae_dci}")

    # Seleccionar el mejor modelo basado en MIG + DCI
    tae_score = tae_mig + tae_dci
    tvae_score = tvae_mig + tvae_dci
    if tvae_score > tae_score:
        best_model_name = "TVAE"
        best_encoder = tvae.encoder
    else:
        best_model_name = "TAE"
        best_encoder = tae.encoder

    print(f"Best model for IWM JEPA: {best_model_name}")

    def get_latent(z):
        """
        Obtiene la representación latente correcta según el modelo seleccionado.

        Args:
            z (torch.Tensor): Representación latente del modelo.

        Returns:
            torch.Tensor: Representación latente utilizada.
        """
        if best_model_name == "TVAE":
            return z[:, :latent_dim]
        else:
            return z

    # Congelar encoder seleccionado (f_theta)
    f_theta = best_encoder
    for param in f_theta.parameters():
        param.requires_grad = False
    f_theta.eval()

    # Crear f_EMA_theta como copia congelada
    f_EMA_theta = copy.deepcopy(f_theta)
    for param in f_EMA_theta.parameters():
        param.requires_grad = False

    # ============================
    # Etapa 2: Entrenar IWM JEPA
    # ============================
    print("Entrenando IWM JEPA Predictor...")
    iwm_predictor = IWM_JEPA_Predictor(latent_dim=latent_dim, action_dim=4).to(device)
    iwm_optimizer = AdamW(iwm_predictor.parameters(), lr=1e-4, weight_decay=1e-5)
    # === MODIFICACIÓN AQUÍ ===
    # Antes: iwm_loader = DataLoader(test_dataset, batch_size=256, shuffle=True, num_workers=4, pin_memory=True)
    # Ahora: Utilizar train_dataset en lugar de test_dataset
    iwm_loader = DataLoader(
        train_dataset,
        batch_size=256,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        worker_init_fn=worker_init_fn  # Asegurar reproducibilidad en los workers
    )
    # ============================

    # Aumentar el número de épocas y aplicar Early Stopping
    epochs_iwm = 50
    best_loss = float('inf')
    patience = 3
    trigger_times = 0
    # Utilizar un scheduler para ajustar la tasa de aprendizaje
    scheduler_iwm = torch.optim.lr_scheduler.ReduceLROnPlateau(iwm_optimizer, mode='min', factor=0.1, patience=3, verbose=True)

    for epoch in range(epochs_iwm):
        epoch_loss = 0
        pbar = tqdm(iwm_loader, desc=f"IWM JEPA Epoch {epoch+1}/{epochs_iwm}")
        for imgs, labels in pbar:
            imgs = imgs.to(device)
            x_imgs, y_imgs = generate_views_for_iwm(imgs, device)
            with torch.no_grad():
                z_x = f_theta(x_imgs)
                z_y = f_EMA_theta(y_imgs)
                z_x = get_latent(z_x)
                z_y = get_latent(z_y)
            a_xy = generate_action_params(x_imgs, y_imgs)
            iwm_optimizer.zero_grad()
            z_pred = iwm_predictor(z_x, a_xy)
            loss = iwm_jepa_loss(z_pred, z_y)
            loss.backward()
            iwm_optimizer.step()
            epoch_loss += loss.item()
            pbar.set_postfix({"Loss": epoch_loss / len(iwm_loader)})
        avg_loss = epoch_loss / len(iwm_loader)
        scheduler_iwm.step(avg_loss)
        print(f"IWM JEPA Epoch {epoch+1}, Loss: {avg_loss}")
        
        # Early Stopping
        if avg_loss < best_loss:
            best_loss = avg_loss
            trigger_times = 0
            torch.save(iwm_predictor.state_dict(), 'best_iwm_jepa.pth')
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print("Early stopping en IWM JEPA")
                break

    # Cargar el mejor modelo entrenado
    iwm_predictor.load_state_dict(torch.load('best_iwm_jepa.pth'))
    iwm_predictor.eval()

    # ============================
    # Etapa 3: Same/Different
    # ============================
    print("Preparando pares SAME/DIFFERENT usando todo el conjunto de entrenamiento...")
    # Usar todo el conjunto de entrenamiento
    train_loader_full = DataLoader(
        train_dataset,
        batch_size=1024,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        worker_init_fn=worker_init_fn  # Asegurar reproducibilidad en los workers
    )
    all_train_imgs = []
    all_train_factors = []
    with torch.no_grad():
        for img, label in train_loader_full:
            all_train_imgs.append(img)
            all_train_factors.append(label)
    all_train_imgs = torch.cat(all_train_imgs, dim=0)  # [N, C, H, W]
    all_train_factors = torch.cat(all_train_factors, dim=0)  # [N, num_factors]

    # Generar pares usando todas las imágenes de entrenamiento
    train_num_pairs = len(all_train_imgs)
    print(f"Generando {train_num_pairs} pares SAME/DIFFERENT para entrenamiento...")
    x_pairs_train, y_pairs_train = create_stratified_pairs(all_train_imgs, all_train_factors, num_pairs=train_num_pairs)

    # La clase same/diff ya está definida como 0 (SAME) y 1 (DIFFERENT)
    same_diff_labels_train = y_pairs_train  # 0 o 1

    # Crear dataset de Same/Different
    train_sd_dataset = TensorDataset(x_pairs_train[0], x_pairs_train[1], same_diff_labels_train)

    # Dividir en train y validación para el clasificador *}*
    train_sd_size = int(0.8 * len(train_sd_dataset))
    val_sd_size = len(train_sd_dataset) - train_sd_size
    train_sd, val_sd = random_split(train_sd_dataset, [train_sd_size, val_sd_size], generator=torch.Generator().manual_seed(SEED))

    train_sd_dataloader = DataLoader(
        train_sd,
        batch_size=256,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        worker_init_fn=worker_init_fn  # Asegurar reproducibilidad en los workers
    )

    val_sd_dataloader = DataLoader(
        val_sd,
        batch_size=256,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        worker_init_fn=worker_init_fn  # Asegurar reproducibilidad en los workers
    )

    # Preparar pares SAME/DIFFERENT para test
    print("Preparando pares SAME/DIFFERENT usando todo el conjunto de test...")
    test_loader_full = DataLoader(
        test_dataset,
        batch_size=1024,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        worker_init_fn=worker_init_fn  # Asegurar reproducibilidad en los workers
    )
    all_test_imgs = []
    all_test_factors = []
    with torch.no_grad():
        for img, label in test_loader_full:
            all_test_imgs.append(img)
            all_test_factors.append(label)
    all_test_imgs = torch.cat(all_test_imgs, dim=0)  # [N, C, H, W]
    all_test_factors = torch.cat(all_test_factors, dim=0)  # [N, num_factors]

    test_num_pairs = len(all_test_imgs)
    print(f"Generando {test_num_pairs} pares SAME/DIFFERENT para test...")
    x_pairs_test, y_pairs_test = create_stratified_pairs(all_test_imgs, all_test_factors, num_pairs=test_num_pairs)
    same_diff_labels_test = y_pairs_test  # 0 o 1

    test_sd_dataset = TensorDataset(x_pairs_test[0], x_pairs_test[1], same_diff_labels_test)
    test_sd_dataloader = DataLoader(
        test_sd_dataset,
        batch_size=256,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        worker_init_fn=worker_init_fn  # Asegurar reproducibilidad en los workers
    )

    # ============================
    # Definir y Entrenar el Clasificador Same/Different
    # ============================
    print("Definiendo y entrenando el clasificador SAME/DIFFERENT con características mejoradas...")
    concatenated_feature_dim = 3 * latent_dim  # z1, z2, z_pred - z2

    # Definir una arquitectura más compleja para el clasificador
    classifier = nn.Sequential(
        nn.Linear(concatenated_feature_dim, 256),
        nn.BatchNorm1d(256),
        nn.LeakyReLU(0.1),
        nn.Dropout(0.5),
        nn.Linear(256, 128),
        nn.BatchNorm1d(128),
        nn.LeakyReLU(0.1),
        nn.Dropout(0.5),
        nn.Linear(128, 64),
        nn.BatchNorm1d(64),
        nn.LeakyReLU(0.1),
        nn.Dropout(0.5),
        nn.Linear(64, 2)  # Output: 2 clases (SAME, DIFFERENT)
    ).to(device)

    # Probar diferentes tasas de aprendizaje
    learning_rates = [1e-4, 3e-4]
    best_clf_metrics = {'accuracy': 0, 'f1': 0, 'roc_auc': 0, 'pr_auc': 0}
    best_clf_state = None

    for lr in learning_rates:
        print(f"\nEntrenando clasificador con tasa de aprendizaje: {lr}")
        clf_optimizer = AdamW(classifier.parameters(), lr=lr, weight_decay=1e-5)
        clf_loss = nn.CrossEntropyLoss()
        scheduler_clf = torch.optim.lr_scheduler.ReduceLROnPlateau(clf_optimizer, mode='min', factor=0.1, patience=3, verbose=True)

        # Early Stopping
        epochs_clf = 20
        best_val_loss = float('inf')
        patience_clf = 3
        trigger_times_clf = 0

        for epoch in range(epochs_clf):
            # Entrenar
            classifier.train()
            epoch_loss = 0
            for x1_batch, x2_batch, sd_label in train_sd_dataloader:
                x1_batch, x2_batch, sd_label = x1_batch.to(device), x2_batch.to(device), sd_label.to(device)
                with torch.no_grad():
                    z1 = f_theta(x1_batch)
                    z2 = f_theta(x2_batch)
                    z_pred = iwm_predictor(z1, generate_action_params(x1_batch, x2_batch))
                    # Obtener las representaciones latentes
                    z1 = get_latent(z1)
                    z2 = get_latent(z2)
                    z_pred_minus_z2 = z_pred - z2
                    # Concatenar las características
                    features = torch.cat([z1, z2, z_pred_minus_z2], dim=1)  # [batch_size, 3 * latent_dim]
                clf_optimizer.zero_grad()
                logits = classifier(features)
                loss = clf_loss(logits, sd_label)
                loss.backward()
                clf_optimizer.step()
                epoch_loss += loss.item()
            avg_train_loss = epoch_loss / len(train_sd_dataloader)

            # Validar
            classifier.eval()
            val_loss = 0
            with torch.no_grad():
                for x1_val, x2_val, sd_label_val in val_sd_dataloader:
                    x1_val, x2_val, sd_label_val = x1_val.to(device), x2_val.to(device), sd_label_val.to(device)
                    z1 = f_theta(x1_val)
                    z2 = f_theta(x2_val)
                    z_pred = iwm_predictor(z1, generate_action_params(x1_val, x2_val))
                    z1 = get_latent(z1)
                    z2 = get_latent(z2)
                    z_pred_minus_z2 = z_pred - z2
                    features = torch.cat([z1, z2, z_pred_minus_z2], dim=1)  # [batch_size, 3 * latent_dim]
                    logits = classifier(features)
                    loss = clf_loss(logits, sd_label_val)
                    val_loss += loss.item()
            avg_val_loss = val_loss / len(val_sd_dataloader)
            scheduler_clf.step(avg_val_loss)
            print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}")

            # Early Stopping
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                trigger_times_clf = 0
                best_clf_state = copy.deepcopy(classifier.state_dict())
            else:
                trigger_times_clf += 1
                if trigger_times_clf >= patience_clf:
                    print("Early stopping en el clasificador")
                    break

        # Cargar el mejor estado del clasificador para esta tasa de aprendizaje
        if best_clf_state is not None:
            classifier.load_state_dict(best_clf_state)

        # Evaluar en el conjunto de validación
        classifier.eval()
        all_val_preds = []
        all_val_trues = []
        all_val_probs = []
        with torch.no_grad():
            for x1_val, x2_val, sd_label_val in val_sd_dataloader:
                x1_val, x2_val, sd_label_val = x1_val.to(device), x2_val.to(device), sd_label_val.to(device)
                z1 = f_theta(x1_val)
                z2 = f_theta(x2_val)
                z_pred = iwm_predictor(z1, generate_action_params(x1_val, x2_val))
                z1 = get_latent(z1)
                z2 = get_latent(z2)
                z_pred_minus_z2 = z_pred - z2
                features = torch.cat([z1, z2, z_pred_minus_z2], dim=1)
                logits = classifier(features)
                probs = torch.softmax(logits, dim=1)[:,1]  # Probabilidad de la clase 'DIFFERENT'
                preds = logits.argmax(dim=1)
                all_val_preds.append(preds.cpu().numpy())
                all_val_trues.append(sd_label_val.cpu().numpy())
                all_val_probs.append(probs.cpu().numpy())
        all_val_preds = np.concatenate(all_val_preds)
        all_val_trues = np.concatenate(all_val_trues)
        all_val_probs = np.concatenate(all_val_probs)

        # Calcular métricas
        acc = accuracy_score(all_val_trues, all_val_preds)
        f1 = f1_score(all_val_trues, all_val_preds, average='weighted')
        roc_auc = roc_auc_score(all_val_trues, all_val_probs)
        precision_vals, recall_vals, _ = precision_recall_curve(all_val_trues, all_val_probs)
        pr_auc = auc(recall_vals, precision_vals)

        print(f"Métricas de Validación para lr={lr}:")
        print(f"Accuracy: {acc}")
        print(f"F1-score: {f1}")
        print(f"ROC-AUC: {roc_auc}")
        print(f"PR-AUC: {pr_auc}")

        # Guardar las mejores métricas
        if roc_auc > best_clf_metrics['roc_auc']:
            best_clf_metrics['accuracy'] = acc
            best_clf_metrics['f1'] = f1
            best_clf_metrics['roc_auc'] = roc_auc
            best_clf_metrics['pr_auc'] = pr_auc
            torch.save(classifier.state_dict(), 'best_classifier.pth')

    # Cargar el mejor clasificador
    classifier.load_state_dict(torch.load('best_classifier.pth'))
    classifier.eval()

    # ============================
    # Evaluar el Clasificador en el Conjunto de Test
    # ============================
    print("Evaluando en test Same/Different con el mejor clasificador...")
    all_preds = []
    all_trues = []
    all_probs = []

    with torch.no_grad():
        for x1_batch, x2_batch, sd_label in tqdm(test_sd_dataloader, desc="Evaluando en test"):
            x1_batch, x2_batch, sd_label = x1_batch.to(device), x2_batch.to(device), sd_label.to(device)
            z1 = f_theta(x1_batch)
            z2 = f_theta(x2_batch)
            z_pred = iwm_predictor(z1, generate_action_params(x1_batch, x2_batch))
            z1 = get_latent(z1)
            z2 = get_latent(z2)
            z_pred_minus_z2 = z_pred - z2
            features = torch.cat([z1, z2, z_pred_minus_z2], dim=1)  # [batch_size, 3 * latent_dim]
            logits = classifier(features)
            probs = torch.softmax(logits, dim=1)[:,1]  # Probabilidad de la clase 'DIFFERENT'
            preds = logits.argmax(dim=1)
            all_preds.append(preds.cpu().numpy())
            all_trues.append(sd_label.cpu().numpy())
            all_probs.append(probs.cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_trues = np.concatenate(all_trues)
    all_probs = np.concatenate(all_probs)

    acc = accuracy_score(all_trues, all_preds)
    f1 = f1_score(all_trues, all_preds, average='weighted')
    prec = precision_score(all_trues, all_preds, average='weighted')
    rec = recall_score(all_trues, all_preds, average='weighted')
    cm = confusion_matrix(all_trues, all_preds)
    roc_auc = roc_auc_score(all_trues, all_probs)
    precision_vals, recall_vals, _ = precision_recall_curve(all_trues, all_probs)
    pr_auc = auc(recall_vals, precision_vals)

    print("Same/Different Test Metrics (mejor clasificador):")
    print(f"Accuracy: {acc}")
    print(f"F1-score: {f1}")
    print(f"Precision: {prec}")
    print(f"Recall: {rec}")
    print(f"ROC-AUC: {roc_auc}")
    print(f"PR-AUC: {pr_auc}")
    print("Confusion Matrix:")
    print(cm)

    # Graficar Curva ROC
    fpr, tpr, thresholds = roc_curve(all_trues, all_probs)
    plt.figure()
    plt.plot(fpr, tpr, label=f'ROC curve (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], 'k--')  # Línea diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Tasa de Falsos Positivos')
    plt.ylabel('Tasa de Verdaderos Positivos')
    plt.title('Curva ROC - Clasificador Same/Different')
    plt.legend(loc="lower right")
    plt.savefig('roc_curve.png')
    plt.close()

    # Graficar Curva Precision-Recall
    plt.figure()
    plt.plot(recall_vals, precision_vals, label=f'PR curve (area = {pr_auc:.2f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Curva Precision-Recall - Clasificador Same/Different')
    plt.legend(loc="lower left")
    plt.savefig('precision_recall_curve.png')
    plt.close()

    print("Curvas ROC y Precision-Recall guardadas como 'roc_curve.png' y 'precision_recall_curve.png' respectivamente.")
    print("Proceso final completado.")

Train size: 216000, Validation size: 24000
Entrenando TAE...


Epoch 1/10: 100%|██████████| 422/422 [03:02<00:00,  2.31it/s, Loss=0.826]
Epoch 2/10: 100%|██████████| 422/422 [03:07<00:00,  2.25it/s, Loss=0.128] 
Epoch 3/10: 100%|██████████| 422/422 [03:08<00:00,  2.24it/s, Loss=0.123] 
Epoch 4/10: 100%|██████████| 422/422 [03:08<00:00,  2.23it/s, Loss=0.118] 
Epoch 5/10: 100%|██████████| 422/422 [03:08<00:00,  2.24it/s, Loss=0.082] 
Epoch 6/10: 100%|██████████| 422/422 [03:09<00:00,  2.23it/s, Loss=0.0731]
Epoch 7/10: 100%|██████████| 422/422 [03:09<00:00,  2.23it/s, Loss=0.0705]
Epoch 8/10: 100%|██████████| 422/422 [03:09<00:00,  2.23it/s, Loss=0.0677]
Epoch 9/10: 100%|██████████| 422/422 [03:09<00:00,  2.23it/s, Loss=0.0562]
Epoch 10/10: 100%|██████████| 422/422 [03:09<00:00,  2.23it/s, Loss=0.0473]
Saving tae_representations.pkl: 100%|██████████| 235/235 [00:06<00:00, 34.43it/s]


Representations saved to tae_representations.pkl
Entrenando TVAE...


Epoch 1/10: 100%|██████████| 422/422 [03:09<00:00,  2.23it/s, Loss=1.26] 
Epoch 2/10: 100%|██████████| 422/422 [03:09<00:00,  2.23it/s, Loss=1.24] 
Epoch 3/10: 100%|██████████| 422/422 [03:08<00:00,  2.24it/s, Loss=1.24] 
Epoch 4/10: 100%|██████████| 422/422 [03:09<00:00,  2.23it/s, Loss=1.24] 
Epoch 5/10: 100%|██████████| 422/422 [03:08<00:00,  2.23it/s, Loss=1.24] 
Epoch 6/10: 100%|██████████| 422/422 [03:09<00:00,  2.23it/s, Loss=1.24] 
Epoch 7/10: 100%|██████████| 422/422 [03:08<00:00,  2.23it/s, Loss=1.23] 
Epoch 8/10: 100%|██████████| 422/422 [03:09<00:00,  2.23it/s, Loss=1.24] 
Epoch 9/10: 100%|██████████| 422/422 [03:09<00:00,  2.23it/s, Loss=1.23] 
Epoch 10/10: 100%|██████████| 422/422 [03:09<00:00,  2.23it/s, Loss=1.23] 
Saving tvae_representations.pkl: 100%|██████████| 235/235 [00:06<00:00, 34.79it/s]


Representations saved to tvae_representations.pkl
Evaluando MIG y DCI para TAE y TVAE...




TAE MIG: 0.026787081159060967, TAE DCI: 0.07815311970843693
TVAE MIG: 3.4842608072310507e-06, TVAE DCI: 0.08165103878308098
Best model for IWM JEPA: TAE
Entrenando IWM JEPA Predictor...


IWM JEPA Epoch 1/50: 100%|██████████| 844/844 [00:22<00:00, 37.15it/s, Loss=9.55]


IWM JEPA Epoch 1, Loss: 9.549289584724823


IWM JEPA Epoch 2/50: 100%|██████████| 844/844 [00:22<00:00, 37.39it/s, Loss=4.77] 


IWM JEPA Epoch 2, Loss: 4.767750910673096


IWM JEPA Epoch 3/50: 100%|██████████| 844/844 [00:22<00:00, 37.09it/s, Loss=3.58] 


IWM JEPA Epoch 3, Loss: 3.579071264413861


IWM JEPA Epoch 4/50: 100%|██████████| 844/844 [00:22<00:00, 37.15it/s, Loss=2.98] 


IWM JEPA Epoch 4, Loss: 2.9808663398168664


IWM JEPA Epoch 5/50: 100%|██████████| 844/844 [00:22<00:00, 37.19it/s, Loss=2.66] 


IWM JEPA Epoch 5, Loss: 2.6641454727728786


IWM JEPA Epoch 6/50: 100%|██████████| 844/844 [00:22<00:00, 37.33it/s, Loss=2.47] 


IWM JEPA Epoch 6, Loss: 2.4723788799267807


IWM JEPA Epoch 7/50: 100%|██████████| 844/844 [00:22<00:00, 37.42it/s, Loss=2.35] 


IWM JEPA Epoch 7, Loss: 2.3479428763073202


IWM JEPA Epoch 8/50: 100%|██████████| 844/844 [00:22<00:00, 37.16it/s, Loss=2.25] 


IWM JEPA Epoch 8, Loss: 2.2501683108049546


IWM JEPA Epoch 9/50: 100%|██████████| 844/844 [00:22<00:00, 36.98it/s, Loss=2.18] 


IWM JEPA Epoch 9, Loss: 2.1795124551413747


IWM JEPA Epoch 10/50: 100%|██████████| 844/844 [00:22<00:00, 37.13it/s, Loss=2.11] 


IWM JEPA Epoch 10, Loss: 2.1123713835438283


IWM JEPA Epoch 11/50: 100%|██████████| 844/844 [00:22<00:00, 37.19it/s, Loss=2.07] 


IWM JEPA Epoch 11, Loss: 2.0697249144456964


IWM JEPA Epoch 12/50: 100%|██████████| 844/844 [00:22<00:00, 37.17it/s, Loss=2.02] 


IWM JEPA Epoch 12, Loss: 2.0170686429993236


IWM JEPA Epoch 13/50: 100%|██████████| 844/844 [00:22<00:00, 37.07it/s, Loss=1.98] 


IWM JEPA Epoch 13, Loss: 1.975217038844999


IWM JEPA Epoch 14/50: 100%|██████████| 844/844 [00:22<00:00, 37.27it/s, Loss=1.93] 


IWM JEPA Epoch 14, Loss: 1.9284846067993562


IWM JEPA Epoch 15/50: 100%|██████████| 844/844 [00:22<00:00, 37.12it/s, Loss=1.9]  


IWM JEPA Epoch 15, Loss: 1.8978852735593985


IWM JEPA Epoch 16/50: 100%|██████████| 844/844 [00:22<00:00, 37.20it/s, Loss=1.87] 


IWM JEPA Epoch 16, Loss: 1.8707664638616464


IWM JEPA Epoch 17/50: 100%|██████████| 844/844 [00:22<00:00, 37.25it/s, Loss=1.84] 


IWM JEPA Epoch 17, Loss: 1.842900204997492


IWM JEPA Epoch 18/50: 100%|██████████| 844/844 [00:22<00:00, 37.12it/s, Loss=1.82] 


IWM JEPA Epoch 18, Loss: 1.818584969540908


IWM JEPA Epoch 19/50: 100%|██████████| 844/844 [00:22<00:00, 37.30it/s, Loss=1.79] 


IWM JEPA Epoch 19, Loss: 1.790524871993404


IWM JEPA Epoch 20/50: 100%|██████████| 844/844 [00:22<00:00, 37.34it/s, Loss=1.77] 


IWM JEPA Epoch 20, Loss: 1.7660667208416203


IWM JEPA Epoch 21/50: 100%|██████████| 844/844 [00:22<00:00, 37.21it/s, Loss=1.74] 


IWM JEPA Epoch 21, Loss: 1.7415844912495093


IWM JEPA Epoch 22/50: 100%|██████████| 844/844 [00:22<00:00, 37.25it/s, Loss=1.72] 


IWM JEPA Epoch 22, Loss: 1.7237950276707021


IWM JEPA Epoch 23/50: 100%|██████████| 844/844 [00:22<00:00, 37.19it/s, Loss=1.7]  


IWM JEPA Epoch 23, Loss: 1.6974902609230782


IWM JEPA Epoch 24/50: 100%|██████████| 844/844 [00:22<00:00, 37.20it/s, Loss=1.68] 


IWM JEPA Epoch 24, Loss: 1.6786941870411425


IWM JEPA Epoch 25/50: 100%|██████████| 844/844 [00:22<00:00, 37.25it/s, Loss=1.66] 


IWM JEPA Epoch 25, Loss: 1.655157514940506


IWM JEPA Epoch 26/50: 100%|██████████| 844/844 [00:22<00:00, 37.20it/s, Loss=1.64] 


IWM JEPA Epoch 26, Loss: 1.644217922930469


IWM JEPA Epoch 27/50: 100%|██████████| 844/844 [00:22<00:00, 37.32it/s, Loss=1.63] 


IWM JEPA Epoch 27, Loss: 1.6280499532889416


IWM JEPA Epoch 28/50: 100%|██████████| 844/844 [00:22<00:00, 37.26it/s, Loss=1.61] 


IWM JEPA Epoch 28, Loss: 1.6149404759373145


IWM JEPA Epoch 29/50: 100%|██████████| 844/844 [00:22<00:00, 37.23it/s, Loss=1.6]  


IWM JEPA Epoch 29, Loss: 1.5959248984876968


IWM JEPA Epoch 30/50: 100%|██████████| 844/844 [00:22<00:00, 37.00it/s, Loss=1.59] 


IWM JEPA Epoch 30, Loss: 1.5858129771399836


IWM JEPA Epoch 31/50: 100%|██████████| 844/844 [00:22<00:00, 37.09it/s, Loss=1.57] 


IWM JEPA Epoch 31, Loss: 1.5708314639414656


IWM JEPA Epoch 32/50: 100%|██████████| 844/844 [00:22<00:00, 37.34it/s, Loss=1.55] 


IWM JEPA Epoch 32, Loss: 1.5545150978022841


IWM JEPA Epoch 33/50: 100%|██████████| 844/844 [00:22<00:00, 37.32it/s, Loss=1.54] 


IWM JEPA Epoch 33, Loss: 1.5405287975666082


IWM JEPA Epoch 34/50: 100%|██████████| 844/844 [00:22<00:00, 37.12it/s, Loss=1.53] 


IWM JEPA Epoch 34, Loss: 1.5331027592527924


IWM JEPA Epoch 35/50: 100%|██████████| 844/844 [00:22<00:00, 37.08it/s, Loss=1.52] 


IWM JEPA Epoch 35, Loss: 1.5180081760431352


IWM JEPA Epoch 36/50: 100%|██████████| 844/844 [00:22<00:00, 37.32it/s, Loss=1.51] 


IWM JEPA Epoch 36, Loss: 1.5149562953490217


IWM JEPA Epoch 37/50: 100%|██████████| 844/844 [00:22<00:00, 37.09it/s, Loss=1.49] 


IWM JEPA Epoch 37, Loss: 1.4911896993198666


IWM JEPA Epoch 38/50: 100%|██████████| 844/844 [00:22<00:00, 37.03it/s, Loss=1.49] 


IWM JEPA Epoch 38, Loss: 1.4902453921135004


IWM JEPA Epoch 39/50: 100%|██████████| 844/844 [00:22<00:00, 37.12it/s, Loss=1.48] 


IWM JEPA Epoch 39, Loss: 1.4753968945894196


IWM JEPA Epoch 40/50: 100%|██████████| 844/844 [00:22<00:00, 36.79it/s, Loss=1.47] 


IWM JEPA Epoch 40, Loss: 1.4671429121663786


IWM JEPA Epoch 41/50: 100%|██████████| 844/844 [00:22<00:00, 37.32it/s, Loss=1.46] 


IWM JEPA Epoch 41, Loss: 1.4570867501163935


IWM JEPA Epoch 42/50: 100%|██████████| 844/844 [00:22<00:00, 37.31it/s, Loss=1.44] 


IWM JEPA Epoch 42, Loss: 1.4434536832486284


IWM JEPA Epoch 43/50: 100%|██████████| 844/844 [00:22<00:00, 37.10it/s, Loss=1.43] 


IWM JEPA Epoch 43, Loss: 1.4297177252046305


IWM JEPA Epoch 44/50: 100%|██████████| 844/844 [00:22<00:00, 37.21it/s, Loss=1.42] 


IWM JEPA Epoch 44, Loss: 1.4202729590696181


IWM JEPA Epoch 45/50: 100%|██████████| 844/844 [00:22<00:00, 37.22it/s, Loss=1.42] 


IWM JEPA Epoch 45, Loss: 1.4152426417405006


IWM JEPA Epoch 46/50: 100%|██████████| 844/844 [00:22<00:00, 37.17it/s, Loss=1.4]  


IWM JEPA Epoch 46, Loss: 1.3991631873975998


IWM JEPA Epoch 47/50: 100%|██████████| 844/844 [00:22<00:00, 37.35it/s, Loss=1.39] 


IWM JEPA Epoch 47, Loss: 1.3888826094814952


IWM JEPA Epoch 48/50: 100%|██████████| 844/844 [00:22<00:00, 36.87it/s, Loss=1.38] 


IWM JEPA Epoch 48, Loss: 1.3843155506097875


IWM JEPA Epoch 49/50: 100%|██████████| 844/844 [00:22<00:00, 37.07it/s, Loss=1.37] 


IWM JEPA Epoch 49, Loss: 1.3701862273340542


IWM JEPA Epoch 50/50: 100%|██████████| 844/844 [00:22<00:00, 37.19it/s, Loss=1.36] 

IWM JEPA Epoch 50, Loss: 1.3619043339767727
Preparando pares SAME/DIFFERENT usando todo el conjunto de entrenamiento...



  iwm_predictor.load_state_dict(torch.load('best_iwm_jepa.pth'))


Generando 216000 pares SAME/DIFFERENT para entrenamiento...
Preparando pares SAME/DIFFERENT usando todo el conjunto de test...
Generando 240000 pares SAME/DIFFERENT para test...
Definiendo y entrenando el clasificador SAME/DIFFERENT con características mejoradas...

Entrenando clasificador con tasa de aprendizaje: 0.0001
Epoch 1, Train Loss: 0.17157928121862587, Val Loss: 0.00990383271899449
Epoch 2, Train Loss: 0.017535620199309456, Val Loss: 0.0018073022331127665
Epoch 3, Train Loss: 0.006526607850359546, Val Loss: 0.0006189090382313279
Epoch 4, Train Loss: 0.0034442176033432285, Val Loss: 0.00026130220045462337
Epoch 5, Train Loss: 0.0019421151234699345, Val Loss: 0.0001292352936352526
Epoch 6, Train Loss: 0.0012138608668896334, Val Loss: 6.388502128142588e-05
Epoch 7, Train Loss: 0.0008742487582343596, Val Loss: 4.104287566368457e-05
Epoch 8, Train Loss: 0.0006777666617374591, Val Loss: 2.217688789324609e-05
Epoch 9, Train Loss: 0.0005017503572701201, Val Loss: 1.29715134246332e-05



Epoch 1, Train Loss: 9.181498461169179e-05, Val Loss: 1.4146165553876548e-08
Epoch 2, Train Loss: 3.781738305540919e-05, Val Loss: 7.128195761134357e-09
Epoch 3, Train Loss: 0.0001389282121302855, Val Loss: 2.5242053789713585e-08
Epoch 4, Train Loss: 0.00010319238678440107, Val Loss: 3.626252220460328e-08
Epoch 5, Train Loss: 8.390912891867528e-05, Val Loss: 6.323607865975889e-09
Epoch 6, Train Loss: 5.923508150902779e-05, Val Loss: 6.792911347981999e-09
Epoch 7, Train Loss: 4.049766957400102e-05, Val Loss: 3.8630520682380235e-09
Epoch 8, Train Loss: 4.80391556190322e-05, Val Loss: 5.422559073362387e-09
Epoch 9, Train Loss: 2.1144952707059188e-05, Val Loss: 1.977450368802556e-09
Epoch 10, Train Loss: 3.179289025222555e-05, Val Loss: 1.918660294087484e-09
Epoch 11, Train Loss: 9.571608946830396e-06, Val Loss: 5.786322691537919e-10
Epoch 12, Train Loss: 3.3246368384809514e-05, Val Loss: 4.91377986130789e-10
Epoch 13, Train Loss: 4.05627004379133e-05, Val Loss: 2.7278377321599747e-10
Epoc

  classifier.load_state_dict(torch.load('best_classifier.pth'))


Métricas de Validación para lr=0.0003:
Accuracy: 1.0
F1-score: 1.0
ROC-AUC: 1.0
PR-AUC: 1.0
Evaluando en test Same/Different con el mejor clasificador...


Evaluando en test: 100%|██████████| 938/938 [00:05<00:00, 164.54it/s]


Same/Different Test Metrics (mejor clasificador):
Accuracy: 0.9946416666666666
F1-score: 0.9946416516331504
Precision: 0.9946472178450688
Recall: 0.9946416666666666
ROC-AUC: 0.9998755543402778
PR-AUC: 0.9998763414685661
Confusion Matrix:
[[119156    844]
 [   442 119558]]
Curvas ROC y Precision-Recall guardadas como 'roc_curve.png' y 'precision_recall_curve.png' respectivamente.
Proceso final completado.
