U-Net model works with convolution and spectrograms

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, input_channels=1, output_channels=4):
        super(UNet, self).__init__()
        
        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True)
            )

        self.encoder1 = conv_block(input_channels, 32)
        self.encoder2 = conv_block(32, 64)
        self.encoder3 = conv_block(64, 128)
        
        self.middle = conv_block(128, 256)
        
        self.decoder3 = conv_block(256 + 128, 128)
        self.decoder2 = conv_block(128 + 64, 64)
        self.decoder1 = conv_block(64 + 32, 32)

        self.final = nn.Conv2d(32, output_channels, kernel_size=1)

        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        e1 = self.encoder1(x)
        p1 = self.pool(e1)
        
        e2 = self.encoder2(p1)
        p2 = self.pool(e2)
        
        e3 = self.encoder3(p2)
        p3 = self.pool(e3)

        mid = self.middle(p3)

        d3 = F.interpolate(mid, scale_factor=2) 
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.decoder3(d3)

        d2 = F.interpolate(d3, scale_factor=2)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.decoder2(d2)

        d1 = F.interpolate(d2, scale_factor=2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.decoder1(d1)

        out = self.final(d1)
        return out

Spectrogram Dataset Class

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy as np
import os
class SpectrogramDataset(Dataset):
    def __init__(self, mix_files, sr=44100, n_fft=2048, hop_length=512, instruments=["drums", "bass", "piano", "guitar"]):
        """
        Args:
            mix_files (list): List of paths to mixed WAV files.
            sr (int): Sample rate for librosa.
            n_fft (int): FFT window size.
            hop_length (int): Hop length for STFT.
            instruments (list): List of instrument names for separation.
        """
        self.mix_files = mix_files
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.instruments = instruments

    def audio_to_spectrogram(self, audio):
        """Convert waveform to a magnitude spectrogram."""
        stft = librosa.stft(audio, n_fft=self.n_fft, hop_length=self.hop_length)
        magnitude = np.abs(stft)
        return (magnitude - magnitude.mean()) / (magnitude.std() + 1e-6)  # Normalize

    def __len__(self):
        return len(self.mix_files)

    def __getitem__(self, idx):
        mix_path = self.mix_files[idx]
        mix_audio, _ = librosa.load(mix_path, sr=self.sr, mono=True)  # Load mixed audio
        mix_spec = self.audio_to_spectrogram(mix_audio)
        mix_spec = np.expand_dims(mix_spec, axis=0)  # Add channel dimension

        return torch.tensor(mix_spec, dtype=torch.float32), mix_path  # Return path for reference

Train Model

In [11]:
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_epochs = 1
batch_size = 50

project_root = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
track_name = "Track00001"
track_path = os.path.join(project_root, "data", "raw", track_name, "mix.wav")

dataset = SpectrogramDataset([track_path])
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Initialize model
num_stems = 10  # CHANGE THIS TO NOT BE HARDCODED
model = UNet(input_channels=1, output_channels=4).to(device)
criterion = nn.MSELoss()  # Loss for spectrograms
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print("model initialized")

for epoch in range(num_epochs):
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
    for mixed, stems in dataloader:
        
        optimizer.zero_grad()
        outputs = model(mixed)
        outputs = outputs[..., :stems.shape[-1]]
        loss = criterion(outputs, stems)
        loss.backward()
        optimizer.step()
        progress_bar.set_postfix(loss=f"{loss.item():.4f}")

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

model initialized


Epoch 1/1:   0%|          | 0/1 [03:50<?, ?batch/s]


RuntimeError: [enforce fail at alloc_cpu.cpp:114] data. DefaultCPUAllocator: not enough memory: you tried to allocate 2729878400 bytes.