This notebook defines a simple autoencoder-type model, trains it on the training part of a dataset of paired clean and degraded audio, and tests the result on the test part of the dataset. Note: to run this code you need one path pointing to the clean data and one path pointing to the degraded data.

In [None]:
import torch
import torch.nn as nn
import torchaudio
import torch.nn.functional as F
import os
from tqdm import tqdm
from tqdm.notebook import tqdm
import numpy as np
import shutil
from torch.utils.data import Dataset, DataLoader

# Defining model

In [None]:
class SimpleConvAutoencoder(nn.Module):
    """
    Small convolutional autoencoder for single-channel 2D inputs.

    The encoder stacks three stride-2 convolutions (1 -> 16 -> 32 -> 64
    channels, each followed by ReLU). The decoder mirrors it with three
    stride-2 transposed convolutions (64 -> 32 -> 16 -> 1 channels); the first
    two are followed by ReLU and the last by a sigmoid. After every decoder
    stage the result is center-cropped to the spatial size of the tensor that
    entered the matching encoder stage, so the output has the same height and
    width as the input.
    """

    def __init__(self):
        super().__init__()

        # --- ENCODER --- (stride 2: each stage roughly halves height and width)
        self.enc_conv1 = self._down_stage(1, 16)
        self.enc_conv2 = self._down_stage(16, 32)
        self.enc_conv3 = self._down_stage(32, 64)

        # --- DECODER --- (stride 2: each stage doubles height and width)
        self.dec_conv1 = self._up_stage(64, 32, nn.ReLU())
        self.dec_conv2 = self._up_stage(32, 16, nn.ReLU())
        self.dec_conv3 = self._up_stage(16, 1, nn.Sigmoid())

    @staticmethod
    def _down_stage(in_channels, out_channels):
        """Build one encoder stage: 3x3 stride-2 convolution followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

    @staticmethod
    def _up_stage(in_channels, out_channels, activation):
        """Build one decoder stage: 3x3 stride-2 transposed convolution plus activation."""
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels,
                kernel_size=3, stride=2, padding=1, output_padding=1,
            ),
            activation,
        )

    def forward(self, x):
        """
        Encode and decode a batch.
        Args:
            x (Tensor): Input batch, indexed as [batch, channel, height, width].
        Returns:
            Tensor: Reconstruction with the same height and width as ``x``.
        """
        # Encoder: keep the intermediate activations, their shapes drive the crops.
        enc1 = self.enc_conv1(x)
        enc2 = self.enc_conv2(enc1)
        bottleneck = self.enc_conv3(enc2)

        # Decoder: the transposed convolutions can overshoot by one row/column
        # when an encoder input had an odd size, hence the crop after each stage.
        decoded = self.crop(self.dec_conv1(bottleneck), enc2.shape)
        decoded = self.crop(self.dec_conv2(decoded), enc1.shape)
        return self.crop(self.dec_conv3(decoded), x.shape)

    def crop(self, tensor_to_crop, target_shape):
        """
        Center-crop the two trailing spatial axes of a 4D tensor.
        Args:
            tensor_to_crop (Tensor): Tensor indexed as [batch, channel, height, width].
            target_shape (tuple): 4D shape whose entries 2 and 3 give the target height and width.
        Returns:
            Tensor: View of ``tensor_to_crop`` restricted to the centered window.
        """
        out_h, out_w = target_shape[2], target_shape[3]
        top = (tensor_to_crop.shape[2] - out_h) // 2
        left = (tensor_to_crop.shape[3] - out_w) // 2
        return tensor_to_crop[:, :, top:top + out_h, left:left + out_w]


def process_file_to_tensor(file_path, device, n_fft=1024, target_sr=48000):
    """
    Loads an audio file and converts it to a normalized magnitude spectrogram plus phase.

    The audio is resampled to ``target_sr`` when its sample rate differs,
    downmixed to a single channel by averaging, and transformed with an STFT
    (Hann window of length ``n_fft``, hop length ``n_fft // 4``). The magnitude
    is min-max normalized to the range [0, 1].

    Args:
        file_path (str): Path to the audio file.
        device (torch.device): Device to load the tensor onto.
        n_fft (int): FFT window size.
        target_sr (int): Sample rate the audio is resampled to before the STFT.
    Returns:
        Tuple[Tensor, Tensor]: Normalized magnitude spectrogram and phase tensor.
            The magnitude carries one extra leading singleton dimension compared
            to the phase; assuming torchaudio.load returns (channels, time), the
            magnitude is (1, 1, freq, frames) and the phase is (1, freq, frames).
    """
    waveform, sr = torchaudio.load(file_path)

    if sr != target_sr:
        # Resample before moving to `device`: the freshly built transform keeps
        # its kernel on the CPU, so applying it to a CUDA tensor would fail.
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
    if waveform.shape[0] > 1:
        # Downmix multi-channel audio to mono.
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    waveform = waveform.to(device)

    window = torch.hann_window(n_fft, device=device)
    stft = torch.stft(waveform, n_fft=n_fft, hop_length=n_fft // 4,
                      window=window, return_complex=True)

    magnitude = torch.abs(stft)
    phase = torch.angle(stft)

    # Min-max normalization. The range is clamped away from zero so that a
    # constant (e.g. silent) spectrogram maps to all zeros instead of NaN.
    magnitude_min = magnitude.min()
    magnitude_range = (magnitude.max() - magnitude_min).clamp_min(1e-8)
    magnitude_norm = (magnitude - magnitude_min) / magnitude_range

    # Add a leading singleton dimension in front of the STFT output.
    magnitude_norm_3d = magnitude_norm.unsqueeze(0)

    return magnitude_norm_3d, phase


def tensor_to_audio_file(output_magnitude, original_phase, file_path, n_fft=1024, sample_rate=48000):
    """
    Reconstructs and saves an audio file from magnitude and phase tensors.

    The magnitude is bilinearly resized to the phase's (freq, frames) size,
    combined with the phase into a complex spectrogram, inverted with an ISTFT
    (Hann window of length ``n_fft``, hop length ``n_fft // 4``) and written to disk.

    Args:
        output_magnitude (Tensor): The output magnitude tensor from the model,
            4D with a singleton dimension at index 1 (it is squeezed away).
        original_phase (Tensor): The original phase tensor (3D; dims 1 and 2
            give the target size).
        file_path (str): Path to save the reconstructed audio file.
        n_fft (int): FFT window size.
        sample_rate (int): Sample rate written to the audio file.
    """
    # torch.polar needs both operands on the same device.
    output_magnitude = output_magnitude.to(original_phase.device)

    target_shape = original_phase.shape
    # NOTE(review): bilinear resizing rescales the time axis; if the model input
    # was padded/truncated to a fixed length, cropping may be more appropriate.
    resized_output_magnitude = F.interpolate(
        output_magnitude,
        size=(target_shape[1], target_shape[2]),
        mode='bilinear',
        align_corners=False,
    )

    resized_output_magnitude_squeezed = resized_output_magnitude.squeeze(1)

    spectrogram_complex = torch.polar(resized_output_magnitude_squeezed, original_phase)

    # Use the same Hann window as the forward STFT in process_file_to_tensor.
    # Without it torch.istft falls back to a rectangular window, which changes
    # the output amplitude and distorts the first and last frames.
    window = torch.hann_window(n_fft, device=spectrogram_complex.device)
    waveform = torch.istft(spectrogram_complex, n_fft=n_fft, hop_length=n_fft // 4, window=window)

    torchaudio.save(file_path, waveform.detach().cpu(), sample_rate)

def preprocess_audio(file_path, device=None):
    """
    Preprocesses an audio file into a normalized magnitude spectrogram and phase.

    Thin wrapper around process_file_to_tensor with its default FFT size.

    Args:
        file_path (str): Path to the audio file.
        device (str or torch.device, optional): Device to load the tensors onto.
            If None, a module-level ``device`` variable is used when one is
            defined; otherwise CUDA is used when available, else CPU.
    Returns:
        Tuple[Tensor, Tensor]: Normalized magnitude spectrogram, phase tensor.
    """
    if device is None:
        # Backward compatibility: this function used to read a global `device`,
        # which raised NameError when the notebook had not defined it.
        device = globals().get("device")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return process_file_to_tensor(file_path, device)

def postprocess_audio(output_magnitude, original_phase, n_fft=1024):
    """
    Converts the model's output magnitude and original phase back to a waveform.

    The magnitude is bilinearly resized to the phase's (freq, frames) size,
    combined with the phase into a complex spectrogram and inverted with an
    ISTFT (Hann window of length ``n_fft``, hop length ``n_fft // 4``). The
    result is peak-normalized.

    Args:
        output_magnitude (Tensor): The output magnitude tensor from the model,
            4D with a singleton dimension at index 1 (it is squeezed away).
        original_phase (Tensor): The original phase tensor (3D; dims 1 and 2
            give the target size).
        n_fft (int): FFT window size.
    Returns:
        Tensor: The reconstructed waveform, normalized to a peak amplitude of
            (at most) 1, on the same device as ``original_phase``.
    """
    # torch.polar needs both operands on the same device; callers may hand in
    # a CPU magnitude together with a phase that still lives on the GPU.
    output_magnitude = output_magnitude.to(original_phase.device)

    target_shape = original_phase.shape
    # NOTE(review): bilinear resizing rescales the time axis; if the model input
    # was padded/truncated to a fixed length (as AudioSuperResDataset.fix_length
    # does), cropping/padding may be more appropriate - confirm.
    resized_output_magnitude = F.interpolate(
        output_magnitude,
        size=(target_shape[1], target_shape[2]),
        mode='bilinear',
        align_corners=False
    )

    resized_output_magnitude_squeezed = resized_output_magnitude.squeeze(1)

    spectrogram_complex = torch.polar(resized_output_magnitude_squeezed, original_phase)

    # Use the same Hann window as the forward STFT in process_file_to_tensor.
    # Without it torch.istft falls back to a rectangular window, which does not
    # invert a Hann-windowed STFT exactly (the edge frames get distorted).
    window = torch.hann_window(n_fft, device=spectrogram_complex.device)
    waveform = torch.istft(spectrogram_complex, n_fft=n_fft, hop_length=n_fft // 4, window=window)

    # Peak-normalize; the epsilon guards against an all-zero waveform.
    waveform = waveform / (waveform.abs().max() + 1e-8)
    return waveform


# Importing data

For this part you need a set of clean audio files and a set of associated degraded audio files saved somewhere on disk.

In [None]:
def prepare_audio_pairs(clean_dir, degraded_dir, local_clean_dir=None, max_pairs=1000, file_extension='.mp3'):
    """
    Prepares pairs of degraded and clean audio files for training.

    The degraded directory is walked recursively; a degraded file is paired
    with the clean file that has the same path relative to the clean directory.
    Pairs are returned in sorted order so that the selected subset (and any
    train/test split derived from it) is reproducible across runs and file systems.

    Args:
        clean_dir (str): Path to the directory containing clean (high-quality) audio files.
        degraded_dir (str): Path to the directory containing degraded audio files.
        local_clean_dir (str, optional): If provided, the clean dataset will be copied to this local directory
                                         if it does not already exist. If None, no copying is performed.
        max_pairs (int): Maximum number of pairs to return (for faster training/testing).
        file_extension (str): File extension to look for (default: '.mp3').

    Returns:
        List[Tuple[str, str]]: List of tuples, each containing (degraded_file_path, clean_file_path).
            Empty if a required directory is missing.
    """
    # Optionally copy the clean dataset to a local directory
    if local_clean_dir is None:
        clean_dir_to_use = clean_dir
    else:
        if not os.path.exists(clean_dir):
            print(f"WARNING: Clean dataset folder not found at {clean_dir}")
            return []
        if os.path.exists(local_clean_dir):
            print("Clean dataset already exists in the local directory.")
        else:
            # Only announce the copy when it actually happens.
            print(f"Copying clean dataset from {clean_dir} to {local_clean_dir} using shutil.copytree...")
            shutil.copytree(clean_dir, local_clean_dir, dirs_exist_ok=True)
        clean_dir_to_use = local_clean_dir

    # Check that both directories exist
    if not (os.path.exists(degraded_dir) and os.path.exists(clean_dir_to_use)):
        print("ERROR: Could not find the degraded or clean dataset directories.")
        return []

    # Find matching pairs: same relative path under both roots
    data_pairs = []
    for root, _, files in os.walk(degraded_dir):
        for file in files:
            if not file.endswith(file_extension):
                continue
            degraded_path = os.path.join(root, file)
            relative_path = os.path.relpath(degraded_path, degraded_dir)
            clean_path = os.path.join(clean_dir_to_use, relative_path)

            if os.path.exists(clean_path):
                data_pairs.append((degraded_path, clean_path))

    # os.walk yields entries in arbitrary, file-system dependent order; sort so
    # the subset taken below is deterministic.
    data_pairs.sort()

    print(f"Found {len(data_pairs)} matching pairs of audio files for training.")

    # Subsample for faster training if needed
    training_pairs = data_pairs[:max_pairs]
    print(f"Using a subset of {len(training_pairs)} pairs for this training session.")

    return training_pairs

In [None]:
class AudioSuperResDataset(Dataset):
    """
    Map-style PyTorch Dataset of (degraded, clean) spectrogram pairs.

    Every item is produced on the fly: both files of a pair are run through
    the preprocessor, and the resulting spectrograms are padded or truncated
    along their last axis to a common fixed length.
    """

    def __init__(self, data_pairs, preprocessor, target_length=5000):
        """
        Args:
            data_pairs (List[Tuple[str, str]]): (degraded_path, clean_path) pairs.
            preprocessor (callable): Maps a file path to a 2-tuple whose first
                element is the spectrogram tensor (the second element is ignored).
            target_length (int): Size every spectrogram is given along its last axis.
        """
        self.data_pairs = data_pairs
        self.preprocessor = preprocessor
        self.target_length = target_length

    def __len__(self):
        """Number of (degraded, clean) pairs."""
        return len(self.data_pairs)

    def __getitem__(self, idx):
        """Return the (degraded_spectrogram, clean_spectrogram) tuple for pair ``idx``."""
        degraded_path, clean_path = self.data_pairs[idx]

        # The preprocessor returns (spectrogram, phase); the phase is not needed here.
        degraded_spec, _ = self.preprocessor(degraded_path)
        clean_spec, _ = self.preprocessor(clean_path)

        return self._standardize(degraded_spec), self._standardize(clean_spec)

    def _standardize(self, spectrogram):
        """Drop the leading axis of a 4D spectrogram, then fix its last-axis length."""
        if spectrogram.dim() == 4:
            spectrogram = spectrogram.squeeze(0)
        return self.fix_length(spectrogram, self.target_length)

    def fix_length(self, spectrogram, target_length):
        """
        Force the last axis of a spectrogram to a given size.

        Longer inputs are cut at the end, shorter ones are zero-padded at the
        end, and inputs that already match are returned unchanged.

        Args:
            spectrogram (Tensor): Input spectrogram tensor.
            target_length (int): Desired size of the last axis.

        Returns:
            Tensor: Spectrogram whose last axis has size ``target_length``.
        """
        shortfall = target_length - spectrogram.shape[-1]

        if shortfall < 0:
            return spectrogram[..., :target_length]
        if shortfall > 0:
            return F.pad(spectrogram, (0, shortfall), mode='constant', value=0)
        return spectrogram

# Example: Splitting data into train and test sets
# (Assume data_pairs is already defined, e.g., from prepare_audio_pairs)
# NOTE: prepare_audio_pairs defaults to max_pairs=1000, which leaves nothing for
# the test split below; request at least TRAIN_SIZE + TEST_SIZE pairs.
TRAIN_SIZE = 1000
TEST_SIZE = 200
BATCH_SIZE = 8

train_pairs = data_pairs[:TRAIN_SIZE]
test_pairs = data_pairs[TRAIN_SIZE:TRAIN_SIZE + TEST_SIZE]

# Slicing never fails on a short list, so make a short or empty test split visible.
if len(test_pairs) < TEST_SIZE:
    print(
        f"WARNING: only {len(data_pairs)} pairs available; "
        f"using {len(train_pairs)} for training and {len(test_pairs)} for testing "
        f"(expected {TRAIN_SIZE} and {TEST_SIZE})."
    )

# Create Dataset objects
train_dataset = AudioSuperResDataset(train_pairs, preprocess_audio)
test_dataset = AudioSuperResDataset(test_pairs, preprocess_audio)

# Create DataLoader objects for batching and shuffling.
# The test loader is not shuffled so that batch order matches test_pairs order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("PyTorch Dataset and DataLoader are ready.")

# Training the model

In [None]:
def train_autoencoder(
    model,
    train_dataloader,
    num_epochs=5,
    learning_rate=1e-4,
    criterion=None,
    device=None,
    model_save_path="finetuned_autoencoder_simple.pth",
    print_every=20
):
    """
    Runs a supervised training loop for an autoencoder and saves its weights.

    The model is optimized with Adam on (input, target) batches from the
    dataloader; the loss of the current batch is printed periodically and the
    state dict is written to disk once all epochs are done.

    Args:
        model (nn.Module): The autoencoder model to train.
        train_dataloader (DataLoader): Yields (input_batch, target_batch) tuples.
        num_epochs (int): Number of passes over the dataloader.
        learning_rate (float): Learning rate handed to Adam.
        criterion (callable, optional): Loss function. If None, uses nn.MSELoss().
        device (str or torch.device, optional): Device to train on. If None,
            "cuda" is used when available, otherwise "cpu".
        model_save_path (str): File the trained state dict is saved to.
        print_every (int): Print the batch loss every N batches.

    Returns:
        nn.Module: The trained model (moved to ``device``).
    """
    device = ("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
    loss_fn = nn.MSELoss() if criterion is None else criterion

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch_number in range(1, num_epochs + 1):
        print(f"\n--- Starting Epoch {epoch_number}/{num_epochs} ---")
        for step, (degraded, clean) in enumerate(tqdm(train_dataloader), start=1):
            degraded, clean = degraded.to(device), clean.to(device)

            # Clear stale gradients, then forward / backward / update.
            optimizer.zero_grad()
            loss = loss_fn(model(degraded), clean)
            loss.backward()
            optimizer.step()

            if step % print_every == 0:
                print(f'Epoch [{epoch_number}/{num_epochs}], Step [{step}/{len(train_dataloader)}], Loss: {loss.item():.4f}')

    print("\n--- Training Finished! ---")

    # Persist only the weights (state dict), not the whole module object.
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    return model

For our project we did the training with num_epochs = 5, learning_rate = 1e-4 and criterion = nn.MSELoss().

# Testing the model on test part of dataset

In [None]:
import os
import torch
import torchaudio
from tqdm import tqdm

def evaluate_and_save_examples(
    model,
    test_dataloader,
    test_pairs,
    process_file_to_tensor,
    postprocess_audio,
    device=None,
    criterion=None,
    num_examples_to_save=3,
    output_dir=".",
    sample_rate=48000,
    print_progress=True
):
    """
    Evaluates a model on the test set and saves a few output audio examples.

    The average loss is weighted by batch size. The first
    ``num_examples_to_save`` model outputs are converted back to audio using
    the phase of the corresponding degraded file and written as WAV files.

    Args:
        model (nn.Module): The trained model to evaluate.
        test_dataloader (DataLoader): DataLoader for the test data. It must iterate
            in the same order as ``test_pairs`` (i.e. not shuffled), because saved
            examples are matched to files by their position in the dataset.
        test_pairs (list): List of (degraded_path, clean_path) pairs, used for file naming.
        process_file_to_tensor (callable): Function to process a file path into (magnitude, phase).
        postprocess_audio (callable): Function to convert model output and phase into waveform.
        device (str or torch.device, optional): Device to use. If None, auto-detects.
        criterion (callable, optional): Loss function. If None, uses nn.MSELoss().
        num_examples_to_save (int): Number of output examples to save.
        output_dir (str): Directory to save output audio files (created if missing).
        sample_rate (int): Sample rate for saving audio.
        print_progress (bool): Whether to print progress with tqdm.

    Returns:
        float: Average test loss (``inf`` if the dataloader yielded no examples).
        list: List of saved output filenames.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    if criterion is None:
        criterion = torch.nn.MSELoss()

    test_loss = 0.0
    num_examples = 0
    saved_filenames = []

    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        data_iter = tqdm(test_dataloader, desc="Test Progress") if print_progress else test_dataloader
        for input_batch, target_batch in data_iter:
            input_batch = input_batch.to(device)
            target_batch = target_batch.to(device)
            outputs = model(input_batch)

            loss = criterion(outputs, target_batch)
            batch_size = input_batch.size(0)

            # Dataset index of the first item of this batch. Using the running
            # count (rather than batch_index * batch_size) stays correct when a
            # batch is smaller than the others, e.g. the last one.
            batch_start = num_examples

            test_loss += loss.item() * batch_size
            num_examples += batch_size

            # Save up to num_examples_to_save output examples
            remaining = num_examples_to_save - len(saved_filenames)
            for j in range(min(remaining, batch_size)):
                degraded_path, _ = test_pairs[batch_start + j]
                base_name = os.path.splitext(os.path.basename(degraded_path))[0]
                _, input_phase = process_file_to_tensor(degraded_path, device)
                # Magnitude and phase must be on the same device for the
                # reconstruction, so both are moved to the CPU.
                output_waveform = postprocess_audio(
                    outputs[j].unsqueeze(0).cpu(), input_phase.cpu()
                )
                output_filename = os.path.join(
                    output_dir,
                    f"test_output_example_{len(saved_filenames) + 1}_{base_name}.wav"
                )
                torchaudio.save(output_filename, output_waveform.cpu(), sample_rate)
                saved_filenames.append(output_filename)

    avg_loss = test_loss / num_examples if num_examples > 0 else float('inf')
    print(f"Test Loss (MSE) on {num_examples} test examples: {avg_loss:.4f}")
    # Report what was actually written, which can be fewer than requested.
    print(f"{len(saved_filenames)} test output examples saved:")
    for fname in saved_filenames:
        print("-", fname)

    return avg_loss, saved_filenames