<a href="https://colab.research.google.com/github/d4vidi4n/EL7006/blob/main/run_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import h5py
import json
import os
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
from timeit import default_timer as timer
from collections import defaultdict
import gc

In [None]:
# Mount Google Drive so that the /content/drive/... dataset and weight paths
# used later in this notebook are reachable (Colab-only).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Model_SEED

In [None]:
# ===================================================================
# MÓDULOS DEL MODELO
# ===================================================================

class ConvBlock_SEED(nn.Module):
    """Stack of ``block_depth`` Conv1d -> BatchNorm1d -> ReLU layers.

    The first convolution maps ``in_channels`` to ``out_channels``; the
    following ones keep ``out_channels``. Padding is ``int(window / 2)``, so
    an odd ``window`` preserves the sequence length:
    (B, in_channels, L) -> (B, out_channels, L).
    """

    def __init__(self, in_channels, out_channels, block_depth, window):
        super(ConvBlock_SEED, self).__init__()
        self.block_depth = block_depth

        pad = int(window / 2)
        conv_layers = []
        for depth in range(block_depth):
            source_channels = in_channels if depth == 0 else out_channels
            conv_layers.append(nn.Conv1d(source_channels, out_channels, window, padding=pad))
        self.convs = nn.ModuleList(conv_layers)
        self.bns = nn.ModuleList(nn.BatchNorm1d(out_channels) for _ in range(block_depth))

        # Not used by forward() (it calls torch.relu directly); kept so the
        # module structure stays the same.
        self.gelu = nn.GELU()
        self.relu = nn.ReLU()

    def forward(self, x):  # (B, in_channels, L)
        for conv, norm in zip(self.convs, self.bns):
            x = torch.relu(norm(conv(x)))
        return x


class ConvEmbed_SEED(nn.Module):
    """Convolutional front-end that turns a raw 1-D signal into a channel embedding.

    The (B, L) input is normalised by a single-channel BatchNorm and passed
    through ``blocks`` ConvBlock_SEED stages. Each stage is followed by 2x
    average pooling, and dropout is applied at the end. The last stage outputs
    ``embed_dim`` channels, so the result is
    (B, embed_dim, L // 2**blocks), with the floor applied at each pooling.
    """

    def __init__(self, embed_dim, block_depth, blocks, dropout, window_len):
        super(ConvEmbed_SEED, self).__init__()

        self.bn = nn.BatchNorm1d(1)
        # NOTE: unused in forward(), but it owns parameters, so it is kept to
        # stay compatible with saved state dicts of this module.
        self.ln = nn.LayerNorm(1)
        self.blocks = blocks

        stages = []
        for stage in range(blocks):
            stage_in = 1 if stage == 0 else embed_dim // (2 ** (blocks - stage))
            stage_out = embed_dim // (2 ** (blocks - stage - 1))
            stages.append(ConvBlock_SEED(stage_in, stage_out, block_depth, window=window_len))
        self.conv_blocks = nn.ModuleList(stages)

        self.pool = nn.AvgPool1d(2, stride=2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):  # (B, L)
        h = self.bn(x.unsqueeze(1))  # (B, 1, L)
        for block in self.conv_blocks:
            h = self.pool(block(h))  # channels grow, length halves
        return self.dropout(h)  # (B, embed_dim, L // 2**blocks)


class StageMultiDilatedConvolutions(nn.Module):
    """Parallel dilated-convolution branches whose outputs are concatenated.

    Dilations are 1, 2, 4, ..., 2**m, with m = round(log2(max_dil)).
    - Branch k (dilation 2**k) gets int(filters_out / 2**(k+1)) channels.
    - The last branch gets twice its share, so when ``filters_out`` is
      divisible by 2**(m+1) the branch widths add up to ``filters_out``.
    - Each branch is two Conv1d(kernel 3, padding = dilation, no bias) ->
      BatchNorm1d -> ReLU layers, which keep the sequence length.

    The per-branch (filters, dilation) list is stored in ``stage_config`` and
    printed when the module is built.
    """

    def __init__(self, max_dil, filters_in, filters_out):
        super(StageMultiDilatedConvolutions, self).__init__()
        top_exponent = int(np.round(np.log(max_dil) / np.log(2)))
        self.stage_config = [
            (int(filters_out / (2 ** (k + 1))), int(2 ** k))
            for k in range(top_exponent + 1)
        ]
        last_filters, last_dilation = self.stage_config[-1]
        self.stage_config[-1] = (2 * last_filters, last_dilation)

        print(self.stage_config)
        branches = []
        for n_filters, dilation in self.stage_config:
            layers = []
            for is_first in (True, False):
                layers += [
                    nn.Conv1d(
                        in_channels=filters_in if is_first else n_filters,
                        out_channels=n_filters,
                        kernel_size=3,
                        padding=dilation,
                        dilation=dilation,
                        bias=False,
                    ),
                    nn.BatchNorm1d(n_filters),
                    nn.ReLU(),
                ]
            branches.append(nn.Sequential(*layers))
        self.branches = nn.ModuleList(branches)

    def forward(self, inputs):
        # Each branch is an nn.Sequential, so calling it runs its layers in order.
        return torch.cat([branch(inputs) for branch in self.branches], dim=1)

class Local_encoding_stage_SEED(nn.Module):
    """Conv embedding followed by ``block_dilatation`` multi-dilated stages.

    Stage i maps filters_in * 2**i channels to twice as many and is followed
    by 2x average pooling. The result is transposed to (B, L', C) for the
    sequence model. ``filters_in`` must match the channel count produced by
    ConvEmbed_SEED, which is ``embed_dim``.
    """

    def __init__(self,  embed_dim, block_depth, blocks, dropout, window_len, filters_in, block_dilatation):
        super(Local_encoding_stage_SEED, self).__init__()

        self.block_dilatation = block_dilatation
        self.conv_layer = ConvEmbed_SEED(embed_dim, block_depth, blocks, dropout, window_len)

        stages = []
        for level in range(block_dilatation):
            stage_in = filters_in * 2 ** level
            stages.append(StageMultiDilatedConvolutions(8, stage_in, 2 * stage_in))
        self.layer_dilatations = nn.ModuleList(stages)
        self.pool = nn.AvgPool1d(2, stride=2)

    def forward(self, inputs):
        features = self.conv_layer(inputs)  # (B, C, L)
        for stage in self.layer_dilatations:
            features = self.pool(stage(features))  # channels double, length halves
        return features.transpose(1, 2)  # (B, L', C)


class BiLSTM(nn.Module):
    """Two-layer bidirectional LSTM applied to a batch-first sequence.

    Input (B, T, embed_dim) -> output (B, T, 2 * embed_dim). Dropout 0.2 is
    applied to the input, and dropout 0.5 between the two LSTM layers.
    """

    def __init__(self, embed_dim=256):
        super(BiLSTM, self).__init__()
        self.dropout = nn.Dropout(0.2)
        self.lstm = nn.LSTM(embed_dim, embed_dim, num_layers=2, dropout=0.5, bidirectional=True)

    def forward(self, x):
        # The LSTM is built without batch_first, so it expects (T, B, C).
        seq_first = self.dropout(x).transpose(0, 1)
        outputs, _state = self.lstm(seq_first)
        return outputs.transpose(0, 1)

class RED_FC(nn.Module):
    """Per-time-step classification head producing one logit per position.

    Input (..., 2 * embed_dim) -> output (...), computed as
    dropout(0.5) -> Linear -> ReLU -> Linear(1).
    - The hidden layer uses Xavier-uniform init with ReLU gain and zero bias.
    - The output bias is logit(0.1), i.e. sigmoid(bias) = 0.1.
    """

    def __init__(self, embed_dim = 256, hide_dim = 128):
        super(RED_FC, self).__init__()

        self.dropout = nn.Dropout(0.5)

        hidden = nn.Linear(2 * embed_dim, hide_dim)
        nn.init.xavier_uniform_(hidden.weight, gain=nn.init.calculate_gain('relu'))
        nn.init.zeros_(hidden.bias)
        self.hidden = hidden

        logits = nn.Linear(hide_dim, 1)
        nn.init.xavier_uniform_(logits.weight)
        prior = 0.1
        logits.bias = nn.Parameter(torch.log(torch.tensor(prior / (1 - prior))).reshape(1))
        self.logits = logits

        # Registers hidden/logits a second time under fc.*; kept for
        # compatibility with saved state dicts.
        self.fc = nn.Sequential(self.hidden, nn.ReLU(), self.logits)

    def forward(self, x):
        return self.fc(self.dropout(x)).squeeze(-1)

In [None]:
# Model hyper-parameters
EMBED_DIM = 64
EMBED_BLOCKS = 1
# Must equal EMBED_DIM: the first dilated stage consumes the conv embedding's
# output channels.
FILTER_IN = 64
BLOCK_DILATATION = 2
# Upsampling factor of the per-step logits; should equal the total pooling
# factor 2 ** (EMBED_BLOCKS + BLOCK_DILATATION) so the output length matches
# the input length.
RESOL_FACTOR = 8

# Training hyper-parameters
EPOCHS = 5
BATCH_SIZE = 128  # Raised to 128 to use more memory; expected (unverified) to make epochs faster
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 12
DROPOUT = 0.1

# Model configurations
EMBEDDING_CONFIG_SEED = {
    "embed_dim": EMBED_DIM,
    "block_depth": 2,
    "blocks": EMBED_BLOCKS,
    "dropout": 0.0,
    "filters_in": FILTER_IN,
    "block_dilatation": BLOCK_DILATATION,
}

CLASSIFIER_CONFIG = {
    # BiLSTM hidden size = channels after the last dilated stage. RED_FC reads
    # 2 * embed_dim features because the LSTM is bidirectional. Derived rather
    # than hard-coded so it cannot drift from FILTER_IN / BLOCK_DILATATION.
    "embed_dim": FILTER_IN * 2 ** BLOCK_DILATATION,
    "hide_dim": 128,
}


class SEED_Model(nn.Module):
    """Conv/dilated-conv encoder + BiLSTM + per-step classifier.

    ``forward`` takes a (B, L) batch of signals. It returns (B, L' * RESOL_FACTOR)
    logits, where L' = L // 2**(EMBED_BLOCKS + BLOCK_DILATATION) is the
    encoder output length (floor at each pooling). Each step's logit is
    repeated RESOL_FACTOR times. ``criterion`` (BCEWithLogitsLoss) is attached
    for the training loop. Configuration comes from the module-level
    constants EMBEDDING_CONFIG_SEED, FILTER_IN, BLOCK_DILATATION,
    CLASSIFIER_CONFIG and RESOL_FACTOR.
    """

    def __init__(self):
        super(SEED_Model, self).__init__()

        self.embedding = Local_encoding_stage_SEED(**EMBEDDING_CONFIG_SEED, window_len=3)
        # Each dilated stage doubles the channel count.
        final_channels = FILTER_IN * (2 ** BLOCK_DILATATION)
        self.transformer = BiLSTM(final_channels)
        # RED_FC expects 2 * embed_dim input features (bidirectional LSTM).
        # Take embed_dim from the LSTM size so a stale hard-coded value in
        # CLASSIFIER_CONFIG cannot cause a shape mismatch.
        classifier_config = {**CLASSIFIER_CONFIG, "embed_dim": final_channels}
        self.classifier = RED_FC(**classifier_config)

        # Loss
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        embeds = self.embedding(x)  # (B, L', C)
        output = self.transformer(embeds)  # (B, L', 2C)
        output = self.classifier(output)  # (B, L')
        output = output.repeat_interleave(RESOL_FACTOR, dim=1)
        return output

# Train_model


In [None]:


# ===================================================================
# PARÁMETROS DE CONFIGURACIÓN
# ===================================================================

# Model hyper-parameters (duplicated from the model cell; keep both in sync)
EMBED_DIM = 64
EMBED_BLOCKS = 1
# Must equal EMBED_DIM: the first dilated stage consumes the conv embedding's
# output channels.
FILTER_IN = 64
BLOCK_DILATATION = 2
# Should equal the total pooling factor 2 ** (EMBED_BLOCKS + BLOCK_DILATATION).
RESOL_FACTOR = 8

# Training hyper-parameters
EPOCHS = 5
BATCH_SIZE = 128
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 12
DROPOUT = 0.1

# Model configurations
EMBEDDING_CONFIG_SEED = {
    "embed_dim": EMBED_DIM,
    "block_depth": 2,
    "blocks": EMBED_BLOCKS,
    "dropout": 0.0,
    "filters_in": FILTER_IN,
    "block_dilatation": BLOCK_DILATATION,
}

CLASSIFIER_CONFIG = {
    # BiLSTM hidden size (RED_FC reads 2x this because the LSTM is
    # bidirectional); derived so it cannot drift from the encoder settings.
    "embed_dim": FILTER_IN * 2 ** BLOCK_DILATATION,
    "hide_dim": 128,
}

# Paths
h5_file_path = "/content/drive/MyDrive/codigo_eeg_redes/data/moda_augmented_s5.h5"
folds_json_path = "/content/drive/MyDrive/codigo_eeg_redes/data/moda_folds.json"
weights_save_dir = "/content/drive/MyDrive/codigo_eeg_redes/model_weights_best_original"  # or model_weights

print("🚀 CONFIGURACIÓN CARGADA:")
print(f"   📁 Dataset: {h5_file_path}")
print(f"   📁 Folds: {folds_json_path}")
print(f"   💾 Pesos: {weights_save_dir}")
print(f"   🧠 Embed dim: {EMBED_DIM}")
print(f"   🎯 Epochs: {EPOCHS}")
print(f"   📦 Batch size: {BATCH_SIZE}")


class H5Dataset(Dataset):
    """Map-style dataset over an HDF5 file with ``signals``, ``labels`` and ``subject_ids``.

    Each item is ``(real_idx, signal, labels)``, where ``real_idx`` is the row
    index in the HDF5 file and ``signal`` / ``labels`` are float32 arrays. The
    file is opened anew on every access; no handle is stored on the instance.

    Args:
        h5_file_path: path to the HDF5 file.
        subject_ids: iterable of subject ids to keep (a single id may also be
            passed as a string); ``None`` keeps every row.
        norm_data: if True, each signal is min-max scaled to [-1, 1].
    """

    def __init__(self, h5_file_path, subject_ids=None, norm_data=False):
        self.h5_file_path = h5_file_path
        self.norm_data = norm_data

        # Only metadata and subject ids are read here.
        with h5py.File(h5_file_path, 'r') as f:
            self.total_samples = len(f['signals'])
            self.signal_length = f['signals'].shape[1]
            self.label_length = f['labels'].shape[1]
            all_subject_ids = [s.decode('utf-8') for s in f['subject_ids'][:]]

        if subject_ids is not None:
            if isinstance(subject_ids, str):
                # Treat a bare string as one id, not as a collection of characters.
                subject_ids = [subject_ids]
            # Set lookup: O(1) per row instead of scanning the id list each time.
            wanted = set(subject_ids)
            self.valid_indices = [i for i, subj_id in enumerate(all_subject_ids) if subj_id in wanted]
        else:
            self.valid_indices = list(range(self.total_samples))

        # Build df_id_group, consumed by BalancedBatchSampler.
        self._calculate_median_groups()

        print(f"📊 H5Dataset creado:")
        print(f"   Total muestras en H5: {self.total_samples}")
        print(f"   Muestras válidas: {len(self.valid_indices)}")
        print(f"   Longitud señal: {self.signal_length}")
        print(f"   Longitud labels: {self.label_length}")

    def _calculate_median_groups(self):
        """Split the selected samples at the median of their label sums.

        Builds ``self.df_id_group`` with two columns:
        - ``id``: 1-based position in ``valid_indices`` (the sampler subtracts 1).
        - ``median_group``: 1 if the sample's label sum (the number of positive
          samples, assuming 0/1 labels) is above the median, else 0.
        An empty selection yields an empty frame that still has both columns.
        """
        print("📊 Calculando grupos de mediana para balanceado...")

        with h5py.File(self.h5_file_path, 'r') as f:
            labels_ds = f['labels']
            positives_per_sample = [np.sum(labels_ds[idx]) for idx in self.valid_indices]

        # np.median([]) warns and returns NaN; use 0.0 for an empty selection.
        median_positives = np.median(positives_per_sample) if positives_per_sample else 0.0

        # Explicit columns so an empty selection still has 'id'/'median_group'.
        self.df_id_group = pd.DataFrame(
            {
                'id': np.arange(1, len(positives_per_sample) + 1),  # +1 for compatibility
                'median_group': (np.asarray(positives_per_sample) > median_positives).astype(int),
            },
            columns=['id', 'median_group'],
        )

        groups = self.df_id_group['median_group']
        above_median = int((groups == 1).sum())
        below_median = int((groups == 0).sum())

        print(f"   ✅ Mediana de positivos: {median_positives:.1f}")
        print(f"   📊 Above median: {above_median} muestras")
        print(f"   📊 Below median: {below_median} muestras")

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        real_idx = self.valid_indices[idx]

        with h5py.File(self.h5_file_path, 'r') as f:
            signal = f['signals'][real_idx]
            labels = f['labels'][real_idx]

        signal = signal.astype(np.float32)
        labels = labels.astype(np.float32)

        # Optional per-signal min-max scaling to [-1, 1] (skipped for constant signals).
        if self.norm_data:
            min_val, max_val = np.min(signal), np.max(signal)
            if max_val > min_val:
                signal = 2 * (signal - min_val) / (max_val - min_val) - 1

        return real_idx, signal, labels

# ===================================================================
# BALANCED BATCH SAMPLER
# ===================================================================

class BalancedBatchSampler:
    """Batch sampler that alternates below-median and above-median samples.

    ``df_id_group`` must have the columns ``id`` (1-based dataset position)
    and ``median_group`` (1 = above median, 0 = below), as built by
    H5Dataset. At construction the smaller group is oversampled until both
    groups have the same size. Batches are then filled with (below, above)
    pairs of 0-based dataset indices.

    Args:
        df_id_group: DataFrame described above.
        batch_size: indices per batch; the last batch may be shorter.
        shuffle: reshuffle both groups at the start of every iteration.

    Raises:
        ValueError: if ``batch_size`` < 1, or if one group is empty while the
            other is not (nothing to oversample from).
    """

    def __init__(self, df_id_group, batch_size, shuffle=True):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.shuffle = shuffle
        self.batch_size = batch_size

        # Explicit copies: the arrays are shuffled in place in __iter__, and
        # pandas may return read-only views (e.g. under Copy-on-Write).
        self.id_above_median = df_id_group[
            df_id_group.median_group == 1
        ]['id'].to_numpy(copy=True)
        self.id_below_median = df_id_group[
            df_id_group.median_group == 0
        ]['id'].to_numpy(copy=True)

        # Original above/below-median balancing.
        self._balance_above_below_median()

    @staticmethod
    def _draw_extra_ids(pool, count, group_name):
        """Draw ``count`` ids from ``pool``; with replacement only if the pool is too small."""
        if len(pool) == 0:
            raise ValueError(f"Cannot balance batches: group '{group_name}' is empty")
        # Sampling without replacement is only possible when the pool holds
        # at least ``count`` ids; otherwise np.random.choice would raise.
        return np.random.choice(pool, count, replace=count > len(pool))

    def _balance_above_below_median(self):
        """Oversample the smaller median group so both groups have equal size."""
        print(f"\n📊 BALANCEO ABOVE/BELOW MEDIAN:")
        print(f"   Above median: {len(self.id_above_median)} IDs")
        print(f"   Below median: {len(self.id_below_median)} IDs")

        difference = len(self.id_above_median) - len(self.id_below_median)

        if difference > 0:  # more above-median ids than below
            ids_repeated = self._draw_extra_ids(self.id_below_median, difference, "below_median")
            self.id_below_median = np.append(self.id_below_median, ids_repeated)
            print(f"   ✅ Duplicados {difference} IDs del grupo below_median")

        elif difference < 0:  # more below-median ids than above
            ids_repeated = self._draw_extra_ids(self.id_above_median, -difference, "above_median")
            self.id_above_median = np.append(self.id_above_median, ids_repeated)
            print(f"   ✅ Duplicados {-difference} IDs del grupo above_median")

        print(f"   📊 RESULTADO: Above={len(self.id_above_median)}, Below={len(self.id_below_median)}")

    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.id_above_median)
            np.random.shuffle(self.id_below_median)

        # Interleave (below, above) pairs, converting 1-based ids to 0-based
        # dataset indices.
        order = []
        for above_id, below_id in zip(self.id_above_median, self.id_below_median):
            order.append(int(below_id) - 1)
            order.append(int(above_id) - 1)

        # Chunking the flat order (instead of writing pairs into fixed slots)
        # also handles odd batch sizes and keeps __len__ exact.
        for start in range(0, len(order), self.batch_size):
            yield order[start:start + self.batch_size]

    def __len__(self):
        total_ids = len(self.id_above_median) + len(self.id_below_median)
        return (total_ids + self.batch_size - 1) // self.batch_size

def create_batch_h5(data_list):
    """Collate ``(id, signal, labels)`` samples from H5Dataset into tensors.

    ``None`` entries are dropped. Returns ``(ids, signals, labels)`` as
    int64 / float32 / float32 tensors stacked along a new batch dimension.
    If nothing remains, returns empty tensors of shapes (0,), (0, 0), (0, 0).
    """
    samples = list(filter(lambda item: item is not None, data_list))

    if not samples:
        return torch.empty(0, dtype=torch.long), torch.empty(0, 0), torch.empty(0, 0)

    ids, signals, labels = (np.array(column) for column in zip(*samples))
    return torch.LongTensor(ids), torch.FloatTensor(signals), torch.FloatTensor(labels)

# ===================================================================
# FUNCIONES DE ENTRENAMIENTO
# ===================================================================

def early_stopping(loss_dev, patience=8):
    """Return True when the last ``patience`` losses never decrease.

    Returns False while fewer than ``patience`` losses are available. With
    ``patience`` < 2 there are no consecutive pairs to compare, so the
    answer is simply whether enough losses exist.
    """
    if patience < 2:
        return len(loss_dev) >= patience
    if len(loss_dev) < patience:
        return False
    tail = loss_dev[-patience:]
    return all(later >= earlier for earlier, later in zip(tail, tail[1:]))

def train_epoch(model, optimizer, loader, epoch, device):
    """Run one training epoch and return the mean per-batch loss.

    Each batch from ``loader`` is ``(ids, inputs, targets)``; ids are ignored.
    The loss is ``model.criterion`` applied to the flattened output and
    targets. Gradients are clipped to a global norm of 1.0 before every
    optimizer step. A single-line progress indicator is rewritten in place
    (carriage return, no trailing newline).

    Args:
        model: module exposing a ``criterion`` attribute.
        optimizer: optimizer over ``model``'s parameters.
        loader: sized iterable of batches (e.g. a DataLoader).
        epoch: epoch number shown in the progress line.
        device: device the batches are moved to.

    Returns:
        Mean training loss, or NaN if ``loader`` has no batches (instead of
        dividing by zero).
    """
    model.train()
    run_loss = 0.0
    batches = len(loader)
    if batches == 0:
        return float('nan')

    for i, (_, inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        loss = model.criterion(outputs.flatten(), targets.flatten())

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

        run_loss += loss.item()
        print(f'\rEpoch {epoch} ({100*(i + 1)/batches:.2f}%) | Train Loss: {run_loss / (i + 1):.4f} ', end='', flush=True)

    return run_loss / batches

def evaluate_epoch(model, loader, device):
    """Compute the mean per-batch loss on ``loader`` without updating the model.

    Each batch is ``(ids, inputs, targets)``; ids are ignored. Runs in eval
    mode under ``torch.no_grad`` and prints the result.

    Args:
        model: module exposing a ``criterion`` attribute.
        loader: iterable of batches (e.g. a DataLoader).
        device: device the batches are moved to.

    Returns:
        Mean validation loss, or NaN if ``loader`` yields no batches
        (instead of dividing by zero).
    """
    model.eval()
    val_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for _, inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            val_loss += model.criterion(outputs.flatten(), targets.flatten()).item()
            n_batches += 1

    avg_loss = val_loss / n_batches if n_batches else float('nan')
    print(f"| Val Loss: {avg_loss:.4f}")
    return avg_loss

# ===================================================================
# FUNCIÓN PRINCIPAL DE ENTRENAMIENTO
# ===================================================================

def train_model_for_fold(eval_num, fold_num, val_subjects, test_subjects, all_subjects):
    """Train a fresh SEED_Model for one (evaluation, fold) split and save its weights.

    Subjects in ``all_subjects`` that are in neither ``val_subjects`` nor
    ``test_subjects`` form the training set. Training uses
    BalancedBatchSampler batches and AdamW for up to ``EPOCHS`` epochs. After
    each epoch the validation loss is computed; whenever it improves, the
    embedding / LSTM / classifier state dicts are written as ``*_best.pth``
    under ``weights_save_dir/eval_<eval_num>/fold_<fold_num>``. The weights
    after the last epoch are written there as ``*_final.pth``.

    Reads module-level settings: h5_file_path, weights_save_dir, NUM_WORKERS,
    BATCH_SIZE, LEARNING_RATE, WEIGHT_DECAY, EPOCHS.

    Args:
        eval_num: evaluation number (used in the folder name and logs).
        fold_num: fold number (used in the folder name and logs).
        val_subjects: subject ids used for validation.
        test_subjects: subject ids excluded from training (not evaluated here).
        all_subjects: every subject id in the dataset.

    Returns:
        Lowest validation loss seen; ``inf`` if no epoch ever improved on the
        initial ``inf`` (e.g. every validation loss was NaN), in which case no
        ``*_best.pth`` files are written.
    """

    print(f"\n{'='*60}")
    print(f"🎯 ENTRENANDO EVAL {eval_num} - FOLD {fold_num}")
    print(f"{'='*60}")

    # Free the CUDA cache and run the garbage collector before building a new model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Training subjects = every subject not used for validation or test
    train_subjects = [s for s in all_subjects if s not in val_subjects and s not in test_subjects]

    print(f"📊 Distribución de sujetos:")
    print(f"   Train: {len(train_subjects)} sujetos")
    print(f"   Val: {len(val_subjects)} sujetos")
    print(f"   Test: {len(test_subjects)} sujetos")

    # Build the datasets
    train_dataset = H5Dataset(h5_file_path, subject_ids=train_subjects, norm_data=False)
    val_dataset = H5Dataset(h5_file_path, subject_ids=val_subjects, norm_data=False)

    # Training loader uses the balanced (above/below median) batch sampler
    train_loader = DataLoader(
        train_dataset,
        num_workers=NUM_WORKERS,
        collate_fn=create_batch_h5,
        batch_sampler=BalancedBatchSampler(train_dataset.df_id_group, batch_size=BATCH_SIZE, shuffle=True)
    )

    # Validation loader: plain sequential batches
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        collate_fn=create_batch_h5
    )

    print(f"📦 Datos cargados:")
    print(f"   Train samples: {len(train_dataset)}")
    print(f"   Val samples: {len(val_dataset)}")
    print(f"   Train batches: {len(train_loader)}")
    print(f"   Val batches: {len(val_loader)}")

    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🖥️  Device: {device}")

    # A brand-new model for every fold (nothing is reused between folds)
    model = SEED_Model().to(device)
    print(f"🧠 Parámetros del modelo: {sum(p.numel() for p in model.parameters()):,}")

    # A brand-new optimizer for every fold
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Tracking variables
    loss_train, loss_dev = [], []
    best_val_loss = float('inf')
    patience_count = 0

    # Output folder for this fold's weights
    save_folder = Path(weights_save_dir) / f"eval_{eval_num}" / f"fold_{fold_num}"
    save_folder.mkdir(parents=True, exist_ok=True)

    print(f"💾 Guardando en: {save_folder}")

    # Training loop
    for epoch in range(1, EPOCHS + 1):
        t0 = timer()

        # Train
        train_loss = train_epoch(model, optimizer, train_loader, epoch, device)

        # Evaluate
        val_loss = evaluate_epoch(model, val_loader, device)

        # Record per-epoch losses
        loss_train.append(train_loss)
        loss_dev.append(val_loss)

        dt = timer() - t0
        print(f'| Epoch Time: {dt//60:.0f}m {dt%60:.1f}s')

        # Keep the best model (lowest validation loss so far)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_count = 0

            # Save the three sub-modules separately
            torch.save(model.embedding.state_dict(), save_folder / "embedding_best.pth")
            torch.save(model.transformer.state_dict(), save_folder / "lstm_best.pth")
            torch.save(model.classifier.state_dict(), save_folder / "classifier_best.pth")

            print(f"✅ Mejor modelo guardado | Val Loss: {best_val_loss:.4f}")
        else:
            patience_count += 1


        # Patience: stop after 6 consecutive epochs without improvement.
        # NOTE(review): the hard-coded 6 differs from early_stopping()'s
        # default patience of 8 (that helper is not called here). With
        # EPOCHS = 5 as configured, this branch can never trigger.
        if patience_count >= 6:
            print(f"🛑 Early stopping por paciencia en época {epoch}")
            break

    # Save the weights after the last epoch run
    torch.save(model.embedding.state_dict(), save_folder / "embedding_final.pth")
    torch.save(model.transformer.state_dict(), save_folder / "lstm_final.pth")
    torch.save(model.classifier.state_dict(), save_folder / "classifier_final.pth")

    print(f"✅ Entrenamiento completado para Eval {eval_num} - Fold {fold_num}")
    print(f"📊 Mejor Val Loss: {best_val_loss:.4f}")
    print(f"💾 Pesos guardados en: {save_folder}")

    # Drop references to the model/data and free memory before the next fold
    del model
    del optimizer
    del train_loader
    del val_loader
    del train_dataset
    del val_dataset

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    print(f"🧹 Memoria limpiada para siguiente fold")

    return best_val_loss

# ===================================================================
# EJECUCIÓN PRINCIPAL
# ===================================================================

def main():
    """Train one SEED model per (evaluation, fold) entry of the folds JSON.

    The lookups below require the JSON at ``folds_json_path`` to map keys
    like ``eval_<n>`` to dicts mapping ``fold_<m>`` to
    ``{"val": [...], "test": [...]}`` subject lists. Every other subject in
    the HDF5 file is used for training. A summary of the best validation
    loss per fold is printed at the end.
    """

    print("🚀 INICIANDO ENTRENAMIENTO COMPLETO DEL MODELO SEED")
    print("="*60)

    # Load the fold definitions
    print("📁 Cargando configuración de folds...")
    with open(folds_json_path, 'r') as f:
        folds_data = json.load(f)

    # Unique subject ids in the dataset, sorted for a reproducible order:
    # iteration order of a set of str depends on per-process hash randomisation.
    with h5py.File(h5_file_path, 'r') as f:
        all_subjects = sorted({s.decode('utf-8') for s in f['subject_ids'][:]})

    print(f"📊 Total sujetos en dataset: {len(all_subjects)}")

    # parents=True so missing intermediate directories don't raise FileNotFoundError
    Path(weights_save_dir).mkdir(parents=True, exist_ok=True)

    def _key_number(key):
        # 'eval_10' -> 10. Sorting by this number avoids the lexicographic
        # order of plain string sorting ('eval_10' < 'eval_2').
        return int(key.split('_')[1])

    results = []

    for eval_key in sorted(folds_data, key=_key_number):
        eval_num = _key_number(eval_key)

        for fold_key in sorted(folds_data[eval_key], key=_key_number):
            fold_num = _key_number(fold_key)

            val_subjects = folds_data[eval_key][fold_key]['val']
            test_subjects = folds_data[eval_key][fold_key]['test']

            # A fresh model is trained for every fold
            best_loss = train_model_for_fold(eval_num, fold_num, val_subjects, test_subjects, all_subjects)

            # Same selection rule train_model_for_fold uses, so the count is
            # right even if val/test overlap or name subjects absent from the file.
            held_out = set(val_subjects) | set(test_subjects)
            n_train = sum(1 for s in all_subjects if s not in held_out)

            results.append({
                'eval': eval_num,
                'fold': fold_num,
                'best_val_loss': best_loss,
                'train_subjects': n_train,
                'val_subjects': len(val_subjects),
                'test_subjects': len(test_subjects)
            })

    print(f"\n{'='*60}")
    print("📊 RESUMEN FINAL DE ENTRENAMIENTO")
    print(f"{'='*60}")

    if not results:
        # An empty DataFrame has no 'best_val_loss' column; report and stop.
        print("⚠️ No se encontraron folds en el archivo de configuración.")
        return

    results_df = pd.DataFrame(results)
    print(results_df.to_string(index=False))

    print(f"\n📊 Estadísticas generales:")
    print(f"   Mejor Val Loss promedio: {results_df['best_val_loss'].mean():.4f}")
    print(f"   Mejor Val Loss std: {results_df['best_val_loss'].std():.4f}")
    print(f"   Mejor Val Loss min: {results_df['best_val_loss'].min():.4f}")
    print(f"   Mejor Val Loss max: {results_df['best_val_loss'].max():.4f}")

    print(f"\n✅ ENTRENAMIENTO COMPLETO FINALIZADO")
    print(f"💾 Todos los pesos guardados en: {weights_save_dir}")

# Ejecutar entrenamiento
# Start training when this cell runs: in a notebook kernel __name__ is
# "__main__", so executing the cell calls main().
if __name__ == "__main__":
    main()

🚀 CONFIGURACIÓN CARGADA:
   📁 Dataset: /content/drive/MyDrive/docs phd/2025-2/Redes neuronales/Proyecto/código_eeg_redes/código_eeg_redes/data/moda_augmented_s5.h5
   📁 Folds: /content/drive/MyDrive/docs phd/2025-2/Redes neuronales/Proyecto/código_eeg_redes/código_eeg_redes/data/moda_folds.json
   💾 Pesos: /content/drive/MyDrive/docs phd/2025-2/Redes neuronales/Proyecto/código_eeg_redes/código_eeg_redes/model_weights
   🧠 Embed dim: 64
   🎯 Epochs: 5
   📦 Batch size: 32
🚀 INICIANDO ENTRENAMIENTO COMPLETO DEL MODELO SEED
📁 Cargando configuración de folds...
📊 Total sujetos en dataset: 180

🎯 ENTRENANDO EVAL 1 - FOLD 1
📊 Distribución de sujetos:
   Train: 108 sujetos
   Val: 36 sujetos
   Test: 36 sujetos
📊 Calculando grupos de mediana para balanceado...
   ✅ Mediana de positivos: 132.0
   📊 Above median: 4471 muestras
   📊 Below median: 4484 muestras
📊 H5Dataset creado:
   Total muestras en H5: 14940
   Muestras válidas: 8955
   Longitud señal: 4000
   Longitud labels: 4000
📊 Calc

# Inference

In [13]:
import torch
import torch.nn as nn
import numpy as np
import h5py
import json
from pathlib import Path
import gc
from torch.utils.data import Dataset

# ===================================================================
# CONFIGURACIÓN DEL MODELO (misma que en entrenamiento)
# ===================================================================



EMBED_DIM = 64
EMBED_BLOCKS = 1
# Must equal EMBED_DIM (the dilated stages consume the conv embedding's channels).
FILTER_IN = 64
BLOCK_DILATATION = 2
# Should equal the total pooling factor 2 ** (EMBED_BLOCKS + BLOCK_DILATATION).
RESOL_FACTOR = 8
BATCH_SIZE = 32
NUM_WORKERS = 12

EMBEDDING_CONFIG_SEED = {
    "embed_dim": EMBED_DIM,
    "block_depth": 2,
    "blocks": EMBED_BLOCKS,
    "dropout": 0.0,
    "filters_in": FILTER_IN,
    "block_dilatation": BLOCK_DILATATION,
}

CLASSIFIER_CONFIG = {
    # BiLSTM hidden size; must match training to load the saved weights.
    # Derived from the encoder settings instead of hard-coding 256.
    "embed_dim": FILTER_IN * 2 ** BLOCK_DILATATION,
    "hide_dim": 128,
}

# Paths
h5_file_path = "/content/drive/MyDrive/codigo_eeg_redes/data/moda_augmented_s5.h5"
folds_json_path = "/content/drive/MyDrive/codigo_eeg_redes/data/moda_folds.json"
weights_save_dir = "/content/drive/MyDrive/codigo_eeg_redes/model_weights_best_original"  # or model_weights

print("🔍 CONFIGURACIÓN DE INFERENCIA:")
print(f"   📁 Dataset: {h5_file_path}")
print(f"   📁 Folds: {folds_json_path}")
print(f"   💾 Pesos: {weights_save_dir}")


# ===================================================================
# DATASET H5 PARA INFERENCIA
# ===================================================================

class H5InferenceDataset(Dataset):
    """Read-only HDF5 dataset for inference, optionally restricted to some subjects.

    Each item is ``(real_idx, signal, labels)``, where ``real_idx`` is the row
    index in the HDF5 file and signal/labels are float32 arrays (no
    normalisation). The file is opened anew on every access.

    Args:
        h5_file_path: path to an HDF5 file with ``signals``, ``labels`` and
            ``subject_ids`` datasets.
        subject_ids: iterable of subject ids to keep (a single id may also be
            passed as a string); ``None`` keeps every row.
    """

    def __init__(self, h5_file_path, subject_ids=None):
        self.h5_file_path = h5_file_path

        with h5py.File(h5_file_path, 'r') as f:
            self.total_samples = len(f['signals'])
            self.signal_length = f['signals'].shape[1]
            self.label_length = f['labels'].shape[1]

            all_subject_ids = [s.decode('utf-8') for s in f['subject_ids'][:]]

        if subject_ids is not None:
            if isinstance(subject_ids, str):
                # Treat a bare string as one id, not as a collection of characters.
                subject_ids = [subject_ids]
            # Set lookup: O(1) per row instead of scanning the id list each time.
            wanted = set(subject_ids)
            self.valid_indices = [i for i, subj_id in enumerate(all_subject_ids) if subj_id in wanted]
        else:
            self.valid_indices = list(range(self.total_samples))

        print(f"📊 Dataset de inferencia creado:")
        print(f"   Total muestras: {len(self.valid_indices)}")

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        real_idx = self.valid_indices[idx]

        with h5py.File(self.h5_file_path, 'r') as f:
            signal = f['signals'][real_idx]
            labels = f['labels'][real_idx]

        return real_idx, signal.astype(np.float32), labels.astype(np.float32)

def create_batch_inference(data_list):
    """Collate ``(id, signal, labels)`` inference samples into batched tensors.

    ``None`` entries are skipped. Returns int64 ids and float32 signals and
    labels. If no sample remains, returns empty tensors of shapes (0,),
    (0, 0) and (0, 0).
    """
    kept = [sample for sample in data_list if sample is not None]
    if not kept:
        return torch.empty(0, dtype=torch.long), torch.empty(0, 0), torch.empty(0, 0)

    id_col, signal_col, label_col = zip(*kept)
    return (
        torch.LongTensor(np.array(id_col)),
        torch.FloatTensor(np.array(signal_col)),
        torch.FloatTensor(np.array(label_col)),
    )

# ===================================================================
# FUNCIONES DE Post PROCESAMIENTO
# ===================================================================

def process_spindle_simple(det_out, dur_out, fs=200, min_seconds=0.3, max_seconds=5.0):
    """Post-process a binary per-sample spindle mask.

    Two passes:
      1. Gaps between consecutive events (as read from ``dur_out``) that are
         shorter than the minimum duration are filled with ones, merging the
         events.
      2. Events shorter than the minimum duration are erased. Events longer
         than the maximum duration are trimmed to a centred window of exactly
         the maximum length. An event that runs up to the last sample
         includes that sample.

    Args:
        det_out: unused; kept for backward compatibility with callers.
        dur_out: 1-D numpy array whose entries equal 1 inside an event.
        fs: sampling frequency in Hz used to convert the limits to samples.
        min_seconds: minimum event duration, which is also the largest gap
            that gets merged.
        max_seconds: maximum event duration.

    Returns:
        A corrected copy of ``dur_out`` (the input is not modified).
    """
    new_output = dur_out.copy()
    n_samples = len(new_output)

    min_duration = int(fs * min_seconds)
    max_duration = int(fs * max_seconds)

    # Pass 1: merge events separated by short gaps.
    gap_start = 0  # first 0 after the previous event; 0 means "no event ended yet"
    in_event = False
    for i, label in enumerate(dur_out):
        if label == 1 and not in_event:
            if gap_start != 0 and i - gap_start < min_duration:
                new_output[gap_start:i] = np.ones(i - gap_start)
            in_event = True
        elif label == 0 and in_event:
            gap_start = i
            in_event = False

    # Pass 2: remove too-short events and trim too-long ones.
    start = 0
    in_event = False
    for i, label in enumerate(new_output):
        if label == 1 and not in_event:
            start = i
            in_event = True
        # Separate `if` (not `elif`) so an event that begins on the last
        # sample is still checked.
        if in_event and (label == 0 or i + 1 == n_samples):
            # Event is [start, end); an event still open at the final sample
            # includes it.
            end = i if label == 0 else i + 1
            length = end - start
            if length < min_duration:
                new_output[start:end] = np.zeros(length)
            elif length > max_duration:
                p0 = (length - max_duration) // 2 + start
                pf = p0 + max_duration
                new_output[start:p0] = np.zeros(p0 - start)
                new_output[pf:end] = np.zeros(end - pf)
            in_event = False

    return new_output

def get_results_simple(labels, output):
    """Pair true and predicted events in time order and score each pair.

    Walks ``labels`` (ground truth) and ``output`` (prediction) in lockstep
    and records each run of 1s as a ``(start, end)`` stamp, where ``end`` is
    the index of the first 0 after the run. When one sequence closes an event
    and the other has fallen behind, a placeholder ``0`` is appended to the
    other stamp list so the two lists stay aligned.

    Args:
        labels: per-sample ground-truth sequence; entries equal to 1 are inside an event.
        output: per-sample predicted sequence with the same convention
            (presumably the same length; ``zip`` truncates to the shorter one).

    Returns:
        A list with one entry per aligned pair:
          -1  -> placeholder label paired with a prediction (false positive),
           0  -> label paired with a placeholder prediction (false negative),
          otherwise ``inter / union`` of the two stamps, which is <= 0 when
          the paired intervals do not overlap.
        If the two stamp lists end up with different lengths, ``zip`` drops
        the trailing entries of the longer one.

    NOTE(review): the "in event" variables store the event start index and
    use 0 to mean "not in an event". An event starting at sample 0 is
    therefore recorded as starting at sample 1, or missed entirely if it is
    a single sample long.
    NOTE(review): an event still open at the last sample is closed with
    end = last index, excluding that sample. An event that starts on the
    last sample is never recorded.
    NOTE(review): if two placeholders were ever aligned, the IoU branch would
    index into the int 0 and raise TypeError; not verified whether this can occur.
    """

    label_stamps = []
    predict_stamps = []

    # Start index of the currently open event, or 0 when none is open.
    in_predict_event = 0
    in_label_event = 0

    for i, (label, predict) in enumerate(zip(labels, output)):
        # Ground-truth events
        if not in_label_event and label == 1:
            in_label_event = i
        elif in_label_event and (label == 0 or i + 1 == len(labels)):
            label_stamps.append((in_label_event, i))
            in_label_event = 0
            # Pad the prediction list if it lags: no prediction open and
            # fewer stamps, or behind by two or more.
            if (not in_predict_event and len(predict_stamps) < len(label_stamps)) or len(predict_stamps) + 1 < len(label_stamps):
                predict_stamps.append(0)

        # Predicted events
        if not in_predict_event and predict == 1:
            in_predict_event = i
        elif in_predict_event and (predict == 0 or i + 1 == len(output)):
            predict_stamps.append((in_predict_event, i))
            in_predict_event = 0
            # Symmetric padding of the label list.
            if (not in_label_event and len(label_stamps) < len(predict_stamps)) or len(label_stamps) + 1 < len(predict_stamps):
                label_stamps.append(0)

    result = []
    # Score each aligned pair
    for label, predict in zip(label_stamps, predict_stamps):
        if not label and predict:  # False positive: prediction with no label
            result.append(-1)
        elif not predict and label:  # False negative: label with no prediction
            result.append(0)
        else:  # IoU of the two intervals
            inter = min(label[1], predict[1]) - max(label[0], predict[0])
            union = max(label[1], predict[1]) - min(label[0], predict[0])
            result.append(inter / union)

    return result

def calculate_metrics_simple(results, thresholds, threshold=0.2):
    """Compute precision / recall / F1 over a grid of IoU thresholds.

    Each entry of ``results`` is a per-event code as produced by
    ``get_results_simple``: ``-1`` counts as a false positive, ``0`` as a false
    negative, and any other value is an IoU that counts as a true positive when
    ``>= t`` and as a false negative otherwise.

    Args:
        results: iterable of per-event codes / IoUs.
        thresholds: 1-D array-like of IoU thresholds (plain lists are accepted).
        threshold: IoU threshold of interest; the index of the nearest grid
            point is returned.

    Returns:
        ``(threshold_index, precision, recall, f1_score)``; the last three are
        float arrays with one value per threshold, and 0/0 ratios are reported
        as 0.
    """
    thresholds = np.asarray(thresholds, dtype=float)
    ious = np.asarray(results, dtype=float)

    is_fp = ious == -1
    is_missed = ious == 0
    scored = np.sort(ious[~(is_fp | is_missed)])
    n_scored = scored.size

    # For each threshold t, searchsorted(side='left') gives the number of scored
    # IoUs strictly below t, so the rest (>= t) are true positives.
    true_positives = (n_scored - np.searchsorted(scored, thresholds, side='left')).astype(float)
    false_positives = np.full(thresholds.shape, float(np.count_nonzero(is_fp)))
    false_negatives = float(np.count_nonzero(is_missed)) + (n_scored - true_positives)

    threshold_index = np.argmin(np.abs(thresholds - threshold))

    with np.errstate(divide='ignore', invalid='ignore'):
        precision = np.divide(true_positives, (true_positives + false_positives))
        recall = np.divide(true_positives, (true_positives + false_negatives))

        precision = np.nan_to_num(precision, nan=0.0)
        recall = np.nan_to_num(recall, nan=0.0)

        f1_score = np.divide(2 * precision * recall, (precision + recall),
                             out=np.zeros_like(recall), where=(precision + recall) != 0)

    return threshold_index, precision, recall, f1_score
# ===================================================================
# FUNCIÓN PRINCIPAL DE INFERENCIA
# ===================================================================

def run_inference_for_fold(eval_num, fold_num, test_subjects, model_weights_dir):
    """Run deterministic inference for one (evaluation, fold) pair without a DataLoader.

    Builds the test split from ``h5_file_path`` restricted to ``test_subjects``,
    loads the fold's embedding / transformer / classifier weights from
    ``<model_weights_dir>/eval_<eval_num>/fold_<fold_num>/``, scores every signal
    with ``process_full_signal``, binarises the probabilities, post-processes them
    with ``process_spindle_simple`` and scores the events with ``get_results_simple``.

    Args:
        eval_num: evaluation number (used in the weights path and log messages).
        fold_num: fold number (used in the weights path and log messages).
        test_subjects: subject ids selecting the test signals.
        model_weights_dir: root directory of the saved weights.

    Returns:
        List of per-event codes from ``get_results_simple`` (-1 = false positive,
        0 = missed event, otherwise the IoU). Empty if the weights could not be
        loaded (the error is printed); a fold without any events is also empty.
    """
    DET_THRESHOLD = 0.5
    DUR_THRESHOLD = 0.425

    print(f"\n🔍 INFERENCIA EVAL {eval_num} - FOLD {fold_num}")
    print(f"   Test subjects: {test_subjects}")

    # Free memory left over from the previous fold
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Test dataset (iterated directly, no DataLoader)
    test_dataset = H5InferenceDataset(h5_file_path, subject_ids=test_subjects)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SEED_Model().to(device)

    weights_folder = Path(model_weights_dir) / f"eval_{eval_num}" / f"fold_{fold_num}"

    try:
        # strict=False: the checkpoint may contain extra weights from the original
        # embeddings that this model does not use.
        embedding_load = model.embedding.load_state_dict(
            torch.load(weights_folder / "embedding_best.pth", map_location=device), strict=False)
        model.transformer.load_state_dict(torch.load(weights_folder / "lstm_best.pth", map_location=device))
        model.classifier.load_state_dict(torch.load(weights_folder / "classifier_best.pth", map_location=device))
        print(f"   ✅ Pesos cargados desde: {weights_folder}")
    except Exception as e:
        print(f"   ❌ Error cargando pesos: {e}")
        return []

    # strict=False also hides parameters that are *missing* from the checkpoint;
    # those keep their random initialisation, so report them.
    if embedding_load.missing_keys:
        print(f"   ⚠️ Embedding: {len(embedding_load.missing_keys)} parámetros ausentes en el checkpoint "
              f"(se usan valores aleatorios): {embedding_load.missing_keys}")

    model.eval()

    all_results = []
    total_signals = len(test_dataset)

    with torch.no_grad():
        for s, (_, comp_signal, comp_labels) in enumerate(test_dataset):

            print(f"   Procesando señal {s+1}/{total_signals}")

            # Sigmoid probabilities (not logits), one per sample
            comp_output = process_full_signal(model, comp_signal, device)

            # Binarise with the detection and duration thresholds
            det_out = (comp_output >= DET_THRESHOLD).astype(comp_output.dtype)
            dur_out = (comp_output >= DUR_THRESHOLD).astype(comp_output.dtype)

            processed_out = process_spindle_simple(det_out, dur_out)

            sample_results = get_results_simple(comp_labels.astype(int), processed_out)
            all_results.extend(sample_results)

    print(f"\n   📊 Total resultados: {len(all_results)}")

    # Free memory before the next fold
    del model
    del test_dataset

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return all_results

def process_full_signal(model, comp_signal, device):
    """Score a whole signal with overlapping windows and stitch the outputs.

    The signal is cut into windows of ``SIGNAL_WINDOW`` seconds taken every
    ``INFERENCE_STRIDE`` seconds (at ``F_RESAMPLE`` Hz). From each window only
    its central ``INFERENCE_STRIDE`` seconds are kept, except for the first
    window, which keeps its start, and the last one, which keeps its end. If the
    length is not a whole number of strides past the first window, one extra
    window aligned to the end of the signal covers the remainder. All arithmetic
    is done in integer samples.

    Args:
        model: callable taking a ``(1, window)`` float tensor and returning logits,
            assumed to hold one value per input sample.
        comp_signal: 1-D array of samples.
        device: torch device for the input tensors.

    Returns:
        1-D numpy array of sigmoid probabilities (not logits) covering the whole
        signal (same length as ``comp_signal`` under the assumption above).

    Raises:
        ValueError: if the signal is shorter than one window.
    """
    F_RESAMPLE = 200
    SIGNAL_WINDOW = 20  # seconds
    INFERENCE_STRIDE = 10  # seconds

    window = SIGNAL_WINDOW * F_RESAMPLE
    stride = INFERENCE_STRIDE * F_RESAMPLE
    margin = (window - stride) // 2  # samples discarded at each side of a middle window
    n_samples = len(comp_signal)

    if n_samples < window:
        raise ValueError(
            f"Signal has {n_samples} samples; at least {window} ({SIGNAL_WINDOW} s) are required")

    def predict(segment):
        tensor = torch.FloatTensor(segment).unsqueeze(0).to(device)
        return torch.sigmoid(model(tensor)).squeeze().cpu().numpy()

    segments = (n_samples - window) // stride + 1
    rest = (n_samples - window) % stride

    pieces = []
    for i in range(segments):
        p0 = i * stride
        output = predict(comp_signal[p0:p0 + window])
        # First window keeps its beginning, last window keeps its end; with a
        # single window the whole output is kept.
        lo = 0 if i == 0 else margin
        hi = None if i == segments - 1 else margin + stride
        pieces.append(output[lo:hi])

    # Remainder not covered by the regular windows: one window aligned to the end
    if rest:
        output = predict(comp_signal[n_samples - window:])
        pieces.append(output[-rest:])

    return np.concatenate(pieces, axis=0)

def main_inference():
    """Run inference on every evaluation/fold and report micro-averaged metrics.

    Reads the per-fold test subjects from ``folds_json_path``, calls
    ``run_inference_for_fold`` for evaluations 1-3 x folds 1-5 with the weights in
    ``weights_save_dir``, pools every per-event result across folds
    (micro-average) and computes precision / recall / F1 at an IoU threshold of
    ``IOU_THRESHOLD``. The metrics are printed and written to
    ``inference_results_moda_microaverage.json`` in the current working directory.
    """
    IOU_THRESHOLD = 0.2

    print("🚀 INICIANDO INFERENCIA COMPLETA")
    print("="*60)

    # Fold configuration
    with open(folds_json_path, 'r') as f:
        folds_data = json.load(f)

    all_fold_results = []
    evaluations = [1, 2, 3]
    folds = [1, 2, 3, 4, 5]

    total_folds = len(evaluations) * len(folds)
    current_fold = 0

    for eval_num in evaluations:
        for fold_num in folds:
            current_fold += 1

            eval_key = f"eval_{eval_num}"
            fold_key = f"fold_{fold_num}"

            test_subjects = folds_data[eval_key][fold_key]['test']

            print(f"\n{'='*60}")
            print(f"🎯 PROCESANDO {current_fold}/{total_folds}: EVAL {eval_num} - FOLD {fold_num}")
            print(f"{'='*60}")

            fold_results = run_inference_for_fold(eval_num, fold_num, test_subjects, weights_save_dir)

            # NOTE: run_inference_for_fold returns [] both on a weight-loading
            # error and for a fold without events.
            if fold_results:
                all_fold_results.extend(fold_results)
                print(f"✅ Fold completado - Resultados agregados: {len(fold_results)}")
            else:
                print(f"❌ Error en fold {eval_num}-{fold_num}")

    # Final metrics (micro-average over all pooled events)
    print(f"\n{'='*60}")
    print("📊 CALCULANDO MÉTRICAS FINALES")
    print(f"{'='*60}")

    if all_fold_results:
        # 101 points (step 0.01) so that IOU_THRESHOLD = 0.2 lies on the grid;
        # with 100 points the nearest grid value was 20/99 ~ 0.202.
        thresholds = np.linspace(0, 1, 101)
        index, precision, recall, f1_score = calculate_metrics_simple(
            all_fold_results, thresholds, threshold=IOU_THRESHOLD)

        final_metrics = {
            'total_results': len(all_fold_results),
            'f1_score': float(f1_score[index] * 100),
            'precision': float(precision[index] * 100),
            'recall': float(recall[index] * 100),
            'threshold_iou': float(thresholds[index]),
            'evaluations': evaluations,
            'folds': folds,
            'microaverage': True
        }

        print(f"📊 RESULTADOS FINALES (Microaverage):")
        print(f"   Total resultados: {final_metrics['total_results']}")
        print(f"   F1-Score: {final_metrics['f1_score']:.3f}%")
        print(f"   Precision: {final_metrics['precision']:.3f}%")
        print(f"   Recall: {final_metrics['recall']:.3f}%")

        output_file = "inference_results_moda_microaverage.json"
        with open(output_file, 'w') as f:
            json.dump(final_metrics, f, indent=4)

        print(f"\n💾 Resultados guardados en: {output_file}")

    else:
        print("❌ No se obtuvieron resultados de ningún fold")

# Run the full inference pipeline over all evaluations and folds.
if __name__ == "__main__":
    main_inference()

[1;30;43mSe truncaron las últimas líneas 5000 del resultado de transmisión.[0m
   Procesando señal 993/2992
   Procesando señal 994/2992
   Procesando señal 995/2992
   Procesando señal 996/2992
   Procesando señal 997/2992
   Procesando señal 998/2992
   Procesando señal 999/2992
   Procesando señal 1000/2992
   Procesando señal 1001/2992
   Procesando señal 1002/2992
   Procesando señal 1003/2992
   Procesando señal 1004/2992
   Procesando señal 1005/2992
   Procesando señal 1006/2992
   Procesando señal 1007/2992
   Procesando señal 1008/2992
   Procesando señal 1009/2992
   Procesando señal 1010/2992
   Procesando señal 1011/2992
   Procesando señal 1012/2992
   Procesando señal 1013/2992
   Procesando señal 1014/2992
   Procesando señal 1015/2992
   Procesando señal 1016/2992
   Procesando señal 1017/2992
   Procesando señal 1018/2992
   Procesando señal 1019/2992
   Procesando señal 1020/2992
   Procesando señal 1021/2992
   Procesando señal 1022/2992
   Procesando señal 1023/2

In [14]:
!ls "/content/drive/MyDrive/codigo_eeg_redes"


Conv_MDB_Diagram.png			  Proyecto_RN.ipynb
data					  README.md
inference_results_moda_microaverage.json  requirements.txt
mc_dropout_results			  run_model.ipynb
model_weights				  SEED_diagram.png
model_weights_best_original


# MC dropout

In [12]:
import torch
import torch.nn as nn
import numpy as np
import h5py
import json
from pathlib import Path
import gc
from torch.utils.data import Dataset

# ===================================================================
# MODEL CONFIGURATION (same as used for training)
# ===================================================================
# NOTE(review): these module-level names presumably feed SEED_Model (defined in
# an earlier cell) when it is constructed, so they must match the training-time
# values — TODO confirm.

EMBED_DIM = 64
EMBED_BLOCKS = 1
FILTER_IN = 64
BLOCK_DILATATION = 2
RESOL_FACTOR = 8
# BATCH_SIZE and NUM_WORKERS are not referenced by the inference code in this
# cell, which iterates over the dataset directly without a DataLoader.
BATCH_SIZE = 32
NUM_WORKERS = 12

EMBEDDING_CONFIG_SEED = {
    "embed_dim": EMBED_DIM,
    "block_depth": 2,
    "blocks": EMBED_BLOCKS,
    # NOTE(review): if this value drives the embedding's Dropout layers
    # (presumably — TODO confirm), the embedding adds no randomness and MC
    # Dropout samples can only differ through other Dropout layers with p > 0.
    "dropout": 0.0,
    "filters_in": FILTER_IN,
    "block_dilatation": BLOCK_DILATATION,
}

CLASSIFIER_CONFIG = {
    "embed_dim": 256,
    "hide_dim": 128,
}

# Paths
h5_file_path = "/content/drive/MyDrive/codigo_eeg_redes/data/moda_augmented_s5.h5"
folds_json_path = "/content/drive/MyDrive/codigo_eeg_redes/data/moda_folds.json"
weights_save_dir = "/content/drive/MyDrive/codigo_eeg_redes/model_weights_best_original"  # or model_weights

# Output directory for the MC Dropout .npy arrays; created if it does not exist.
mc_dropout_dir = Path("/content/drive/MyDrive/codigo_eeg_redes/mc_dropout_results")
mc_dropout_dir.mkdir(parents=True, exist_ok=True)

print("🔍 CONFIGURACIÓN DE INFERENCIA:")
print(f"   📁 Dataset: {h5_file_path}")
print(f"   📁 Folds: {folds_json_path}")
print(f"   💾 Pesos: {weights_save_dir}")

# ===================================================================
# FUNCIONES PARA MC DROPOUT
# ===================================================================

def enable_mc_dropout(model):
    """Switch only the Dropout layers of ``model`` back to training mode.

    Meant to be called after ``model.eval()``. BatchNorm and all other layers stay
    in evaluation mode, while Dropout keeps sampling random masks, which is what
    MC Dropout needs.

    Layers are selected with ``isinstance`` against torch's dropout classes rather
    than by class-name prefix. Name matching missed Dropout subclasses with other
    names. It also called ``.train()`` on any container whose class name starts
    with "Dropout", which recursively put that container's children (e.g.
    BatchNorm) into training mode too.

    NOTE(review): dropout applied functionally inside other modules (e.g. the
    attention dropout of ``nn.MultiheadAttention``, governed by that module's own
    training flag) is not enabled here — check whether the model relies on it.

    Args:
        model: ``torch.nn.Module`` modified in place.

    Returns:
        int: number of dropout layers switched to training mode.
    """
    dropout_types = tuple(
        cls for cls in (getattr(nn, name, None) for name in (
            "Dropout", "Dropout1d", "Dropout2d", "Dropout3d",
            "AlphaDropout", "FeatureAlphaDropout"))
        if cls is not None
    )
    enabled = 0
    for module in model.modules():
        if isinstance(module, dropout_types):
            module.train()
            enabled += 1
    return enabled



# ===================================================================
# DATASET H5 PARA INFERENCIA
# ===================================================================

class H5InferenceDataset(Dataset):
    """Indexable view over the rows of an HDF5 file, optionally limited to some subjects.

    The file must contain the datasets ``signals``, ``labels`` and ``subject_ids``
    (the ids are stored as bytes and decoded as UTF-8 here). The file is reopened
    on every ``__getitem__`` call; no handle is kept open between calls.

    Args:
        h5_file_path: path of the HDF5 file.
        subject_ids: collection of subject ids to keep, or None to keep every row.
    """

    def __init__(self, h5_file_path, subject_ids=None):
        self.h5_file_path = h5_file_path

        with h5py.File(h5_file_path, 'r') as h5:
            signal_ds = h5['signals']
            self.total_samples = len(signal_ds)
            self.signal_length = signal_ds.shape[1]
            self.label_length = h5['labels'].shape[1]

            stored_ids = [raw.decode('utf-8') for raw in h5['subject_ids'][:]]

        if subject_ids is None:
            self.valid_indices = list(range(self.total_samples))
        else:
            self.valid_indices = [row for row, subject in enumerate(stored_ids)
                                  if subject in subject_ids]

        print(f"📊 Dataset de inferencia creado:")
        print(f"   Total muestras: {len(self.valid_indices)}")

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return ``(row_index_in_file, signal_float32, labels_float32)`` for item ``idx``."""
        row = self.valid_indices[idx]

        with h5py.File(self.h5_file_path, 'r') as h5:
            signal = h5['signals'][row].astype(np.float32)
            labels = h5['labels'][row].astype(np.float32)

        return row, signal, labels

def create_batch_inference(data_list):
    """Collate ``(id, signal, labels)`` triples into tensors, dropping ``None`` entries.

    Returns:
        ``(ids, signals, labels)`` as a LongTensor and two FloatTensors stacked
        along a new first axis, or three empty tensors when no entry remains.
    """
    samples = [item for item in data_list if item is not None]

    if not samples:
        return torch.empty(0, dtype=torch.long), torch.empty(0, 0), torch.empty(0, 0)

    ids, signals, labels = zip(*samples)
    return (
        torch.LongTensor(np.array(ids)),
        torch.FloatTensor(np.array(signals)),
        torch.FloatTensor(np.array(labels)),
    )

# ===================================================================
# FUNCIONES DE Post PROCESAMIENTO
# ===================================================================

def _event_bounds(mask):
    """Return ``(starts, ends)`` index arrays of the runs where ``mask == 1`` (ends exclusive)."""
    active = (np.asarray(mask) == 1).astype(np.int8)
    edges = np.diff(np.concatenate(([0], active, [0])))
    return np.flatnonzero(edges == 1), np.flatnonzero(edges == -1)


def process_spindle_simple(det_out, dur_out, *, f_resample=200, min_seconds=0.3, max_seconds=5):
    """Clean up a binary event mask (simplified spindle post-processing).

    Steps, applied to a copy of ``dur_out``:
      1. Events separated by a gap shorter than ``min_seconds`` are merged (gaps
         are measured on the original ``dur_out``).
      2. Events shorter than ``min_seconds`` are removed.
      3. Events longer than ``max_seconds`` are trimmed to a centred window of
         ``max_seconds``.
    Event lengths use exclusive ends, so events that reach the last sample are
    measured and cleared/trimmed correctly.

    NOTE(review): ``det_out`` is accepted but not used; only ``dur_out`` drives
    the result, so the caller's detection threshold currently has no effect.
    Confirm whether events without any ``det_out`` sample should be dropped.

    Args:
        det_out: binary detection mask (currently unused).
        dur_out: binary duration mask (numpy array); not modified.
        f_resample: sampling rate in Hz.
        min_seconds: minimum event length and minimum gap between events.
        max_seconds: maximum event length.

    Returns:
        New numpy array with the same dtype as ``dur_out``.
    """
    new_output = np.array(dur_out, copy=True)

    min_duration = int(f_resample * min_seconds)
    max_duration = int(f_resample * max_seconds)

    # 1) Merge events separated by short gaps
    starts, ends = _event_bounds(new_output)
    for prev_end, next_start in zip(ends[:-1], starts[1:]):
        if next_start - prev_end < min_duration:
            new_output[prev_end:next_start] = 1

    # 2-3) Remove too-short events and trim too-long ones around their centre
    starts, ends = _event_bounds(new_output)
    for start, end in zip(starts, ends):
        length = end - start
        if length < min_duration:
            new_output[start:end] = 0
        elif length > max_duration:
            p0 = start + (length - max_duration) // 2
            pf = p0 + max_duration
            new_output[start:p0] = 0
            new_output[pf:end] = 0

    return new_output

def get_results_simple(labels, output):
    """Pair ground-truth and predicted events and score each pair.

    Events are maximal runs of samples equal to 1. Each event is stored as a
    ``(start, end)`` tuple with an exclusive ``end``. The k-th ground-truth event
    is paired with the k-th predicted event; when one side closes an event that
    the other side cannot match, a placeholder ``0`` is appended to the other side
    so both lists stay aligned.

    Args:
        labels: sequence of 0/1 ground-truth values, one per sample.
        output: sequence of 0/1 predicted values, one per sample.

    Returns:
        One entry per pair: ``-1`` for a false positive (prediction without a
        label event), ``0`` for a false negative (label event without a
        prediction), otherwise the IoU ``intersection / union`` of the pair
        (zero or negative when the paired events do not overlap).
    """
    label_stamps = []
    predict_stamps = []

    # Start index of the currently open event, or None when no event is open.
    # None (not 0) is the sentinel so that events starting at sample 0 are kept.
    label_start = None
    predict_start = None

    n_labels = len(labels)
    n_output = len(output)

    for i, (label, predict) in enumerate(zip(labels, output)):
        # Ground-truth events
        if label_start is None and label == 1:
            label_start = i
        elif label_start is not None and (label == 0 or i + 1 == n_labels):
            # Exclusive end: the first 0 after the event, or one past the last
            # sample when the event runs to the end of the signal.
            label_end = i if label == 0 else i + 1
            label_stamps.append((label_start, label_end))
            label_start = None
            # Placeholder "no prediction" entry keeps the two lists paired.
            if ((predict_start is None and len(predict_stamps) < len(label_stamps))
                    or len(predict_stamps) + 1 < len(label_stamps)):
                predict_stamps.append(0)

        # Predicted events
        if predict_start is None and predict == 1:
            predict_start = i
        elif predict_start is not None and (predict == 0 or i + 1 == n_output):
            predict_end = i if predict == 0 else i + 1
            predict_stamps.append((predict_start, predict_end))
            predict_start = None
            # Placeholder "no label" entry keeps the two lists paired.
            if ((label_start is None and len(label_stamps) < len(predict_stamps))
                    or len(label_stamps) + 1 < len(predict_stamps)):
                label_stamps.append(0)

    result = []
    # Score each pair
    for label, predict in zip(label_stamps, predict_stamps):
        if not label and predict:  # false positive
            result.append(-1)
        elif not predict and label:  # false negative
            result.append(0)
        else:  # IoU of the paired events
            inter = min(label[1], predict[1]) - max(label[0], predict[0])
            union = max(label[1], predict[1]) - min(label[0], predict[0])
            result.append(inter / union)

    return result

def calculate_metrics_simple(results, thresholds, threshold=0.2):
    """Compute precision / recall / F1 over a grid of IoU thresholds.

    Each entry of ``results`` is a per-event code as produced by
    ``get_results_simple``: ``-1`` counts as a false positive, ``0`` as a false
    negative, and any other value is an IoU that counts as a true positive when
    ``>= t`` and as a false negative otherwise.

    Args:
        results: iterable of per-event codes / IoUs.
        thresholds: 1-D array-like of IoU thresholds (plain lists are accepted).
        threshold: IoU threshold of interest; the index of the nearest grid
            point is returned.

    Returns:
        ``(threshold_index, precision, recall, f1_score)``; the last three are
        float arrays with one value per threshold, and 0/0 ratios are reported
        as 0.
    """
    thresholds = np.asarray(thresholds, dtype=float)
    ious = np.asarray(results, dtype=float)

    is_fp = ious == -1
    is_missed = ious == 0
    scored = np.sort(ious[~(is_fp | is_missed)])
    n_scored = scored.size

    # For each threshold t, searchsorted(side='left') gives the number of scored
    # IoUs strictly below t, so the rest (>= t) are true positives.
    true_positives = (n_scored - np.searchsorted(scored, thresholds, side='left')).astype(float)
    false_positives = np.full(thresholds.shape, float(np.count_nonzero(is_fp)))
    false_negatives = float(np.count_nonzero(is_missed)) + (n_scored - true_positives)

    threshold_index = np.argmin(np.abs(thresholds - threshold))

    with np.errstate(divide='ignore', invalid='ignore'):
        precision = np.divide(true_positives, (true_positives + false_positives))
        recall = np.divide(true_positives, (true_positives + false_negatives))

        precision = np.nan_to_num(precision, nan=0.0)
        recall = np.nan_to_num(recall, nan=0.0)

        f1_score = np.divide(2 * precision * recall, (precision + recall),
                             out=np.zeros_like(recall), where=(precision + recall) != 0)

    return threshold_index, precision, recall, f1_score
# ===================================================================
# FUNCIÓN PRINCIPAL DE INFERENCIA
# ===================================================================

def run_inference_for_fold(eval_num, fold_num, test_subjects, model_weights_dir):
    """Ejecuta inferencia para un fold específico SIN DataLoader"""

    print(f"\n🔍 INFERENCIA EVAL {eval_num} - FOLD {fold_num}")
    print(f"   Test subjects: {test_subjects}")

    # Limpiar memoria
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Crear dataset de test (SIN DataLoader)
    test_dataset = H5InferenceDataset(h5_file_path, subject_ids=test_subjects)

    # Cargar modelo
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SEED_Model().to(device)

    # Cargar pesos

    weights_folder = Path(model_weights_dir) / f"eval_{eval_num}" / f"fold_{fold_num}"

    try:
        model.embedding.load_state_dict(torch.load(weights_folder / "embedding_best.pth", map_location=device), strict=False) # Stric=False para no cargar pesos innecesarios de los embeddings originales
        model.transformer.load_state_dict(torch.load(weights_folder / "lstm_best.pth", map_location=device))
        model.classifier.load_state_dict(torch.load(weights_folder / "classifier_best.pth", map_location=device))
        print(f"   ✅ Pesos cargados desde: {weights_folder}")
    except Exception as e:
        print(f"   ❌ Error cargando pesos: {e}")
        return []

    # Modo evaluación estándar
    model.eval()

    # 🔹 ACTIVAR MC DROPOUT
    enable_mc_dropout(model)
    print("🔹 MC Dropout activado: las capas Dropout permanecerán activas durante la inferencia.")


    all_results = []
    total_signals = len(test_dataset)

    # ✅ ITERAR DIRECTAMENTE SOBRE EL DATASET
    with torch.no_grad():
        for s, (_, comp_signal, comp_labels) in enumerate(test_dataset):

            print(f"   Procesando señal {s+1}/{total_signals}")

            # Procesar señal completa
            comp_output = process_full_signal(model, comp_signal, device, signal_idx=s) #LOGITS

            # Aplicar thresholds
            det_out = np.where(comp_output >= 0.5, np.ones_like(comp_output), np.zeros_like(comp_output))
            dur_out = np.where(comp_output >= 0.425, np.ones_like(comp_output), np.zeros_like(comp_output))

            # Procesar detecciones
            processed_out = process_spindle_simple(det_out, dur_out)

            # Calcular IoUs
            sample_results = get_results_simple(comp_labels.astype(int), processed_out)
            all_results.extend(sample_results)

    print(f"\n   📊 Total resultados: {len(all_results)}")

    # Limpiar memoria
    del model
    del test_dataset

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return all_results

def process_full_signal(model, comp_signal, device, signal_idx):
    """Procesa una señal completa usando segmentos (como en detections_2.py)"""

    F_RESAMPLE = 200
    SIGNAL_WINDOW = 20  # 20 segundos
    INFERENCE_STRIDE = 10  # 10 segundos

    # Cantidad de segmentos
    segments = int(len(comp_signal) / F_RESAMPLE - SIGNAL_WINDOW) // INFERENCE_STRIDE + 1

    for i in range(segments):
        p0 = INFERENCE_STRIDE * F_RESAMPLE * i
        pf = p0 + SIGNAL_WINDOW * F_RESAMPLE  # 20 segundos desde la pos inicial

        # Obtener output del segmento
        signal_segment = torch.FloatTensor(comp_signal[p0:pf]).unsqueeze(0).to(device)
        output = model(signal_segment)
        output = torch.sigmoid(output).squeeze().cpu().numpy()

        # Concatenar 10 segundos centrales del output
        p1 = F_RESAMPLE * (SIGNAL_WINDOW - INFERENCE_STRIDE) // 2
        p2 = p1 + F_RESAMPLE * INFERENCE_STRIDE

        if i == 0:
            comp_output = output[:p2]
        elif i + 1 == segments:
            comp_output = np.concatenate((comp_output, output[p1:]), axis=0)
        else:
            comp_output = np.concatenate((comp_output, output[p1:p2]), axis=0)

    # Analizar si queda resto de señal
    rest = (len(comp_signal) / F_RESAMPLE - SIGNAL_WINDOW) % INFERENCE_STRIDE
    if rest != 0:
        p0 = len(comp_signal) - SIGNAL_WINDOW * F_RESAMPLE
        signal_segment = torch.FloatTensor(comp_signal[p0:]).unsqueeze(0).to(device)



        # 🔹 MC Dropout: realizar varias predicciones para estimar incertidumbre
        NUM_MC_SAMPLES = 20  # número de pasadas forward, ajusta según tu GPU
        outputs_mc = []

        for _ in range(NUM_MC_SAMPLES):
            mc_output = model(signal_segment)
            mc_output = torch.sigmoid(mc_output).squeeze().cpu().numpy()
            outputs_mc.append(mc_output)

        # Convertir a numpy array: shape = (NUM_MC_SAMPLES, longitud_segmento)
        outputs_mc = np.stack(outputs_mc)

        # Calcular promedio y desviación estándar
        output_mean = outputs_mc.mean(axis=0)
        output_std = outputs_mc.std(axis=0)  # Incertidumbre

        # Guardar la incertidumbre para cada señal analizada
        np.save(mc_dropout_dir / f"mc_std_signal{signal_idx+1}.npy", output_std)
        np.save(mc_dropout_dir / f"mc_outputs_signal{signal_idx+1}.npy", outputs_mc)      # todas las predicciones (opcional)



        # Usar el promedio para el pipeline original
        output = output_mean



        comp_output = np.concatenate((comp_output, output[-int(rest * F_RESAMPLE):]), axis=0)

    return comp_output

def main_inference():
    """Run MC-Dropout inference on every evaluation/fold and report micro-averaged metrics.

    Reads the per-fold test subjects from ``folds_json_path``, calls
    ``run_inference_for_fold`` for evaluations 1-3 x folds 1-5 with the weights in
    ``weights_save_dir``, pools every per-event result across folds
    (micro-average) and computes precision / recall / F1 at an IoU threshold of
    ``IOU_THRESHOLD``. The metrics are printed and written to
    ``inference_results_moda_microaverage_mc_dropout.json`` in the current
    working directory.
    """
    IOU_THRESHOLD = 0.2

    print("🚀 INICIANDO INFERENCIA COMPLETA")
    print("="*60)

    # Fold configuration
    with open(folds_json_path, 'r') as f:
        folds_data = json.load(f)

    all_fold_results = []
    evaluations = [1, 2, 3]
    folds = [1, 2, 3, 4, 5]

    total_folds = len(evaluations) * len(folds)
    current_fold = 0

    for eval_num in evaluations:
        for fold_num in folds:
            current_fold += 1

            eval_key = f"eval_{eval_num}"
            fold_key = f"fold_{fold_num}"

            test_subjects = folds_data[eval_key][fold_key]['test']

            print(f"\n{'='*60}")
            print(f"🎯 PROCESANDO {current_fold}/{total_folds}: EVAL {eval_num} - FOLD {fold_num}")
            print(f"{'='*60}")

            fold_results = run_inference_for_fold(eval_num, fold_num, test_subjects, weights_save_dir)

            # NOTE: run_inference_for_fold returns [] both on a weight-loading
            # error and for a fold without events.
            if fold_results:
                all_fold_results.extend(fold_results)
                print(f"✅ Fold completado - Resultados agregados: {len(fold_results)}")
            else:
                print(f"❌ Error en fold {eval_num}-{fold_num}")

    # Final metrics (micro-average over all pooled events)
    print(f"\n{'='*60}")
    print("📊 CALCULANDO MÉTRICAS FINALES")
    print(f"{'='*60}")

    if all_fold_results:
        # 101 points (step 0.01) so that IOU_THRESHOLD = 0.2 lies on the grid;
        # with 100 points the nearest grid value was 20/99 ~ 0.202.
        thresholds = np.linspace(0, 1, 101)
        index, precision, recall, f1_score = calculate_metrics_simple(
            all_fold_results, thresholds, threshold=IOU_THRESHOLD)

        final_metrics = {
            'total_results': len(all_fold_results),
            'f1_score': float(f1_score[index] * 100),
            'precision': float(precision[index] * 100),
            'recall': float(recall[index] * 100),
            'threshold_iou': float(thresholds[index]),
            'evaluations': evaluations,
            'folds': folds,
            'microaverage': True
        }

        print(f"📊 RESULTADOS FINALES (Microaverage):")
        print(f"   Total resultados: {final_metrics['total_results']}")
        print(f"   F1-Score: {final_metrics['f1_score']:.3f}%")
        print(f"   Precision: {final_metrics['precision']:.3f}%")
        print(f"   Recall: {final_metrics['recall']:.3f}%")

        output_file = "inference_results_moda_microaverage_mc_dropout.json"
        with open(output_file, 'w') as f:
            json.dump(final_metrics, f, indent=4)

        print(f"\n💾 Resultados guardados en: {output_file}")

    else:
        print("❌ No se obtuvieron resultados de ningún fold")

# Run the full MC-Dropout inference pipeline over all evaluations and folds.
if __name__ == "__main__":
    main_inference()

[1;30;43mSe truncaron las últimas líneas 5000 del resultado de transmisión.[0m
   Procesando señal 994/2992
   Procesando señal 995/2992
   Procesando señal 996/2992
   Procesando señal 997/2992
   Procesando señal 998/2992
   Procesando señal 999/2992
   Procesando señal 1000/2992
   Procesando señal 1001/2992
   Procesando señal 1002/2992
   Procesando señal 1003/2992
   Procesando señal 1004/2992
   Procesando señal 1005/2992
   Procesando señal 1006/2992
   Procesando señal 1007/2992
   Procesando señal 1008/2992
   Procesando señal 1009/2992
   Procesando señal 1010/2992
   Procesando señal 1011/2992
   Procesando señal 1012/2992
   Procesando señal 1013/2992
   Procesando señal 1014/2992
   Procesando señal 1015/2992
   Procesando señal 1016/2992
   Procesando señal 1017/2992
   Procesando señal 1018/2992
   Procesando señal 1019/2992
   Procesando señal 1020/2992
   Procesando señal 1021/2992
   Procesando señal 1022/2992
   Procesando señal 1023/2992
   Procesando señal 1024/

# Borrador

In [None]:
model = SEED_Model()

# List every registered parameter: its name on one line, its shape on the next.
for param_name, param_tensor in model.named_parameters():
    print(param_name)
    print(param_tensor.shape)

[(64, 1), (32, 2), (16, 4), (16, 8)]
[(128, 1), (64, 2), (32, 4), (32, 8)]
embedding.conv_layer.bn.weight
torch.Size([1])
embedding.conv_layer.bn.bias
torch.Size([1])
embedding.conv_layer.ln.weight
torch.Size([1])
embedding.conv_layer.ln.bias
torch.Size([1])
embedding.conv_layer.conv_blocks.0.convs.0.weight
torch.Size([64, 1, 3])
embedding.conv_layer.conv_blocks.0.convs.0.bias
torch.Size([64])
embedding.conv_layer.conv_blocks.0.convs.1.weight
torch.Size([64, 64, 3])
embedding.conv_layer.conv_blocks.0.convs.1.bias
torch.Size([64])
embedding.conv_layer.conv_blocks.0.bns.0.weight
torch.Size([64])
embedding.conv_layer.conv_blocks.0.bns.0.bias
torch.Size([64])
embedding.conv_layer.conv_blocks.0.bns.1.weight
torch.Size([64])
embedding.conv_layer.conv_blocks.0.bns.1.bias
torch.Size([64])
embedding.layer_dilatations.0.branches.0.0.weight
torch.Size([64, 64, 3])
embedding.layer_dilatations.0.branches.0.1.weight
torch.Size([64])
embedding.layer_dilatations.0.branches.0.1.bias
torch.Size([64])
em