Wavenet Class

used chat for some of these - credit will come later

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import librosa


class WaveUNet(nn.Module):
    def __init__(self, input_channels=1, output_channels=1, num_layers=6, features=24):
        super(WaveUNet, self).__init__()

        # Encoder
        self.encoders = nn.ModuleList([
            nn.Conv1d(input_channels if i == 0 else features * (2**i), 
                      features * (2**(i+1)), kernel_size=15, stride=2, padding=7)
            for i in range(num_layers)
        ])

        # Decoder
        self.decoders = nn.ModuleList([
            nn.ConvTranspose1d(features * (2**(i+1)), 
                               features * (2**i), kernel_size=15, stride=2, padding=7, output_padding=1)
            for i in range(num_layers)
        ])

        # Final output layer
        self.output_layer = nn.Conv1d(features, output_channels, kernel_size=1)

    def forward(self, x):
        enc_outs = []

        # Encoder Pass
        for encoder in self.encoders:
            x = F.relu(encoder(x))
            enc_outs.append(x)
            print(f"Encoder output shape: {x.shape}")  # Debugging shape

        # Decoder Pass
        for i, decoder in enumerate(self.decoders):
            print(f"Decoder input shape before skip connection: {x.shape}")
            print(f"Before decoding, input shape: {x.shape}")
            x = decoder(x)
            print(f"After decoding, output shape: {x.shape}")
            x = F.relu(x)
            print(f"Decoder output shape: {x.shape}")
            if x.shape != enc_outs[-(i+1)].shape:
                # Adjust the size using interpolate if necessary
                print("hello")
                x = F.interpolate(x, size=enc_outs[-(i+1)].shape[2], mode='linear', align_corners=False)
            
            # Ensure decoder output matches encoder output size
            x += enc_outs[-(i+1)]  # Skip connection
            print(f"Shape after skip connection: {x.shape}")

        # Final Output
        return self.output_layer(x)

Dataset Loader

In [16]:
from dataclasses import dataclass
from torch.utils.data import Dataset
from scipy.io import wavfile
import numpy as np
import os

@dataclass
class AudioPair:
    mixed_waveform: torch.Tensor
    target_waveforms: torch.Tensor  # Multiple stems

class SourceSeparationDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.track_folders = sorted(os.listdir(root_dir))  # List of tracks (Track00001, Track00002, ...)

    def __len__(self):
        return len(self.track_folders)

    def __getitem__(self, idx):
        track_folder = self.track_folders[idx]
        track_path = os.path.join(self.root_dir, track_folder)

        # Paths to mix and stems
        mix_path = os.path.join(track_path, "mix.wav")
        stems_path = os.path.join(track_path, "stems")  

        # Load mixed waveform
        _, mixed_waveform = wavfile.read(mix_path)
        mixed_waveform = mixed_waveform.astype(np.float32) / 32768.0

        # Load all stem waveforms
        stem_files = sorted([f for f in os.listdir(stems_path) if f.endswith(".wav")])  
        target_waveforms = []
        
        for stem_file in stem_files:
            stem_path = os.path.join(stems_path, stem_file)
            _, stem_waveform = wavfile.read(stem_path)
            target_waveforms.append(stem_waveform.astype(np.float32) / 32768.0)

        # Stack stems into (num_instruments, time)
        target_waveforms = np.stack(target_waveforms)

        # Convert to PyTorch tensors
        mixed_waveform = torch.tensor(mixed_waveform, dtype=torch.float32).unsqueeze(0)  # (1, time)
        target_waveforms = torch.tensor(target_waveforms, dtype=torch.float32)  # (num_instruments, time)

        return AudioPair(mixed_waveform=mixed_waveform, target_waveforms=target_waveforms)
    


In [17]:
class SingleTrackDataset(Dataset):
    def __init__(self, track_path):
        self.mix_path = os.path.join(track_path, "mix.wav")
        self.stem_paths = sorted([
            os.path.join(track_path, "stems", f) 
            for f in os.listdir(os.path.join(track_path, "stems")) if f.endswith(".wav")
        ])
    
    def __len__(self):
        return 1  # Only one track

    def __getitem__(self, idx):
        # Load mix
        _, mixed_waveform = wavfile.read(self.mix_path)
        mixed_waveform = mixed_waveform.astype(np.float32) / 32768.0
        mixed_waveform = torch.tensor(mixed_waveform, dtype=torch.float32).unsqueeze(0)

        # Load all stems
        target_waveforms = []
        for stem_path in self.stem_paths:
            _, stem_waveform = wavfile.read(stem_path)
            stem_waveform = stem_waveform.astype(np.float32) / 32768.0
            target_waveforms.append(torch.tensor(stem_waveform, dtype=torch.float32))

        target_waveforms = torch.stack(target_waveforms)  # Shape: [num_stems, time]

        return AudioPair(mixed_waveform=mixed_waveform, target_waveforms=target_waveforms)  # Fixed the argument name


Model initialization

In [18]:
# Hyperparameters
num_epochs = 50
batch_size = 8
learning_rate = 1e-3

# Initialize model, loss, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = WaveUNet(input_channels=1, output_channels=10).to(device)
criterion = nn.MSELoss()  # Mean Squared Error for waveform reconstruction
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Training Function

In [19]:
def train_step(mixed, target):
    model.train()
    optimizer.zero_grad()

    # Forward pass
    output = model(mixed)

    # Compute loss
    loss = criterion(output, target)

    # Backpropagation
    loss.backward()
    optimizer.step()
    
    return loss.item()

def separate_audio(model, mixed_audio):
    model.eval()
    with torch.no_grad():
        mixed_audio = torch.tensor(mixed_audio).unsqueeze(0).unsqueeze(0)  # Add batch & channel dims
        separated = model(mixed_audio)
    return separated.squeeze(0).numpy()


Data loader coallate

In [20]:
def audio_pair_collate(batch):
    mixed_waveforms = [item.mixed_waveform for item in batch]
    target_waveforms = [item.target_waveforms for item in batch]
    
    # Stack the tensors for each field
    mixed_waveforms = torch.stack(mixed_waveforms)
    target_waveforms = torch.stack(target_waveforms)
    
    return AudioPair(mixed_waveform=mixed_waveforms, target_waveforms=target_waveforms)

Load in training data

In [38]:
project_root = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))

# Construct the correct path
track_name = "Track00001"

track_path = os.path.join(project_root, "data", "raw", track_name)

dataset = SingleTrackDataset(track_path)
train_loader = DataLoader(dataset, batch_size=100, shuffle=False, collate_fn=audio_pair_collate)

# Initialize model
num_stems = len(dataset[0].target_waveforms)  # Get number of instruments
model = WaveUNet(input_channels=1, output_channels=num_stems)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()

# Training loop
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in train_loader:
        # Since batch is an AudioPair, access the attributes directly
        mixed, target = batch.mixed_waveform, batch.target_waveforms
        
        mixed, target = mixed.to(device), target.to(device)
        loss = train_step(mixed, target)
        total_loss += loss.item()
        print(f"Batch, Loss: {total_loss:.6f}")
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.6f}")

print("Training complete!")

Encoder output shape: torch.Size([1, 48, 1932458])
Encoder output shape: torch.Size([1, 96, 966229])
Encoder output shape: torch.Size([1, 192, 483115])
Encoder output shape: torch.Size([1, 384, 241558])
Encoder output shape: torch.Size([1, 768, 120779])
Encoder output shape: torch.Size([1, 1536, 60390])
Decoder input shape before skip connection: torch.Size([1, 1536, 60390])
Before decoding, input shape: torch.Size([1, 1536, 60390])


RuntimeError: Given transposed=1, weight of size [48, 24, 15], expected input[1, 1536, 60390] to have 48 channels, but got 1536 channels instead