## First time

In [1]:
# def dprint(expr):
#     val = eval(expr, globals(), locals())
#     print(f"{expr}: {val}")

In [2]:
# !pip install librosa torchaudio sox
# !pip install torch torchaudio tqdm einops

In [3]:
# !wget -O clean_trainset_28spk_wav.zip "https://datashare.ed.ac.uk/bitstream/handle/10283/2791/clean_trainset_28spk_wav.zip"
# !wget -O noisy_trainset_28spk_wav.zip "https://datashare.ed.ac.uk/bitstream/handle/10283/2791/noisy_trainset_28spk_wav.zip"
# !wget -O clean_testset_wav.zip "https://datashare.ed.ac.uk/bitstream/handle/10283/2791/clean_testset_wav.zip"
# !wget -O noisy_testset_wav.zip "https://datashare.ed.ac.uk/bitstream/handle/10283/2791/noisy_testset_wav.zip"

# !unzip -q clean_trainset_28spk_wav.zip -d ./VoiceBank/Clean_Train
# !unzip -q noisy_trainset_28spk_wav.zip -d ./VoiceBank/Noisy_Train
# !unzip -q clean_testset_wav.zip -d ./VoiceBank/Clean_Test
# !unzip -q noisy_testset_wav.zip -d ./VoiceBank/Noisy_Test

## Second Step

In [4]:
# import os
# import torchaudio
# import torchaudio.transforms as transforms
# import torch

# # Define paths
# base_dir = "./drive/MyDrive/Datasets/"
# clean_train_path = os.path.join(base_dir, "VoiceBank", "Clean_Train","clean_trainset_28spk_wav")
# noisy_train_path = os.path.join(base_dir, "VoiceBank", "Noisy_Train", "noisy_trainset_28spk_wav")
# output_clean = os.path.join(base_dir, "VoiceBank_processed", "Clean_Train")
# output_noisy = os.path.join(base_dir, "VoiceBank_processed", "Noisy_Train")


# os.makedirs(output_clean, exist_ok=True)
# os.makedirs(output_noisy, exist_ok=True)

# # Convert audio to Mel Spectrogram
# mel_transform = transforms.MelSpectrogram(sample_rate=16000, n_mels=80)

# def process_audio(file, input_dir, output_dir):
#     file_path = os.path.join(input_dir, file)
#     waveform, sr = torchaudio.load(file_path)

#     # Convert to Mel Spectrogram
#     mel_spec = mel_transform(waveform)

#     # Save as torch tensor
#     torch.save(mel_spec, os.path.join(output_dir, file.replace('.wav', '.pt')))

# # Process dataset
# for file in os.listdir(clean_train_path):
#     if file.endswith(".wav"):
#         process_audio(file, clean_train_path, output_clean)
#         process_audio(file, noisy_train_path, output_noisy)

# print("✅ Dataset preprocessed and saved as Mel spectrograms!")

## Main code

In [5]:
import os
import torch
base_dir = "/kaggle/input/voicebank"
clean_train_path = os.path.join(base_dir, "clean_trainset_28spk_wav","clean_trainset_28spk_wav")
noisy_train_path = os.path.join(base_dir, "noisy_trainset_28spk_wav", "noisy_trainset_28spk_wav")

output_clean = os.path.join(base_dir, "VoiceBank_processed", "Clean_Train")
output_noisy = os.path.join(base_dir, "VoiceBank_processed", "Noisy_Train")

Noise Scheduler

In [6]:
class NoiseScheduler:
    def __init__(self, timesteps=1000, s=0.008):
        self.timesteps = timesteps
        
        # Create a cosine schedule for alpha_bar
        # We compute timesteps + 1 values to derive betas for the discrete intervals.
        steps = timesteps + 1  
        t_lin = torch.linspace(0, timesteps, steps) / timesteps  # normalized [0,1]
        
        # Cosine schedule as introduced in improved diffusion:
        # f(t) = cos((t + s) / (1+s) * (pi/2))^2; normalized so f(0)=1.
        alpha_bar = torch.cos((t_lin + s) / (1 + s) * (math.pi / 2)) ** 2
        alpha_bar = alpha_bar / alpha_bar[0]  # ensure alpha_bar[0] is 1
        
        # Now, derive discrete betas from the continuous alpha_bar schedule.
        betas = []
        for t in range(timesteps):
            # Clip beta to be in a reasonable range to avoid numerical issues.
            beta = min(1 - alpha_bar[t+1] / alpha_bar[t], 0.999)
            betas.append(beta)
        self.beta = torch.tensor(betas)
        
        # Compute other parameters based on beta
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.snr = self.alpha_bar / (1 - self.alpha_bar)
    
    def _move_to_device(self, tensor, device):
        """Helper to move tensors to device while preserving gradient info"""
        return tensor.to(device, non_blocking=True, copy=False)
    
    def add_noise(self, x, t, noise=None):
        """
        Adds noise to input x at timestep t, handling device compatibility.
        Ensures all tensors are on the same device as input x.
        """
        device = x.device
        alpha_bar = self._move_to_device(self.alpha_bar, device)
    
        # Reshape for broadcasting - assuming x is in NCHW or similar format.
        sqrt_alpha_bar_t = alpha_bar[t] ** 0.5
        sqrt_alpha_bar_t = sqrt_alpha_bar_t.view(-1, 1, 1, 1)
    
        sqrt_one_minus_alpha_bar_t = (1 - alpha_bar[t]) ** 0.5
        sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t.view(-1, 1, 1, 1)
    
        if noise is None:
            noise = torch.randn_like(x)
    
        return sqrt_alpha_bar_t * x + sqrt_one_minus_alpha_bar_t * noise, noise
    
    def get_loss_weight(self, t):
        """
        Returns SNR-based weights with proper device handling.
        Ensures SNR tensor is on same device as timestep tensor t.
        """
        device = t.device
        snr = self._move_to_device(self.snr, device)
        return torch.clamp(snr[t], min=1.0, max=10.0)
    
    def get_beta(self, t):
        """
        Device-safe beta accessor.
        """
        device = t.device
        return self._move_to_device(self.beta, device)[t]

UNet Architecture

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim=None):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_channels) if time_emb_dim else None

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_channels)

    def forward(self, x, time_emb=None):
        h = self.conv1(x)
        h = self.norm1(h)
        h = F.relu(h)

        # Time embedding injection
        if self.time_mlp and time_emb is not None:
            time_emb = self.time_mlp(time_emb)
            time_emb = time_emb.reshape(time_emb.shape[0], -1, 1, 1)
            h = h + time_emb

        h = self.conv2(h)
        h = self.norm2(h)
        h = F.relu(h)

        return h

class UNet2D(nn.Module):
    def __init__(self, in_channels=1, cond_channels=1, out_channels=1, features=[64, 128, 256], time_emb_dim=256):
        super().__init__()
        self.time_dim = time_emb_dim
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # Learnable projection of spectrogram condition
        self.cond_proj = nn.Conv2d(cond_channels, in_channels, kernel_size=1)

        # Encoder (Downsampling)
        self.encoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2)

        prev_channels = in_channels
        for feature in features:
            self.encoder.append(ConvBlock(prev_channels, feature, time_emb_dim))
            prev_channels = feature

        # Bottleneck
        self.bottleneck = ConvBlock(features[-1], features[-1] * 2, time_emb_dim)

        # Decoder (Upsampling)
        self.decoder = nn.ModuleList()
        self.upsamples = nn.ModuleList()

        reversed_features = list(reversed(features))
        prev_channels = features[-1] * 2

        for i, feature in enumerate(reversed_features):
            self.upsamples.append(nn.ConvTranspose2d(prev_channels, feature, kernel_size=2, stride=2))
            self.decoder.append(ConvBlock(feature * 2, feature, time_emb_dim))
            prev_channels = feature

        # Final output layer
        self.final = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x, cond_spec, timestep):
        # x: [batch,1,time] → add pseudo-height for 2D conv
        x = x.unsqueeze(2)           # → [batch,1,1,time]
        # Project and fuse condition
        cond = self.cond_proj(cond_spec)
        x = x + cond

        # Compute time embedding
        t = self.time_mlp(timestep)

        # Encoder with skip connections
        skips = []
        for encoder in self.encoder:
            x = encoder(x, t)
            skips.append(x)
            x = self.pool(x)

        # Bottleneck
        x = self.bottleneck(x, t)

        # Decoder with skip connections
        skips = skips[::-1]  # Reverse for easier access

        for i, (upsample, decoder) in enumerate(zip(self.upsamples, self.decoder)):
            x = upsample(x)
            skip = skips[i]

            # Ensure matching shapes for concatenation
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=True)

            # Concatenate with skip connection
            x = torch.cat([x, skip], dim=1)

            # Apply decoder block with time embedding
            x = decoder(x, t)

        out = self.final(x)
        return out.squeeze(2)       # → [batch,out_channels,time]

DataLoader

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import torch.nn.functional as F
import torchaudio

class SpectrogramDataset(Dataset):
    def __init__(self, clean_dir, noisy_dir, sample_rate=16000, n_mels=128, n_fft=1024, hop_length=256,
                 chunk_width=None, cache_data=False):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.sample_rate = sample_rate
        self.file_list = [f for f in os.listdir(clean_dir) if f.endswith(".wav")]
        self.chunk_width = chunk_width  # Width in frames for random slicing
        self.cache_data = cache_data
        self.cached_data = {}

        # use magnitude (power=1.0) instead of power spectrogram
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            power=1.0,  # Use magnitude spectrogram
            normalized=True,  # Normalize the output
        )

        if cache_data:
            print("Caching dataset...")
            for idx, file_name in enumerate(self.file_list):
                if idx % 100 == 0:
                    print(f"Caching: {idx}/{len(self.file_list)}")
                clean_path = os.path.join(self.clean_dir, file_name)
                noisy_path = os.path.join(self.noisy_dir, file_name)

                clean = self._process_wav(clean_path)
                noisy = self._process_wav(noisy_path)

                if self.chunk_width:
                    clean, noisy = self._random_chunk_pair(clean, noisy)

                self.cached_data[file_name] = (clean, noisy)
            print("Dataset cached successfully!")

    # def _process_wav(self, path):
    #     waveform, sr = torchaudio.load(path)
    #     if sr != self.sample_rate:
    #         resampler = torchaudio.transforms.Resample(
    #             orig_freq=sr, new_freq=self.sample_rate
    #         )
    #         waveform = resampler(waveform)
    #     # Mel spectrogram to dB-scaled [0,1]
    #     mel = self.mel_transform(waveform)
    #     with torch.no_grad():
    #         spec = 20 * torch.log10(torch.clamp(mel, min=1e-5)) - 20
    #         spec = torch.clamp((spec + 100) / 100, 0.0, 1.0)
    #     return spec

    # def _random_chunk_pair(self, clean, noisy):
    #     _, _, time_steps = clean.shape
    #     if time_steps <= self.chunk_width:
    #         # print("Oops")
    #         # Pad or truncate to match chunk_width
    #         clean = F.pad(clean, (0, self.chunk_width - time_steps)) if time_steps < self.chunk_width else clean[:, :, :self.chunk_width]
    #         noisy = F.pad(noisy, (0, self.chunk_width - time_steps)) if time_steps < self.chunk_width else noisy[:, :, :self.chunk_width]
    #         return clean, noisy
    #     # Random slicing for larger inputs
    #     start = torch.randint(0, time_steps - self.chunk_width, (1,)).item()
    #     end = start + self.chunk_width
    #     return clean[:, :, start:end], noisy[:, :, start:end]


    # def _normalize(self, tensor, eps=1e-5):
    #     mean = tensor.mean()
    #     std = tensor.std()
    #     return (tensor - mean) / (std + eps)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_name = self.file_list[idx]
        clean_waveform, _ = torchaudio.load(os.path.join(self.clean_dir, file_name))
        noisy_waveform, _ = torchaudio.load(os.path.join(self.noisy_dir, file_name))
        # compute noisy spectrogram with manual dB scaling
        with torch.no_grad():
            mel = self.mel_transform(noisy_waveform)
            noisy_spec = 20 * torch.log10(torch.clamp(mel, min=1e-5)) - 20
            noisy_spec = torch.clamp((noisy_spec + 100) / 100, 0.0, 1.0)
        # noisy_spec = self._normalize(noisy_spec)
        return clean_waveform, noisy_waveform, noisy_spec

In [9]:
from sklearn.model_selection import train_test_split

def split_dataset(clean_dir, noisy_dir, val_ratio=0.3, cache_data=False, chunk_width=None):
    file_list = [f for f in os.listdir(clean_dir) if f.endswith(".wav")]
    train_files, val_files = train_test_split(file_list, test_size=val_ratio, random_state=42)

    train_dataset = SpectrogramDataset(
        clean_dir, noisy_dir, cache_data=cache_data, chunk_width=chunk_width
    )
    val_dataset = SpectrogramDataset(
        clean_dir, noisy_dir, cache_data=cache_data, chunk_width=chunk_width
    )

    train_dataset.file_list = train_files
    val_dataset.file_list = val_files

    return train_dataset, val_dataset

def create_dataloaders(clean_dir, noisy_dir, batch_size=8, num_workers=4, cache_data=False, chunk_width=None):
    train_dataset, val_dataset = split_dataset(
        clean_dir, noisy_dir, cache_data=cache_data, chunk_width=chunk_width
    )

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )

    return train_loader, val_loader

Training Loop

In [10]:
from tqdm import tqdm
import torch.nn.functional as F
import torch.optim as optim
import torch
import os

def train(model, train_loader, val_loader, scheduler, optimizer, epochs=10, device="cuda",
          checkpoint_interval=2, checkpoint_dir="checkpoints", resume_from_checkpoint=None,
          use_weighted_loss=True, clip_grad_norm=1.0, mixed_precision=True):

    model = model.to(device)
    scaler = torch.amp.GradScaler(enabled=mixed_precision)
    os.makedirs(checkpoint_dir, exist_ok=True)
    epoch_losses = []
    start_epoch = 0

    # Resume from checkpoint if specified
    if resume_from_checkpoint:
        checkpoint = torch.load(resume_from_checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state'])
        start_epoch = checkpoint['epoch']
        print(f"Resuming training from epoch {start_epoch}...")

    for epoch in range(start_epoch, epochs):
        model.train()
        running_loss = 0.0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Training)", leave=False)

        for clean, noisy_waveform, cond_spec in loop:
            # move tensors to device
            clean = clean.to(device)    # clean waveform]
            noisy_waveform = noisy_waveform.to(device)  # noisy waveform
            cond_spec = cond_spec.to(device)    # noisy spectrogram
            batch_size = clean.size(0)

            # sample timesteps
            t = torch.randint(0, scheduler.timesteps, (batch_size,), device=device) # [batch_size,]

            # noise scheduling (unchanged)
            noise_scale = scheduler.alpha_bar[t].view(batch_size,1,1,1)
            noise_scale_sqrt = noise_scale.sqrt()
            m = (((1 - noise_scale) / noise_scale.sqrt()) ** 0.5).view(batch_size,1,1,1)
            noise = torch.randn_like(clean)

            noisy_t, combine_noise = (
                (1-m) * noise_scale_sqrt * clean + m * noise_scale_sqrt * noisy_waveform + (1.0 - (1+m**2)*noise_scale).sqrt()*noise,
                (m * noise_scale_sqrt * (noisy_waveform - clean) + (1.0 - (1+m**2)*noise_scale).sqrt()*noise) / (1.0 - noise_scale).sqrt()
            )

            print(f"noisy_t: {noisy_t.shape}, combine_noise: {combine_noise.shape}, cond_spec: {cond_spec.shape}")
            # forward with spectrogram condition
            predicted = model(noisy_t, cond_spec, t)
            print(f"predicted: {predicted.shape}")

            loss = F.mse_loss(combine_noise, predicted.squeeze(1))

            optimizer.zero_grad()
            scaler.scale(loss).backward()

            if clip_grad_norm > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)

            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item()
            loop.set_postfix(loss=loss.item())

        avg_train_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1} Training Loss: {avg_train_loss:.4f}")

        # Validation Step
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} (Validation)", leave=False):
                clean = batch[0].to(device) if isinstance(batch, (list, tuple)) else batch.to(device)
                batch_size = clean.size(0)
                indices = torch.randperm(scheduler.timesteps, device=device)
                t = indices[:batch_size] if batch_size <= scheduler.timesteps else indices.repeat(batch_size // len(indices) + 1)[:batch_size]

                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    noisy_t, _ = scheduler.add_noise(clean, t)
                    predicted_x0 = model(noisy_t, t)
                    val_loss += F.mse_loss(predicted_x0, clean).item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}")

        epoch_losses.append({'epoch': epoch+1, 'train_loss': avg_train_loss, 'val_loss': avg_val_loss})

        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f"UNet_21_04_{epoch+1}.pth")
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'scaler_state': scaler.state_dict(),
            }, checkpoint_path)

    return epoch_losses

In [11]:
# Initialize model, optimizer, and scheduler
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet2D(in_channels=1, out_channels=1, features=[64, 128, 256], time_emb_dim=256)

if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = torch.nn.DataParallel(model)

model = model.to(device)

scheduler = NoiseScheduler()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Recommended chunk width (128 frames ≈ 2 seconds of audio)
chunk_width = 128
num_workers = min(8, os.cpu_count())

# Create train and validation dataloaders with chunking
train_loader, val_loader = create_dataloaders(
    clean_dir=clean_train_path,
    noisy_dir=noisy_train_path,
    batch_size=64,
    num_workers=num_workers,
    cache_data=False,
    chunk_width=chunk_width
)

# Train the model
losses = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    scheduler=scheduler,
    optimizer=optimizer,
    epochs=50,
    device=device,
    checkpoint_interval=2,
    checkpoint_dir="./checkpoints",
    use_weighted_loss=True,
    clip_grad_norm=1.0, 
    # resume_from_checkpoint="./checkpoints/UNet_New_11_04_30.pth"
)

# Saving Epoch_losses for future use and concatenation purpose
import pickle
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Save losses with a timestamped filename
with open(f"losses_{timestamp}.pkl", 'wb') as f:
    pickle.dump(losses, f)

                                                              

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 398, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 211, in collate
    return [
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 212, in <listcomp>
    collate(samples, collate_fn_map=collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 155, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 272, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [1, 131074] at entry 0 and [1, 161620] at entry 1


In [None]:
import matplotlib.pyplot as plt

# Extract training and validation losses
epochs = [entry['epoch'] for entry in losses]
train_losses = [entry['train_loss'] for entry in losses]
val_losses = [entry['val_loss'] for entry in losses]

# Plot the losses
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_losses, label='Training Loss', marker='o')
plt.plot(epochs, val_losses, label='Validation Loss', marker='o')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.grid(True)
plt.show()