This notebook runs audio through the trained autoencoder model and reconstructs the resulting waveforms. The input data can be any audio, for example clips generated by MusicGen. Before running, set the path to the model checkpoint and the path to the data directory in the cells below.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import os
import glob
from tqdm.notebook import tqdm

# Defining the model

In [2]:
class SimpleConvAutoencoder(nn.Module):
    """Small convolutional autoencoder for single-channel 2-D inputs.

    Encoder: three stride-2 3x3 convolutions (1 -> 16 -> 32 -> 64 channels),
    each followed by ReLU. Decoder: three stride-2 3x3 transposed convolutions
    (64 -> 32 -> 16 -> 1); the first two are followed by ReLU, the last by a
    Sigmoid. After every decoder stage the result is centre-cropped to the
    spatial size of the tensor that entered the mirrored encoder stage, so the
    output has the same height and width as the input.
    """

    def __init__(self):
        super().__init__()
        down = dict(kernel_size=3, stride=2, padding=1)
        up = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        # Attribute names and the Sequential layout determine the state_dict
        # keys of saved checkpoints, so they are kept as-is.
        self.enc_conv1 = nn.Sequential(nn.Conv2d(1, 16, **down), nn.ReLU())
        self.enc_conv2 = nn.Sequential(nn.Conv2d(16, 32, **down), nn.ReLU())
        self.enc_conv3 = nn.Sequential(nn.Conv2d(32, 64, **down), nn.ReLU())
        self.dec_conv1 = nn.Sequential(nn.ConvTranspose2d(64, 32, **up), nn.ReLU())
        self.dec_conv2 = nn.Sequential(nn.ConvTranspose2d(32, 16, **up), nn.ReLU())
        self.dec_conv3 = nn.Sequential(nn.ConvTranspose2d(16, 1, **up), nn.Sigmoid())

    def forward(self, x):
        """Encode then decode `x`, returning a tensor with the same shape as `x`."""
        # Remember the shape entering each encoder stage; the decoder crops
        # back to these shapes in reverse order.
        shapes_before_stage = []
        hidden = x
        for encoder_stage in (self.enc_conv1, self.enc_conv2, self.enc_conv3):
            shapes_before_stage.append(hidden.shape)
            hidden = encoder_stage(hidden)

        for decoder_stage in (self.dec_conv1, self.dec_conv2, self.dec_conv3):
            hidden = self.crop(decoder_stage(hidden), shapes_before_stage.pop())
        return hidden

    def crop(self, tensor_to_crop, target_shape):
        """Centre-crop dims 2 and 3 of `tensor_to_crop` to `target_shape[2:4]`."""
        out_h, out_w = target_shape[2], target_shape[3]
        top = (tensor_to_crop.shape[2] - out_h) // 2
        left = (tensor_to_crop.shape[3] - out_w) // 2
        return tensor_to_crop[:, :, top : top + out_h, left : left + out_w]

def process_file_to_tensor(file_path, device, n_fft=1024):
    """Load an audio file and convert it to a normalised magnitude spectrogram and phase.

    The waveform is loaded with torchaudio, moved to `device`, resampled to
    48 kHz when its sample rate differs, and averaged to a single channel when
    it has more than one. An STFT with a Hann window of length `n_fft` and a
    hop of `n_fft // 4` is then taken; its magnitude is min-max normalised.

    Args:
        file_path: Path of the audio file, passed to `torchaudio.load`.
        device (str or torch.device): Device on which the processing is done.
        n_fft (int): FFT size / Hann window length.

    Returns:
        tuple: `(magnitude, phase)` where
            - `magnitude` is the min-max normalised STFT magnitude with one
              extra leading dimension added. For a `(channels, samples)`
              waveform this is 4-D: `(1, 1, n_fft // 2 + 1, frames)`.
            - `phase` is the STFT angle without the extra dimension:
              `(1, n_fft // 2 + 1, frames)`.
    """
    target_sr = 48000
    waveform, sr = torchaudio.load(file_path)
    waveform = waveform.to(device)
    if sr != target_sr:
        # The Resample module holds its interpolation kernel as a buffer, so
        # it must live on the same device as the waveform; otherwise
        # resampling fails with a device mismatch when `device` is a GPU.
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr).to(device)
        waveform = resampler(waveform)
    if waveform.shape[0] > 1:
        # Downmix to mono, keeping the channel dimension.
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    window = torch.hann_window(n_fft, device=device)
    stft = torch.stft(waveform, n_fft=n_fft, hop_length=n_fft // 4, window=window, return_complex=True)
    magnitude = torch.abs(stft)
    phase = torch.angle(stft)

    # Min-max normalise; the epsilon avoids division by zero on constant input.
    mag_min, mag_max = magnitude.min(), magnitude.max()
    magnitude_norm = (magnitude - mag_min) / (mag_max - mag_min + 1e-8)
    return magnitude_norm.unsqueeze(0), phase

# Load trained model

In [4]:
# 4. Load the trained model
model_path = "trained_autoencoder_simple.pth"  # Change this to your model path
device = "cuda" if torch.cuda.is_available() else "cpu"

model = SimpleConvAutoencoder().to(device)
# The checkpoint is a plain state_dict, so restrict unpickling to tensors and
# primitive containers: a tampered .pth file then cannot execute arbitrary code.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode for the rest of the notebook
print(f"Loaded model from {model_path}")

Mounted at /content/drive
Loaded model from /content/drive/MyDrive/finetuned_autoencoder_simple.pth


# Load data

The input data can be any `.wav` audio, for example clips generated with MusicGen.

In [5]:
def list_audio_files(directory, extension=".wav"):
    """Return the sorted paths directly inside `directory` that match ``*<extension>``."""
    pattern = os.path.join(directory, "*{}".format(extension))
    matches = glob.glob(pattern)
    matches.sort()
    return matches


# Change this to your data directory
audio_dir = "MusicGen_data"
audio_files = list_audio_files(audio_dir, ".wav")
print(f"Found {len(audio_files)} audio files.")

Found 43 audio files.


# Get output of data through trained model

In [7]:
def get_model_outputs_on_files(model, file_list, device, preprocess_fn, max_files=None):
    """
    Applies the model to a list of audio files and returns the outputs.

    The model is switched to eval mode and run under `torch.no_grad()`.
    Each preprocessed input is brought to 4 dimensions (leading singleton
    dimensions are added as needed) before being fed to the model.

    Args:
        model (nn.Module): Trained autoencoder model.
        file_list (iterable): Audio file paths.
        device (str or torch.device): Device to use.
        preprocess_fn (callable): Called as `preprocess_fn(file_path, device)`;
            must return a pair whose first element is the model input tensor.
            The second element is ignored here.
        max_files (int, optional): If set, only the first `max_files` files
            are processed (values <= 0 process nothing).

    Returns:
        outputs (list): List of model outputs (tensors, on CPU).
        inputs (list): List of input tensors as fed to the model (on CPU).
        file_names (list): List of file paths processed.
    """
    # Truncate up front so the progress bar total matches the work actually
    # done, instead of breaking out of a bar sized for the full list.
    selected_files = list(file_list)
    if max_files is not None:
        selected_files = selected_files[: max(max_files, 0)]

    model.eval()
    outputs = []
    inputs = []
    file_names = []
    with torch.no_grad():
        for file_path in tqdm(selected_files, desc="Processing files"):
            input_tensor, _ = preprocess_fn(file_path, device)
            # Pad with leading singleton dims until the tensor is 4-D, so
            # 2-D and 3-D spectrograms are accepted alongside 4-D ones.
            while input_tensor.dim() < 4:
                input_tensor = input_tensor.unsqueeze(0)
            input_tensor = input_tensor.to(device)
            output = model(input_tensor)
            outputs.append(output.cpu())
            inputs.append(input_tensor.cpu())
            file_names.append(file_path)
    return outputs, inputs, file_names

In [9]:
# Run the model on a single example: the first file found in `audio_dir`.
# `first_audio_file` is defined here so the cell works on a fresh kernel.
if not audio_files:
    raise FileNotFoundError(f"No audio files found in {audio_dir!r}")
first_audio_file = audio_files[0]

outputs, inputs, processed_files = get_model_outputs_on_files(
    model, [first_audio_file], device, process_file_to_tensor, max_files=5
)
print(f"Processed {len(outputs)} files.")

Processing files:   0%|          | 0/1 [00:00<?, ?it/s]

Processed 1 files.


# Reconstruct the audio files

In [10]:
def postprocess_audio(output_magnitude, original_phase, n_fft=1024):
    """
    Converts the model's output magnitude and original phase back to a waveform.

    The magnitude is moved to the phase's device and dtype, bilinearly
    resized to the phase's (freq, time) size when the two differ, combined
    with the phase into a complex spectrogram and inverted with an inverse
    STFT. The inverse uses a Hann window of length `n_fft` and a hop of
    `n_fft // 4`, matching the analysis in `process_file_to_tensor`.

    Args:
        output_magnitude (Tensor): The output magnitude tensor from the model (shape: [1, freq, time] or [1, 1, freq, time]).
        original_phase (Tensor): The original phase tensor (shape: [1, freq, time] or [freq, time]).
        n_fft (int): FFT window size.

    Returns:
        Tensor: The reconstructed waveform, scaled to a peak absolute value of
        (about) 1. It keeps the leading dimension of `original_phase`:
        [1, samples] for a 3-D phase, [samples] for a 2-D phase.
    """
    # If output_magnitude has batch and channel dimensions, squeeze them
    if output_magnitude.dim() == 4:
        output_magnitude = output_magnitude.squeeze(0).squeeze(0)
    elif output_magnitude.dim() == 3:
        output_magnitude = output_magnitude.squeeze(0)

    # Model outputs are typically collected on the CPU while the phase may
    # still be on the GPU; torch.polar needs both on one device and dtype.
    output_magnitude = output_magnitude.to(device=original_phase.device, dtype=original_phase.dtype)

    # Resize output magnitude to match the original phase shape (skipped when
    # the sizes already agree, where the resize would be an identity anyway).
    target_size = (original_phase.shape[-2], original_phase.shape[-1])
    if tuple(output_magnitude.shape[-2:]) != target_size:
        output_magnitude = F.interpolate(
            output_magnitude.unsqueeze(0).unsqueeze(0),  # add batch and channel dims
            size=target_size,
            mode='bilinear',
            align_corners=False
        ).squeeze(0).squeeze(0)

    # Combine magnitude and phase to get the complex spectrogram
    spectrogram_complex = torch.polar(output_magnitude, original_phase)

    # Inverse STFT with the same window as the forward transform; without it
    # istft assumes a rectangular window and the signal edges are distorted.
    window = torch.hann_window(n_fft, dtype=original_phase.dtype, device=original_phase.device)
    waveform = torch.istft(spectrogram_complex, n_fft=n_fft, hop_length=n_fft // 4, window=window)

    # Normalize waveform
    waveform = waveform / (waveform.abs().max() + 1e-8)
    return waveform

In [12]:
for output_tensor, file_path in zip(outputs, processed_files):
    # Recompute the original phase of this file (the model only predicts magnitude)
    _, phase = process_file_to_tensor(file_path, device)
    # The model outputs were moved to the CPU by get_model_outputs_on_files,
    # while the phase is computed on `device`; put both on the same device so
    # the reconstruction does not fail when `device` is a GPU.
    phase = phase.to(output_tensor.device)
    # Reconstruct the audio waveform
    waveform = postprocess_audio(output_tensor, phase)
    # torchaudio.save expects a 2-D (channels, samples) tensor
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    # Build the output filename
    base_name = os.path.basename(file_path)
    output_filename = f"reconstructed_{base_name}"
    # Save the waveform as a .wav file at the 48 kHz rate used in preprocessing
    torchaudio.save(output_filename, waveform.cpu(), 48000)
    print(f"Saved reconstructed audio as {output_filename}")

Saved reconstructed audio as reconstructed_A_Balkan_brass_band_with_high_energy_and.wav
