# Stage 1: Teacher Models Training

This notebook executes the training pipeline for the initial "Teacher" models (`cnn_paper`, `cnn_paper_L2`, `resnet50`, `resnet18`).

**Objective:** Train competent base classifiers using only the labeled subset (`train_val`) created by the data preparation script. These models later serve as "Teachers" that generate pseudo-labels for the unlabeled data in the semi-supervised stage.

In [None]:
# ==============================================================================
# CELL 2: Imports & Setup
# ==============================================================================
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import torchvision.transforms as T
import time
import json
import random
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, recall_score, confusion_matrix

# --- Local Modules (Relative Paths handled in config.py) ---
import config
import models
import utils

# Ensure reproducibility
# The augmentation transforms in this notebook draw from Python's built-in
# `random` module (random.uniform), so it has to be seeded alongside torch
# and NumPy; otherwise augmentations differ between runs whenever data
# loading happens in the main process.
random.seed(config.SEED)
torch.manual_seed(config.SEED)
np.random.seed(config.SEED)

In [None]:
# ==============================================================================
# CELL 3: Training Utilities and Helper Classes
# ==============================================================================

class ThesisHelper:
    """
    Handles experiment logging, checkpoint saving, and artifact generation
    (plots, tables) for thesis documentation.

    One instance corresponds to one training run. The run's artifacts are
    written to ``<base_dir>/<run_name>/`` where ``run_name`` is
    ``<run_type>_<MODEL_NAME>_lr<LR>_wd<WEIGHT_DECAY>`` (students use a
    fixed ``student_trained_with_win_teachers`` prefix instead).

    Args:
        params: Hyperparameter dict. Keys read by this class: 'LR',
            'WEIGHT_DECAY', 'MODEL_NAME', 'EPOCHS', 'BATCH_SIZE'.
        class_names: Sequence of class names (its length is logged and it
            is forwarded to ``utils.log_metrics_excel``).
        base_dir: Directory under which the run folder is created.
        run_type: Prefix of the run folder; 'student' selects the student
            naming scheme.
    """

    # Characters that are special in LaTeX text mode and their escaped form.
    _LATEX_SPECIALS = str.maketrans({
        '&': r'\&',
        '%': r'\%',
        '$': r'\$',
        '#': r'\#',
        '_': r'\_',
        '{': r'\{',
        '}': r'\}',
    })

    def __init__(self, params, class_names, base_dir, run_type='teacher'):
        self.params = params
        self.class_names = class_names

        # Construct unique run tag based on hyperparameters
        lr = self.params['LR']
        wd = self.params['WEIGHT_DECAY']
        hparams_tag = f"_lr{lr}_wd{wd}"

        if run_type == 'student':
            base_name = f"student_trained_with_win_teachers{self.params['MODEL_NAME']}"
        else:
            base_name = f"{run_type}_{self.params['MODEL_NAME']}"

        self.run_name = base_name + hparams_tag
        self.output_dir = Path(base_dir) / self.run_name
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.run_type = run_type

        self.history = []                # one metrics dict per logged epoch
        self.best_f1_macro = -1.0        # below any valid F1, so epoch 1 always wins
        self.best_epoch_metrics = None   # metrics dict of the best epoch so far

        print(f"[INFO] ThesisHelper initialized. Artifacts dir: {self.output_dir}")

    @classmethod
    def _escape_latex(cls, text):
        """Return ``text`` with LaTeX text-mode special characters escaped."""
        return str(text).translate(cls._LATEX_SPECIALS)

    def log_epoch(self, model, metrics):
        """Record one epoch and checkpoint the model if F1-macro improved.

        Args:
            model: Model whose ``state_dict`` is saved on improvement.
            metrics: Dict for this epoch; must contain 'f1m' and 'epoch'.
        """
        self.history.append(metrics)
        current_f1_macro = metrics['f1m']

        # Strict '>' keeps the earliest epoch on ties.
        if current_f1_macro > self.best_f1_macro:
            self.best_f1_macro = current_f1_macro
            self.best_epoch_metrics = metrics
            print(f"   >>> [NEW BEST] F1-Macro: {self.best_f1_macro:.4f} (Epoch {metrics['epoch']}). Saving checkpoint...")
            self._save_checkpoint(model)

    def _save_checkpoint(self, model):
        """Write the model's state_dict to ``best_model.pth`` in the run folder."""
        torch.save(model.state_dict(), self.output_dir / 'best_model.pth')

    def finalize(self, total_duration_seconds):
        """Write all end-of-run artifacts for the best epoch.

        Produces ``training_history.csv``, ``summary.json``,
        ``training_curves.png`` and ``summary_table.tex`` in the run folder
        and forwards the summary to ``utils.log_metrics_excel``. Does nothing
        (besides a warning) when no epoch was logged.

        Args:
            total_duration_seconds: Wall-clock training time in seconds.
        """
        if not self.history:
            print("[WARN] No history to finalize.")
            return

        # 1. Save History CSV
        history_df = pd.DataFrame(self.history)
        history_df.to_csv(self.output_dir / 'training_history.csv', index=False)

        # 2. Save JSON Summary
        summary = self.best_epoch_metrics.copy()
        summary['total_duration_min'] = total_duration_seconds / 60
        # The confusion matrix is removed before dumping and passed on separately.
        cm = summary.pop('cm', None)

        with open(self.output_dir / 'summary.json', 'w') as f:
            json.dump(summary, f, indent=4)

        # 3. Generate Artifacts
        self._plot_curves(history_df)
        self._generate_latex_table(summary)
        self._log_to_excel(summary, cm)

        print(f"[INFO] Experiment finalized. All artifacts saved in {self.output_dir}")

    def _plot_curves(self, df):
        """Plot loss and metric curves per epoch and save them as a PNG.

        Args:
            df: History frame with columns 'epoch', 'tr_loss', 'loss',
                'tr_acc', 'acc' and 'f1m'.
        """
        best_epoch = self.best_epoch_metrics['epoch']

        plt.style.use('seaborn-v0_8-whitegrid')
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

        # Loss Curve
        ax1.plot(df['epoch'], df['tr_loss'], 'o-', label='Training Loss', markersize=4)
        ax1.plot(df['epoch'], df['loss'], 'o-', label='Validation Loss', markersize=4)
        ax1.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.5, label=f'Best Epoch ({best_epoch})')
        ax1.set_ylabel('Loss')
        ax1.set_title(f'Training Curves: {self.run_name}')
        ax1.legend()
        ax1.grid(True, linestyle='--', alpha=0.6)

        # Metrics Curve
        ax2.plot(df['epoch'], df['tr_acc'], 'o-', label='Train Accuracy', markersize=4, alpha=0.7)
        ax2.plot(df['epoch'], df['acc'], 'o-', label='Val Accuracy', markersize=4, alpha=0.7)
        ax2.plot(df['epoch'], df['f1m'], 'o-', label='Val F1-Macro', linewidth=2, markersize=6, color='green')
        ax2.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.5)
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Metric Score')
        ax2.legend()
        ax2.grid(True, linestyle='--', alpha=0.6)
        # Scores are fractions in [0, 1]; show them as percentages.
        ax2.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))

        plt.tight_layout()
        plt.savefig(self.output_dir / 'training_curves.png', dpi=300)
        plt.close()

    def _generate_latex_table(self, summary):
        """Write ``summary_table.tex``, a ready-to-copy LaTeX table for the paper.

        Args:
            summary: Best-epoch dict with 'epoch', 'total_duration_min',
                'f1m', 'acc', 'loss' and 'recm'.
        """
        # Model names such as "cnn_paper_L2" contain underscores, which are
        # invalid in LaTeX text mode unless escaped. Computed outside the
        # f-string because backslashes are not allowed inside f-string
        # expressions on older Python versions.
        model_name_tex = self._escape_latex(self.params['MODEL_NAME'])

        latex_str = f"""
\\begin{{table}}[h!]
\\centering
\\caption{{Training summary for model {self.run_name.replace('_', ' ')}. Metrics reported at the best epoch.}}
\\label{{tab:training_summary_{self.run_name}}}
\\begin{{tabular}}{{ll}}
\\hline
\\textbf{{Parameter}} & \\textbf{{Value}} \\\\
\\hline
Base Architecture & {model_name_tex} \\\\
Best Epoch & {summary['epoch']} \\\\
Total Duration (min) & {summary['total_duration_min']:.2f} \\\\
\\hline
\\textbf{{Validation Metrics}} & \\textbf{{Value}} \\\\
\\hline
F1-Macro (Best) & {summary['f1m']:.4f} \\\\
Accuracy & {summary['acc']:.4f} \\\\
Loss & {summary['loss']:.4f} \\\\
Recall (Macro) & {summary['recm']:.4f} \\\\
\\hline
\\end{{tabular}}
\\end{{table}}
        """
        with open(self.output_dir / 'summary_table.tex', 'w') as f:
            f.write(latex_str)

    def _log_to_excel(self, summary, cm):
        """Forward the best-epoch metrics to ``utils.log_metrics_excel``.

        Args:
            summary: Best-epoch dict with 'epoch', 'acc', 'loss', 'f1m',
                'f1w' and 'recm'.
            cm: Confusion matrix of the best epoch (may be None).
        """
        note = f"{self.run_type.capitalize()} - Best Checkpoint at Epoch {summary['epoch']}"

        metrics_to_log = {
            'carrier': config.CURRENT_CARRIER,
            'model_name': self.params['MODEL_NAME'],
            'run_tag': self.run_name,
            'num_classes': len(self.class_names),
            'acc': summary['acc'],
            'loss': summary['loss'],
            'f1m': summary['f1m'],
            'f1w': summary['f1w'],
            'recm': summary['recm'],
            'cm': cm,
            'epochs': self.params['EPOCHS'],
            'batch_size': self.params['BATCH_SIZE'],
            'lr': self.params['LR'],
            'weight_decay': self.params['WEIGHT_DECAY'],
            'notes': note
        }
        utils.log_metrics_excel(config.METRICS_FILE, config.ARTIFACTS_DIR, self.class_names, metrics_to_log)

# --- Data Augmentation ---
class RandomTimeShift(torch.nn.Module):
    """Applies a random cyclic shift along the time axis.

    The last dimension of the input is treated as the time axis, so the
    transform works for a single (C, H, W) image as well as for any other
    tensor whose final dimension is time (e.g. a batched (N, C, H, W) input).

    Args:
        max_frac: Maximum shift, as a fraction of the time-axis length. The
            actual shift is drawn uniformly from [-max_frac, max_frac] with
            Python's ``random`` module on every call.
    """
    def __init__(self, max_frac=0.1):
        super().__init__()
        self.max_frac = max_frac

    def forward(self, x):
        """Return ``x`` rolled along its last dimension by a random offset."""
        width = x.shape[-1]
        # int() truncates toward zero, so small fractions can yield no shift.
        shift = int(random.uniform(-self.max_frac, self.max_frac) * width)
        return torch.roll(x, shifts=shift, dims=-1)

class RandomGain(torch.nn.Module):
    """Multiplies the input by one random gain factor and clips to [0, 1].

    Args:
        a: Lower bound of the gain range.
        b: Upper bound of the gain range.

    A single factor is drawn uniformly from [a, b] with Python's ``random``
    module on each call and applied to every element of the input.
    """
    def __init__(self, a=0.95, b=1.05):
        super().__init__()
        self.a, self.b = a, b

    def forward(self, x):
        """Return ``x`` scaled by the sampled gain, clamped to [0, 1]."""
        gain = random.uniform(self.a, self.b)
        scaled = x * gain
        return torch.clamp(scaled, min=0, max=1)

# Weak augmentation pipeline: a small cyclic time shift (up to 8% of the
# width) followed by a gain jitter of +/-5%.
weak_aug = T.Compose(
    [
        RandomTimeShift(max_frac=0.08),
        RandomGain(a=0.95, b=1.05),
    ]
)

class LabeledSpectro(Dataset):
    """Dataset yielding ``(image, label)`` pairs from parallel sequences.

    Args:
        files: Sequence of image paths, loaded with ``utils.load_png_gray``.
        labels: Sequence of labels, aligned index-by-index with ``files``.
        transform: Optional callable applied to the loaded image.

    Loading is best-effort: if reading or transforming an item fails, a
    warning is printed and an all-zero image of shape
    ``(1, config.IMG_SIZE[0], config.IMG_SIZE[1])`` is returned with the
    item's original label, so batch sizes stay intact.
    """

    def __init__(self, files, labels, transform=None):
        self.files, self.labels, self.transform = files, labels, transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        path = self.files[i]
        try:
            image = utils.load_png_gray(path)
            image = self.transform(image) if self.transform else image
            return image, self.labels[i]
        except Exception as e:
            print(f"[WARN] Skipping file due to error: {path} -> {e}")
            # Placeholder sample so the batch keeps its expected size.
            height, width = config.IMG_SIZE[0], config.IMG_SIZE[1]
            return torch.zeros(1, height, width), self.labels[i]

def maybe_resize_for_resnet(x, should_resize):
    """Optionally resize a batch to 224x224 with bilinear interpolation.

    Args:
        x: Input tensor accepted by ``torch.nn.functional.interpolate`` with
            a two-element ``size`` (i.e. a batched image tensor).
        should_resize: When falsy, ``x`` is returned unchanged.

    Returns:
        The resized tensor, or the very same ``x`` object if no resize
        was requested.
    """
    if not should_resize:
        return x
    return torch.nn.functional.interpolate(
        x,
        size=(224, 224),
        mode="bilinear",
        align_corners=False,
    )

# --- Early Stopping & Evaluation ---
class EarlyStopping:
    """Patience-based early stopping on a single monitored metric.

    Args:
        patience: Number of consecutive non-improving steps after which
            ``step`` reports that training should stop.
        min_delta: Margin by which the metric must beat the best value to
            count as an improvement.
        mode: 'max' if larger is better; any other value means smaller
            is better.
        restore_best: When True, a CPU copy of the model's state_dict is
            kept at every improvement so ``restore`` can reload it.
    """

    def __init__(self, patience, min_delta, mode='max', restore_best=True):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.restore_best = restore_best
        # Start from the worst possible value so the first step always improves.
        self.best = float('inf') if mode != 'max' else -float('inf')
        self.wait = 0
        self.best_state = None

    def step(self, metric, model):
        """Update with the latest metric; return True when training should stop."""
        if self.mode == 'max':
            improved = metric > self.best + self.min_delta
        else:
            improved = metric < self.best - self.min_delta

        if not improved:
            self.wait += 1
            return self.wait >= self.patience

        self.best = metric
        self.wait = 0
        if self.restore_best:
            # Detached CPU clones so later training steps cannot alter the snapshot.
            self.best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
        return False

    def restore(self, model):
        """Load the stored best weights into ``model`` if a snapshot exists."""
        if not self.restore_best or self.best_state is None:
            return
        model.load_state_dict(self.best_state)

def evaluate(model, loader, criterion, params):
    """Run one evaluation pass and compute loss and classification metrics.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss callable taking ``(logits, targets)``. The per-batch
            value is weighted by batch size, which assumes the criterion
            returns a batch mean (as ``nn.CrossEntropyLoss`` does by default).
        params: Hyperparameter dict; only the optional 'RESIZE_224' flag
            is read here.

    Returns:
        Dict with 'loss' (sample-weighted mean), 'acc', 'f1m' (macro F1),
        'f1w' (weighted F1), 'recm' (macro recall) and 'cm' (confusion
        matrix).

    Raises:
        ValueError: If the loader yields no samples.
    """
    device = torch.device(config.DEVICE)
    resize = params.get('RESIZE_224', False)

    model.eval()
    va_loss, n_seen = 0.0, 0
    preds, gts = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = maybe_resize_for_resnet(xb, resize)
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            loss = criterion(logits, yb)

            batch_size = xb.size(0)
            va_loss += loss.item() * batch_size
            n_seen += batch_size
            # argmax of the logits equals argmax of their softmax, so the
            # softmax is skipped.
            preds.append(logits.argmax(1).cpu())
            gts.append(yb.cpu())

    if n_seen == 0:
        raise ValueError("evaluate() received an empty loader; no samples to score.")

    # Normalise by the samples actually iterated (same convention as the
    # training loop) rather than by len(loader.dataset), which differs when
    # the loader uses a sampler or drops the last batch.
    va_loss /= n_seen
    y_pred = torch.cat(preds).numpy()
    y_true = torch.cat(gts).numpy()

    return {
        'loss': va_loss,
        'acc': accuracy_score(y_true, y_pred),
        'f1m': f1_score(y_true, y_pred, average='macro', zero_division=0),
        'f1w': f1_score(y_true, y_pred, average='weighted', zero_division=0),
        'recm': recall_score(y_true, y_pred, average='macro', zero_division=0),
        'cm': confusion_matrix(y_true, y_pred)
    }

def run_training_session(params, train_loader, val_loader, num_classes, class_names, run_type):
    """Train one model, log every epoch and return it with its best weights.

    The model is built with ``models.make_model``, optimised with SGD and a
    cosine-annealed learning rate, validated after every epoch, and stopped
    early when validation F1-macro stops improving. The best checkpoint
    (by validation F1-macro) is written by ``ThesisHelper`` and reloaded
    before returning.

    Args:
        params: Hyperparameter dict. Keys read here: 'MODEL_NAME', 'LR',
            'MOMENTUM', 'WEIGHT_DECAY', 'EPOCHS', 'PATIENCE' and the
            optional 'USE_PRETRAIN' (default True) and 'RESIZE_224'
            (default False).
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Loader passed to ``evaluate`` after every epoch.
        num_classes: Number of output classes for the model.
        class_names: Class names forwarded to ``ThesisHelper``.
        run_type: Run-type tag forwarded to ``ThesisHelper``.

    Returns:
        Dict with 'model' (the trained model) and 'helper' (the
        ``ThesisHelper`` holding history and best-epoch metrics).
    """
    device = torch.device(config.DEVICE)
    helper = ThesisHelper(params, class_names, base_dir=config.ARTIFACTS_DIR, run_type=run_type)

    model = models.make_model(
        params['MODEL_NAME'],
        num_classes,
        params.get('USE_PRETRAIN', True)
    ).to(device)

    opt = torch.optim.SGD(model.parameters(), lr=params['LR'], momentum=params['MOMENTUM'], weight_decay=params['WEIGHT_DECAY'])
    crit = nn.CrossEntropyLoss()
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=params['EPOCHS'], eta_min=params['LR'] * config.ETA_MIN_FACTOR)
    # Checkpointing is handled by ThesisHelper, so no in-memory snapshot is kept.
    es = EarlyStopping(patience=params['PATIENCE'], min_delta=config.EARLY_STOPPING_CONFIG['min_delta'], restore_best=False)
    resize = params.get('RESIZE_224', False)

    t0 = time.time()
    for ep in range(1, params['EPOCHS'] + 1):
        model.train()
        # Learning rate in effect for this epoch. It has to be read before
        # sched.step(); afterwards get_last_lr() already reports the rate
        # of the next epoch.
        epoch_lr = sched.get_last_lr()[0]
        tr_loss, n = 0.0, 0
        tr_preds, tr_gts = [], []

        for xb, yb in train_loader:
            xb = maybe_resize_for_resnet(xb, resize)
            xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.CLIP_MAX_NORM)
            opt.step()

            tr_loss += loss.item() * xb.size(0)
            n += xb.size(0)
            tr_preds.append(logits.softmax(1).argmax(1).cpu())
            tr_gts.append(yb.cpu())

        tr_loss /= n
        sched.step()

        # Calculate Train Accuracy (from predictions made during training,
        # i.e. in train mode and on the batches as fed to the model)
        tr_acc = accuracy_score(torch.cat(tr_gts).numpy(), torch.cat(tr_preds).numpy())

        # Validation
        val_metrics = evaluate(model, val_loader, crit, params)
        monitor_metric_val = val_metrics['f1m'] # Defaulting to f1m as per config

        log_entry = {
            'epoch': ep, 'tr_loss': tr_loss, 'tr_acc': tr_acc,
            **val_metrics, 'lr': epoch_lr
        }
        helper.log_epoch(model, log_entry)

        print(f"[{params['MODEL_NAME']}] Ep {ep:03d}/{params['EPOCHS']} | Train Loss: {tr_loss:.4f} | Val Loss: {val_metrics['loss']:.4f} | Val F1-Macro: {monitor_metric_val:.4f}")

        if es.step(monitor_metric_val, model):
            print(f"[INFO] Early stopping triggered at epoch {ep}.")
            break

    dur = time.time() - t0
    helper.finalize(dur)

    # Reload best model
    best_model_path = helper.output_dir / 'best_model.pth'
    if best_model_path.exists():
        # map_location keeps the load independent of the device the
        # checkpoint tensors were saved from.
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        print(f"[DONE] Best model loaded (F1-Macro: {helper.best_f1_macro:.4f}).")

    return {'model': model, 'helper': helper}

In [None]:
# ==============================================================================
# CELL 4: Comparative Analysis Artifacts
# ==============================================================================

def generate_teacher_comparison_artifacts(results_dir, teacher_keys):
    """
    Aggregates summaries from all teachers and generates comparison charts and LaTeX tables.

    For every key, the run folder ``teacher_<key>_lr<lr>_wd<wd>`` inside
    ``results_dir`` is located and its ``summary.json`` loaded. If several
    runs exist for the same key (different hyperparameters), the one whose
    summary was written most recently is used. Keys without a summary are
    reported and skipped.

    Outputs, written into ``results_dir``:
        * ``comparison_teachers_performance.png`` - bar chart of F1-macro.
        * ``comparison_teachers_table.tex`` - LaTeX table, best model in bold.

    Args:
        results_dir: Directory containing the per-run artifact folders.
        teacher_keys: Model keys to compare. Each key is assumed to equal the
            'MODEL_NAME' used in the run-folder name.

    Returns:
        None. Returns early (after printing an error) if no summary is found.
    """
    import re

    print("\n--- Generating Teacher Comparison Artifacts ---")
    results_dir = Path(results_dir)
    summaries = []

    for key in teacher_keys:
        # Match the complete folder name. A substring test would let the key
        # "cnn_paper" pick up the folder of "cnn_paper_L2" as well.
        run_pattern = re.compile(rf"teacher_{re.escape(key)}_lr[^_]+_wd[^_]+")
        candidates = [
            path / "summary.json"
            for path in results_dir.iterdir()
            if path.is_dir()
            and run_pattern.fullmatch(path.name)
            and (path / "summary.json").exists()
        ]
        if not candidates:
            print(f"[WARN] Summary not found for teacher: {key}")
            continue

        # Deterministic choice when several runs of the same model exist.
        summary_path = max(candidates, key=lambda p: p.stat().st_mtime)
        with open(summary_path, 'r') as f:
            data = json.load(f)
        data['model_key'] = key
        summaries.append(data)

    if not summaries:
        print("[ERROR] No summaries found. Cannot generate comparison.")
        return

    # Best model first; index 0 is used below for highlighting.
    df = pd.DataFrame(summaries).sort_values('f1m', ascending=False).reset_index(drop=True)

    # 1. Bar Chart
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(10, 6))

    # Highlight the best model
    colors = ['#ff7f0e' if i == 0 else '#1f77b4' for i in df.index]
    bars = ax.bar(df['model_key'], df['f1m'], color=colors)

    ax.bar_label(bars, fmt='%.4f', padding=3)
    ax.set_ylabel('F1-Score Macro (Validation)')
    ax.set_xlabel('Teacher Model Architecture')
    ax.set_title('Performance Comparison of Teacher Models')
    # Leave headroom for the bar labels; fall back to 1.0 if every score is 0
    # so the axis does not collapse to an empty range.
    y_top = max(df['f1m']) * 1.1
    ax.set_ylim(0, y_top if y_top > 0 else 1.0)

    fig.tight_layout()
    plot_path = results_dir / "comparison_teachers_performance.png"
    plt.savefig(plot_path, dpi=300)
    print(f"Comparison chart saved to: {plot_path}")
    plt.close()

    # 2. LaTeX Table
    df_latex = df[['model_key', 'f1m', 'acc', 'loss']].copy()
    df_latex.columns = ['Model', 'F1-Macro', 'Accuracy', 'Loss']

    # Escape underscores by hand and bold the best model name. Automatic
    # escaping is disabled below because it would also escape the \textbf{}
    # markup, while leaving the names raw would put bare underscores into
    # text mode.
    model_cells = [str(name).replace('_', '\\_') for name in df_latex['Model']]
    model_cells[0] = f"\\textbf{{{model_cells[0]}}}"
    df_latex['Model'] = model_cells

    latex_table = df_latex.to_latex(
        index=False,
        escape=False,
        float_format="%.4f",
        caption="Validation performance comparison of Teacher models. The best model (by F1-Macro) is highlighted in bold.",
        label="tab:teacher_comparison",
        position="h!"
    )

    table_path = results_dir / "comparison_teachers_table.tex"
    with open(table_path, 'w') as f:
        f.write(latex_table)
    print(f"LaTeX table saved to: {table_path}")

In [None]:
# ==============================================================================
# CELL 5: Main Execution Pipeline
# ==============================================================================

# 1. Data Loading & Splitting
# ---------------------------
if not config.TRAIN_VAL_DIR.exists():
    raise FileNotFoundError(f"Data directory not found at {config.TRAIN_VAL_DIR}. Did you run 00_prepare_data.py?")

# One sub-directory per class; sorted so class indices are stable.
class_names = sorted([p.name for p in config.TRAIN_VAL_DIR.iterdir() if p.is_dir()])
cls2idx = {name: i for i, name in enumerate(class_names)}
num_classes = len(class_names)

all_files, all_labels = [], []
for class_name in class_names:
    class_path = config.TRAIN_VAL_DIR / class_name
    # glob() yields files in filesystem order, which varies between machines.
    # Sorting fixes the input order so the seeded split below is reproducible.
    files = sorted(class_path.glob("*.png"))
    all_files.extend(files)
    all_labels.extend([cls2idx[class_name]] * len(files))

if not all_files:
    raise FileNotFoundError(f"No .png files found under {config.TRAIN_VAL_DIR}. Did you run 00_prepare_data.py?")

# Stratified Split (80% Train, 20% Val)
train_files, val_files, train_labels, val_labels = train_test_split(
    all_files, all_labels, test_size=0.2, random_state=config.SEED, stratify=all_labels
)

# Augmentation is applied to the training split only.
train_ds = LabeledSpectro(train_files, train_labels, transform=weak_aug)
val_ds = LabeledSpectro(val_files, val_labels, transform=None)

print(f"[INFO] Dataset Loaded. Train: {len(train_ds)} | Val: {len(val_ds)} | Classes: {num_classes}")

# 2. Training Loop
# ----------------
# Every entry of config.TRAIN_PARAMS is trained as one teacher; the best-epoch
# metrics of each run are collected for the final summary.
all_teacher_results = {}
teacher_model_keys = list(config.TRAIN_PARAMS)

for model_key in teacher_model_keys:
    print(f"\n{'-'*20} TRAINING TEACHER: {model_key.upper()} {'-'*20}")

    params = config.TRAIN_PARAMS[model_key]

    # Loaders are rebuilt per model because the batch size is model-specific.
    train_loader = DataLoader(
        train_ds,
        batch_size=params['BATCH_SIZE'],
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=params['BATCH_SIZE'],
        shuffle=False,
        num_workers=config.NUM_WORKERS,
    )

    results = run_training_session(
        params=params,
        train_loader=train_loader,
        val_loader=val_loader,
        num_classes=num_classes,
        class_names=class_names,
        run_type='teacher',
    )

    all_teacher_results[model_key] = results['helper'].best_epoch_metrics

# 3. Final Comparison
# -------------------
# Build the cross-model chart and LaTeX table from the saved run summaries.
generate_teacher_comparison_artifacts(config.ARTIFACTS_DIR, teacher_model_keys)

print("\n--- FINAL TEACHER PERFORMANCE SUMMARY ---")
for model_key, metrics in all_teacher_results.items():
    # Runs that logged no epoch have no best-epoch metrics to report.
    if not metrics:
        continue
    print(f"  > {model_key:<15} | Best F1-Macro: {metrics['f1m']:.4f} (Epoch {metrics['epoch']})")