In [1]:
# Concepts we will cover today

# 1. controllable generative modeling of electric guitar
# 2. the "domain shift" problem
# 3. the possibilities for future work with this code and dataset

# Other concepts covered in this notebook:
# A. Variational Auto-encoder (VAE)
# B. how to train, save, and rebuild a pytorch model
# C. griffin-lim algorithm (audio reconstruction from spectrogram)
# D. Electric guitar pitch and effects

In [2]:
# Setup and installation
# Install the mirdata library for accessing music datasets
!pip install mirdata
# https://www.youtube.com/watch?v=ebddOjzolkc <- see my video

# Core imports for numerical operations and deep learning
import torch  # PyTorch for building and training neural networks
import numpy as np  # NumPy for numerical operations

# PyTorch utilities for data loading and processing
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F  # Functions module from PyTorch

# Libraries for audio processing and visualization
import mirdata  # Access to music datasets
import librosa  # Library for audio and music analysis
import librosa.display  # Specific librosa functionalities for visualizations

# Neural network module from PyTorch
from torch import nn  # Importing the neural network module

Collecting mirdata
  Downloading mirdata-0.3.8-py3-none-any.whl (17.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.2/17.2 MB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
Collecting black>=23.3.0 (from mirdata)
  Downloading black-24.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m66.3 MB/s[0m eta [36m0:00:00[0m
Collecting Deprecated>=1.2.14 (from mirdata)
  Downloading Deprecated-1.2.14-py2.py3-none-any.whl (9.6 kB)
Collecting jams>=0.3.4 (from mirdata)
  Downloading jams-0.3.4.tar.gz (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.3/51.3 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pretty-midi>=0.2.10 (from mirdata)
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m65.1 MB/

In [None]:
# Importing necessary library for Google Drive access
from google.colab import drive

# Mounting Google Drive to the Colab environment
# This allows us to access files stored in Google Drive directly from this notebook.
# The 'force_remount=True' parameter ensures the Drive is remounted to reflect any recent changes.
drive.mount('/content/drive', force_remount=True)

# Creating a directory for the dataset if it doesn't already exist
# This ensures that we have a dedicated place to store our guitar sound dataset.
!mkdir -p /content/drive/MyDrive/mirdatasets/egfxset

# Initializing the dataset with mirdata
# 'egfxset' is specified as the dataset to initialize, and its location is set to the previously created directory.
egfxset = mirdata.initialize('egfxset', data_home='/content/drive/MyDrive/mirdatasets/egfxset')

# Downloading the dataset

##########################################################################
# UNCOMMENT these lines if you need to download and validate the dataset #
##########################################################################
# egfxset.download()
# egfxset.validate()

In [5]:
  #    Dataset Initialization
  #
  #
  #    +---------------------+
  #    |     Guitar Tones    | (trimmed to 4.9 seconds)
  #    +---------------------+
  #              |
  #              v
  #    Pitch Calculation: 0 - 46
  #              |
  #              v
  #    +------------------------+
  #    | Compute STFT Magnitude |
  #    +------------------------+
  #              |
  #              v
  #    Torch Tensors Conversion
  #              |
  #              v
  #    DataLoader; Train/Test Split (string number 3 is saved for testing)
  #              |
  #              v
  #       Model Training

In [6]:
# Custom dataset class for converting guitar sounds to torch tensors
class Data(Dataset):
    """
    A custom PyTorch Dataset of single guitar notes loaded through mirdata.

    Tracks are filtered by string number, each kept track gets an integer pitch
    label derived from its (string, fret) annotation, and every item is returned
    as a peak-normalised STFT magnitude spectrogram (not a mel spectrogram)
    together with that pitch label, both as torch tensors.
    """

    def __init__(self, dataset_name, split_strings, hop_length=256, n_fft=512, dur=4.9, data_home='/content/drive/MyDrive/mirdatasets/egfxset'):
        """
        Index the tracks that belong to this split.

        Args:
            dataset_name: name passed to ``mirdata.initialize``.
            split_strings: collection of string numbers to keep.
            hop_length: STFT hop length in samples.
            n_fft: STFT window size in samples.
            dur: clip duration in seconds; audio is trimmed or zero-padded to it.
            data_home: folder where mirdata looks for the dataset files.
        """
        self.dataset = mirdata.initialize(dataset_name, data_home=data_home)
        self.split_strings = split_strings
        self.guitar_tones = []

        for t in self.dataset.track_ids:
            stringfret = self.dataset.track(t).stringfret_tuple
            if not stringfret:
                continue  # Skip tracks without string-fret tuple annotation
            string, fret = stringfret
            # Frets above 22 are treated as a metadata bug and skipped
            if string not in self.split_strings or fret > 22:
                continue
            self.guitar_tones.append({'track_name': t, 'pitch': self._string_fret_to_pitch(string, fret)})

        self.dur = dur
        self.hop_length = hop_length
        self.n_fft = n_fft

    @staticmethod
    def _string_fret_to_pitch(string, fret):
        """Return the pitch in semitones above the open 6th string."""
        # Each step from string 6 towards string 1 adds 5 semitones, except the
        # fourth step (G string -> B string), which adds only 4.
        steps = max(6 - string, 0)
        return fret + 5 * steps - (1 if steps >= 4 else 0)

    @staticmethod
    def _fit_length(x, n_samples):
        """Trim x to n_samples and zero-pad 1-D signals that are shorter."""
        x = x[:n_samples]
        missing = n_samples - x.shape[0]
        if x.ndim == 1 and missing > 0:
            # Keeps every spectrogram the same width so default batching works
            x = np.pad(x, (0, missing))
        return x

    @staticmethod
    def _peak_normalize(S):
        """Scale S so its maximum is 1; an all-zero S is returned unchanged."""
        peak = np.max(S)
        # Guard against 0 / 0 -> NaN for silent clips
        return S / peak if peak > 0 else S

    def __getitem__(self, index):
        """
        Load the audio for ``index`` and return ``(spectrogram, pitch)``.

        The spectrogram is the peak-normalised STFT magnitude of the clip; the
        pitch is a float32 tensor of shape (1,).
        """
        entry = self.guitar_tones[index]
        x, fs = self.dataset.track(entry['track_name']).audio
        x = self._fit_length(x, int(fs * self.dur))  # Fix the clip to the specified duration

        S = np.abs(librosa.stft(x, n_fft=self.n_fft, hop_length=self.hop_length))  # Magnitude spectrogram
        pitch = np.array([entry['pitch']]).astype(np.float32)

        return torch.from_numpy(self._peak_normalize(S)), torch.from_numpy(pitch)

    def __len__(self):
        """Number of tracks kept for this split."""
        return len(self.guitar_tones)

# --- Dataset / STFT settings shared by both splits ---------------------------
DATASET_NAME = 'egfxset'
TRAIN_STRINGS = [1, 2, 4, 5, 6]  # strings the model is trained on
TEST_STRINGS = [3]  # held-out string used only for evaluation
HOP_LENGTH = 512  # STFT hop, in samples
N_FFT = 1024  # STFT window, in samples

batch_size = 32  # items per batch in both loaders

# --- Training split: shuffled every epoch ------------------------------------
train_data = Data(DATASET_NAME, TRAIN_STRINGS, hop_length=HOP_LENGTH, n_fft=N_FFT)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)

# --- Test split: fixed order --------------------------------------------------
test_data = Data(DATASET_NAME, TEST_STRINGS, hop_length=HOP_LENGTH, n_fft=N_FFT)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

# Sanity check: pull the first test batch only and report its tensor shapes
for batch, (X, m) in enumerate(test_loader):
    print(f"Batch: {batch + 1}")
    print(f"X shape: {X.shape}")  # Spectrogram shape
    print(f"m shape: {m.shape}")  # Pitch labels shape
    break


Batch: 1
X shape: torch.Size([32, 513, 460])
m shape: torch.Size([32, 1])


In [7]:
# Setting the index for the data point to be heard
nth_datapoint = 0

# Necessary imports for displaying audio within a Jupyter notebook
from IPython.display import Audio, display

# Fetch the item ONCE: every train_data[...] access re-reads the audio and recomputes
# the STFT, so indexing separately for the spectrogram and the label doubles the work.
spec_tensor, pitch_tensor = train_data[nth_datapoint]

# Spectrogram as a numpy array for Griffin-Lim
y = spec_tensor.numpy()

# Printing the pitch information for the nth data point to understand its musical properties
print(f"Pitch label for the selected data point: {pitch_tensor.numpy()}")

# Reconstructing the audio from the magnitude spectrogram using the Griffin-Lim algorithm
y_inv = librosa.griffinlim(y, n_fft=N_FFT, hop_length=HOP_LENGTH)

# Displaying an audio player for the reconstructed signal to allow auditory evaluation
display(Audio(y_inv, rate=48000))  # Assumes the audio sampling rate is 48000 Hz

# Printing the shape of the spectrogram to provide insights into the dimensions of input data for the model
print(f"Spectrogram shape: {y.shape} (Frequency Bins, Time Frames)")

# Displaying the total number of data points in the dataset to give a sense of its size
print(f"Total number of data points in the dataset: {len(train_data)}")


Pitch label for the selected data point: [24.]


Spectrogram shape: (513, 460) (Frequency Bins, Time Frames)
Total number of data points in the dataset: 6876


In [8]:
#   Input Spectrogram
#          |
#          v
#     +-----------+
#     |  Flatten  |
#     +-----------+
#          |
#          v
#     +-----------+
#     |  Encoder  |
#     +-----------+
#          |
#     ----- -----
#      [μ]  [σ^2]
#     ----- -----
#          |
#          v
#  +----------------+ +-------------------+
#  |     Latent     | |   Concatenate     |
#  |     Space      | |     with f0       |
#  +----------------+ +-------------------+
#                    |
#                    v
#               +-----------+
#               |  Decoder  |
#               +-----------+
#                    |
#                    v
#         Reconstructed Spectrogram
#                    |
#                    v
#             +-------------+
#             | Griffin-Lim |
#             |  Algorithm  |
#             +-------------+
#                    |
#                    v
#          Generated Guitar Sound

In [43]:
# defining the VAE

class VAE(nn.Module):
    """
    Fully-connected variational auto-encoder whose decoder is conditioned on pitch.

    The encoder maps a flattened spectrogram to the mean and log-variance of a
    Gaussian over the latent code z. The decoder receives z concatenated with a
    small learned embedding of the pitch value f0 and outputs a flattened,
    non-negative spectrogram.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim, f0_dim=4):
        """
        Args:
            x_dim: number of values in one flattened spectrogram.
            h_dim1: width of the outer hidden layer (encoder input / decoder output side).
            h_dim2: width of the inner hidden layer.
            z_dim: size of the latent code.
            f0_dim: size of the learned pitch embedding appended to z (default 4).
        """
        super().__init__()

        # Remembered so forward() can flatten inputs without a hard-coded size
        self.x_dim = x_dim

        # Encoder network
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # Mean μ of z
        self.fc32 = nn.Linear(h_dim2, z_dim)  # Log variance log(σ^2) of z

        # Decoder network (its input is z concatenated with the pitch embedding)
        self.fc4 = nn.Linear(z_dim + f0_dim, h_dim2)
        self.fcf0 = nn.Linear(1, f0_dim)  # Embeds the scalar f0 before concatenation
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Encode flattened input x and return (mean, log variance) of z."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)

    def sampling(self, mu, log_var):
        """Reparameterization trick: z = μ + σ·ε with ε drawn from N(0, I)."""
        std = torch.exp(0.5 * log_var)  # Standard deviation σ
        eps = torch.randn_like(std)  # Sampling ε
        return mu + eps * std

    def decoder(self, z, f0):
        """Decode latent code z, conditioned on pitch f0 of shape (..., 1)."""
        z = torch.cat((z, self.fcf0(f0)), -1)  # Concatenate z with embedded f0
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return F.relu(self.fc6(h))  # ReLU keeps the output non-negative

    def forward(self, x):
        """
        Run a full encode-sample-decode pass.

        Args:
            x: tuple ``(spectrogram, f0)``; the spectrogram is flattened to
               ``(-1, x_dim)`` before encoding.

        Returns:
            ``(reconstruction, mu, log_var)``.
        """
        x, f0 = x  # Unpack the input tuple
        mu, log_var = self.encoder(x.reshape(-1, self.x_dim))  # Flatten x and encode
        z = self.sampling(mu, log_var)  # Sample z
        return self.decoder(z, f0), mu, log_var

# Build the VAE for flattened 513 x 460 spectrograms with a 64-dimensional latent code
vae = VAE(x_dim=513 * 460, h_dim1=512, h_dim2=256, z_dim=64)

# Use the GPU when one is present; otherwise the model stays on the CPU
if torch.cuda.is_available():
    vae = vae.cuda()

In [45]:
# Adam over all VAE parameters with a learning rate of 1e-3
optimizer = torch.optim.Adam(params=vae.parameters(), lr=1e-3)

def loss_function(recon_x, x, mu, log_var):
    """
    VAE training loss: summed squared reconstruction error plus KL divergence.

    Args:
        recon_x: decoder output (flattened reconstruction).
        x: original input; reshaped to ``recon_x``'s shape before comparison,
           so no spectrogram size is hard-coded here.
        mu: mean of the approximate posterior.
        log_var: log variance of the approximate posterior.

    Returns:
        Scalar tensor, summed (not averaged) over the whole batch.
    """
    # Summed squared error between the reconstruction and the flattened input
    MSE = F.mse_loss(recon_x, x.reshape(recon_x.shape), reduction='sum')

    # KL divergence between the posterior and a standard normal distribution
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    # Total loss is the sum of MSE and KLD
    return MSE + KLD

In [46]:
def train(epoch):
    """
    Train the global ``vae`` for one pass over ``train_loader``.

    Batches are moved to whatever device the model's parameters live on, so the
    function also runs on a CPU-only machine (the model is only moved to the GPU
    when CUDA is available).

    Args:
        epoch: epoch number, used only in the progress messages.

    Returns:
        The average training loss per data point for this epoch.
    """
    vae.train()  # Set the model to training mode
    device = next(vae.parameters()).device  # Follow the model instead of assuming CUDA
    train_loss = 0.0  # Running sum of the (batch-summed) losses

    # Iterate over batches of data in the training loader
    for batch_idx, (data, note) in enumerate(train_loader):
        data, note = data.to(device), note.to(device)

        # Clear the gradients of all optimized variables
        optimizer.zero_grad()

        # Forward pass: the model takes (spectrogram, pitch) as a single tuple
        recon_batch, mu, log_var = vae((data, note))

        # Calculate the loss
        loss = loss_function(recon_batch, data, mu, log_var)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Update the total loss
        train_loss += loss.item()

        # Print log info
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')

    # Print average loss for the epoch
    average_loss = train_loss / len(train_loader.dataset)
    print(f'====> Epoch: {epoch} Average loss: {average_loss:.4f}')
    return average_loss

In [47]:
def test():
    """
    Evaluate the global ``vae`` on ``test_loader`` without updating it.

    Batches are moved to whatever device the model's parameters live on, so the
    function also runs on a CPU-only machine.

    Returns:
        The average loss per data point over the test set.
    """
    vae.eval()  # Set the model to evaluation mode
    device = next(vae.parameters()).device  # Follow the model instead of assuming CUDA
    test_loss = 0.0  # Running sum of the (batch-summed) losses

    # Disables gradient calculation to save memory and computations during evaluation
    with torch.no_grad():
        for data, note in test_loader:  # Iterate over test data
            data, note = data.to(device), note.to(device)

            # Forward pass: the model takes (spectrogram, pitch) as a single tuple
            recon, mu, log_var = vae((data, note))

            # Sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var).item()

    # Compute the average loss over all test data
    test_loss /= len(test_loader.dataset)

    # Print the average test loss
    print(f'====> Test set loss: {test_loss:.4f}')
    return test_loss


In [None]:
# Training loop: epochs are numbered 1 through 9, and the held-out test split
# is evaluated right after each training pass.
FIRST_EPOCH, LAST_EPOCH = 1, 9
for epoch in range(FIRST_EPOCH, LAST_EPOCH + 1):
    train(epoch)
    test()

In [60]:
# Persist just the learned parameters (the state_dict), not the whole model object
trained_weights = vae.state_dict()
torch.save(trained_weights, '/content/drive/MyDrive/vae_model_state_dict.pth')

In [65]:
## uncomment to download Iran's pretrained weights
## NOTE: gdown writes the file into the current working directory
## (/content/vae_model_state_dict.pth on Colab), NOT into Google Drive, so also
## switch WEIGHTS_PATH below when you want to load the downloaded weights.
# !gdown 1-1fxTb3yl22ZWsSxmnWY6n8WTKpYTeAS

# Weights saved by this notebook's own training run
WEIGHTS_PATH = '/content/drive/MyDrive/vae_model_state_dict.pth'
# WEIGHTS_PATH = '/content/vae_model_state_dict.pth'  # <- file fetched by gdown above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reconstruct the model with the same dimensions it was trained with
vae = VAE(x_dim=513*460, h_dim1=512, h_dim2=256, z_dim=64).to(device)

# Load the model parameters into the reconstructed model.
# weights_only=True restricts unpickling to tensors and plain containers, which
# matters because a downloaded .pth file is otherwise an arbitrary pickle.
state_dict = torch.load(WEIGHTS_PATH, map_location=device, weights_only=True)
vae.load_state_dict(state_dict)

Downloading...
From (original): https://drive.google.com/uc?id=1-1fxTb3yl22ZWsSxmnWY6n8WTKpYTeAS
From (redirected): https://drive.google.com/uc?id=1-1fxTb3yl22ZWsSxmnWY6n8WTKpYTeAS&confirm=t&uuid=60cf45c2-5048-4a99-bcd1-3c5f3cb15cec
To: /content/vae_model_state_dict.pth
100% 969M/969M [00:07<00:00, 131MB/s] 


<All keys matched successfully>

In [69]:
# Pitch label to condition the decoder on
pitch = 12.0

# Inference only, so gradient tracking is switched off
with torch.no_grad():
    # Latent code drawn from the standard normal prior, plus the pitch control value
    z = torch.randn(1, 64).to(device)
    f0 = torch.tensor([[pitch]]).to(device)

    # The decoder output already lives on the same device as its inputs
    sample = vae.decoder(z, f0)

# Un-flatten into (batch, frequency bins, time frames)
sample = sample.view(-1, 513, 460)

# Take the single generated spectrogram and hand it to numpy on the CPU
sample_np = sample[0].cpu().numpy()

# Griffin-Lim estimates a phase for the magnitude spectrogram and returns a waveform
y_inv = librosa.griffinlim(sample_np, n_fft=N_FFT, hop_length=HOP_LENGTH)

# Audio player for the generated sound
display(Audio(y_inv, rate=48000))