In [None]:
import os
from pathlib import Path
import numpy as np
from tqdm import tqdm
import torch
import torchaudio
from fastai.vision.all import *
from fastai.callback.all import *
from fastai.data.all import *
import torch.nn.functional as F

In [None]:
%run model.ipynb

In [None]:
%run loss.ipynb

In [None]:
%run metrics.ipynb

In [None]:
# # %%
# # ENHANCED PATCH - Shows header table + console progress updates
# from fastprogress.fastprogress import NBMasterBar, NBProgressBar, master_bar, progress_bar
# from IPython.display import display

# print("Applying enhanced ProgressCallback patch...")

# # ============================================================================
# # PATCH 1: NBMasterBar initialization
# # ============================================================================
# _original_nbmasterbar_init = NBMasterBar.__init__

# def _patched_nbmasterbar_init(self, gen, total=None, parent=None, display=True, leave=True, **kwargs):
#     try:
#         _original_nbmasterbar_init(self, gen, total, parent, display, leave, **kwargs)
#     except Exception as e:
#         self.gen = gen
#         self.total = total
#         self.parent = parent
#         self.leave = leave
#         self.display = display
#         self.first_bar = None
    
#     if not hasattr(self, 'out'):
#         try:
#             from ipywidgets import Output
#             self.out = Output()
#             if display:
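#                 # NOTE(review): `display` here is the bool parameter, shadowing IPython's display(); calling it raises TypeError, which the bare except below swallows, so the DummyOut fallback is always used.
#                 display(self.out)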
#         except:
#             class DummyOut:
#                 def update(self, *args, **kwargs): pass
#             self.out = DummyOut()
    
#     if not hasattr(self, 'order') or not isinstance(self.order, (list, tuple)):
#         self.order = ['main', 'text']
    
#     if not hasattr(self, 'inner_dict'):
#         self.inner_dict = {}
    
#     if not hasattr(self, 'text_parts'):
#         self.text_parts = []
    
#     if not hasattr(self, 'lines'):
#         self.lines = []

# NBMasterBar.__init__ = _patched_nbmasterbar_init

# # ============================================================================
# # PATCH 2: NBMasterBar.show()
# # ============================================================================
# _original_show = NBMasterBar.show

# def _patched_show(self):
#     if not hasattr(self, 'out'): return
#     if not hasattr(self, 'inner_dict'): self.inner_dict = {}
#     if not hasattr(self, 'text_parts'): self.text_parts = []
#     if not hasattr(self, 'order') or not isinstance(self.order, (list, tuple)):
#         self.order = ['main', 'text']
#     try:
#         from fastprogress.fastprogress import Div
#         if self.text_parts:
#             self.inner_dict['text'] = Div(*self.text_parts)
#         children = []
#         for n in self.order:
#             item = self.inner_dict.get(n)
#             if item is not None:
#                 child = getattr(item, 'progress', None) or item
#                 children.append(child)
#         if children and hasattr(self.out, 'update'):
#             self.out.update(Div(*children))
#     except: pass

# NBMasterBar.show = _patched_show

# # ============================================================================
# # PATCH 3: NBMasterBar.write() - Enhanced with console output
# # ============================================================================
# _original_write = NBMasterBar.write

# def _patched_write(self, line, table=False):
#     if not hasattr(self, 'lines'): self.lines = []
#     if not hasattr(self, 'text_parts'): self.text_parts = []
    
#     # ALWAYS print to console for visibility
#     if table and isinstance(line, list):
#         # Skip printing the header (first line)
#         if len(self.lines) > 0:  # This is data, not header
#             print(' | '.join(f'{x:>12}' if isinstance(x, (int, float)) else f'{x:>12}' for x in line))
    
#     try:
#         from fastprogress.fastprogress import text2html_table, P
#         if table:
#             self.lines.append(line)
#             self.text_parts = [text2html_table(self.lines)]
#         else:
#             self.text_parts.append(P(line))
#         self.show()
#     except:
#         # Fallback console output
#         if isinstance(line, (list, tuple)):
#             print(' | '.join(str(x) for x in line))
#         else:
#             print(line)

# NBMasterBar.write = _patched_write

# # ============================================================================
# # PATCH 4: NBMasterBar.update() - Add console progress
# # ============================================================================
# _original_update = NBMasterBar.update

# def _patched_update(self, val=None):
#     try:
#         _original_update(self, val)
#     except:
#         pass
    
#     # Print epoch progress to console
#     if val is not None and hasattr(self, 'total'):
#         if self.total:
#             pct = (val + 1) / self.total * 100
#             print(f"\rEpoch {val + 1}/{self.total} ({pct:.0f}%)", end='', flush=True)

# NBMasterBar.update = _patched_update

# # ============================================================================
# # PATCH 5: NBProgressBar for batch progress (optional)
# # ============================================================================
# _original_nbprogress_update = NBProgressBar.update

# def _patched_nbprogress_update(self, val=None):
#     try:
#         _original_nbprogress_update(self, val)
#     except:
#         pass

# NBProgressBar.update = _patched_nbprogress_update

# print("✓ Enhanced ProgressCallback patched")
# print("  - Table header will display")
# print("  - Metrics will print to console after each epoch")
# print("  - Epoch progress will show inline")
# print("\nYou can now train:")
# print("  learn.fit_one_cycle(80, lr_max=3e-4)")

In [None]:
# Probe for the optional torch-directml backend. Later cells read
# `USE_DIRECTML` / `dml` to decide where the model and data are placed.
try:
    import torch_directml
    dml = torch_directml.device()
    print(
        f"DirectML device available: {dml} | {torch_directml.device_name(0)}"
    )
    USE_DIRECTML = True
except ImportError:
    # Package missing: everything stays on the CPU and `dml` is None.
    print("torch_directml not available, using CPU")
    dml, USE_DIRECTML = None, False

In [None]:
class AudioTensor(TensorBase):
    """Marker subclass of fastai's ``TensorBase`` used to label loaded waveforms.

    It defines no attributes or methods of its own; ``load_audio`` wraps
    every waveform it returns in this type.
    """


def load_audio(path_str, target_length=64000, target_sr=16000):
    """Load an audio file as a mono, fixed-length waveform.

    Args:
        path_str: Path (str or PathLike) of the audio file to read.
        target_length: Number of samples in the returned waveform. Shorter
            clips are zero-padded at the end; longer clips are truncated.
        target_sr: Sample rate (Hz) to resample to when the file's native
            rate differs. Defaults to 16 kHz (the previously hard-coded value).

    Returns:
        AudioTensor of shape (1, target_length).
    """
    # torchaudio returns (channels, frames) plus the file's sample rate.
    wave, sr = torchaudio.load(str(Path(path_str)))

    # Collapse multi-channel audio to a single channel by averaging.
    if wave.shape[0] > 1:
        wave = wave.mean(0, keepdim=True)

    if sr != target_sr:
        wave = torchaudio.transforms.Resample(sr, target_sr)(wave)

    # Zero-pad on the right or crop so every item has the same length,
    # which allows default batch collation.
    n_samples = wave.shape[1]
    if n_samples < target_length:
        wave = F.pad(wave, (0, target_length - n_samples))
    else:
        wave = wave[:, :target_length]

    return AudioTensor(wave)

In [None]:
def generate_dataloaders(noisy_dir, clean_dir, bs=8, valid_pct=0.15, verbose=False,
                              target_length=64000, num_workers=0, device=torch.device("cpu")):
    """
    Create paired noisy/clean DataLoaders for the VoiceBank-DEMAND dataset.

    Every ``*.wav`` in ``noisy_dir`` that has a file with the same name in
    ``clean_dir`` becomes one (noisy, clean) pair.

    Args:
        noisy_dir: Directory containing the noisy ``.wav`` files.
        clean_dir: Directory containing the clean ``.wav`` files (same filenames).
        bs: Batch size.
        valid_pct: Fraction (0-1) of pairs held out for validation.
        verbose: Forwarded to ``DataBlock.dataloaders``.
        target_length: Fixed audio length in samples (64000 = 4 seconds @ 16kHz).
        num_workers: Number of data loading workers. The item getters below
            are local closures; these may not be picklable when workers are
            started with the "spawn" method (e.g. on Windows) -- keep 0 there.
        device: Device the DataLoaders are moved to (forwarded unchanged
            to ``DataLoaders.to``).

    Returns:
        fastai ``DataLoaders`` yielding (noisy, clean) waveform batches.

    Raises:
        ValueError: If no noisy file has a matching clean file.
    """
    noisy_dir = Path(noisy_dir)
    clean_dir = Path(clean_dir)

    # Keep only noisy files that have a same-named clean counterpart.
    items = [str(noisy_file) for noisy_file in sorted(noisy_dir.glob('*.wav'))
             if (clean_dir / noisy_file.name).exists()]

    print(f"Found {len(items)} audio pairs")
    if not items:
        # Fail early with a clear message instead of an obscure fastai error.
        raise ValueError(
            f"No matching .wav pairs found between {noisy_dir} and {clean_dir}"
        )

    def get_x(noisy_audio_path):
        return load_audio(noisy_audio_path, target_length)

    def get_y(noisy_audio_path):
        clean_path = clean_dir / Path(noisy_audio_path).name
        return load_audio(str(clean_path), target_length)

    # Pass-through block: items are already AudioTensors, so no type or
    # batch transforms are applied.
    def AudioTensorBlock():
        return TransformBlock(type_tfms=[], batch_tfms=[])

    dblock = DataBlock(
        blocks=(AudioTensorBlock(), AudioTensorBlock()),
        get_x=get_x,
        get_y=get_y,
        splitter=RandomSplitter(valid_pct=valid_pct, seed=42)
    )

    dls = dblock.dataloaders(items, bs=bs, num_workers=num_workers, verbose=verbose)
    # Honor the caller's device (previously the global `dml` was always used).
    dls = dls.to(device)

    return dls


In [None]:
def generate_learner(
    train_noisy_dir="data/train/noisy_trainset_28spk_wav",
    train_clean_dir="data/train/clean_trainset_28spk_wav",
    epochs=80,
    lr=3e-4,
    batch_size=8,
    channels=96,
    num_blocks=4,
    device=torch.device("cpu"),
    verbose=False
):
    """
    Build a fastai Learner for the causal noise-removal model (no training).

    Loads paired noisy/clean audio and prints one batch's shapes as a sanity
    check. It then instantiates ``CausalDNoizeConvTasNet`` and wraps it in a
    ``Learner`` with the combined loss, PESQ/STOI/accuracy metrics and
    best-checkpoint saving.

    Args:
        train_noisy_dir: Path to noisy training audio.
        train_clean_dir: Path to clean training audio.
        epochs: Currently unused -- the epoch count is chosen when calling
            ``fit``/``fit_one_cycle``. Kept for backward compatibility.
        lr: Default learning rate stored on the Learner. It is used by
            ``fit``/``fit_one_cycle`` when no explicit rate is passed.
        batch_size: Batch size.
        channels: Number of channels in the model.
        num_blocks: Number of processing blocks in the model.
        device: Device for both the model and the DataLoaders. ``None``
            leaves placement untouched.
        verbose: Forwarded to ``generate_dataloaders``.

    Returns:
        The configured fastai ``Learner``.
    """
    print("Loading data from:")
    print(f"Noisy: {train_noisy_dir}")
    print(f"Clean: {train_clean_dir}")

    dls = generate_dataloaders(
        train_noisy_dir,
        train_clean_dir,
        target_length=80000,  # 5 s at load_audio's default 16 kHz
        bs=batch_size,
        valid_pct=0.1,
        device=device,
        verbose=verbose
    )

    # Pull one batch to fail fast on bad data before building the model.
    print("\nDataLoader check:")
    xb, yb = dls.one_batch()
    print(f"  Noisy batch shape: {xb.shape}")
    print(f"  Clean batch shape: {yb.shape}")

    model = CausalDNoizeConvTasNet(channels=channels, num_blocks=num_blocks)

    learn = Learner(
        dls,
        model,
        loss_func=CombinedLoss(),
        opt_func=Adam,
        lr=lr,
        metrics=[pesq_metric, stoi_metric, DenoisingAccuracy()],
        cbs=[
            # Save the best checkpoint according to the 'accuracy_%' metric.
            # NOTE(review): assumes DenoisingAccuracy reports under exactly
            # this name -- confirm against metrics.ipynb.
            SaveModelCallback(
                monitor='accuracy_%',
                fname='causal_dnoize_best'
            )
        ]
    # NOTE(review): enabled=False is forwarded to fastai's MixedPrecision;
    # presumably meant to keep training in fp32 -- confirm, or drop the call.
    ).to_fp16(enabled=False)

    # Put model and data on the same device in one place. The DataLoaders
    # device is set explicitly so it matches the model regardless of where
    # generate_dataloaders placed them.
    if device is not None:
        learn.dls.device = device
        learn.model = learn.model.to(device)
        print(f"\nModel moved to {device}")

    print(f"  Batch size: {batch_size}")
    print(f"  Model channels: {channels}")
    print(f"  Model blocks: {num_blocks}")

    return learn

In [None]:
# Build the learner on the 28-speaker split. `dml` is None when DirectML
# is unavailable.
learn = generate_learner(
    device=dml,
    batch_size=4,
    channels=128,
    num_blocks=6,
    train_clean_dir="data/train/clean_trainset_28spk_wav",
    train_noisy_dir="data/train/noisy_trainset_28spk_wav",
)

In [None]:
# learn.lr_find()

In [None]:
# Training parameters
n_epochs = 150
lr = 0.000575
best_pesq = 0

print("Starting training...")

In [None]:
learn.fit_one_cycle(
    n_epochs,
    lr_max=lr,      # peak learning rate of the cycle
    pct_start=0.3,  # fraction of iterations spent increasing the LR
    div=25,         # starting LR = lr_max / div
    wd=1e-5,        # weight decay
)

In [None]:
# Save the final-epoch weights; the best checkpoint is written separately
# by SaveModelCallback as 'causal_dnoize_best'.
learn.save(file="causal_dnoize_final")