In [None]:
import csv
import os, time
from datetime import timedelta
from pathlib import Path

import torch
from fastai.callback.all import *
from fastai.optimizer import Lookahead, Adam

In [None]:
def lookahead_adamw(params, *, k=6, alpha=0.5, **kwargs):
    """
    Build an Adam optimizer with decoupled weight decay, wrapped in Lookahead.

    Args:
        params: parameters / parameter groups, forwarded unchanged to ``Adam``.
        k: forwarded to ``Lookahead`` as ``k``. Defaults to 6, the value that
            was previously hard-coded here.
        alpha: forwarded to ``Lookahead`` as ``alpha``. Defaults to 0.5, the
            value that was previously hard-coded here.
        **kwargs: any other keyword arguments, forwarded to ``Adam``.
            ``decouple_wd`` is always overridden to ``True``.

    Returns:
        The ``Lookahead`` object wrapping the ``Adam`` optimizer.
    """
    # Force decoupled weight decay regardless of what the caller passed.
    kwargs['decouple_wd'] = True
    return Lookahead(Adam(params, **kwargs), k=k, alpha=alpha)

In [None]:
class EpochTracker(Callback):
    """
    Tracks epoch count, lr_max and per-epoch metrics.
    Saves to disk after every epoch for clean resume after failure.

    Files:
      tracker_fname — epochs_done, lr_max (for resume), stored as
                      "<epochs_done>,<lr_max>"
      log_fname     — full per-epoch metrics CSV (header written once,
                      one row appended per epoch)

    Order=60 is meant to run this after Recorder (order=50) so
    recorder.values[-1] is fully populated when after_epoch fires.

    Args:
        lr_max: value recorded as lr_max in the tracker file and the log.
            It is only used when the tracker file is missing, unreadable,
            or has no lr_max field; otherwise the value on disk wins.
        total_epochs: stored on the instance; not used by this class itself.
        tracker_fname: path of the resume file.
        log_fname: path of the per-epoch CSV log.
    """
    order = 60

    def __init__(self, lr_max=0, total_epochs=200,
                 tracker_fname='models/epoch_tracker.txt',
                 log_fname='models/training_log.csv'):
        self.saved_lr_max  = lr_max
        self.total_epochs  = total_epochs
        self.tracker_fname = tracker_fname
        self.log_fname     = log_fname
        # Set in before_epoch; initialised here so after_epoch never hits an
        # AttributeError if it somehow fires first.
        self._epoch_start  = None
        self.epochs_done, self.saved_lr_max = self._load_tracker()
        self._init_log()

    @staticmethod
    def _ensure_parent_dir(fname):
        """Create the parent directory of `fname` if it has one.

        os.path.dirname() returns '' for a bare filename, and
        os.makedirs('') raises FileNotFoundError, so that case is skipped.
        """
        parent = os.path.dirname(fname)
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _fmt(value):
        """Format a metric with 6 decimals; a missing (None) value becomes ''."""
        return '' if value is None else f"{value:.6f}"

    def _load_tracker(self):
        """Read (epochs_done, lr_max) from the tracker file.

        Returns (0, self.saved_lr_max) when the file cannot be read or its
        content cannot be parsed. If the file holds only the epoch count,
        self.saved_lr_max is used for lr_max.
        """
        try:
            with open(self.tracker_fname, 'r') as f:
                parts = f.read().strip().split(',')
            epochs = int(parts[0])
            lr_max = float(parts[1]) if len(parts) > 1 else self.saved_lr_max
            return epochs, lr_max
        except (OSError, ValueError):
            # Missing/unreadable file or malformed content: start from scratch.
            # Deliberately narrower than a bare `except:` so KeyboardInterrupt
            # and SystemExit still propagate.
            return 0, self.saved_lr_max

    def _save_tracker(self):
        """Write "<epochs_done>,<lr_max>" to the tracker file.

        The content is written to a sibling temp file and then moved over the
        real file with os.replace, so an interruption during the write does
        not leave a truncated tracker behind.
        """
        self._ensure_parent_dir(self.tracker_fname)
        tmp_fname = f"{self.tracker_fname}.tmp"
        with open(tmp_fname, 'w') as f:
            f.write(f"{self.epochs_done},{self.saved_lr_max}")
        os.replace(tmp_fname, self.tracker_fname)

    def _init_log(self):
        """Create the CSV log with its header row if it does not exist yet."""
        self._ensure_parent_dir(self.log_fname)
        if not Path(self.log_fname).exists():
            with open(self.log_fname, 'w', newline='') as f:
                csv.writer(f).writerow([
                    'epoch',
                    'train_loss',
                    'valid_loss',
                    'sisnr_db',
                    'noise_removed_percentage',
                    'current_lr',
                    'lr_max',
                    'epoch_elapsed_time'
                ])

    def before_epoch(self):
        """Record the wall-clock start time of the epoch."""
        self._epoch_start = time.time()

    def after_epoch(self):
        """
        Append this epoch's metrics to the log and update the tracker file.

        Meant to run after Recorder.after_epoch (order=50) so recorder.values[-1]
        is fully populated with [train_loss, valid_loss, sisnr_db, noise_removed_%]
        """
        started = self._epoch_start
        elapsed = time.time() - started if started is not None else 0.0
        self.epochs_done += 1

        recorder = self.learn.recorder

        # recorder.values[-1] layout (per the original author's debug output):
        # [train_loss, valid_loss, sisnr_db, noise_removed_%]
        vals = list(recorder.values[-1]) if recorder.values else []
        # Pad with None so missing trailing metrics become empty CSV cells.
        train_loss, valid_loss, sisnr_db, noise_removed = (vals + [None] * 4)[:4]

        current_lr = self.learn.opt.hypers[-1]['lr'] if self.learn.opt else None

        with open(self.log_fname, 'a', newline='') as f:
            csv.writer(f).writerow([
                self.epochs_done,
                self._fmt(train_loss),
                self._fmt(valid_loss),
                self._fmt(sisnr_db),
                self._fmt(noise_removed),
                f"{current_lr}" if current_lr is not None else '',
                f"{self.saved_lr_max}",
                f"{timedelta(seconds=elapsed)}"
            ])

        self._save_tracker()

In [None]:
class SISNRDiagnostic(Callback):
    """
    Computes manual SI-SNR on validation batch after each epoch
    and writes to file for monitoring without interrupting training.

    One CSV row is appended per diagnosed epoch: the SI-SNR of the first four
    samples of the batch, the batch mean, the min/max of the prediction, and
    the mean zero-mean power of prediction and target plus their ratio.
    If anything fails, a row containing "ERROR: <message>" is written instead.

    Args:
        fname: path of the diagnostic CSV (header written once, rows appended).
        every_n: run the diagnostic only when the 0-based `self.epoch` is a
            multiple of this value. Must be non-zero.

    Raises:
        ValueError: if `every_n` is 0 (or otherwise falsy), which would
            previously only surface as a ZeroDivisionError after the first epoch.
    """
    def __init__(self, fname='models/sisnr_diagnostic.csv', every_n=1):
        if not every_n:
            raise ValueError(f"every_n must be a non-zero integer, got {every_n!r}")
        self.fname   = fname
        self.every_n = every_n
        self._init_file()

    def _init_file(self):
        """Create the diagnostic CSV with its header row if it does not exist."""
        # os.path.dirname() is '' for a bare filename and os.makedirs('')
        # raises FileNotFoundError, so only create a directory when there is one.
        parent = os.path.dirname(self.fname)
        if parent:
            os.makedirs(parent, exist_ok=True)
        if not Path(self.fname).exists():
            self._append_row([
                'epoch',
                'sisnr_sample_0', 'sisnr_sample_1',
                'sisnr_sample_2', 'sisnr_sample_3',
                'sisnr_mean',
                'pred_min', 'pred_max',
                'pred_power_mean', 'targ_power_mean',
                'power_ratio'
            ])

    def _append_row(self, row):
        """Append a single row to the diagnostic CSV (creating the file if needed)."""
        with open(self.fname, 'a', newline='') as f:
            csv.writer(f).writerow(row)

    def after_epoch(self):
        """Run the model on one validation batch and log SI-SNR statistics."""
        if self.epoch % self.every_n != 0:
            return

        model = self.learn.model
        # Remember the current mode so it can be restored afterwards instead of
        # unconditionally forcing train mode (which would leave the model in
        # train mode once fitting has finished).
        was_training = getattr(model, 'training', True)
        try:
            model.eval()
            xb, yb = self.learn.dls.valid.one_batch()

            with torch.no_grad():
                pred = model(xb)

            pred_sq = pred.squeeze(1).detach().float()
            targ_sq = yb.squeeze(1).detach().float()

            # Manual SI-SNR: remove the mean along the last dim, project the
            # prediction onto the target, compare projection vs. residual energy.
            targ_zm = targ_sq - targ_sq.mean(dim=-1, keepdim=True)
            pred_zm = pred_sq - pred_sq.mean(dim=-1, keepdim=True)

            eps = 1e-8
            alpha = (targ_zm * pred_zm).sum(-1, keepdim=True) / \
                    (targ_zm.pow(2).sum(-1, keepdim=True) + eps)
            s = (alpha * targ_zm).pow(2).sum(-1)
            n = (pred_zm - alpha * targ_zm).pow(2).sum(-1)
            sisnr = 10 * torch.log10((s + eps) / (n + eps))

            # Power ratio — should trend toward 1.0 as training progresses
            pred_power = pred_zm.pow(2).sum(dim=-1).mean().item()
            targ_power = targ_zm.pow(2).sum(dim=-1).mean().item()
            ratio      = pred_power / (targ_power + eps)

            # Log the first four samples; pad with '' if the batch is smaller.
            per_sample = sisnr.cpu().tolist()
            sample_cells = [f"{v:.4f}" for v in per_sample[:4]]
            sample_cells += [''] * (4 - len(sample_cells))

            self._append_row([
                self.epoch + 1,
                *sample_cells,
                f"{sisnr.mean().item():.4f}",
                f"{pred_sq.min().item():.4f}",
                f"{pred_sq.max().item():.4f}",
                f"{pred_power:.4f}",
                f"{targ_power:.4f}",
                f"{ratio:.4f}"
            ])

        except Exception as e:
            # Never crash training — just log the error
            try:
                self._append_row([self.epoch + 1, f"ERROR: {e}"] + [''] * 9)
            except OSError:
                # Deliberate best-effort: if even the error row cannot be
                # written (e.g. disk full), skip it rather than abort training.
                pass
        finally:
            model.train(was_training)