# Stage 3: Student Model Training

This notebook trains the final "Student" model using a **data augmentation strategy** based on Semi-Supervised Learning.

**Methodology:**
* **Dataset:** A combination of the original labeled data (`train_val`) and high-quality pseudo-labels generated by the "Teacher" models.
* **Objective:** To demonstrate that the Student model, by learning from a larger volume of data (including unverified but high-confidence samples), achieves better generalization performance than the original Teachers.
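* **Experiment:** We perform a "Pure Lineage" training: each Student architecture (e.g., ResNet18) is trained on the original labeled training split plus the pseudo-labels generated only by its homologous (same-architecture) Teacher (ResNet18); pseudo-labels from other Teacher architectures are not mixed in.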

In [None]:
# ==============================================================================
# CELL 2: Imports & Setup
# ==============================================================================
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import torchvision.transforms as T
import time
import json
import random
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, recall_score, confusion_matrix

# --- Local Modules ---
import config
import models
import utils

# Ensure reproducibility: seed every RNG this notebook draws from.
# Python's `random` is included because the augmentation transforms below
# sample their parameters with random.uniform().
random.seed(config.SEED)
torch.manual_seed(config.SEED)
np.random.seed(config.SEED)

In [None]:
# ==============================================================================
# CELL 3: ThesisHelper Class
# ==============================================================================
class ThesisHelper:
    """
    Manages experiment logging, checkpointing, and artifact generation
    (plots, LaTeX tables) for thesis documentation.

    One instance tracks a single training run. Per-epoch metric dicts are
    accumulated in ``history``; whenever the validation F1-Macro (``'f1m'``)
    improves, the model's state dict is written to ``best_model.pth`` inside
    the run's output directory. ``finalize`` then writes the history CSV, a
    JSON summary of the best epoch, the training-curve figure, a LaTeX table
    and forwards the best metrics to ``utils.log_metrics_excel``.
    """

    def __init__(self, params, class_names, base_dir, run_type='teacher'):
        """
        Args:
            params: Hyper-parameter dict. ``'MODEL_NAME'`` is read here;
                ``'EPOCHS'``, ``'BATCH_SIZE'``, ``'LR'`` and ``'WEIGHT_DECAY'``
                are read when logging to Excel.
            class_names: Sequence of class names (its length is logged).
            base_dir: Directory under which the run folder is created.
            run_type: Tag used to build the run name (e.g. 'teacher', 'student').
        """
        self.params = params
        self.class_names = class_names

        # Construct run name based on execution type.
        # NOTE(review): the student name has no separator before the model
        # name; kept unchanged so existing artifact directories still resolve.
        if run_type == 'student':
            self.run_name = f"student_trained_with_win_teachers{params['MODEL_NAME']}"
        else:
            self.run_name = f"{run_type}_{params['MODEL_NAME']}"

        self.output_dir = Path(base_dir) / self.run_name
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.run_type = run_type

        self.history = []
        self.best_f1_macro = -1.0
        self.best_epoch_metrics = None

        print(f"[INFO] ThesisHelper initialized for '{self.run_name}'. Artifacts dir: {self.output_dir}")

    def log_epoch(self, model, metrics):
        """Logs epoch metrics and saves the best checkpoint based on F1-Macro.

        Args:
            model: Model whose ``state_dict()`` is saved on improvement.
            metrics: Dict for this epoch; must contain ``'f1m'`` and ``'epoch'``.
        """
        self.history.append(metrics)
        current_f1_macro = metrics['f1m']

        # A NaN score never compares greater, so it can never become the best.
        if current_f1_macro > self.best_f1_macro:
            self.best_f1_macro = current_f1_macro
            self.best_epoch_metrics = metrics
            print(f"[INFO] New best F1-Macro: {self.best_f1_macro:.4f} (Epoch {metrics['epoch']}). Saving checkpoint...")
            self._save_checkpoint(model)

    def _save_checkpoint(self, model):
        """Saves the model state dictionary."""
        torch.save(model.state_dict(), self.output_dir / 'best_model.pth')

    def finalize(self, total_duration_seconds):
        """Generates all final artifacts after training completes.

        Args:
            total_duration_seconds: Wall-clock training time in seconds;
                stored in the summary converted to minutes.
        """
        if not self.history:
            print("[WARN] No history to finalize. Skipping artifact generation.")
            return

        # 1. Save History and Summary
        history_df = pd.DataFrame(self.history)
        history_df.to_csv(self.output_dir / 'training_history.csv', index=False)

        # History can be non-empty while no epoch ever improved on the initial
        # best (e.g. every F1 was NaN). Nothing below is meaningful then, and
        # calling .copy() on None would raise.
        if self.best_epoch_metrics is None:
            print("[WARN] No best epoch was recorded. History saved; skipping summary, plots and tables.")
            return

        summary = self.best_epoch_metrics.copy()
        summary['total_duration_min'] = total_duration_seconds / 60
        # The confusion matrix is kept out of the JSON summary and passed on separately.
        cm = summary.pop('cm', None)

        with open(self.output_dir / 'summary.json', 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=4)

        print("[INFO] History and summary saved.")

        # 2. Generate Plots
        self._plot_curves(history_df)
        print("[INFO] Training curves plot saved.")

        # 3. Generate LaTeX Table
        self._generate_latex_table(summary, cm)
        print("[INFO] LaTeX table generated.")

        # 4. Log to Main Excel
        self._log_to_excel(summary, cm)
        print("[INFO] Final metrics appended to Master Excel file.")

    def _plot_curves(self, df):
        """Generates and saves training vs validation curves.

        Args:
            df: History DataFrame with columns ``epoch``, ``tr_loss``, ``loss``,
                ``tr_acc``, ``acc`` and ``f1m``.
        """
        best_epoch = self.best_epoch_metrics['epoch']

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

        # Subplot 1: Loss
        ax1.plot(df['epoch'], df['tr_loss'], 'o-', label='Training Loss')
        ax1.plot(df['epoch'], df['loss'], 'o-', label='Validation Loss')
        ax1.axvline(x=best_epoch, color='r', linestyle='--', label=f'Best Epoch ({best_epoch})')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True, linestyle='--', alpha=0.6)

        # Subplot 2: Metrics
        ax2.plot(df['epoch'], df['tr_acc'], 'o-', label='Training Accuracy')
        ax2.plot(df['epoch'], df['acc'], 'o-', label='Validation Accuracy')
        ax2.plot(df['epoch'], df['f1m'], 'o-', label='Validation F1-Macro', linewidth=2, markersize=8)
        ax2.axvline(x=best_epoch, color='r', linestyle='--')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Metric')
        ax2.legend()
        ax2.grid(True, linestyle='--', alpha=0.6)
        ax2.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))

        # Operate on this figure explicitly so another open figure is never
        # saved or closed by mistake.
        fig.tight_layout()
        fig.savefig(self.output_dir / 'training_curves.png', dpi=300)
        plt.close(fig)

    def _generate_latex_table(self, summary, cm):
        """Generates a LaTeX table summarizing the run.

        Args:
            summary: Best-epoch metrics; reads ``epoch``, ``total_duration_min``,
                ``f1m``, ``acc``, ``loss`` and ``recm``.
            cm: Confusion matrix (accepted for a uniform signature; unused here).
        """
        # A bare underscore is invalid in LaTeX text mode, so escape it in the
        # table cell. Done outside the f-string because backslashes are not
        # allowed inside f-string expressions on older Python versions.
        model_name_tex = str(self.params['MODEL_NAME']).replace('_', '\\_')
        latex_str = f"""
\\begin{{table}}[h!]
\\centering
\\caption{{Training summary for model {self.run_name.replace('_', ' ')}. Metrics reported at the best epoch.}}
\\label{{tab:training_summary_{self.run_name}}}
\\begin{{tabular}}{{ll}}
\\hline
\\textbf{{Parameter}} & \\textbf{{Value}} \\\\
\\hline
Base Architecture & {model_name_tex} \\\\
Best Epoch & {summary['epoch']} \\\\
Total Duration (min) & {summary['total_duration_min']:.2f} \\\\
\\hline
\\textbf{{Validation Metrics}} & \\textbf{{Value}} \\\\
\\hline
F1-Macro (Best) & {summary['f1m']:.4f} \\\\
Accuracy & {summary['acc']:.4f} \\\\
Loss & {summary['loss']:.4f} \\\\
Recall (Macro) & {summary['recm']:.4f} \\\\
\\hline
\\end{{tabular}}
\\end{{table}}
        """
        with open(self.output_dir / 'summary_table.tex', 'w', encoding='utf-8') as f:
            # Drop the leading blank line / trailing indentation of the literal.
            f.write(latex_str.strip() + "\n")

    def _log_to_excel(self, summary, cm):
        """Appends metrics to the master Excel file.

        Builds one flat record from the best-epoch summary and the run's
        hyper-parameters and hands it to ``utils.log_metrics_excel``.
        """
        metrics_to_log = {
            'carrier': config.CURRENT_CARRIER,
            'model_name': self.params['MODEL_NAME'],
            'run_tag': self.run_name,
            'num_classes': len(self.class_names),
            'acc': summary['acc'],
            'loss': summary['loss'],
            'f1m': summary['f1m'],
            'f1w': summary['f1w'],
            'recm': summary['recm'],
            'cm': cm,
            'epochs': self.params['EPOCHS'],
            'batch_size': self.params['BATCH_SIZE'],
            'lr': self.params['LR'],
            'weight_decay': self.params['WEIGHT_DECAY'],
            # Label the note with the actual run type ('student' -> "Student").
            'notes': f"{str(self.run_type).capitalize()} - Best checkpoint at epoch {summary['epoch']}"
        }
        utils.log_metrics_excel(config.METRICS_FILE, config.ARTIFACTS_DIR, self.class_names, metrics_to_log)

In [None]:
# ==============================================================================
# CELL 4: Training Components & Student Loop
# ==============================================================================

class RandomTimeShift(torch.nn.Module):
    """Circularly shift a tensor along its last axis by a random offset.

    The offset is drawn uniformly from ``[-max_frac, +max_frac]`` times the
    size of the last axis and truncated toward zero. Values rolled off one end
    re-enter at the other (``torch.roll``), so no padding is introduced.

    Args:
        max_frac: Maximum shift as a fraction of the last-axis length.
    """

    def __init__(self, max_frac=0.1):
        super().__init__()
        self.max_frac = max_frac

    def forward(self, x):
        """Return ``x`` rolled along ``dim=-1``; works for any rank >= 1."""
        # Only the last-axis length is needed, so read it directly instead of
        # unpacking a fixed 3-D shape (which rejected batched or 2-D inputs).
        width = x.shape[-1]
        # Sampled with Python's `random`, so reproducibility depends on random.seed().
        shift = int(random.uniform(-self.max_frac, self.max_frac) * width)
        return torch.roll(x, shifts=shift, dims=-1)

class RandomGain(torch.nn.Module):
    """Multiply the input by a random scalar gain and clip to ``[0, 1]``.

    Args:
        a: Lower bound of the uniform gain range.
        b: Upper bound of the uniform gain range.
    """

    def __init__(self, a=0.95, b=1.05):
        super().__init__()
        self.a, self.b = a, b

    def forward(self, x):
        """Return ``x`` scaled by one gain sampled for the whole tensor."""
        gain = random.uniform(self.a, self.b)
        scaled = x * gain
        # Clip so the scaled values stay inside [0, 1].
        return torch.clamp(scaled, 0, 1)

# Weak augmentation pipeline applied to training samples: a small circular
# shift along the last axis followed by a mild random gain change.
weak_aug = T.Compose(transforms=[
    RandomTimeShift(max_frac=0.08),
    RandomGain(a=0.95, b=1.05),
])

class LabeledSpectro(Dataset):
    """Dataset of image files paired with integer class labels.

    Each item is loaded with ``utils.load_png_gray`` and optionally passed
    through ``transform``. If loading or transforming fails, the sample is NOT
    dropped: an all-zero tensor of shape ``(1, *config.IMG_SIZE[:2])`` is
    returned together with the sample's original label so that batch sizes
    stay intact.

    Args:
        files: Indexable sequence of image paths.
        labels: Indexable sequence of labels, aligned with ``files``.
        transform: Optional callable applied to the loaded tensor.

    Raises:
        ValueError: If ``files`` and ``labels`` differ in length.
    """

    def __init__(self, files, labels, transform=None):
        # A silent length mismatch would pair images with the wrong labels or
        # fail only deep inside a DataLoader worker; fail early instead.
        if len(files) != len(labels):
            raise ValueError(
                f"files and labels must have the same length (got {len(files)} and {len(labels)})."
            )
        self.files = files
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        path = self.files[i]
        y = self.labels[i]
        # Keep the guarded region to the operations that can fail per file.
        try:
            x = utils.load_png_gray(path)
            if self.transform is not None:
                x = self.transform(x)
        except Exception as e:
            # Deliberate best-effort: report which file failed and substitute a
            # zero image to preserve batch integrity. Note that the zero image
            # still carries the real label.
            print(f"[WARN] Could not load '{path}' ({e}); substituting an all-zero image.")
            x = torch.zeros(1, config.IMG_SIZE[0], config.IMG_SIZE[1], dtype=torch.float32)
        return x, y

def maybe_resize_for_resnet(x, should_resize):
    """Optionally resample a batch to 224x224 with bilinear interpolation.

    Args:
        x: Input tensor handed to ``torch.nn.functional.interpolate``.
        should_resize: When falsy, ``x`` is returned untouched (same object).

    Returns:
        The resized tensor, or ``x`` itself when no resize was requested.
    """
    if not should_resize:
        return x
    target_hw = (224, 224)
    return torch.nn.functional.interpolate(
        x, size=target_hw, mode="bilinear", align_corners=False
    )

class EarlyStopping:
    """Patience-based early stopping on a monitored metric.

    Args:
        patience: Number of consecutive non-improving steps that triggers a stop.
        min_delta: Margin by which the metric must beat the best value so far.
        mode: ``'max'`` if larger is better; any other value is treated as
            "smaller is better".
        restore_best: When true, a CPU copy of the model's state dict is kept
            at every improvement so it can be reloaded with ``restore``.
    """

    def __init__(self, patience, min_delta, mode='max', restore_best=True):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.restore_best = restore_best
        # Start from the worst possible value for the chosen direction.
        self.best = float('inf') if mode != 'max' else -float('inf')
        self.wait = 0
        self.best_state = None

    def step(self, metric, model):
        """Record one metric value; return True when training should stop."""
        if self.mode == 'max':
            improved = metric > self.best + self.min_delta
        else:
            improved = metric < self.best - self.min_delta

        if not improved:
            self.wait += 1
            return self.wait >= self.patience

        self.best = metric
        self.wait = 0
        if self.restore_best:
            # Detached CPU clones, so later training steps cannot alter the snapshot.
            snapshot = {}
            for name, tensor in model.state_dict().items():
                snapshot[name] = tensor.detach().cpu().clone()
            self.best_state = snapshot
        return False

    def restore(self, model):
        """Load the best snapshot into ``model`` if one was captured."""
        if not self.restore_best or self.best_state is None:
            return
        model.load_state_dict(self.best_state)

def evaluate(model, loader, criterion, params):
    """Run one evaluation pass and compute loss and classification metrics.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss returning a batch-mean scalar (it is re-weighted by
            batch size below).
        params: Dict; ``'RESIZE_224'`` (default False) toggles the 224x224 resize.

    Returns:
        Dict with keys ``loss``, ``acc``, ``f1m`` (macro F1), ``f1w``
        (weighted F1), ``recm`` (macro recall) and ``cm`` (confusion matrix).

    Raises:
        ValueError: If the loader yields no samples.
    """
    device = torch.device(config.DEVICE)
    resize = params.get('RESIZE_224', False)
    model.eval()

    total_loss, n_seen, num_classes = 0.0, 0, None
    preds, gts = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = maybe_resize_for_resnet(xb, resize)
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            batch_size = xb.size(0)
            total_loss += criterion(logits, yb).item() * batch_size
            n_seen += batch_size
            num_classes = logits.shape[1]
            # softmax is monotonic, so argmax over raw logits picks the same class.
            preds.append(logits.argmax(1).cpu())
            gts.append(yb.cpu())

    if n_seen == 0:
        raise ValueError("evaluate() received an empty loader; no metrics can be computed.")

    # Normalise by the samples actually processed rather than len(dataset),
    # which differs when the loader uses a sampler or drop_last.
    va_loss = total_loss / n_seen
    y_pred = torch.cat(preds).numpy()
    y_true = torch.cat(gts).numpy()

    return {
        'loss': va_loss,
        'acc': accuracy_score(y_true, y_pred),
        'f1m': f1_score(y_true, y_pred, average='macro', zero_division=0),
        'f1w': f1_score(y_true, y_pred, average='weighted', zero_division=0),
        'recm': recall_score(y_true, y_pred, average='macro', zero_division=0),
        # Fix the label set so the matrix is always num_classes x num_classes,
        # even when a class is absent from both targets and predictions.
        'cm': confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    }

def train_student(params, train_loader, val_loader, num_classes, class_names):
    """Train a Student model and return it loaded with its best checkpoint.

    Uses SGD with momentum, cross-entropy loss, cosine LR annealing, gradient
    norm clipping and early stopping on validation F1-Macro. All logging and
    artifact generation is delegated to ``ThesisHelper``.

    Args:
        params: Hyper-parameter dict. Reads ``MODEL_NAME``, ``LR``,
            ``MOMENTUM``, ``WEIGHT_DECAY``, ``EPOCHS``, ``PATIENCE`` and the
            optional ``USE_PRETRAIN`` (default True) / ``RESIZE_224``
            (default False).
        train_loader: Loader yielding ``(inputs, targets)`` training batches.
        val_loader: Loader used for per-epoch validation.
        num_classes: Number of output classes passed to ``models.make_model``.
        class_names: Class names forwarded to ``ThesisHelper``.

    Returns:
        Dict with ``'model'`` (the trained network) and ``'helper'`` (the
        ``ThesisHelper`` holding history and best metrics).

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    device = torch.device(config.DEVICE)
    # Re-seed every RNG in use; Python's `random` drives the augmentations.
    random.seed(config.SEED)
    torch.manual_seed(config.SEED)
    np.random.seed(config.SEED)

    helper = ThesisHelper(params, class_names, base_dir=config.ARTIFACTS_DIR, run_type='student')

    model = models.make_model(
        params['MODEL_NAME'],
        num_classes,
        params.get('USE_PRETRAIN', True)
    ).to(device)

    opt = torch.optim.SGD(model.parameters(), lr=params['LR'], momentum=params['MOMENTUM'], weight_decay=params['WEIGHT_DECAY'])
    crit = nn.CrossEntropyLoss()
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=params['EPOCHS'], eta_min=params['LR'] * config.ETA_MIN_FACTOR)

    # restore_best=False: the best weights are reloaded from the helper's
    # checkpoint file below, so no in-memory snapshot is kept.
    es = EarlyStopping(patience=params['PATIENCE'], min_delta=config.EARLY_STOPPING_CONFIG['min_delta'], restore_best=False)
    resize = params.get('RESIZE_224', False)
    t0 = time.time()

    for ep in range(1, params['EPOCHS'] + 1):
        model.train()
        tr_loss, n = 0.0, 0
        tr_preds, tr_gts = [], []

        for xb, yb in train_loader:
            xb = maybe_resize_for_resnet(xb, resize)
            xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.CLIP_MAX_NORM)
            opt.step()

            tr_loss += loss.item() * xb.size(0)
            n += xb.size(0)

            # argmax over raw logits equals argmax over softmax probabilities.
            tr_preds.append(logits.detach().argmax(1).cpu())
            tr_gts.append(yb.cpu())

        if n == 0:
            raise ValueError("train_loader yielded no samples; cannot train.")

        tr_loss /= n
        sched.step()

        # Calculate Train Accuracy (running, measured while weights were still updating)
        y_pred_tr = torch.cat(tr_preds).numpy()
        y_true_tr = torch.cat(tr_gts).numpy()
        tr_acc = accuracy_score(y_true_tr, y_pred_tr)

        # Validation
        val_metrics = evaluate(model, val_loader, crit, params)
        monitor_metric_val = val_metrics['f1m']  # Defaulting to f1m

        # 'lr' is read after sched.step(), i.e. it is the rate the scheduler
        # has set for the next epoch.
        log_entry = {
            'epoch': ep, 'tr_loss': tr_loss, 'tr_acc': tr_acc,
            **val_metrics, 'lr': sched.get_last_lr()[0]
        }

        helper.log_epoch(model, log_entry)

        print(f"[{params['MODEL_NAME']}] Ep {ep:03d}/{params['EPOCHS']} | Tr Loss: {tr_loss:.4f} | Val Loss: {val_metrics['loss']:.4f} | Val F1-Macro: {monitor_metric_val:.4f}")

        if es.step(monitor_metric_val, model):
            print(f"[INFO] Early stopping triggered at epoch {ep}.")
            break

    dur = time.time() - t0
    helper.finalize(dur)

    # Reload best model. Only trust the file if THIS run wrote it: a
    # best_model.pth left over from an earlier run in the same directory must
    # not be loaded and reported as this run's result.
    best_model_path = helper.output_dir / 'best_model.pth'
    if helper.best_epoch_metrics is not None and best_model_path.exists():
        # map_location keeps loading robust if the checkpoint tensors were
        # saved from a different device than the one in use now.
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        print(f"[DONE] Final model loaded from best checkpoint (F1-Macro: {helper.best_f1_macro:.4f}).")
    else:
        print("[WARN] No checkpoint was saved during this run; returning the last-epoch weights.")

    return {'model': model, 'helper': helper}

In [None]:
# ==============================================================================
# CELL 5: Pure Lineage Experiment (Student learns from same-architecture Teacher)
# ==============================================================================

# CONFIGURATION: Define the architectures for the experiment.
ARCHITECTURES = ['resnet18', 'resnet50']

print("[INFO] STARTING PURE LINEAGE EXPERIMENT")
print(f"[INFO] Architectures to evaluate: {ARCHITECTURES}")
print("[INFO] Strategy: Students learn EXCLUSIVELY from their homologous Teacher.")
print("[INFO] Validation: Fixed 'Gold Standard' (Original Labeled Data).")

# 1. Prepare "Gold Standard" Data (Original Labeled Set)
# -----------------------------------------------------------------
class_names = sorted([p.name for p in config.TRAIN_VAL_DIR.iterdir() if p.is_dir()])
cls2idx = {name: i for i, name in enumerate(class_names)}
num_classes = len(class_names)

all_original_files, all_original_labels = [], []
for class_name in class_names:
    class_path = config.TRAIN_VAL_DIR / class_name
    # glob() yields entries in arbitrary, filesystem-dependent order. Sort so
    # the seeded split below selects the same files on every machine.
    # NOTE(review): if another stage rebuilds this split from an unsorted
    # glob, align it with this ordering so both use the same validation set.
    files = sorted(class_path.glob("*.png"))
    all_original_files.extend(files)
    all_original_labels.extend([cls2idx[class_name]] * len(files))

# Fixed Validation Split (20% of original)
original_train_files, val_files, original_train_labels, val_labels = train_test_split(
    all_original_files, all_original_labels,
    test_size=0.2,
    random_state=config.SEED,
    stratify=all_original_labels
)

# Common Validation Dataset (no augmentation)
val_ds = LabeledSpectro(val_files, val_labels, transform=None)
print(f"\n[GOLD STANDARD] Validation Set: {len(val_ds)} samples")
print(f"[GOLD STANDARD] Base Training Set: {len(original_train_files)} samples")
print("-" * 60)

# 2. Training Loop per Architecture
# -----------------------------------------------------------------
for arch in ARCHITECTURES:
    teacher_name = arch  # Teacher is same arch
    student_arch = arch  # Student is same arch

    print(f"\n[INFO] TRAINING LINEAGE: {arch.upper()} (Student {arch} <- Teacher {arch})")

    # A. Load Verified Pseudo-Labels
    hitl_verified_files = []
    hitl_verified_labels = []

    # Using config.PSEUDO_LABEL_DIR relative path
    pseudo_root = config.PSEUDO_LABEL_DIR / teacher_name

    if not pseudo_root.exists():
        print(f"[WARN] No pseudo-labels found for {teacher_name} at {pseudo_root}. Skipping...")
        continue

    # Sorted for a deterministic sample order (see note on glob above).
    for class_path in sorted(pseudo_root.glob('*')):
        if not class_path.is_dir():
            continue
        if class_path.name not in cls2idx:
            # Surface folders that would otherwise be dropped silently.
            print(f"[WARN] Ignoring pseudo-label folder with unknown class '{class_path.name}'.")
            continue
        for file_path in sorted(class_path.glob("*.png")):
            hitl_verified_files.append(file_path)
            hitl_verified_labels.append(cls2idx[class_path.name])

    if not hitl_verified_files:
        print("[WARN] Pseudo-label directory is empty. Skipping...")
        continue

    # B. Data Fusion (Original Train + Teacher Pseudo-Labels)
    augmented_train_files = original_train_files + hitl_verified_files
    augmented_train_labels = original_train_labels + hitl_verified_labels

    train_ds = LabeledSpectro(augmented_train_files, augmented_train_labels, transform=weak_aug)

    print(f"  > Source Original (Train): {len(original_train_files)}")
    print(f"  > Source Pseudo-Labels ({teacher_name}): {len(hitl_verified_files)}")
    print(f"  > TOTAL AUGMENTED DATASET: {len(train_ds)} samples")

    # C. Configuration (copy so the shared config entry is not mutated)
    current_params = config.TRAIN_PARAMS[student_arch].copy()
    current_params['MODEL_NAME'] = f"{student_arch}"

    # Fresh Dataloaders
    train_loader = DataLoader(train_ds, batch_size=current_params['BATCH_SIZE'], shuffle=True, num_workers=config.NUM_WORKERS)
    val_loader = DataLoader(val_ds, batch_size=current_params['BATCH_SIZE'], shuffle=False, num_workers=config.NUM_WORKERS)

    # D. Execute Training
    print(f"  > Starting training for Student {current_params['MODEL_NAME']}...")

    results = train_student(current_params, train_loader, val_loader, num_classes, class_names)

    print(f"[DONE] Lineage {arch} completed.")
    print("-" * 60)

print("\n\n[INFO] EXPERIMENT COMPLETED.")