# Generation with Variational auto encoders

In this notebook, we implement and train a Variational Autoencoder (VAE) to generate musical sequences consisting of notes/chords and their corresponding durations. The VAE learns a compressed latent representation of musical patterns and is then used to reconstruct or generate new, stylistically coherent sequences. This allows for creative applications such as "jazzifying" existing melodies or generating original music from random latent vectors.

## Import Dataset and Definition of Useful functions

### Import Dataset

In [205]:
root = 'data_processed/'

In [206]:
def safe_parse_all_columns_df(df):
    """
    Parse all columns in a DataFrame to numeric, coercing errors.
    """
    df['notes'] = df['notes'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    df['chords'] = df['chords'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    df['velocities'] = df['velocities'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    df['durations'] = df['durations'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    df['offsets'] = df['offsets'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    df['ordered_events'] = df['ordered_events'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    return df

def load_dataframe_from_two_csvs(file1, file2):
    """
    Load and concatenate two CSV files into a single pandas DataFrame.
    """
    df1 = pd.read_csv(file1)
    df2 = pd.read_csv(file2)
    full_df = pd.concat([df1, df2], ignore_index=True)
    full_df = safe_parse_all_columns_df(full_df)

    return full_df

def save_dataframe_to_two_csvs(df, file1, file2):
    """
    Split a DataFrame in half and save it into two CSV files.
    """
    halfway = len(df) // 2
    df.iloc[:halfway].to_csv(file1, index=False)
    df.iloc[halfway:].to_csv(file2, index=False)

def load_dataframe_from_one_csv(file):
    """
    Load a DataFrame from a single CSV file.
    """
    df = pd.read_csv(file)
    
    return df

def save_dataframe_to_one_csv(df, file):
    """
    Save a DataFrame to a single CSV file.
    """
    df.to_csv(file, index=True)

def load_reconstructed_events(file):
    """
    Loads the reconstructed events CSV and safely parses the 'sequence' column,
    converting notes to integers and chords to lists of integers.
    """
    df = pd.read_csv(file)

    def safe_parse(seq_str):
        try:
            parsed = ast.literal_eval(seq_str)
            if not isinstance(parsed, list):
                raise ValueError("Parsed sequence is not a list")

            normalized = []
            for el in parsed:
                if isinstance(el, list):
                    normalized.append([int(x) for x in el])
                else:
                    normalized.append(int(el))
            return normalized

        except Exception as e:
            print(f"Error parsing sequence: {seq_str}")
            raise e

    df['sequence'] = df['sequence'].apply(safe_parse)
    return df

In [208]:
file1 = root + 'data_part1.csv'
file2 = root + 'data_part2.csv'

df = load_dataframe_from_two_csvs(file1, file2)

In [209]:
reconstructed_dataset = load_reconstructed_events(root + 'reconstructed_ordered_events.csv')

In [210]:
ordered_events_with_durations = reconstructed_dataset.copy()
ordered_events_with_durations['durations'] = df['durations']

In [211]:
ordered_events_with_durations

Unnamed: 0,index,sequence,durations
0,0,"[88, 38, 38, 45, 45, 45, [44, 45], 45, 45, 38,...","[0.25, 2.5, 0.25, 0.3333, 0.3333, 0.25, 0.3333..."
1,1,"[[56, 65, 68, 80, 60], [54, 61, 65, 75], 77, 7...","[1.25, 1.6667, 0.25, 1.0, 1.0, 0.5, 0.6667, 0...."
2,2,"[46, 70, [53, 60, 62], [65, 46, 53, 60], 70, [...","[2.6667, 0.25, 1.25, 0.75, 0.75, 1.6667, 1.75,..."
3,3,"[[52, 28, 40], 64, [65, 68], [53, 55, 56], [58...","[3.75, 2.5, 1.0, 2.0, 1.6667, 1.3333, 0.75, 0...."
4,4,"[59, 59, [60, 62, 38, 54], 62, [43, 62, 50, 52...","[0.75, 0.6667, 0.6667, 0.5, 1.0, 0.25, 2.0, 0...."
...,...,...,...
2770,2770,"[61, 49, [63, 60], 65, 70, [58, 61, 66, 49], [...","[1.3333, 1.6667, 1.0, 0.5, 0.5, 0.5, 0.6667, 0..."
2771,2771,"[[45, 55, 67, 69, 72, 75], [46, 74, 56, 62, 68...","[1.0, 1.0, 0.25, 1.25, 0.3333, 0.25, 0.25, 0.7..."
2772,2772,"[[45, 57], 71, 52, 61, 64, 45, 69, [45, 69], 4...","[1.6667, 0.6667, 3.3333, 2.75, 2.25, 0.75, 0.5..."
2773,2773,"[48, 41, 60, [67, 68, 72, 79], 66, [65, 68, 72...","[0.3333, 0.3333, 0.25, 0.25, 0.3333, 0.3333, 0..."


In [212]:
mismatches = ordered_events_with_durations.apply(
    lambda row: len(row['sequence']) != len(row['durations']),
    axis=1
)

invalid_rows = ordered_events_with_durations[mismatches]


if not invalid_rows.empty:
    print("Rows with mismatched 'sequence' and 'durations':")
    print(invalid_rows)
else:
    print("All rows have matching 'sequence' and 'durations' lengths.")

All rows have matching 'sequence' and 'durations' lengths.


In [213]:
df_reconstructed = load_reconstructed_events('data_processed/reconstructed_with_durations.csv')

In [214]:
df_reconstructed=ordered_events_with_durations

### Useful functions

In [215]:
def parse_chord_to_list(chord):
    """
    Convert a chord string to a list of integers.
    """
    if isinstance(chord, str):
        print([int(x) for x in chord.split(',') if x.isdigit()])
        return [int(x) for x in chord.split(',') if x.isdigit()]
    return []

In [216]:
def reconstruct_ordered_events(df):
    """
    Reconstruct the ordered list of events (notes and chords) for each song.
    """
    sequences  = []

    for i in range(len(df)):
        idx_note = 0
        idx_chord = 0
        reconstructed = []

        for element in df['ordered_events'][i]:
            if element == 'n':
                reconstructed.append(df['notes'][i][idx_note])
                idx_note += 1
            elif element == 'c':
                parsed_chord = parse_chord_to_list(df['chords'][i][idx_chord])
                reconstructed.append(df['chords'][i][idx_chord])
                idx_chord += 1
            else:
                raise ValueError(f"Unknown event type: {e}")
        
        sequences.append(reconstructed)

    reconstructed_dataset = pd.DataFrame({'sequence': sequences})
    reconstructed_dataset.index.name = 'index'

    return reconstructed_dataset

In [217]:
df_reconstructed

Unnamed: 0,index,sequence,durations
0,0,"[88, 38, 38, 45, 45, 45, [44, 45], 45, 45, 38,...","[0.25, 2.5, 0.25, 0.3333, 0.3333, 0.25, 0.3333..."
1,1,"[[56, 65, 68, 80, 60], [54, 61, 65, 75], 77, 7...","[1.25, 1.6667, 0.25, 1.0, 1.0, 0.5, 0.6667, 0...."
2,2,"[46, 70, [53, 60, 62], [65, 46, 53, 60], 70, [...","[2.6667, 0.25, 1.25, 0.75, 0.75, 1.6667, 1.75,..."
3,3,"[[52, 28, 40], 64, [65, 68], [53, 55, 56], [58...","[3.75, 2.5, 1.0, 2.0, 1.6667, 1.3333, 0.75, 0...."
4,4,"[59, 59, [60, 62, 38, 54], 62, [43, 62, 50, 52...","[0.75, 0.6667, 0.6667, 0.5, 1.0, 0.25, 2.0, 0...."
...,...,...,...
2770,2770,"[61, 49, [63, 60], 65, 70, [58, 61, 66, 49], [...","[1.3333, 1.6667, 1.0, 0.5, 0.5, 0.5, 0.6667, 0..."
2771,2771,"[[45, 55, 67, 69, 72, 75], [46, 74, 56, 62, 68...","[1.0, 1.0, 0.25, 1.25, 0.3333, 0.25, 0.25, 0.7..."
2772,2772,"[[45, 57], 71, 52, 61, 64, 45, 69, [45, 69], 4...","[1.6667, 0.6667, 3.3333, 2.75, 2.25, 0.75, 0.5..."
2773,2773,"[48, 41, 60, [67, 68, 72, 79], 66, [65, 68, 72...","[0.3333, 0.3333, 0.25, 0.25, 0.3333, 0.3333, 0..."


In [218]:
save_dataframe_to_one_csv((df_reconstructed), root + 'reconstructed_with_durations.csv')

## Predict Events (Notes and Chords) and Durations

### Creating the data: Fixed number of events 

Idea for creating the input sequences:
- we take subsets of the list of events representing each song 
- we take the next event of each subset as corresponding training output sequences

This is easy to implement and we will have a consistent sequence lenght for batching, but we are ignoring the timing aspect.

In [219]:
class Vocabulary:
    def __init__(self, reconstructed_df):
        """
        Build vocabulary of unique single notes only.
        """
        self.notes = set()
        for i in range(len(reconstructed_df)):
            sequence = reconstructed_df['sequence'][i]
            for event in sequence:
                if isinstance(event, list):
                    for note in event:
                        self.notes.add(note)
                else:
                    self.notes.add(event)

        self.notes = sorted(self.notes)
        self.note_to_idx = {note: idx for idx, note in enumerate(self.notes)}
        self.idx_to_note = {idx: note for idx, note in enumerate(self.notes)}
        self.vocab_size = len(self.notes)

    def encode_event(self, event):
        """
        Encode an event as a multi-hot vector over single notes.
        """
        vec = np.zeros(self.vocab_size, dtype=np.float32)
        if isinstance(event, list):
            for note in event:
                vec[self.note_to_idx[note]] = 1.0
        else:
            vec[self.note_to_idx[event]] = 1.0
        return vec

    def decode_event(self, vec, threshold=0.5):
        """
        Decode multi-hot vector to list of notes.
        """
        indices = np.where(vec >= threshold)[0]
        notes = [self.idx_to_note[idx] for idx in indices]
        if len(notes) == 1:
            return notes[0]
        else:
            return notes

    def __len__(self):
        return self.vocab_size


## Create Dataset object

Here we consider also the duration while creating the dataset, we return two targets independently, related to durations and events

In [220]:
class MusicEventDataset(Dataset):
    def __init__(self, reconstructed_df, vocab, seq_length=50, max_samples_per_song=None):
        """
        Args:
            reconstructed_df: DataFrame with 'sequence' and 'durations' columns
            vocab: Vocabulary object to encode events
            seq_length: Input sequence length
            max_samples_per_song: Max number of (input, target) pairs to keep per song
        """
        self.samples = []
        self.seq_length = seq_length
        self.vocab = vocab

        for row_index in range(len(reconstructed_df)):
            sequence = reconstructed_df['sequence'][row_index]
            durations = reconstructed_df['durations'][row_index]

            if isinstance(durations, str):
                durations = ast.literal_eval(durations)

            n_events = min(len(sequence), len(durations))
            if n_events <= seq_length:
                continue

            num_added = 0
            for i in range(n_events - seq_length):
                if max_samples_per_song is not None and num_added >= max_samples_per_song:
                    break

                input_seq = sequence[i : i + seq_length]
                input_durs = durations[i : i + seq_length]
                target_event = sequence[i + seq_length]
                target_dur = durations[i + seq_length]

                self.samples.append((input_seq, input_durs, target_event, target_dur))
                num_added += 1

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        input_seq, input_durs, target_event, target_dur = self.samples[idx]

        input_encoded = [self.vocab.encode_event(event) for event in input_seq]
        input_tensor = torch.tensor(input_encoded, dtype=torch.float32)

        dur_tensor = torch.tensor(input_durs, dtype=torch.float32).unsqueeze(-1)

        target_encoded = self.vocab.encode_event(target_event)
        target_tensor = torch.tensor(target_encoded, dtype=torch.float32)

        target_dur_tensor = torch.tensor([target_dur], dtype=torch.float32)

        return input_tensor, dur_tensor, target_tensor, target_dur_tensor


In [221]:
reconstructed_dataset = load_reconstructed_events(root + 'reconstructed_ordered_events.csv')
vocab = Vocabulary(reconstructed_dataset)
dataset = MusicEventDataset(ordered_events_with_durations, vocab=vocab, seq_length=16)


Here we show how the first song is enconded, only the first 16 events. 

In [222]:
x, dur, y, target_dur = dataset[0]

print("Input sequence shape:", x.shape)
print("Duration sequence shape:", dur.shape)
print("Next event shape:", y.shape)
print("Target duration shape:", target_dur.shape)

print("Input sequence (multi-hot vectors):\n", x)
print("Duration sequence:\n", dur)
print("Next event (multi-hot vector):", y)
print("Target duration:", target_dur)


Input sequence shape: torch.Size([16, 88])
Duration sequence shape: torch.Size([16, 1])
Next event shape: torch.Size([88])
Target duration shape: torch.Size([1])
Input sequence (multi-hot vectors):
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
Duration sequence:
 tensor([[0.2500],
        [2.5000],
        [0.2500],
        [0.3333],
        [0.3333],
        [0.2500],
        [0.3333],
        [0.2500],
        [0.3333],
        [0.7500],
        [1.5000],
        [0.3333],
        [0.2500],
        [1.0000],
        [0.3333],
        [0.3333]])
Next event (multi-hot vector): tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,

### Defining the VAE model

In [223]:
batch_size = 32
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)


In [224]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [269]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """
    Variational Autoencoder for music sequence modeling with separate encoders
    for note/chord events and durations.

    Args:
        input_dim (int): Size of the input feature dimension (e.g., 88 for piano roll).
        latent_dim (int): Dimension of the latent space.
        seq_length (int): Length of the input sequence.

    Architecture:
        - Encodes both event sequences and duration sequences.
        - Projects to latent space with reparameterization trick.
        - Decodes back to event predictions (multi-hot) and duration values (scalar).
    """
    def __init__(self, input_dim, latent_dim, seq_length):
        super(VAE, self).__init__()
        self.seq_length = seq_length
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Linear(input_dim * seq_length, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        )

        self.duration_encoder = nn.Sequential(
            nn.Linear(seq_length, 64),
            nn.ReLU(),
        )

        self.fc_mu = nn.Linear(256 + 64, latent_dim)
        self.fc_var = nn.Linear(256 + 64, latent_dim)

        self.event_decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.Sigmoid(),  
        )

        self.duration_decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def encode(self, x, durs):
        h_seq = self.encoder(x.view(-1, self.seq_length * x.size(-1)))
        h_dur = self.duration_encoder(durs.view(-1, self.seq_length))
        h = torch.cat((h_seq, h_dur), dim=1)
        mu = self.fc_mu(h)
        log_var = self.fc_var(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        event = self.event_decoder(z)
        duration = self.duration_decoder(z)
        return event, duration

    def forward(self, x, durs):
        mu, log_var = self.encode(x, durs)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

input_dim = 88  
latent_dim = 20
seq_length = 16

vae = VAE(input_dim, latent_dim, seq_length).to(device)


In [226]:
import torch.optim as optim

input_dim = 88  
latent_dim = 20
seq_length = 16
vae = VAE(input_dim, latent_dim, seq_length).to(device)

optimizer = optim.Adam(vae.parameters(), lr=1e-3)

def loss_function(recon_event, event, recon_dur, dur, mu, log_var):
    """
    Computes the total VAE loss: reconstruction loss for events and durations,
    plus the KL divergence between approximate posterior and prior.

    Args:
        recon_event (Tensor): Predicted event output [B, input_dim]
        event (Tensor): Ground truth event [B, input_dim]
        recon_dur (Tensor): Predicted duration output [B, 1]
        dur (Tensor): Ground truth duration [B, 1]
        mu (Tensor): Latent mean [B, latent_dim]
        log_var (Tensor): Latent log variance [B, latent_dim]

    Returns:
        total_loss (Tensor): Scalar tensor representing total loss
    """
    recon_loss_event = F.binary_cross_entropy(recon_event, event, reduction='sum')
    recon_loss_dur = F.mse_loss(recon_dur, dur, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss_event + recon_loss_dur + KLD


In [227]:
num_epochs = 50

for epoch in range(num_epochs):
    """
    Training loop for the Variational Autoencoder (VAE).
    Each epoch iterates through the entire dataset using batches.
    The model learns to reconstruct musical events and durations while regularizing its latent space.
    """
    vae.train()
    train_loss = 0

    for batch_idx, (input_seq, dur_seq, target_event, target_dur) in enumerate(data_loader):


        input_seq = input_seq.to(device)
        dur_seq = dur_seq.to(device)
        target_event = target_event.to(device)
        target_dur = target_dur.to(device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(input_seq, dur_seq)
        recon_event, recon_dur = recon_batch

        loss = loss_function(recon_event, target_event, recon_dur, target_dur, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 2000 == 0:
            print(f'batch: {batch_idx}/{len(data_loader)}')

    print(f'Epoch: {epoch+1}, Loss: {train_loss / len(data_loader.dataset)}')


batch: 0/139758


batch: 2000/139758


KeyboardInterrupt: 

In [None]:
torch.save(vae.state_dict(), 'vae_model.pth')


### Using a smaller data set for training

In [245]:
ordered_events_with_durations_short = ordered_events_with_durations.iloc[:1000]

In [246]:
dataset = MusicEventDataset(
    reconstructed_df=ordered_events_with_durations,
    vocab=vocab,
    seq_length=16,
    max_samples_per_song=15  
)


In [247]:
batch_size = 32
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

In [None]:
num_epochs=20
for epoch in range(num_epochs):
    """
    Training loop with KL annealing for the Variational Autoencoder (VAE).
    KL weight (beta) increases linearly from 0 to 1 over the first 10 epochs.
    """
    vae.train()
    train_loss = 0

    beta = min(1.0, epoch / 10)  

    for batch_idx, (input_seq, dur_seq, target_event, target_dur) in enumerate(data_loader):
        input_seq = input_seq.to(device)
        dur_seq = dur_seq.to(device)
        target_event = target_event.to(device)
        target_dur = target_dur.to(device)
        optimizer.zero_grad()
        recon_batch, mu, log_var = vae(input_seq, dur_seq)
        recon_event, recon_dur = recon_batch

        recon_loss_event = F.binary_cross_entropy(recon_event, target_event, reduction='sum')
        recon_loss_dur = F.mse_loss(recon_dur, target_dur, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        loss = recon_loss_event + recon_loss_dur + beta * kl_loss

        loss.backward()
        train_loss += loss.item()
        optimizer.step()


    print(f'Epoch: {epoch+1}, Loss: {train_loss / len(data_loader.dataset)}')
    print(f"Epoch {epoch+1}: Recon loss = {recon_loss_event.item():.2f}, KL = {kl_loss.item():.2f}, Total loss = {loss.item():.2f}")



Epoch: 1, Loss: 9.482798246243998
Epoch 1: Recon loss = 232.11, KL = 1.08, Total loss = 256.84
Epoch: 2, Loss: 9.479013233305787
Epoch 2: Recon loss = 241.69, KL = 0.58, Total loss = 291.27
Epoch: 3, Loss: 9.493823725055003
Epoch 3: Recon loss = 215.61, KL = 0.64, Total loss = 273.34
Epoch: 4, Loss: 9.497717188524096
Epoch 4: Recon loss = 248.54, KL = 1.88, Total loss = 286.12
Epoch: 5, Loss: 9.498617762078602
Epoch 5: Recon loss = 243.18, KL = 0.64, Total loss = 273.20
Epoch: 6, Loss: 9.50224819253001
Epoch 6: Recon loss = 218.38, KL = 1.09, Total loss = 260.85
Epoch: 7, Loss: 9.505693193355386
Epoch 7: Recon loss = 218.73, KL = 1.09, Total loss = 288.14
Epoch: 8, Loss: 9.511159260823318
Epoch 8: Recon loss = 289.76, KL = 0.84, Total loss = 360.55
Epoch: 9, Loss: 9.508938914190319
Epoch 9: Recon loss = 268.31, KL = 0.40, Total loss = 322.76
Epoch: 10, Loss: 9.523902287099217
Epoch 10: Recon loss = 275.80, KL = 1.16, Total loss = 315.78
Epoch: 11, Loss: 9.518570483338024
Epoch 11: Reco

In [None]:
torch.save(vae.state_dict(), 'vae_model_KLdivsmalldata.pth')


#### Hyperparameter tuning for smaller models

In [None]:
import optuna

def objective(trial):
    """
    Objective function for Optuna hyperparameter tuning.

    Args:
        trial (optuna.trial.Trial): A single trial run to evaluate.

    Returns:
        float: Average loss over 20 batches for the current hyperparameter setting.
    """
    latent_dim = trial.suggest_categorical('latent_dim', [8, 16, 32])
    hidden_size = trial.suggest_categorical('hidden_size', [128, 256, 512])
    lr = trial.suggest_loguniform('lr', 1e-4, 1e-2)

    model = VAE(input_dim=88, latent_dim=latent_dim, seq_length=16).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    total_loss = 0
    model.train()
    for batch_idx, (x, dur, y, y_dur) in enumerate(data_loader):
        x, dur, y, y_dur = x.to(device), dur.to(device), y.to(device), y_dur.to(device)
        optimizer.zero_grad()
        recon, mu, log_var = model(x, dur)
        recon_event, recon_dur = recon
        loss = loss_function(recon_event, y, recon_dur, y_dur, mu, log_var)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if batch_idx > 20:
            break  

    return total_loss / 20


In [None]:
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=30)
print("Best hyperparameters:", study.best_params)


[I 2025-06-01 18:17:11,239] A new study created in memory with name: no-name-7a9213ab-5a39-4c7c-87cd-359a4f79099a
  lr = trial.suggest_loguniform('lr', 1e-4, 1e-2)
[I 2025-06-01 18:17:12,344] Trial 0 finished with value: 1909.0768676757812 and parameters: {'latent_dim': 16, 'hidden_size': 128, 'lr': 0.00023143635707429156}. Best is trial 0 with value: 1909.0768676757812.
[I 2025-06-01 18:17:13,509] Trial 1 finished with value: 663.9562545776367 and parameters: {'latent_dim': 8, 'hidden_size': 512, 'lr': 0.0024741274585548}. Best is trial 1 with value: 663.9562545776367.
[I 2025-06-01 18:17:14,334] Trial 2 finished with value: 595.090219116211 and parameters: {'latent_dim': 32, 'hidden_size': 128, 'lr': 0.0036996539293656532}. Best is trial 2 with value: 595.090219116211.
[I 2025-06-01 18:17:15,212] Trial 3 finished with value: 1104.8008758544922 and parameters: {'latent_dim': 8, 'hidden_size': 256, 'lr': 0.0007163988662801532}. Best is trial 2 with value: 595.090219116211.
[I 2025-06-0

Best hyperparameters: {'latent_dim': 16, 'hidden_size': 128, 'lr': 0.007322621104329893}


In [242]:
vae = VAE(
    input_dim=88,
    latent_dim=16,
    seq_length=16,
).to(device)

optimizer = torch.optim.Adam(vae.parameters(), lr=0.0073)


In [249]:
num_epochs=20
for epoch in range(num_epochs):
    """
    Training loop for VAE with KL annealing.
    The KL term is weighted by β, which increases linearly from 0 to 1 by epoch 10.
    """
    vae.train()
    train_loss = 0

    beta = min(1.0, epoch / 10)  

    for batch_idx, (input_seq, dur_seq, target_event, target_dur) in enumerate(data_loader):
        input_seq = input_seq.to(device)
        dur_seq = dur_seq.to(device)
        target_event = target_event.to(device)
        target_dur = target_dur.to(device)
        optimizer.zero_grad()
        recon_batch, mu, log_var = vae(input_seq, dur_seq)
        recon_event, recon_dur = recon_batch

        recon_loss_event = F.binary_cross_entropy(recon_event, target_event, reduction='sum')
        recon_loss_dur = F.mse_loss(recon_dur, target_dur, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        loss = recon_loss_event + recon_loss_dur + beta * kl_loss

        loss.backward()
        train_loss += loss.item()
        optimizer.step()


    print(f'Epoch: {epoch+1}, Loss: {train_loss / len(data_loader.dataset)}')
    print(f"Epoch {epoch+1}: Recon loss = {recon_loss_event.item():.2f}, KL = {kl_loss.item():.2f}, Total loss = {loss.item():.2f}")



Epoch: 1, Loss: 259232335.18434486
Epoch 1: Recon loss = 241.23, KL = 30233.47, Total loss = 285.48
Epoch: 2, Loss: 2.47310193869874e+22
Epoch 2: Recon loss = 223.20, KL = 49445.65, Total loss = 9210.66
Epoch: 3, Loss: 2.7958120468247243e+24
Epoch 3: Recon loss = 232.35, KL = 60712.52, Total loss = 13209.80
Epoch: 4, Loss: 2.1955094555948867e+24
Epoch 4: Recon loss = 238.70, KL = 26403.71, Total loss = 8385.95
Epoch: 5, Loss: 3.4229836111250214e+25
Epoch 5: Recon loss = 199.98, KL = 149801.64, Total loss = 117298.65
Epoch: 6, Loss: 4.836802302492654e+26
Epoch 6: Recon loss = 235.46, KL = 157569.31, Total loss = 92288.76
Epoch: 7, Loss: 4.780526032072211e+19
Epoch 7: Recon loss = 253.00, KL = 305095.62, Total loss = 282386.69
Epoch: 8, Loss: 3.1706394544881197e+19
Epoch 8: Recon loss = 273.93, KL = 270564.97, Total loss = 407648.75
Epoch: 9, Loss: 1.9492545171726975e+19
Epoch 9: Recon loss = 256.39, KL = 523088.12, Total loss = 486682.78
Epoch: 10, Loss: 2.1606448106054218e+19
Epoch 10:

RuntimeError: all elements of input should be between 0 and 1

In [250]:
torch.save(vae.state_dict(), 'vae_model_hyperparameterKL.pth')


## Evaluating our models

We choose to both generate random sequences of notes as well as trying to jazzify popular songs. We also retrieve some main evaluation measures in order to picture how our model works. The genrated songs are all saved can all be listened to to perform qualitative evaluations.

#### Random generated output 

We start off by generating random outputs to check their format is coherent

In [228]:
subset_dataset = Subset(dataset, range(100))  

val_loader = DataLoader(subset_dataset, batch_size=32, shuffle=False)

def evaluate_model(model, data_loader, device):
    """
    Evaluates the VAE model on the given DataLoader.

    Args:
        model: Trained VAE model
        data_loader: DataLoader for validation/test set
        device: 'cuda' or 'cpu'
    
    Returns:
        avg_event_loss: Average binary cross-entropy loss for event reconstruction
        avg_dur_loss: Average MSE loss for duration reconstruction
        avg_kl_loss: Average KL divergence
        avg_total_loss: Sum of all components per sample
    """

    model.eval()
    total_recon_event_loss = 0
    total_recon_dur_loss = 0
    total_kl_loss = 0
    total_samples = 0

    with torch.no_grad():
        for input_seq, dur_seq, target_event, target_dur in data_loader:
            input_seq = input_seq.to(device)
            dur_seq = dur_seq.to(device)
            target_event = target_event.to(device)
            target_dur = target_dur.to(device)

            (recon_event, recon_dur), mu, log_var = model(input_seq, dur_seq)

            recon_loss_event = F.binary_cross_entropy(recon_event, target_event, reduction='sum')
            recon_loss_dur = F.mse_loss(recon_dur, target_dur, reduction='sum')

            kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

            total_recon_event_loss += recon_loss_event.item()
            total_recon_dur_loss += recon_loss_dur.item()
            total_kl_loss += kl_loss.item()
            total_samples += input_seq.size(0)

    avg_event_loss = total_recon_event_loss / total_samples
    avg_dur_loss = total_recon_dur_loss / total_samples
    avg_kl_loss = total_kl_loss / total_samples
    avg_total_loss = (total_recon_event_loss + total_recon_dur_loss + total_kl_loss) / total_samples

    print(f"Evaluation Results:")
    print(f"Event BCE Loss:   {avg_event_loss:.4f}")
    print(f"Duration MSE Loss:{avg_dur_loss:.4f}")
    print(f"KL Divergence:    {avg_kl_loss:.4f}")
    print(f"Total Loss:       {avg_total_loss:.4f}")

    return avg_event_loss, avg_dur_loss, avg_kl_loss, avg_total_loss


In [229]:
def generate_random (model, latent_dim, seq_length):
    """
    Generates a random jazzified sequence from a trained VAE model by sampling from the latent space.

    Args:
        model: Trained VAE model
        latent_dim: Dimensionality of latent space
        seq_length: Number of events to generate (not used internally but useful for control)

    Returns:
        generated_sequence: List of decoded musical events (notes/chords)
        generated_durations: List of durations for each event
    """
    generated_sequence = []
    generated_durations = []

    top_k = 5  


    for _ in range(32):  
        z = torch.randn(1, latent_dim).to(device)
        recon_event, recon_dur = vae.decode(z)

        event_vec = recon_event.squeeze().detach().cpu().numpy()
        duration_val = recon_dur.item()

        top_indices = np.argsort(event_vec)[-top_k:]
        decoded_event = [vocab.idx_to_note[idx] for idx in top_indices if event_vec[idx] > 0.05]

        if not decoded_event:
            decoded_event = [vocab.idx_to_note[np.argmax(event_vec)]]

        generated_sequence.append(decoded_event)
        generated_durations.append(duration_val)

    print("Generated sequence:", generated_sequence)
    print("Durations:", generated_durations)


#### Useful Functions

In [259]:
import torch
import numpy as np
import random
from music21 import converter, note, chord


def set_seed(seed=16):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)


def parse_midi_to_event_sequence(midi_path):
    """
    Parses a MIDI file into a list of events (note or chord),
    their durations (in quarterLength), and velocities.

    Returns:
        events: List of MIDI pitch numbers or lists of pitch numbers (for chords)
        durations: List of corresponding durations (floats)
        velocities: List of velocity values (int, 0–127)
    """
    score = converter.parse(midi_path)
    flat = score.flat.notes
    events = []
    durations = []
    velocities = []

    for el in flat:
        if isinstance(el, note.Note):
            events.append(el.pitch.midi)
            durations.append(el.duration.quarterLength)
            velocities.append(el.volume.velocity if el.volume.velocity is not None else 64)
        elif isinstance(el, chord.Chord):
            events.append([p.midi for p in el.pitches])
            durations.append(el.duration.quarterLength)
            velocities.append(el.volume.velocity if el.volume.velocity is not None else 64)

    return events, durations, velocities

def encode_input_sequence(melody, durations, seq_length, vocab):
    """
    Encodes a melody and its corresponding durations into tensor sequences suitable for VAE input.

    Args:
        melody: List of note pitches (e.g., MIDI numbers or chord representations)
        durations: List of durations (floats), aligned with the melody
        seq_length: Length of each input sequence (number of time steps)
        vocab: Vocabulary object with `encode_event()` method for converting notes to one-hot or multi-hot format

    Returns:
        encoded_sequences: List of tensors of shape [seq_length, input_dim] for notes
        encoded_durations: List of tensors of shape [seq_length, 1] for durations
    """
    if len(melody) <= seq_length:
        raise ValueError("Melody too short for the given sequence length.")

    encoded_sequences = []
    encoded_durations = []

    for i in range(len(melody) - seq_length):
        input_seq = melody[i:i+seq_length]
        input_dur = durations[i:i+seq_length]

        input_encoded = [vocab.encode_event(n) for n in input_seq]
        encoded_sequences.append(torch.tensor(input_encoded, dtype=torch.float32))
        encoded_durations.append(torch.tensor(input_dur, dtype=torch.float32).unsqueeze(-1))

    return encoded_sequences, encoded_durations

def write_midi(events, durations, filename, velocities=None, default_velocity=64):
    """
    Writes a list of musical events and durations to a MIDI file.

    Args:
        events: List of notes or chords (each element is an int or a list of ints)
        durations: List of durations in beats, aligned with events
        filename: Output filename (e.g., "output.mid")
        velocities: Optional list of velocities per event (default applied if not provided)
        default_velocity: Default velocity to use if not specified per event
    """
    mid = MidiFile()
    track = MidiTrack()
    mid.tracks.append(track)
    ticks_per_beat = mid.ticks_per_beat

    for i, (event, dur) in enumerate(zip(events, durations)):
        duration_ticks = int(dur * ticks_per_beat)
        notes = event if isinstance(event, list) else [event]

        if velocities is not None:
            velocity = velocities[i] if i < len(velocities) else default_velocity
        else:
            velocity = default_velocity

        for note in notes:
            track.append(Message('note_on', note=note, velocity=velocity, time=0))
        for j, note in enumerate(notes):
            time = duration_ticks if j == 0 else 0
            track.append(Message('note_off', note=note, velocity=velocity, time=time))

    mid.save(filename)
    print(f" MIDI saved: {filename}")


def write_midi(events, durations, filename):
    """
    Converts a list of musical events and durations into a MIDI file with fixed velocity.
    
    Args:
        events: List of note or chord events (each item is an int or list of ints).
        durations: List of durations (floats, in beats) corresponding to each event.
        filename: Output filename for the MIDI file (e.g., "output.mid").
    """
    mid = MidiFile()
    track = MidiTrack()
    mid.tracks.append(track)
    ticks_per_beat = mid.ticks_per_beat
    velocity = 64  

    for event, dur in zip(events, durations):
        duration_ticks = int(dur * ticks_per_beat)
        notes = event if isinstance(event, list) else [event]

        for note in notes:
            track.append(Message('note_on', note=note, velocity=velocity, time=0))
        for i, note in enumerate(notes):
            time = duration_ticks if i == 0 else 0
            track.append(Message('note_off', note=note, velocity=velocity, time=time))

    mid.save(filename)
    print(f"MIDI saved with velocity=64: {filename}")


def generate(encoded_sequences, encoded_durations, top_k=5, temperature=1.2):
    """
    Generates a jazzified version of an input melody using a trained VAE model.

    Args:
        encoded_sequences: List of tensors containing encoded input note sequences.
        encoded_durations: List of tensors containing corresponding duration sequences.
        top_k: Number of top candidate notes to sample from.
        temperature: Controls randomness; higher = more diverse output.

    Returns:
        jazzified_events: List of decoded note/chord events.
        jazzified_durations: List of durations corresponding to each event.
    """
    jazzified_events = []
    jazzified_durations = []

    with torch.no_grad():
        for x_seq, dur_seq in zip(encoded_sequences, encoded_durations):
            x_seq = x_seq.unsqueeze(0).to(device)
            dur_seq = dur_seq.unsqueeze(0).to(device)

            z_mu, z_logvar = vae.encode(x_seq, dur_seq)
            z = vae.reparameterize(z_mu, z_logvar)

            recon_event, recon_dur = vae.decode(z)
            event_vec = recon_event.squeeze().cpu().numpy()
            duration_val = max(0.2, recon_dur.item())

            logits = np.log(np.clip(event_vec, 1e-8, 1.0))  
            scaled_logits = logits / temperature
            probs = np.exp(scaled_logits)
            probs /= probs.sum()

            top_indices = np.argsort(probs)[-top_k:]
            top_probs = probs[top_indices]
            top_probs /= top_probs.sum()

            note_count = np.random.choice([1, 2, 3, 4, 5], p=[0.5, 0.2, 0.2, 0.05, 0.05])
            note_count = min(note_count, len(top_indices))

            try:
                selected_indices = np.random.choice(top_indices, size=note_count, p=top_probs, replace=False)
            except ValueError:
                selected_indices = [top_indices[np.argmax(top_probs)]]

            decoded_event = [vocab.idx_to_note[idx] for idx in selected_indices]

            jazzified_events.append(decoded_event)
            jazzified_durations.append(duration_val)

    return jazzified_events, jazzified_durations

def from_song_generate (original_filename, target_filename):
    """
    Loads a MIDI file, encodes it, passes it through the VAE to jazzify it,
    and saves the generated sequence as a new MIDI file.

    Args:
        original_filename (str): Path to the input MIDI file.
        target_filename (str): Path where the jazzified MIDI will be saved.
    """
    melody, durations, velocities = parse_midi_to_event_sequence(original_filename)
    encoded_sequences, encoded_durations = encode_input_sequence(
        melody, durations, seq_length=16, vocab=vocab)
    jazzified_events, jazzified_durations = generate(encoded_sequences, encoded_durations)
    write_midi(jazzified_events, jazzified_durations, target_filename)


#### Generating with general VAE

In [254]:
vae = VAE(input_dim=88, latent_dim=20, seq_length=16).to(device)
vae.load_state_dict(torch.load('notebook/vae_model_data/vae_model.pth'))
vae.eval()
root='generated_music/generated_music_vae/vae_general/'


In [255]:
generate_random(vae, 20, 16)

Generated sequence: [[60, 62], [64, 57, 65, 62, 60], [65, 57, 67, 62, 60], [64, 62, 63, 67, 60], [63, 65, 57, 62, 60], [65, 67, 57, 62, 60], [65, 62, 57, 64, 60], [57, 64, 67, 62, 60], [57, 65, 62, 60], [67, 57, 64, 62, 60], [57, 63, 62, 65, 60], [64, 55, 67, 62, 60], [65, 63, 59, 62, 60], [58, 62, 67, 64, 60], [62, 64, 57, 60, 65], [64, 65, 62, 60, 63], [65, 63, 57, 62, 60], [57, 67, 63, 62, 60], [63, 62, 57, 67, 60], [63, 67, 65, 60], [64, 63, 67, 60], [64, 67, 62, 63, 60], [72, 67, 65, 63, 60], [63, 57, 65, 67, 60], [63, 67, 60], [60], [63, 64, 65, 62, 60], [64, 57, 67, 62, 60], [55, 63, 67, 62, 60], [67, 65, 62, 57, 60], [58, 64, 57, 62, 60], [64, 65, 57, 62, 60]]
Durations: [0.8988903164863586, 0.7114179730415344, 1.1524525880813599, 1.0198121070861816, 0.8861413598060608, 1.0399314165115356, 1.0773284435272217, 1.138029932975769, 0.6716316342353821, 1.2533535957336426, 0.9123578071594238, 1.117919921875, 1.1200757026672363, 0.8795191049575806, 0.9867668747901917, 1.07492661476135

In [260]:
from_song_generate ("data/popular_songs/HappyBirthday.mid", f"{root}HappyBirthday.mid")
from_song_generate ("data/popular_songs/HipsDontLie.mid", f"{root}HipsDontLie.mid")
from_song_generate ("data/popular_songs/BackInBlack.mid", f"{root}BackInBlack.mid")
from_song_generate ("data/popular_songs/FinalCountDown.mid", f"{root}FinalCountDown.mid")


MIDI saved with velocity=64: generated_music/generated_music_vae/vae_general/HappyBirthday.mid
MIDI saved with velocity=64: generated_music/generated_music_vae/vae_general/HipsDontLie.mid
MIDI saved with velocity=64: generated_music/generated_music_vae/vae_general/BackInBlack.mid
MIDI saved with velocity=64: generated_music/generated_music_vae/vae_general/FinalCountDown.mid


In [235]:
evaluate_model(vae, val_loader, device)

Evaluation Results:
Event BCE Loss:   6.3511
Duration MSE Loss:0.3961
KL Divergence:    0.0085
Total Loss:       6.7557


(6.351111392974854,
 0.39614493429660796,
 0.008490124642848968,
 6.755746451914311)

The event BCE loss, 6.35, dominates the total loss, reflecting the challenge of reconstructing the correct multi-hot event vector. A value around 6 is typical given the sparsity and multi-label nature of the output (e.g., chords or overlapping notes). Lower values would indicate more accurate pitch reconstruction.
The duration mean squared error, 0.396, is relatively low, suggesting that the model predicts durations with good precision. Durations are scalar values and generally easier to regress than high-dimensional multi-hot vectors.
The very small value of KL divergence, 0.0085, indicates that the learned posterior distribution $q(z|x)$ is already very close to the standard Gaussian prior. While this helps stabilize the latent space, excessively low KL values may imply posterior collapse, where the model underuses the latent variable $z$.
The total loss, 6.7557 is simply the sum of the three components. The fact that the BCE term contributes ~94% of the total loss highlights that most of the model's effort is spent reconstructing pitch content, with timing and regularization playing secondary roles.



##### Generating with VAE with KL divergence on small dataset

In [271]:
vae_KL_small = VAE(input_dim=88, latent_dim=20, seq_length=16).to(device)
vae_KL_small.load_state_dict(torch.load('notebook/vae_model_data/vae_model_KLdivsmalldata.pth'))
vae_KL_small.eval()
root = 'generated_music/generated_music_vae/vae_KL_small/'
generate_random(vae_KL_small, 20, 16)
from_song_generate ("data/popular_songs/HappyBirthday.mid", f"{root}HappyBirthday.mid")
from_song_generate ("data/popular_songs/HipsDontLie.mid", f"{root}HipsDontLie.mid")
from_song_generate ("data/popular_songs/BackInBlack.mid", f"{root}BackInBlack.mid")
from_song_generate ("data/popular_songs/FinalCountDown.mid", f"{root}FinalCountDown.mid")

evaluate_model(vae_KL_small, val_loader, device)

Generated sequence: [[45, 50, 48, 88, 85], [60, 91, 35, 85, 88], [52, 88, 58, 67, 80], [21, 50, 68, 85, 71], [67, 24, 68, 102, 88], [99, 21, 42, 67, 102], [71, 35, 85, 88, 50], [85, 108, 75, 45, 93], [85, 108, 83, 88, 94], [50, 35, 83, 67, 102], [85, 108, 52, 88, 30], [93, 58, 68, 52, 102], [91, 34, 35, 62, 94], [88, 85, 83, 35, 94], [58, 30, 67, 42, 85], [50, 52, 75, 96, 88], [88, 85, 45, 102, 50], [67, 88, 58, 85, 102], [49, 89, 24, 75, 88], [35, 45, 108, 88, 85], [50, 67, 108, 40, 85], [67, 50, 42, 85, 71], [85, 76, 88, 42, 50], [96, 77, 85, 62, 35], [49, 85, 102, 58, 45], [88, 96, 85, 49, 108], [62, 80, 73, 102, 38], [45, 48, 75, 50, 85], [50, 40, 108, 75, 102], [108, 89, 35, 30, 85], [85, 67, 30, 50, 88], [76, 63, 91, 88, 50]]
Durations: [0.3971548080444336, 0.07005634903907776, 0.20556096732616425, 0.10306544601917267, 0.36613160371780396, -0.04971347004175186, 0.09934122115373611, 0.20370665192604065, 0.20289696753025055, 0.40567731857299805, 0.015450932085514069, 0.018192991614

  return self.iter().getElementsByClass(classFilterList)


MIDI saved with velocity=64: generated_music/generated_music_vae/vae_KL_small/HappyBirthday.mid
MIDI saved with velocity=64: generated_music/generated_music_vae/vae_KL_small/HipsDontLie.mid
MIDI saved with velocity=64: generated_music/generated_music_vae/vae_KL_small/BackInBlack.mid
MIDI saved with velocity=64: generated_music/generated_music_vae/vae_KL_small/FinalCountDown.mid
Evaluation Results:
Event BCE Loss:   6.3600
Duration MSE Loss:0.4433
KL Divergence:    0.0085
Total Loss:       6.8118


(6.360015144348145,
 0.4433128082752228,
 0.008490124642848968,
 6.811818077266216)

The binary cross-entropy loss for event reconstruction, 6.3832, is relatively high. This suggests the model struggles to accurately predict multi-hot note or chord vectors, possibly due to the complexity and sparsity of the target representation.
The mean squared error for duration prediction, 0.4448, is moderate. The model captures duration trends reasonably well but could benefit from additional tuning or structure. The KL divergence, 0.0085 is extremely low, indicating that the latent space is not being utilized effectively. This often points to posterior collapse, where the decoder ignores the latent code. The total loss, 6.8365, is driven almost entirely by the reconstruction losses, particularly the event loss, with very little contribution from KL divergence.