Wavenet Class

used chat for some of these - credit will come later

In [1]:
import os
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from scipy.io import wavfile
import matplotlib.pyplot as plt
from tqdm import tqdm


In [2]:
class WaveUNet(nn.Module):
    def __init__(
        self,
        input_channels: int = 1,
        output_channels: int = 1,
        num_layers: int = 6,
        features: int = 24
    ) -> None:
        """
        Initializes the Wave-U-Net model for end-to-end audio source separation.

        Parameters:
            input_channels (int): Number of input channels (e.g., 1 for mono audio).
            output_channels (int): Number of output channels (e.g., 1 per source).
            num_layers (int): Depth of the encoder/decoder layers.
            features (int): Base number of feature maps in the first conv layer.
        """
        super(WaveUNet, self).__init__()

        # Encoder layers: Downsample with increasing feature maps
        self.encoders = nn.ModuleList([
            nn.Conv1d(
                in_channels=input_channels if i == 0 else features * (2 ** i),
                out_channels=features * (2 ** (i + 1)),
                kernel_size=15,
                stride=2,
                padding=7
            )
            for i in range(num_layers)
        ])

        # Decoder layers: Upsample with skip connections
        self.decoders = nn.ModuleList([
            nn.ConvTranspose1d(
                in_channels=features * (2 ** (i + 1)),
                out_channels=features * (2 ** i),
                kernel_size=15,
                stride=2,
                padding=7,
                output_padding=1
            )
            for i in range(num_layers)
        ])

        # Final output layer to match desired output channels
        self.output_layer = nn.Conv1d(
            in_channels=features,
            out_channels=output_channels,
            kernel_size=1
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the Wave-U-Net.

        Parameters:
            x (Tensor): Input tensor of shape (batch, channels, time).

        Returns:
            Tensor: Output tensor of shape (batch, output_channels, time).
        """
        enc_outs = []

        # Encoder path with ReLU and storing outputs for skip connections
        for encoder in self.encoders:
            x = F.relu(encoder(x))
            enc_outs.append(x)

        # Decoder path with skip connections
        for i, decoder in enumerate(self.decoders):
            x = decoder(x)
            x = F.relu(x)

            skip_out = enc_outs[-(i + 1)]

            # Match the length using interpolation if needed
            if x.shape[2] != skip_out.shape[2]:
                x = F.interpolate(
                    x,
                    size=skip_out.shape[2],
                    mode='linear',
                    align_corners=False
                )

            # Skip connection
            x = x + skip_out

        return self.output_layer(x)


Dataset Loader

In [3]:
@dataclass
class AudioPair:
    """
    A container for a mixed waveform and its corresponding target stems.

    Attributes:
        mixed_waveform (Tensor): Tensor of shape (1, time).
        target_waveforms (Tensor): Tensor of shape (num_stems, time).
    """
    mixed_waveform: torch.Tensor
    target_waveforms: torch.Tensor


class SourceSeparationDataset(Dataset):
    """
    PyTorch dataset for loading source separation audio pairs.

    Assumes each track directory contains:
        - 'mix.wav' (the full mixture)
        - 'stems/' folder with multiple stem .wav files

    Example directory structure:
        root_dir/
            Track00001/
                mix.wav
                stems/
                    S01.wav
                    S02.wav
                    ...
    """

    def __init__(self, root_dir: str) -> None:
        """
        Initialize the dataset with the root directory containing track folders.

        Parameters:
            root_dir (str): Path to the dataset root directory.
        """
        self.root_dir = root_dir
        self.track_folders: List[str] = sorted(os.listdir(root_dir))

    def __len__(self) -> int:
        """Returns the number of tracks."""
        return len(self.track_folders)

    def __getitem__(self, idx: int) -> AudioPair:
        """
        Loads the mixed waveform and its stem waveforms as tensors.

        Parameters:
            idx (int): Index of the track.

        Returns:
            AudioPair: A dataclass containing the mix and stem tensors.
        """
        track_folder = self.track_folders[idx]
        track_path = os.path.join(self.root_dir, track_folder)

        # Load mixed waveform
        mix_path = os.path.join(track_path, "mix.wav")
        _, mixed_waveform = wavfile.read(mix_path)
        mixed_waveform = mixed_waveform.astype(np.float32) / 32768.0

        # Load stem waveforms
        stems_path = os.path.join(track_path, "stems")
        stem_files = sorted([
            f for f in os.listdir(stems_path) if f.endswith(".wav")
        ])

        target_waveforms = []
        for stem_file in stem_files:
            stem_path = os.path.join(stems_path, stem_file)
            _, stem_waveform = wavfile.read(stem_path)
            stem_waveform = stem_waveform.astype(np.float32) / 32768.0
            target_waveforms.append(stem_waveform)

        # Stack stems: (num_stems, time)
        target_waveforms = np.stack(target_waveforms)

        # Convert to PyTorch tensors
        mixed_tensor = torch.tensor(mixed_waveform, dtype=torch.float32).unsqueeze(0)  # (1, time)
        stems_tensor = torch.tensor(target_waveforms, dtype=torch.float32)             # (num_stems, time)

        return AudioPair(mixed_waveform=mixed_tensor, target_waveforms=stems_tensor)


In [4]:

@dataclass
class AudioPair:
    """
    A container for a mixed waveform and its corresponding target stems.

    Attributes:
        mixed_waveform (Tensor): Tensor of shape (1, time).
        target_waveforms (Tensor): Tensor of shape (num_stems, time).
    """
    mixed_waveform: torch.Tensor
    target_waveforms: torch.Tensor


class SingleTrackDataset(Dataset):
    """
    Dataset wrapper for a single track with one mix and multiple stems.

    This is useful for inference or testing on one audio mixture.

    Directory structure:
        track_path/
            mix.wav
            stems/
                S01.wav
                S02.wav
                ...
    """

    def __init__(self, track_path: str) -> None:
        """
        Initialize with path to a single track.

        Parameters:
            track_path (str): Path to the directory containing mix.wav and stems/
        """
        self.mix_path: str = os.path.join(track_path, "mix.wav")
        stems_dir = os.path.join(track_path, "stems")

        self.stem_paths: List[str] = sorted([
            os.path.join(stems_dir, f)
            for f in os.listdir(stems_dir)
            if f.endswith(".wav")
        ])

    def __len__(self) -> int:
        """
        Always returns 1, since this dataset only wraps a single track.

        Returns:
            int: 1
        """
        return 1

    def __getitem__(self, idx: int) -> AudioPair:
        """
        Loads and returns the mix and stem waveforms as tensors.

        Parameters:
            idx (int): Ignored, always returns the single track.

        Returns:
            AudioPair: A dataclass containing the mix and stem tensors.
        """
        # Load mix waveform
        _, mixed_waveform = wavfile.read(self.mix_path)
        mixed_waveform = mixed_waveform.astype(np.float32) / 32768.0
        mixed_tensor = torch.tensor(mixed_waveform, dtype=torch.float32).unsqueeze(0)  # (1, time)

        # Load stem waveforms
        target_waveforms = []
        for stem_path in self.stem_paths:
            _, stem_waveform = wavfile.read(stem_path)
            stem_waveform = stem_waveform.astype(np.float32) / 32768.0
            target_waveforms.append(torch.tensor(stem_waveform, dtype=torch.float32))

        stems_tensor = torch.stack(target_waveforms)  # (num_stems, time)

        return AudioPair(mixed_waveform=mixed_tensor, target_waveforms=stems_tensor)


Model initialization

In [5]:
# === Hyperparameters ===
num_epochs: int = 50
batch_size: int = 8
learning_rate: float = 1e-3

# === Device Setup ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Model, Loss, and Optimizer ===
model = WaveUNet(input_channels=1, output_channels=10).to(device)

# Mean Squared Error loss for waveform reconstruction
criterion = nn.MSELoss()

# Adam optimizer with specified learning rate
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


Training Function

In [6]:

def train_step(mixed: Tensor, target: Tensor) -> float:
    """
    Perform one training step on a batch of input/output waveforms.

    Parameters:
        mixed (Tensor): Input mixed waveform of shape (batch, 1, time).
        target (Tensor): Ground truth stem waveforms of shape (batch, num_stems, time).

    Returns:
        float: The scalar loss value for this step.
    """
    model.train()
    optimizer.zero_grad()

    # Forward pass
    output = model(mixed)

    # Compute loss
    loss = criterion(output, target)

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    return loss.item()


def separate_audio(model: torch.nn.Module, mixed_audio: np.ndarray) -> np.ndarray:
    """
    Run source separation inference on a single input waveform.

    Parameters:
        model (torch.nn.Module): Trained WaveUNet model.
        mixed_audio (np.ndarray): Input mono waveform of shape (time,).

    Returns:
        np.ndarray: Output separated waveform(s) of shape (num_stems, time).
    """
    model.eval()
    with torch.no_grad():
        # Prepare input tensor: (1, 1, time)
        mixed_tensor = torch.tensor(mixed_audio, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        separated = model(mixed_tensor)

    # Remove batch dimension and convert to numpy: (num_stems, time)
    return separated.squeeze(0).cpu().numpy()


Data loader coallate

In [7]:

def audio_pair_collate(batch: List[AudioPair]) -> AudioPair:
    """
    Collate function for batching AudioPair samples from a DataLoader.

    Stacks mixed_waveforms and target_waveforms into batched tensors.

    Parameters:
        batch (List[AudioPair]): List of AudioPair instances from the dataset.

    Returns:
        AudioPair: A new AudioPair with stacked tensors:
                   - mixed_waveform: (batch_size, 1, time)
                   - target_waveforms: (batch_size, num_stems, time)
    """
    mixed_waveforms = [item.mixed_waveform for item in batch]   # List[(1, time)]
    target_waveforms = [item.target_waveforms for item in batch]  # List[(num_stems, time)]

    # Stack along batch dimension
    mixed_batch = torch.stack(mixed_waveforms)     # (batch_size, 1, time)
    targets_batch = torch.stack(target_waveforms)  # (batch_size, num_stems, time)

    return AudioPair(mixed_waveform=mixed_batch, target_waveforms=targets_batch)


Load in training data

In [9]:

# === Project Setup ===
project_root = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
track_name = "Track00001"
track_path = os.path.join(project_root,"ECE324_PROJECT" ,"data", "raw", track_name)

# === Dataset and Dataloader ===
dataset = SingleTrackDataset(track_path)
train_loader = DataLoader(
    dataset,
    batch_size=100,
    shuffle=False,
    collate_fn=audio_pair_collate
)

# === Model Setup ===
num_stems = len(dataset[0].target_waveforms)  # Infer number of output channels
model = WaveUNet(input_channels=1, output_channels=num_stems).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()

# === Training Loop ===
print("Training started...\n")

for epoch in range(num_epochs):
    total_loss = 0.0

    for batch in train_loader:
        # AudioPair object: extract tensors
        mixed = batch.mixed_waveform.to(device)
        target = batch.target_waveforms.to(device)

        # Train step
        loss = train_step(mixed, target)
        total_loss += loss

        print(f"Batch Loss: {loss:.6f}")

    print(f"Epoch [{epoch + 1}/{num_epochs}], Total Loss: {total_loss:.6f}\n")

print("Training complete!")


Training started...



RuntimeError: Given transposed=1, weight of size [48, 24, 15], expected input[1, 1536, 60390] to have 48 channels, but got 1536 channels instead