In [1]:
# ==============================================================================
# CELDA 1: Markdown de IntroducciÃ³n
# ==============================================================================
# ## ETAPA 3: Entrenamiento del "Student"
#
# Este notebook entrena el modelo final, o "Student". Su principal caracterÃ­stica 
# es que utiliza un **conjunto de datos de entrenamiento aumentado**.
#
# Este conjunto de datos se compone de:
# 1. Los datos de entrenamiento originales (del `train_val` split).
# 2. Las nuevas pseudo-etiquetas de alta calidad generadas por los "Teachers" y **verificadas por un experto**.
#
# El objetivo es que el "Student", al aprender de mÃ¡s datos, logre un mejor 
# rendimiento y capacidad de generalizaciÃ³n que los "Teachers" originales.

In [2]:
# ==============================================================================
# CELDA 2: Importaciones
# ==============================================================================
# --- Importaciones EstÃ¡ndar ---
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, recall_score, confusion_matrix
import time
from pathlib import Path
import random
import json
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import torchvision.transforms as T

# --- Importar desde nuestros mÃ³dulos locales ---
import config
import models
import utils

In [3]:
# ==============================================================================
# CELDA 2.5: Clase Asistente para Tesis (ThesisHelper)
# ==============================================================================
class ThesisHelper:
    """
    Una clase para gestionar el logging, guardado de checkpoints y generaciÃ³n
    de artefactos para una tesis durante el entrenamiento de un modelo.
    """

    def __init__(self, params, class_names, base_dir, run_type='teacher'):
        self.params = params
        self.class_names = class_names
        # El nombre ahora incluye el tipo de ejecuciÃ³n (teacher o student)
        if run_type == 'student':
            self.run_name = f"student_trained_with_win_teachers{params['MODEL_NAME']}"
        else: # Para los 'teacher'
            self.run_name = f"{run_type}_{params['MODEL_NAME']}"
        self.output_dir = Path(base_dir) / self.run_name
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.run_type = run_type # Guardamos el tipo
        
        self.history = []
        self.best_f1_macro = -1.0
        self.best_epoch_metrics = None
        
        print(f"ThesisHelper inicializado para '{self.run_name}'. Artefactos se guardarÃ¡n en: {self.output_dir}")


    def log_epoch(self, model, metrics):
        """Registra las mÃ©tricas de una Ã©poca y guarda el mejor checkpoint."""
        self.history.append(metrics)
        current_f1_macro = metrics['f1m']
        
        if current_f1_macro > self.best_f1_macro:
            self.best_f1_macro = current_f1_macro
            self.best_epoch_metrics = metrics
            print(f"ðŸš€ Nuevo mejor F1-Macro: {self.best_f1_macro:.4f} en la Ã©poca {metrics['epoch']}. Guardando checkpoint...")
            self._save_checkpoint(model)

    def _save_checkpoint(self, model):
        """Guarda el state_dict del modelo."""
        torch.save(model.state_dict(), self.output_dir / 'best_model.pth')

    def finalize(self, total_duration_seconds):
        """Genera todos los artefactos finales despuÃ©s del entrenamiento."""
        if not self.history:
            print("No hay historial para finalizar. Saltando la generaciÃ³n de artefactos.")
            return

        # 1. Guardar historial y resumen
        history_df = pd.DataFrame(self.history)
        history_df.to_csv(self.output_dir / 'training_history.csv', index=False)
        
        summary = self.best_epoch_metrics.copy()
        summary['total_duration_min'] = total_duration_seconds / 60
        # La matriz de confusiÃ³n puede ser grande, la guardamos por separado
        cm = summary.pop('cm', None) 
        
        with open(self.output_dir / 'summary.json', 'w') as f:
            json.dump(summary, f, indent=4)
            
        print(f"ðŸ“„ Historial y resumen guardados.")

        # 2. Generar y guardar grÃ¡ficas
        self._plot_curves(history_df)
        print(f"ðŸ“Š GrÃ¡ficas de entrenamiento guardadas.")
        
        # 3. Generar y guardar tabla LaTeX
        self._generate_latex_table(summary, cm)
        print(f"ðŸ“‹ Tabla LaTeX generada.")
        
        # 4. Registrar en el Excel principal
        self._log_to_excel(summary, cm)
        print(f"âœ… MÃ©tricas finales registradas en Excel.")

    def _plot_curves(self, df):
        """Genera y guarda las curvas de entrenamiento y validaciÃ³n."""
        best_epoch = self.best_epoch_metrics['epoch']
        
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
        fig.suptitle(f'Curvas de Entrenamiento para {self.run_name}', fontsize=16)

        # Subplot 1: PÃ©rdidas (Loss)
        ax1.plot(df['epoch'], df['tr_loss'], 'o-', label='Training Loss')
        ax1.plot(df['epoch'], df['loss'], 'o-', label='Validation Loss')
        ax1.axvline(x=best_epoch, color='r', linestyle='--', label=f'Mejor Ã‰poca ({best_epoch})')
        ax1.set_ylabel('PÃ©rdida (Loss)')
        ax1.legend()
        ax1.grid(True, linestyle='--', alpha=0.6)
        
        # AnotaciÃ³n para la mejor Ã©poca
        best_loss = self.best_epoch_metrics['loss']
        ax1.annotate(f'Mejor F1-Macro\nVal Loss: {best_loss:.4f}',
                     xy=(best_epoch, best_loss),
                     xytext=(best_epoch + 3, best_loss + 0.1*df['loss'].max()),
                     arrowprops=dict(facecolor='black', shrink=0.05, width=1, headwidth=8),
                     bbox=dict(boxstyle="round,pad=0.3", fc="yellow", ec="black", lw=1, alpha=0.7))

        # Subplot 2: MÃ©tricas de Rendimiento
        ax2.plot(df['epoch'], df['tr_acc'], 'o-', label='Training Accuracy')
        ax2.plot(df['epoch'], df['acc'], 'o-', label='Validation Accuracy')
        ax2.plot(df['epoch'], df['f1m'], 'o-', label='Validation F1-Macro', linewidth=2, markersize=8)
        ax2.axvline(x=best_epoch, color='r', linestyle='--')
        ax2.set_xlabel('Ã‰poca')
        ax2.set_ylabel('MÃ©trica')
        ax2.legend()
        ax2.grid(True, linestyle='--', alpha=0.6)
        ax2.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
        
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.savefig(self.output_dir / 'training_curves.png', dpi=300)
        plt.close()

    def _generate_latex_table(self, summary, cm):
        """Genera una tabla LaTeX con el resumen del entrenamiento."""
        latex_str = f"""
\\begin{{table}}[h!]
\\centering
\\caption{{Resumen del entrenamiento del modelo {self.run_name.replace('_', ' ')} y mÃ©tricas finales en la mejor Ã©poca.}}
\\label{{tab:training_summary_{self.run_name}}}
\\begin{{tabular}}{{ll}}
\\hline
\\textbf{{ParÃ¡metro}} & \\textbf{{Valor}} \\\\
\\hline
Modelo Base & {self.params['MODEL_NAME']} \\\\
Mejor Ã‰poca & {summary['epoch']} \\\\
DuraciÃ³n Total (min) & {summary['total_duration_min']:.2f} \\\\
\\hline
\\textbf{{MÃ©trica de ValidaciÃ³n}} & \\textbf{{Valor}} \\\\
\\hline
F1-Macro (Mejor) & {summary['f1m']:.4f} \\\\
Exactitud (Accuracy) & {summary['acc']:.4f} \\\\
PÃ©rdida (Loss) & {summary['loss']:.4f} \\\\
Recall (Macro) & {summary['recm']:.4f} \\\\
\\hline
\\end{{tabular}}
\\end{{table}}
        """
        with open(self.output_dir / 'summary_table.tex', 'w') as f:
            f.write(latex_str)

    def _log_to_excel(self, summary, cm):
        """Registra las mÃ©tricas en el archivo Excel principal."""
        metrics_to_log = {
            'carrier': config.CARRIER,
            'model_name': self.params['MODEL_NAME'],
            'run_tag': self.run_name,
            'num_classes': len(self.class_names),
            'acc': summary['acc'],
            'loss': summary['loss'],
            'f1m': summary['f1m'],
            'f1w': summary['f1w'],
            'recm': summary['recm'],
            'cm': cm,
            'epochs': self.params['EPOCHS'],
            'batch_size': self.params['BATCH_SIZE'],
            'lr': self.params['LR'],
            'weight_decay': self.params['WEIGHT_DECAY'],
            'notes': f"Student - Mejor checkpoint en la Ã©poca {summary['epoch']}"
        }
        utils.log_metrics_excel(config.EXCEL_PATH, config.RESULTS_DIR, self.class_names, metrics_to_log)

In [4]:
# ==============================================================================
# CELDA 3: ReutilizaciÃ³n de Clases y Funciones
# ==============================================================================
# Reutilizamos las mismas clases y funciones de entrenamiento del notebook 01
# ya que la lÃ³gica de entrenamiento es idÃ©ntica.

class RandomTimeShift(torch.nn.Module):
    def __init__(self, max_frac=0.1):
        super().__init__()
        self.max_frac = max_frac
    def forward(self, x):
        _, H, W = x.shape
        s = int(random.uniform(-self.max_frac, self.max_frac) * W)
        return torch.roll(x, shifts=s, dims=-1)

class RandomGain(torch.nn.Module):
    def __init__(self, a=0.95, b=1.05):
        super().__init__()
        self.a = a
        self.b = b
    def forward(self, x):
        g = random.uniform(self.a, self.b)
        return (x * g).clamp(0, 1)

weak_aug = T.Compose([
    RandomTimeShift(0.08),
    RandomGain(0.95, 1.05),
])

class LabeledSpectro(Dataset):
    def __init__(self, files, labels, transform=None):
        self.files = files
        self.labels = labels
        self.transform = transform
    def __len__(self):
        return len(self.files)
    def __getitem__(self, i):
        path = self.files[i]
        try:
            x = utils.load_png_gray(path)
            if self.transform:
                x = self.transform(x)
            y = self.labels[i]
            return x, y
        except Exception as e:
            print(f"[dataset_warning] Saltando archivo por error: {e}")
            return torch.zeros(1, config.IMG_H, config.IMG_W, dtype=torch.float32), self.labels[i]

def maybe_resize_for_resnet(x, should_resize):
    if should_resize:
        return torch.nn.functional.interpolate(x, size=(224, 224), mode="bilinear", align_corners=False)
    return x

class EarlyStopping:
    def __init__(self, patience, min_delta, mode='max', restore_best=True):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.restore_best = restore_best
        self.best = -float('inf') if mode == 'max' else float('inf')
        self.wait = 0
        self.best_state = None

    def step(self, metric, model):
        is_better = (metric > self.best + self.min_delta) if self.mode == 'max' else (metric < self.best - self.min_delta)
        if is_better:
            self.best = metric
            self.wait = 0
            if self.restore_best:
                self.best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            return False
        self.wait += 1
        return self.wait >= self.patience

    def restore(self, model):
        if self.restore_best and self.best_state is not None:
            model.load_state_dict(self.best_state)

def evaluate(model, loader, criterion, params):
    device = torch.device(config.DEVICE)
    model.eval()
    va_loss, preds, gts = 0.0, [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = maybe_resize_for_resnet(xb, params.get('RESNET_RESIZE_TO_224', False))
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            loss = criterion(logits, yb)
            va_loss += loss.item() * xb.size(0)
            preds.append(logits.softmax(1).argmax(1).cpu())
            gts.append(yb.cpu())
    va_loss /= len(loader.dataset)
    y_pred = torch.cat(preds).numpy()
    y_true = torch.cat(gts).numpy()
    
    metrics = {
        'loss': va_loss,
        'acc': accuracy_score(y_true, y_pred),
        'f1m': f1_score(y_true, y_pred, average='macro', zero_division=0),
        'f1w': f1_score(y_true, y_pred, average='weighted', zero_division=0),
        'recm': recall_score(y_true, y_pred, average='macro', zero_division=0),
        'cm': confusion_matrix(y_true, y_pred)
    }
    return metrics

##############################################################################################################
# ==============================================================================
# CELDA 3: train_student (Actualizada)
# ==============================================================================
# (AquÃ­ se asume que las clases LabeledSpectro, EarlyStopping, etc. de tu celda 3 original ya estÃ¡n definidas)

def train_student(params, train_loader, val_loader, num_classes, class_names):
    device = torch.device(config.DEVICE)
    torch.manual_seed(config.SEED)
    np.random.seed(config.SEED)

    # Inicializar el Asistente de Tesis
    helper = ThesisHelper(params, class_names, base_dir=config.RESULTS_DIR, run_type='student')

    model = models.make_model(
        params['MODEL_NAME'], 
        num_classes, 
        params.get('RESNET_USE_PRETRAIN', True)
    ).to(device)
    
    opt = torch.optim.SGD(model.parameters(), lr=params['LR'], momentum=params['MOMENTUM'], weight_decay=params['WEIGHT_DECAY'])
    crit = nn.CrossEntropyLoss()
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=params['EPOCHS'], eta_min=params['LR'] * config.ETA_MIN_FACTOR)
    
    es = EarlyStopping(patience=params['PATIENCE'], min_delta=config.MIN_DELTA, restore_best=False) # restore_best=False, ThesisHelper se encarga
    t0 = time.time()

    for ep in range(1, params['EPOCHS'] + 1):
        model.train()
        tr_loss, n = 0.0, 0
        tr_preds, tr_gts = [], [] # Para calcular tr_acc

        for xb, yb in train_loader:
            xb = maybe_resize_for_resnet(xb, params.get('RESNET_RESIZE_TO_224', False))
            xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
            
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.CLIP_MAX_NORM)
            opt.step()
            
            tr_loss += loss.item() * xb.size(0)
            n += xb.size(0)
            
            # Guardar predicciones y etiquetas para tr_acc
            tr_preds.append(logits.softmax(1).argmax(1).cpu())
            tr_gts.append(yb.cpu())
        
        tr_loss /= n
        sched.step()

        # Calcular tr_acc
        y_pred_tr = torch.cat(tr_preds).numpy()
        y_true_tr = torch.cat(tr_gts).numpy()
        tr_acc = accuracy_score(y_true_tr, y_pred_tr)

        # Evaluar en validaciÃ³n
        val_metrics = evaluate(model, val_loader, crit, params)
        monitor_metric_key = config.MONITOR.replace('_macro', 'm').replace('_weighted', 'w')
        monitor_metric_val = val_metrics[monitor_metric_key]
        
        # Consolidar todas las mÃ©tricas de la Ã©poca
        log_entry = {
            'epoch': ep, 'tr_loss': tr_loss, 'tr_acc': tr_acc, 
            **val_metrics, 'lr': sched.get_last_lr()[0]
        }
        
        # Usar ThesisHelper para registrar y guardar el mejor modelo
        helper.log_epoch(model, log_entry)
        
        print(f"[{params['MODEL_NAME']}] e{ep:03d}/{params['EPOCHS']} | tr_loss={tr_loss:.4f} | va_loss={val_metrics['loss']:.4f} | tr_acc={tr_acc:.4f} | va_acc={val_metrics['acc']:.4f} | {config.MONITOR}={monitor_metric_val:.4f}")

        # Comprobar Early Stopping
        stop = es.step(monitor_metric_val, model)
        if stop:
            print(f"Early stopping en epoch {ep} (mejor {config.MONITOR}={es.best:.4f}).")
            break

    dur = time.time() - t0
    
    # Finalizar el proceso: guardar resÃºmenes, grÃ¡ficas, etc.
    helper.finalize(dur)
    
    # Cargar el mejor modelo guardado por el helper para devolverlo
    best_model_path = helper.output_dir / 'best_model.pth'
    if best_model_path.exists():
        model.load_state_dict(torch.load(best_model_path))
        print(f"Modelo final cargado desde el mejor checkpoint (F1-Macro: {helper.best_f1_macro:.4f}).")

    print(f"[DONE] {params['MODEL_NAME']} | dur={dur/60:.1f} min | best_{config.MONITOR}={helper.best_f1_macro:.4f}")
    
    return {'model': model, 'helper': helper}


In [5]:
# ==============================================================================
# CELDA 4: Bucle Principal de Entrenamiento del Student (Actualizada)
# ==============================================================================

# --- CONFIGURACIÃ“N PARA ESTA EJECUCIÃ“N ---
STUDENT_MODEL_KEY = 'resnet50' 

print(f"Entrenando un Student con la arquitectura '{STUDENT_MODEL_KEY}'...")
print(f"Usando pseudo-etiquetas verificadas de modelo Teacher ganador...")

# --- Carga de datos AUMENTADA ---
class_names = sorted([p.name for p in config.TRAIN_VAL_DIR.iterdir() if p.is_dir()])
cls2idx = {name: i for i, name in enumerate(class_names)}
num_classes = len(class_names)

# 1. Cargar datos de entrenamiento originales (split train/val)
all_original_files, all_original_labels = [], []
for class_name in class_names:
    class_path = config.TRAIN_VAL_DIR / class_name
    files = list(class_path.glob("*.png"))
    all_original_files.extend(files)
    all_original_labels.extend([cls2idx[class_name]] * len(files))

original_train_files, val_files, original_train_labels, val_labels = train_test_split(
    all_original_files, all_original_labels, test_size=0.2, random_state=config.SEED, stratify=all_original_labels
)

# 2. Cargar datos de pseudo-etiquetas de TODOS los teachers
all_pseudo_files = set()
for model_key in config.TRAIN_PARAMS.keys():
    pseudo_root = config.PSEUDO_DIR / model_key
    if pseudo_root.exists():
        print(f"  - Cargando desde '{model_key}'...")
        for class_path in pseudo_root.glob('*'):
             if class_path.is_dir():
                for file_path in class_path.glob("*.png"):
                    all_pseudo_files.add(file_path)

pseudo_files = sorted(list(all_pseudo_files))
pseudo_labels = [cls2idx[fp.parent.name] for fp in pseudo_files if fp.parent.name in cls2idx]
pseudo_files = [fp for fp in pseudo_files if fp.parent.name in cls2idx] # Asegurarse de que coincidan

# 3. Combinar datasets
augmented_train_files = original_train_files + pseudo_files
augmented_train_labels = original_train_labels + pseudo_labels

# 4. Crear Datasets y Dataloaders
train_ds = LabeledSpectro(augmented_train_files, augmented_train_labels, transform=weak_aug)
val_ds = LabeledSpectro(val_files, val_labels, transform=None)

print(f"\nDataset de Entrenamiento Original: {len(original_train_files)} muestras")
print(f"Pseudo-Etiquetas Ãšnicas AÃ±adidas: {len(pseudo_files)} muestras")
print(f"Dataset de Entrenamiento Aumentado Total: {len(train_ds)} muestras")
print(f"Dataset de ValidaciÃ³n: {len(val_ds)} muestras\n")

# --- Entrenamiento del Student ---
params = config.TRAIN_PARAMS[STUDENT_MODEL_KEY]
train_loader = DataLoader(train_ds, batch_size=params['BATCH_SIZE'], shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=params['BATCH_SIZE'], shuffle=False, num_workers=0)

results = train_student(params, train_loader, val_loader, num_classes, class_names)

# --- ImpresiÃ³n final ---
print("\n--- Proceso de Entrenamiento Completado ---")
print("Puedes encontrar todos los artefactos en la carpeta de resultados.")
# Opcional: imprimir la tabla LaTeX directamente en la salida
with open(results['helper'].output_dir / 'summary_table.tex', 'r') as f:
    print("\nTabla LaTeX generada para tu informe:")
    print(f.read())

Entrenando un Student con la arquitectura 'resnet50'...
Usando pseudo-etiquetas verificadas de modelo Teacher ganador...
  - Cargando desde 'resnet50'...

Dataset de Entrenamiento Original: 1049 muestras
Pseudo-Etiquetas Ãšnicas AÃ±adidas: 2416 muestras
Dataset de Entrenamiento Aumentado Total: 3465 muestras
Dataset de ValidaciÃ³n: 263 muestras

ThesisHelper inicializado para 'student_trained_with_win_teachersresnet50'. Artefactos se guardarÃ¡n en: D:\PYTHON\30_CLASIFICADOR_DE_INTERFERENCIAS\RESULTADOS\Carrier_C4_9435\student_trained_with_win_teachersresnet50
ðŸš€ Nuevo mejor F1-Macro: 0.5700 en la Ã©poca 1. Guardando checkpoint...
[resnet50] e001/300 | tr_loss=0.7010 | va_loss=0.6124 | tr_acc=0.7642 | va_acc=0.7985 | f1_macro=0.5700
ðŸš€ Nuevo mejor F1-Macro: 0.7540 en la Ã©poca 2. Guardando checkpoint...
[resnet50] e002/300 | tr_loss=0.2090 | va_loss=0.3635 | tr_acc=0.9354 | va_acc=0.8897 | f1_macro=0.7540
ðŸš€ Nuevo mejor F1-Macro: 0.8991 en la Ã©poca 3. Guardando checkpoint...
[res

  model.load_state_dict(torch.load(best_model_path))
