In [None]:
import os

# Limit OpenMP, MKL, NumExpr and OpenBLAS to three threads each.
# This cell runs ahead of the numpy / torch imports in the next cell.
os.environ.update({
    "OMP_NUM_THREADS": "3",
    "MKL_NUM_THREADS": "3",
    "NUMEXPR_NUM_THREADS": "3",
    "OPENBLAS_NUM_THREADS": "3",
})

In [None]:
import math
from pathlib import Path
import numpy as np
from tqdm import tqdm
import torch
import torchaudio
from fastai.vision.all import *
from fastai.callback.all import *
from fastai.data.all import *
import torch.nn.functional as F

In [None]:
# Execute model.ipynb so that its definitions are available to the cells below.
# NOTE(review): presumably this is where CausalDNoizeConvTasNet (used in
# generate_learner) comes from — confirm.
%run model.ipynb

In [None]:
# Execute loss.ipynb so that its definitions are available to the cells below.
# NOTE(review): presumably this is where CombinedLoss (used in generate_learner)
# comes from — confirm.
%run loss.ipynb

In [None]:
# Execute metrics.ipynb so that its definitions are available to the cells below.
# NOTE(review): presumably this is where SISNRMetric / NoiseReductionPct (used in
# generate_learner) come from — confirm.
%run metrics.ipynb

In [None]:
# Execute helpers.ipynb so that its definitions are available to the cells below.
# NOTE(review): names used later but not defined in this notebook (EpochTracker,
# evaluate_checkpoint, SISNRDiagnostic, lookahead_adamw) presumably come from
# this or one of the other %run notebooks — confirm which.
%run helpers.ipynb

In [None]:
# Optional DirectML backend. The outcome of the probe is recorded in two
# module-level names that later cells read:
#   USE_DIRECTML - True only when torch_directml imported and a device was obtained
#   dml          - the DirectML device handle, or None when the import failed
try:
    import torch_directml
    dml = torch_directml.device()
    print(f"DirectML device available: {dml} | {torch_directml.device_name(0)}")
    USE_DIRECTML = True
except ImportError:
    # Package not installed: no device handle is created.
    dml, USE_DIRECTML = None, False
    print("torch_directml not available, using CPU")

In [None]:
class AudioTensor(TensorBase):
    """TensorBase subclass that adds no behavior; load_audio wraps its waveforms in it."""


def load_audio(path_str, target_length=64000, sample_rate=16000):
    """Load an audio file as a mono, fixed-length AudioTensor.

    The waveform returned by ``torchaudio.load`` is treated as
    (channels, samples): channels are averaged to one, the signal is resampled
    to ``sample_rate`` when the file's rate differs, and the result is
    zero-padded at the end or truncated from the end to ``target_length``.

    Args:
        path_str: Path to the audio file (str or Path).
        target_length: Number of samples in the returned waveform.
        sample_rate: Sample rate in Hz the audio is resampled to. Defaults to
            16000, the value that was previously hard-coded.

    Returns:
        AudioTensor wrapping a tensor of shape (1, target_length).
    """
    wave, sr = torchaudio.load(str(Path(path_str)))

    # Convert to mono by averaging the channel dimension
    if wave.shape[0] > 1:
        wave = wave.mean(0, keepdim=True)

    # Resample only when the file is not already at the requested rate
    if sr != sample_rate:
        wave = torchaudio.transforms.Resample(sr, sample_rate)(wave)

    # Right-pad with zeros, or keep only the first target_length samples
    current_length = wave.shape[1]
    if current_length < target_length:
        wave = F.pad(wave, (0, target_length - current_length))
    else:
        wave = wave[:, :target_length]

    return AudioTensor(wave)

In [None]:
def generate_dataloaders(noisy_dir, clean_dir, bs=8, valid_pct=0.15, verbose=False,
                              target_length=64000, num_workers=0, device=torch.device("cpu")):
    """
    Create DataLoaders for VoiceBank-DEMAND dataset

    A noisy ``*.wav`` file is used only if a file with the same name exists in
    ``clean_dir``; the noisy file is the input and the clean file the target.

    Args:
        noisy_dir: Path to noisy audio files
        clean_dir: Path to clean audio files
        bs: Batch size
        valid_pct: Validation split percentage (random split, fixed seed 42)
        verbose: Forwarded to ``DataBlock.dataloaders``
        target_length: Fixed audio length in samples (64000 = 4 seconds @ 16kHz)
        num_workers: Number of data loading workers
        device: Device the DataLoaders are moved to

    Returns:
        The fastai DataLoaders built from the matched pairs, moved to ``device``.
    """
    noisy_dir = Path(noisy_dir)
    clean_dir = Path(clean_dir)

    # Pair files by name; sorting keeps the item order deterministic.
    items = [str(noisy_file) for noisy_file in sorted(noisy_dir.glob('*.wav'))
             if (clean_dir / noisy_file.name).exists()]

    print(f"Found {len(items)} audio pairs")

    def get_x(noisy_audio_path):
        # Input: the noisy recording itself
        return load_audio(noisy_audio_path, target_length)

    def get_y(noisy_audio_path):
        # Target: the same-named file under clean_dir
        clean_path = clean_dir / Path(noisy_audio_path).name
        return load_audio(str(clean_path), target_length)

    # TransformBlock with no type or batch transforms; load_audio already
    # returns fixed-length AudioTensor items.
    def AudioTensorBlock():
        return TransformBlock(type_tfms=[], batch_tfms=[])

    dblock = DataBlock(
        blocks=(AudioTensorBlock(), AudioTensorBlock()),
        get_x=get_x,
        get_y=get_y,
        splitter=RandomSplitter(valid_pct=valid_pct, seed=42)
    )

    dls = dblock.dataloaders(items, bs=bs, num_workers=num_workers, verbose=verbose)
    # Use the caller-supplied device. The previous code ignored this argument
    # and always moved the loaders to the module-level `dml`.
    dls = dls.to(device)

    return dls


In [None]:
def generate_learner(
    train_noisy_dir="data/train/noisy_trainset_28spk_wav",
    train_clean_dir="data/train/clean_trainset_28spk_wav",
    epochs=80,
    batch_size=8,
    channels=96,
    num_blocks=10,
    num_repeats=2,
    target_length=64000,
    valid_pct=0.05,
    device=torch.device("cpu"),
    verbose=False
):
    """
    Build the fastai Learner for the causal noise removal model.

    Creates the DataLoaders, prints the shapes of one batch as a sanity check,
    instantiates the model and wraps everything in a Learner. No training is
    started here.

    Args:
        train_noisy_dir: Path to noisy training audio
        train_clean_dir: Path to clean training audio
        epochs: Accepted for compatibility; not referenced in this function
        batch_size: Batch size
        channels: Number of channels in model
        num_blocks: Number of processing blocks
        num_repeats: Forwarded to the model constructor
        target_length: Fixed audio length in samples, forwarded to the DataLoaders
        valid_pct: Validation split fraction, forwarded to the DataLoaders
        device: Forwarded to generate_dataloaders. The model itself is moved to
            the module-level `dml` when USE_DIRECTML is set, not to this value.
        verbose: Forwarded to generate_dataloaders

    Returns:
        The configured Learner.
    """
    print("Loading data from:", f"Noisy: {train_noisy_dir}", f"Clean: {train_clean_dir}", sep="\n")

    # Data
    loaders = generate_dataloaders(
        train_noisy_dir,
        train_clean_dir,
        bs=batch_size,
        valid_pct=valid_pct,
        verbose=verbose,
        target_length=target_length,
        device=device,
    )

    # Pull a single batch so shape problems surface before training
    print("\nDataLoader check:")
    noisy_batch, clean_batch = loaders.one_batch()
    print(f"  Noisy batch shape: {noisy_batch.shape}")
    print(f"  Clean batch shape: {clean_batch.shape}")

    # Model
    net = CausalDNoizeConvTasNet(channels=channels, num_blocks=num_blocks, num_repeats=num_repeats)

    # NOTE(review): placement uses the global `dml`, independent of `device` — confirm intent.
    if USE_DIRECTML:
        net = net.to(dml)
        print(f"\nModel moved to DirectML device")

    # Learner pieces, built in the same order the Learner call receives them
    criterion = CombinedLoss(use_spectral=True)
    tracked_metrics = [SISNRMetric(), NoiseReductionPct()]
    callbacks = [
        # keeps the checkpoint with the best 'sisnr_db', including optimizer state
        SaveModelCallback(monitor='sisnr_db', fname='causal_dnoize_best', with_opt=True),
        GradientClip(max_norm=1.0),
        GradientAccumulation(n_acc=4),
        SISNRDiagnostic(),
    ]

    learn = Learner(
        loaders,
        net,
        loss_func=criterion,
        opt_func=lookahead_adamw,
        metrics=tracked_metrics,
        cbs=callbacks,
    )
    # NOTE(review): `enabled=False` is passed through to_fp16's keyword arguments;
    # confirm it disables mixed precision in the installed fastai version.
    learn = learn.to_fp16(enabled=False)

    print(f"  Batch size: {batch_size}")
    print(f"  Model channels: {channels}")
    print(f"  Model blocks: {num_blocks}")

    return learn

In [None]:


# Build the learner for this run. Batch size, channels, blocks and clip length
# are all smaller than generate_learner's defaults.
learn = generate_learner(
    # data
    train_noisy_dir="data/train/noisy_trainset_28spk_wav",
    train_clean_dir="data/train/clean_trainset_28spk_wav",
    valid_pct=0.05,
    target_length=48000,
    batch_size=4,
    # model
    channels=48,
    num_blocks=8,
    num_repeats=2,
    # `dml` is None when torch_directml could not be imported
    device=dml,
)

In [None]:
# Work out how many epochs are left. On a fresh run tracker.epochs_done is 0;
# NOTE(review): EpochTracker presumably restores progress from an earlier,
# interrupted run — confirm against its definition.
n_epochs = 200
tracker = EpochTracker(total_epochs=n_epochs)
# Derive the remainder from n_epochs instead of repeating the literal 200, and
# never go below zero if more epochs were recorded than planned.
epochs_remaining = max(n_epochs - tracker.epochs_done, 0)

# Any recorded progress means this run continues an earlier one
resuming_after_failure = tracker.epochs_done > 0

if resuming_after_failure:
    print(f"Resuming from epoch {tracker.epochs_done}, {epochs_remaining} remaining")


In [None]:
# When resuming after a failure, restore weights and optimizer state from the
# 'causal_dnoize_best' checkpoint. That is the fname SaveModelCallback is
# configured with in generate_learner.
if resuming_after_failure:
    learn.load('causal_dnoize_best', with_opt=True)


In [None]:
# Fresh runs only: run the LR finder with the `steep` and `valley` suggestion
# functions. The result is read as suggested_lr.steep / suggested_lr.valley below.
if not resuming_after_failure:
    suggested_lr = learn.lr_find(stop_div=False, suggest_funcs=(steep, valley))

In [None]:
# Show the raw LR finder suggestions. suggested_lr is only assigned on fresh
# runs, hence the guard.
if not resuming_after_failure:
    print(suggested_lr)

In [None]:
if not resuming_after_failure:
    # Fresh run: derive lr_max from the two LR finder suggestions.
    if suggested_lr.steep < 1e-5:
        # steep is unreasonably small, so fall back to a tenth of valley
        lr_max = suggested_lr.valley / 10.
        print(f"steep too small, using valley / 10: {lr_max:.8f}")
    else:
        # geometric mean of steep and valley, evaluated in log space
        lr_max = math.exp((math.log(suggested_lr.steep) + math.log(suggested_lr.valley)) / 2)
        print(f"using geometric mean: {lr_max:.8f}")
    # Store the chosen value on the tracker; the resume branch reads it back
    tracker.saved_lr_max = lr_max

else:
    # Resumed run: scale the stored maximum LR down with training progress.
    # The factor falls linearly from 1.0 at 0% progress to 0.2 at 100%,
    # and max() keeps it from going below 0.1.
    lr_max = tracker.saved_lr_max
    progress = tracker.epochs_done / tracker.total_epochs

    resume_lr = lr_max * max(0.1, 1 - progress * 0.8)

    print(f"Resume LR: {resume_lr} | Saved Max LR: {lr_max}")


In [None]:
# Register the epoch tracker on the learner.
# NOTE(review): re-running this cell calls add_cb again — confirm whether that
# registers the tracker a second time.
learn.add_cb(tracker)

# Build the start banner with a plain if/else. The previous single f-string
# nested double quotes inside a double-quoted f-string, which only parses on
# Python 3.12+. `!s` keeps the str() formatting the original used, so the
# printed text is unchanged.
if resuming_after_failure:
    training_start_msg = (
        f"Starting training for {epochs_remaining} epochs with lr = {resume_lr!s} | "
        f"resuming after {tracker.epochs_done!s} epochs"
    )
else:
    training_start_msg = (
        f"Starting training for {epochs_remaining} epochs with lr_max = {lr_max!s} | "
        "fresh start"
    )

print(training_start_msg)

In [None]:
# Train for the remaining epochs; the schedule depends on whether this is a
# fresh run or a continuation.
if not resuming_after_failure:
    # fresh run: one-cycle schedule peaking at lr_max
    learn.fit_one_cycle(epochs_remaining, lr_max=lr_max, div=25, pct_start=0.05, wd=1e-4)
else:
    # resumed run: pct_start=0.0 means no flat phase, pure cosine decay from start
    learn.fit_flat_cos(epochs_remaining, lr=resume_lr, pct_start=0.0, wd=1e-4)

In [None]:
# Save the end-of-training state, including optimizer state, under the name
# 'causal_dnoize_final'. This is separate from the 'causal_dnoize_best'
# checkpoint written by SaveModelCallback.
learn.save('causal_dnoize_final', with_opt=True)

In [None]:
# Evaluate the best checkpoint on the learner's DataLoaders.
# NOTE(review): evaluate_checkpoint is defined outside this notebook; judging by
# the variable names it returns a (PESQ, STOI) pair — confirm. The hard-coded
# path assumes the checkpoint was written to 'models/'.
pesq_score, stoi_score = evaluate_checkpoint(
    'models/causal_dnoize_best.pth', learn.dls
)

In [None]:
# Display pesq_score as the cell output
pesq_score

In [None]:
# Display stoi_score as the cell output
stoi_score