In [1]:
import os
from pathlib import Path
import numpy as np
from tqdm import tqdm
import torch
import torch_directml
import torchaudio
from fastai.vision.all import *
from fastai.callback.all import *


In [2]:
# Execute model.ipynb inside this kernel so the names it defines become available here.
# NOTE(review): presumably this is where DNoizeConvTasNet (used by train_fastai) comes from — confirm.
%run model.ipynb

In [3]:
# Execute loss.ipynb inside this kernel so the names it defines become available here.
# NOTE(review): presumably this is where SISNRLoss (used by train_fastai) comes from — confirm.
%run loss.ipynb

In [None]:
# Map-style dataset of paired noisy/clean audio segments, consumed by the fastai DataLoaders

class EnhancementDataset(Dataset):
    """Map-style dataset yielding fixed-length ``(noisy, clean)`` waveform pairs.

    Item ``idx`` is built from ``noisy_paths[idx]`` and ``clean_paths[idx]``.
    Both files are loaded with ``torchaudio.load``, resampled to ``sr`` when
    their native rate differs, trimmed to their common length, and then either
    randomly cropped or right-zero-padded to exactly ``segment_samples``
    samples. The same crop offset / pad amount is applied to both signals so
    they stay time-aligned.

    Args:
        noisy_paths: Sequence of paths to the degraded recordings.
        clean_paths: Sequence of paths to the reference recordings, in the
            same order as ``noisy_paths``.
        sr: Target sample rate in Hz.
        segment_sec: Length of every returned segment, in seconds.

    Returns (from ``__getitem__``):
        Tuple ``(noisy, clean)`` of tensors shaped ``[channels, segment_samples]``
        (channels is whatever the files contain — assumed mono, TODO confirm).
    """

    def __init__(self, noisy_paths, clean_paths, sr=16000, segment_sec=4.0):
        self.noisy_paths = noisy_paths
        self.clean_paths = clean_paths
        self.sr = sr
        self.segment_samples = int(segment_sec * sr)

    def __len__(self):
        # The noisy list defines the dataset size.
        return len(self.noisy_paths)

    def _load(self, path):
        """Load one file and bring it to the target sample rate."""
        waveform, file_sr = torchaudio.load(path)
        if file_sr != self.sr:
            # segment_samples is expressed at self.sr, so the audio must match it.
            waveform = torchaudio.functional.resample(waveform, file_sr, self.sr)
        return waveform

    def __getitem__(self, idx):
        noisy = self._load(self.noisy_paths[idx])
        clean = self._load(self.clean_paths[idx])

        # The two files of a pair may differ slightly in length; keep only the
        # overlapping part so one offset/pad is valid for both signals.
        common_len = min(noisy.shape[1], clean.shape[1])
        noisy = noisy[:, :common_len]
        clean = clean[:, :common_len]

        target_len = self.segment_samples
        if common_len > target_len:
            # randint's upper bound is exclusive; +1 makes the last window reachable.
            start = np.random.randint(0, common_len - target_len + 1)
            stop = start + target_len
            noisy = noisy[:, start:stop]
            clean = clean[:, start:stop]
        elif common_len < target_len:
            # Too short: zero-pad on the right up to the segment length.
            missing = target_len - common_len
            noisy = torch.nn.functional.pad(noisy, (0, missing))
            clean = torch.nn.functional.pad(clean, (0, missing))

        return noisy, clean  # [C, T], [C, T] with T == segment_samples


In [None]:
# Main training

def train_fastai(noisy_dir, clean_dir, epochs=80, lr=3e-4, batch_size=8,
                 save_path=Path("checkpoints/dnoize_fastai.pt"),
                 valid_pct=0.1, seed=42, find_lr=True):
    """Train ``DNoizeConvTasNet`` on paired noisy/clean ``.wav`` folders with fastai.

    Files are paired by sorted order, split into a training and a validation
    subset, and trained with the one-cycle policy. The checkpoint with the
    lowest validation loss is kept.

    Args:
        noisy_dir: Folder containing the degraded ``.wav`` files.
        clean_dir: Folder containing the reference ``.wav`` files; must hold
            the same number of ``.wav`` files as ``noisy_dir``.
        epochs: Number of one-cycle epochs.
        lr: Peak learning rate passed to ``fit_one_cycle``.
        batch_size: Batch size for both data loaders.
        save_path: Where to store the checkpoint. Its parent is used as the
            checkpoint folder and its stem as the file name; fastai always
            appends ``.pth``, so ``checkpoints/dnoize_fastai.pt`` is written
            as ``checkpoints/dnoize_fastai.pth``.
        valid_pct: Fraction of the pairs held out for validation (at least one
            pair is always held out).
        seed: Seed for the train/validation split.
        find_lr: Run the fastai LR finder before training and print its
            suggestion (informational only; ``lr`` is still what is used).

    Raises:
        ValueError: If the folders hold different numbers of ``.wav`` files or
            fewer than two pairs are found.
    """
    directml_available = torch_directml.is_available()
    device = torch_directml.device() if directml_available else torch.device("cpu")
    training_start_msg = f"Training on: {device}"
    if directml_available:
        training_start_msg += f" | {torch_directml.device_name(0)}"
    print(training_start_msg)

    # Get file lists (pairs are assumed to line up once both lists are sorted).
    noisy_paths = sorted(os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir) if f.endswith('.wav'))
    clean_paths = sorted(os.path.join(clean_dir, f) for f in os.listdir(clean_dir) if f.endswith('.wav'))
    if len(noisy_paths) != len(clean_paths):
        raise ValueError(
            f"Found {len(noisy_paths)} noisy but {len(clean_paths)} clean .wav files; "
            "the two folders must contain matching pairs."
        )
    n_files = len(noisy_paths)
    if n_files < 2:
        raise ValueError("Need at least 2 noisy/clean pairs (one for training, one for validation).")

    # Reproducible hold-out split: fit_one_cycle and SaveModelCallback both
    # need a validation loader, which a single dataset does not provide.
    n_valid = min(max(1, int(round(n_files * valid_pct))), n_files - 1)
    order = np.random.default_rng(seed).permutation(n_files)
    valid_idx, train_idx = order[:n_valid], order[n_valid:]

    def _subset(indices):
        return EnhancementDataset(
            [noisy_paths[i] for i in indices],
            [clean_paths[i] for i in indices],
        )

    # `shuffle` is deliberately not passed: from_dsets already shuffles (and
    # drops the last partial batch of) the first dataset only, whereas an
    # explicit shuffle=True would be broadcast to the validation loader too.
    dls = DataLoaders.from_dsets(
        _subset(train_idx),
        _subset(valid_idx),
        bs=batch_size,
        device=device,
        num_workers=0
    )

    model = DNoizeConvTasNet(channels=96, num_blocks=4).to(device)

    # fastai saves to <learner.path>/<model_dir>/<fname>.pth and does not create
    # sub-folders inside fname, so point the learner at the checkpoint folder.
    save_path = Path(save_path)
    ckpt_dir, ckpt_name = save_path.parent, save_path.stem
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): `AdamW` is not defined in this cell; confirm it resolves to a
    # fastai-compatible optimizer factory (fastai's own `Adam` already uses
    # decoupled weight decay by default).
    learn = Learner(
        dls,
        model,
        loss_func=SISNRLoss(),
        opt_func=AdamW,
        metrics=[],
        path=ckpt_dir,
        model_dir=".",
        # The recorder exposes 'train_loss' / 'valid_loss'; plain 'loss' is not a metric name.
        cbs=[SaveModelCallback(monitor='valid_loss', fname=ckpt_name)]
    )

    # LR finder (optional but recommended)
    if find_lr:
        # The suggestion is the return value of lr_find(), not a recorder attribute.
        suggested_lrs = learn.lr_find()
        print("Suggested LR:", suggested_lrs)

    # Train
    learn.fit_one_cycle(epochs, lr_max=lr)

    # SaveModelCallback reloads the best weights when fitting ends, so this
    # writes the best model; learn.save returns the actual file it wrote.
    saved_file = learn.save(ckpt_name)
    print(f"Training finished. Best model saved as: {saved_file}")

In [None]:
# Placeholder dataset locations — point these at the real folders before running.
noisy_dir = "/path/to/noisy_wavs"
clean_dir = "/path/to/clean_wavs"

# Kick off training; every argument is passed by keyword so overrides are explicit.
train_fastai(
    noisy_dir=noisy_dir,
    clean_dir=clean_dir,
    epochs=80,
    batch_size=8,
)