In [None]:
import os
from multiprocessing import Value as MPValue

# Shared (multiprocessing) int that load_audio reads as the per-epoch augmentation seed;
# presumably advanced each epoch by a callback defined elsewhere — TODO confirm.
EPOCH_SEED = MPValue('i', 0)

# Cap the native thread pools (OpenMP, MKL, numexpr, OpenBLAS) at 3 threads each.
# Set here, before numpy/torch are imported in the next cell, so the pools see them.
os.environ.update(dict.fromkeys(
    ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OPENBLAS_NUM_THREADS"),
    "3",
))

In [None]:
import math
from pathlib import Path
import numpy as np
from tqdm import tqdm
import torch
import torchaudio
from fastai.vision.all import *
from fastai.callback.all import *
from fastai.data.all import *
import torch.nn.functional as F


In [None]:
%run model.ipynb

In [None]:
%run loss.ipynb

In [None]:
%run metrics.ipynb

In [None]:
%run helpers.ipynb

In [None]:
# Use a DirectML device when the torch_directml package is importable; otherwise fall
# back to CPU. Downstream cells read `dml` (device or None) and `USE_DIRECTML`.
try:
    import torch_directml

    dml = torch_directml.device()
    print(f"DirectML device available: {dml} | {torch_directml.device_name(0)}")
    USE_DIRECTML = True
except ImportError:
    print("torch_directml not available, using CPU")
    dml, USE_DIRECTML = None, False

In [None]:
class AudioTensor(TensorBase):
    """Marker subclass of fastai's TensorBase used to tag waveform tensors built by load_audio."""


def load_audio(file_path, target_length=48000, role='input', rt60_range=None, clean_dir=None):
    """
    Load one waveform, optionally with reverb augmentation, as a fixed-length AudioTensor.

    Args:
        file_path: Path to a .wav file. For the input role this is the noisy file.
        target_length: Output length in samples; longer audio is cropped, shorter audio
            is zero-padded at the end.
        role: 'target' loads `file_path` unchanged. Any other value is treated as the
            model input and may be augmented (see below).
        rt60_range: RT60 range forwarded to `add_reverb`. When falsy, inputs are loaded
            unchanged (plain denoising pairs).
        clean_dir: Optional directory holding the clean file with the same name as
            `file_path`; only used by the deverb-only strategy. When None, the clean
            directory is derived by replacing "noisy" with "clean" in the parent path.

    Input augmentation (role != 'target' and rt60_range given), chosen per file by
    `_get_strategy(path, epoch_seed)`; the intended split is ~50% / 30% / 20%:
        'denoise_only': noisy          -> clean   (denoise only)
        'joint':        reverb(noisy)  -> clean   (joint denoise + dereverb)
        otherwise:      reverb(clean)  -> clean   (dereverb only)

    Both the strategy and the reverb RNG are derived from (path, EPOCH_SEED.value), so
    the result is reproducible per file and epoch, assuming those helpers are
    deterministic (the intent being that this is safe for any num_workers).
    If reverb synthesis raises, a RuntimeWarning is emitted and the dry signal is used.

    Returns:
        AudioTensor of shape (1, target_length), assuming `_load_raw_np` returns a 1-D
        sample array — TODO confirm.
    """
    import warnings

    path = Path(file_path)
    epoch_seed = EPOCH_SEED.value

    def _with_reverb(dry_np):
        # Best-effort augmentation: keep training on the dry signal, but surface the failure.
        try:
            rng = _get_reverb_rng(path, epoch_seed)
            return add_reverb(dry_np, rt60_range=rt60_range, rng=rng)
        except Exception as exc:
            warnings.warn(
                f"add_reverb failed for {path.name}: {exc!r}; using un-reverbed audio",
                RuntimeWarning,
            )
            return dry_np

    if role == 'target':
        wave_np = _load_raw_np(path)
    else:
        strategy = _get_strategy(path, epoch_seed) if rt60_range else 'denoise_only'

        if strategy == 'denoise_only':
            wave_np = _load_raw_np(path)
        elif strategy == 'joint':
            wave_np = _with_reverb(_load_raw_np(path))
        else:  # deverb_only: reverberate the *clean* counterpart
            if clean_dir is not None:
                clean_file = Path(clean_dir) / path.name
            else:
                # Legacy fallback: swap "noisy" -> "clean" anywhere in the parent path.
                clean_file = Path(path.parent.as_posix().replace("noisy", "clean")) / path.name
            wave_np = _with_reverb(_load_raw_np(clean_file))

    # Crop then pad to exactly target_length samples
    wave_np = wave_np[:target_length]
    if len(wave_np) < target_length:
        wave_np = np.pad(wave_np, (0, target_length - len(wave_np)))

    return AudioTensor(torch.tensor(wave_np).unsqueeze(0))

In [None]:
def generate_dataloaders(
    noisy_dir, clean_dir, bs=8, valid_pct=0.15, verbose=False, target_length=64000,
    num_workers=0, device=torch.device("cpu"), use_reverb=True, rt60_range=(0.2, 0.6)
):
    """
    Create fastai DataLoaders of (input, clean target) waveform pairs for VoiceBank-DEMAND.

    Items are the noisy .wav files in `noisy_dir` that have a same-named file in
    `clean_dir`; unmatched noisy files are skipped.

    Args:
        noisy_dir: Directory of noisy .wav files.
        clean_dir: Directory of clean .wav files (matched by filename).
        bs: Batch size.
        valid_pct: Fraction (0-1) of pairs held out for validation (RandomSplitter, seed=42).
        verbose: Forwarded to DataBlock.dataloaders.
        target_length: Fixed audio length in samples (64000 = 4 seconds @ 16kHz).
        num_workers: Number of data loading workers.
        device: Device the DataLoaders are moved to.
        use_reverb: If True, inputs go through load_audio's reverb augmentation strategy.
        rt60_range: RT60 range forwarded to load_audio when use_reverb is True.

    Returns:
        fastai DataLoaders (train, valid).

    Raises:
        FileNotFoundError: if no noisy/clean filename pairs are found.
    """
    noisy_dir = Path(noisy_dir)
    clean_dir = Path(clean_dir)

    # Keep only noisy files that have a same-named clean counterpart
    items = [
        str(noisy_file)
        for noisy_file in sorted(noisy_dir.glob('*.wav'))
        if (clean_dir / noisy_file.name).exists()
    ]
    if not items:
        # Fail early with a clear message instead of an obscure error inside fastai
        raise FileNotFoundError(
            f"No matching .wav pairs found between {noisy_dir} and {clean_dir}"
        )

    total   = len(items)
    n_val   = int(total * valid_pct)
    n_train = total - n_val
    print(f"Found {total} pairs  |  train={n_train}  valid={n_val}")

    if use_reverb:
        print(f"Strategy C enabled  |  RT60={rt60_range}")
        print(f"  ~{int(n_train*0.5)} denoise only  (~50%)")
        print(f"  ~{int(n_train*0.3)} joint         (~30%)")
        print(f"  ~{int(n_train*0.2)} deverb only   (~20%)")
        print(f"  Same distribution applied to validation")
    else:
        print(f"Augmentation disabled — straight denoising pairs")

    _rt60 = rt60_range if use_reverb else None

    # NOTE(review): get_x augments validation items too, and load_audio seeds from
    # EPOCH_SEED, so validation inputs presumably change whenever EPOCH_SEED does. That
    # makes per-epoch validation metrics (and the "best" checkpoint) less comparable — confirm.
    def get_x(noisy_audio_path):
        return load_audio(noisy_audio_path, target_length=target_length, role='input', rt60_range=_rt60)

    def get_y(noisy_audio_path):
        noisy_path = Path(noisy_audio_path)
        clean_path = clean_dir / noisy_path.name
        return load_audio(str(clean_path), target_length=target_length, role='target')

    # Pass-through block: load_audio already returns ready-to-batch tensors
    def AudioTensorBlock():
        return TransformBlock(type_tfms=[], batch_tfms=[])

    dblock = DataBlock(
        blocks=(AudioTensorBlock(), AudioTensorBlock()),
        get_x=get_x,
        get_y=get_y,
        splitter=RandomSplitter(valid_pct=valid_pct, seed=42)
    )

    dls = dblock.dataloaders(items, bs=bs, num_workers=num_workers, verbose=verbose)
    dls = dls.to(device)

    return dls


In [None]:
def generate_learner(
    train_noisy_dir="data/train/noisy_trainset_28spk_wav",
    train_clean_dir="data/train/clean_trainset_28spk_wav",
    epochs=80,
    batch_size=8,
    channels=96,
    num_blocks=10,
    num_repeats=2,
    target_length=64000,
    valid_pct=0.05,
    device=torch.device("cpu"),
    verbose=False,
    use_reverb=True,
    rt60_range=(0.2, 0.6),
    grad_accum_batches=4,
):
    """
    Build the DataLoaders, CausalDNoizeConvTasNet model and fastai Learner for training.

    Args:
        train_noisy_dir: Path to noisy training audio.
        train_clean_dir: Path to clean training audio.
        epochs: Not used by this function; the epoch count is chosen when calling fit.
        batch_size: Batch size per forward/backward pass.
        channels: Number of channels in the model.
        num_blocks: Number of processing blocks per repeat.
        num_repeats: Number of repeats of the block stack.
        target_length: Fixed audio length in samples for every item.
        valid_pct: Fraction of pairs held out for validation.
        device: Device for both the DataLoaders and the model (None leaves the model as built).
        verbose: Forwarded to the DataLoaders construction.
        use_reverb: Enable the reverb augmentation strategy on inputs.
        rt60_range: RT60 range used by the reverb augmentation.
        grad_accum_batches: Number of batches whose gradients are accumulated before each
            optimizer step (effective batch = batch_size * grad_accum_batches).

    Returns:
        fastai Learner with loss, optimizer, metrics and callbacks attached.
    """

    print(f"Loading data from:")
    print(f"Noisy: {train_noisy_dir}")
    print(f"Clean: {train_clean_dir}")

    dls = generate_dataloaders(
        train_noisy_dir,
        train_clean_dir,
        target_length=target_length,
        bs=batch_size,
        valid_pct=valid_pct,
        device=device,
        verbose=verbose,
        use_reverb=use_reverb,
        rt60_range=rt60_range
    )

    # Show a batch to verify shapes before building the model
    print("\nDataLoader check:")
    xb, yb = dls.one_batch()
    print(f"  Noisy batch shape: {xb.shape}")
    print(f"  Clean batch shape: {yb.shape}")

    model = CausalDNoizeConvTasNet(
        channels=channels, num_blocks=num_blocks,
        num_repeats=num_repeats
    )

    # Keep the model on the same device as the DataLoaders, whatever that device is
    # (previously this only happened when DirectML was available).
    if device is not None:
        model = model.to(device)
        print(f"\nModel moved to {device}")

    # fastai's GradientAccumulation counts *samples* (n_acc), not batches: n_acc=4 with
    # batch_size>=4 meant no accumulation at all (and scaled gradients up when bs > 4).
    n_acc = max(1, grad_accum_batches) * batch_size

    # NOTE(review): `enabled=False` is forwarded to fastai's MixedPrecision; confirm for the
    # installed fastai version that it disables autocast and not only loss scaling.
    learn = Learner(
        dls,
        model,
        loss_func=CombinedLoss(use_spectral=True),
        opt_func=lookahead_adamw,
        metrics=[SISNRMetric(), NoiseReductionPct()],
        cbs=[
            SaveModelCallback(monitor='sisnr_db', fname='causal_dnoize_best', with_opt=True),
            GradientClip(max_norm=1.0),
            GradientAccumulation(n_acc=n_acc),
            SISNRDiagnostic(),
            PeriodicPESQSTOI(every_n=10, n_batches=10),
        ]
    ).to_fp16(enabled=False)

    print(f"  Batch size: {batch_size}")
    print(f"  Effective batch size: {n_acc} ({max(1, grad_accum_batches)} accumulated batches)")
    print(f"  Model channels: {channels}")
    print(f"  Model blocks: {num_blocks}")

    return learn

In [None]:


# Build the learner used by the rest of the notebook: a small 48-channel model on
# 48000-sample crops (3 s if the audio is 16 kHz), on the DirectML device when one was
# found above (`dml` is None otherwise).
learn = generate_learner(
    # data
    train_noisy_dir="data/train/noisy_trainset_28spk_wav",
    train_clean_dir="data/train/clean_trainset_28spk_wav",
    target_length=48000,
    valid_pct=0.05,
    batch_size=4,
    # model
    channels=48,
    num_blocks=8,
    num_repeats=2,
    # augmentation
    use_reverb=True,
    rt60_range=(0.2, 0.6),
    # hardware
    device=dml,
)

In [None]:
# # Run this on one batch to see the loss component magnitudes
# learn.model.eval()
# xb, yb = learn.dls.one_batch()
# with torch.no_grad():
#     pred = learn.model(xb)

# pred_sq = pred.squeeze(1).float().cpu()
# targ_sq = yb.squeeze(1).float().cpu()

# sisnr    = si_snr_loss(pred_sq, targ_sq)
# l1       = F.l1_loss(pred, yb)
# pred_rms = pred.pow(2).mean(dim=-1).sqrt()
# targ_rms = yb.pow(2).mean(dim=-1).sqrt()
# # power_reg = F.l1_loss(pred_rms, targ_rms.detach())
# ratio     = pred_rms / (targ_rms + 1e-8)
# power_reg = (ratio - 1.0).abs().mean()

# print(f"sisnr:      {sisnr.item():.6f}")
# print(f"l1:         {l1.item():.6f}")
# print(f"power_reg:  {power_reg.item():.6f}")
# print(f"pred_rms:   {pred_rms.squeeze().tolist()}")
# print(f"targ_rms:   {targ_rms.squeeze().tolist()}")
# print(f"ratio:      {(pred_rms / (targ_rms + 1e-8)).squeeze().tolist()}")
# learn.model.train()

In [None]:
# learn.model.eval()
# xb, yb = learn.dls.one_batch()
# with torch.no_grad():
#     pred = learn.model(xb)

# pred_cpu = pred.squeeze(1).detach().float().cpu()
# targ_cpu = yb.squeeze(1).detach().float().cpu()

# pred_rms  = pred_cpu.pow(2).mean(dim=-1).sqrt()
# targ_rms  = targ_cpu.pow(2).mean(dim=-1).sqrt()
# ratio     = pred_rms / (targ_rms + 1e-8)
# power_reg = (ratio - 1.0).abs().mean()

# sisnr = si_snr_loss(pred_cpu, targ_cpu)

# print(f"sisnr:       {sisnr.item():.6f}")
# print(f"power_reg:   {power_reg.item():.6f}")
# print(f"pred_rms:    {pred_rms.tolist()}")
# print(f"targ_rms:    {targ_rms.tolist()}")
# print(f"ratio:       {ratio.tolist()}")
# print(f"mean_ratio:  {ratio.mean().item():.4f}")
# print(f"pred range:  [{pred_cpu.min().item():.4f}, {pred_cpu.max().item():.4f}]")

# learn.model.train()

In [None]:
# # Final SI-SNR check
# learn.model.eval()
# xb, yb = learn.dls.one_batch()
# with torch.no_grad():
#     pred = learn.model(xb)

# pred_cpu = pred.squeeze(1).detach().float().cpu()
# targ_cpu = yb.squeeze(1).detach().float().cpu()

# targ_zm = targ_cpu - targ_cpu.mean(dim=-1, keepdim=True)
# pred_zm = pred_cpu - pred_cpu.mean(dim=-1, keepdim=True)
# eps = 1e-8
# alpha = (targ_zm * pred_zm).sum(-1, keepdim=True) / \
#         (targ_zm.pow(2).sum(-1, keepdim=True) + eps)
# s = (alpha * targ_zm).pow(2).sum(-1)
# n = (pred_zm - alpha * targ_zm).pow(2).sum(-1)
# sisnr = 10 * torch.log10((s + eps) / (n + eps))

# print(f"SI-SNR per sample: {[round(v,2) for v in sisnr.tolist()]}")
# print(f"SI-SNR mean:       {sisnr.mean().item():.4f} dB")
# learn.model.train()

In [None]:
# On resume — EpochTracker reports how many epochs a previous run already completed
# (presumably persisted by the tracker itself — see its definition).
n_epochs = 200
tracker = EpochTracker(total_epochs=n_epochs)

# Derive from n_epochs (not a second hard-coded 200) and never go negative if a
# previous run already reached the budget.
epochs_remaining = max(0, n_epochs - tracker.epochs_done)

resuming_after_failure = tracker.epochs_done > 0

if resuming_after_failure:
    print(f"Resuming from epoch {tracker.epochs_done}, {epochs_remaining} remaining")


In [None]:
# On a resumed run, restore weights and optimizer state from the checkpoint that
# SaveModelCallback writes as 'causal_dnoize_best'.
# NOTE(review): this is the *best* checkpoint, not necessarily the last completed epoch,
# while tracker.epochs_done may count epochs after it — confirm this is intended.
if resuming_after_failure:
    learn.load(
        'causal_dnoize_best',
        with_opt=True,
    )


In [None]:
# LR range test only on a fresh start; a resumed run reuses tracker.saved_lr_max instead.
if not resuming_after_failure:
    suggested_lr = learn.lr_find(
        suggest_funcs=(steep, valley),
        num_it=50,
        stop_div=False,
    )

In [None]:
# Show the lr_find suggestions (suggested_lr only exists on a fresh start).
if not resuming_after_failure:
    print(suggested_lr)

In [None]:
if resuming_after_failure:
    # Scale the saved peak LR down linearly with progress: factor 1.0 at epoch 0,
    # 0.2 at the last epoch; the 0.1 floor only applies if epochs_done > total_epochs.
    lr_max = tracker.saved_lr_max
    progress  = tracker.epochs_done / tracker.total_epochs

    resume_lr = lr_max * max(0.1, 1 - progress * 0.8)

    print(f"Resume LR: {resume_lr} | Saved Max LR: {lr_max}")

else:
    # Distinct names: assigning to `steep` / `valley` would shadow fastai's suggestion
    # functions passed to lr_find above and break that cell on re-run.
    lr_steep  = suggested_lr.steep
    lr_valley = suggested_lr.valley
    print(f"lr_find:  steep={lr_steep:.2e}  valley={lr_valley:.2e}")

    # Weight steep more heavily for complex tasks — the one-cycle warmup
    # (pct_start=0.05, div=10 in the fit_one_cycle call below) climbs from
    # lr_max/10 to lr_max, finding the sweet spot.
    w_steep, w_valley = 0.67, 0.33
    lr_max = math.exp(
        w_steep  * math.log(lr_steep) +
        w_valley * math.log(lr_valley)
    )

    # Floor at 1e-5 — never go below this regardless of lr_find output
    lr_max = max(lr_max, 1e-5)

    print(f"weighted geomean (steep×{w_steep} valley×{w_valley}): "
          f"{lr_max:.8f}")
    # 10 and 0.05 mirror div / pct_start of the fit_one_cycle call
    print(f"  warmup will climb from {lr_max/10:.2e} → {lr_max:.2e} "
          f"over first {0.05 * epochs_remaining:.0f} epochs")

    tracker.saved_lr_max = lr_max


In [None]:
# Attach the epoch tracker exactly once, so re-running this cell doesn't register it twice.
if not any(cb is tracker for cb in learn.cbs):
    learn.add_cb(tracker)

# Build the message without nesting same-type quotes inside an f-string
# (that form is a SyntaxError before Python 3.12).
if resuming_after_failure:
    lr_desc = f"lr = {resume_lr}"
    run_desc = f"resuming after {tracker.epochs_done} epochs"
else:
    lr_desc = f"lr_max = {lr_max}"
    run_desc = "fresh start"

training_start_msg = (
    f"Starting training for {epochs_remaining} epochs with {lr_desc} | "
    f"{run_desc}"
)

print(training_start_msg)

In [None]:
# Fresh run: one-cycle schedule, warming up over the first 5% from lr_max/10.
# Resumed run: no flat phase — pure cosine decay from the reduced resume LR.
if not resuming_after_failure:
    learn.fit_one_cycle(
        epochs_remaining,
        lr_max=lr_max,
        div=10,
        pct_start=0.05,
        wd=1e-4,
    )
else:
    learn.fit_flat_cos(epochs_remaining, lr=resume_lr, pct_start=0.0, wd=1e-4)

In [None]:
# Save the end-of-training weights plus optimizer state as 'causal_dnoize_final'
# (separate from the 'causal_dnoize_best' checkpoint written by SaveModelCallback).
learn.save('causal_dnoize_final', with_opt=True)

In [None]:
# Score the best checkpoint written by SaveModelCallback(fname='causal_dnoize_best').
# Resolve the file where fastai saves it (learn.path / learn.model_dir) instead of
# hard-coding 'models/', so a non-default learner path still finds it.
# NOTE(review): channels / num_blocks / num_repeats must match the generate_learner call.
best_ckpt_path = learn.path / learn.model_dir / 'causal_dnoize_best.pth'
pesq_score, stoi_score = evaluate_checkpoint(
    str(best_ckpt_path),
    learn.dls,
    channels=48,
    num_blocks=8,
    num_repeats=2
)

In [None]:
# Display the PESQ score returned by evaluate_checkpoint
pesq_score

In [None]:
# Display the STOI score returned by evaluate_checkpoint
stoi_score