In [2]:
import os
import ast
import random
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from sklearn.model_selection import KFold

root = 'data_processed/'

# Generation with Recurrent Neural Networks

To establish a baseline for music generation that we can improve on, we use Recurrent Neural Networks. We formulate the task as a next-note prediction problem. This approach is quite similar to the recurrence-based language models used in NLP.

The input is sequential, but unlike words in NLP, timing and dynamics (duration, velocity, offset) matter a lot in music. To be able to predict notes/chords + durations + offsets + velocities we might need multi-output heads (e.g., softmax for notes/chords/velocities, regression for durations/offsets).

## Import Dataset and Definition of Useful functions

### Import Dataset

In [3]:
def safe_parse_all_columns_df(df):
    """
    Convert the stringified list columns of the song DataFrame back into Python objects.

    A CSV round-trip stores Python lists as text (e.g. "[60, 62]"). For the columns
    'notes', 'chords', 'velocities', 'durations', 'offsets' and 'ordered_events',
    every cell that is a string is parsed with ``ast.literal_eval``; cells that are
    not strings are left untouched.

    Args:
        df: DataFrame containing all six columns listed above.

    Returns:
        The same DataFrame object (it is modified in place).

    Raises:
        KeyError: if one of the expected columns is missing.
        ValueError, SyntaxError: if a string cell is not a valid Python literal.
    """
    list_columns = ('notes', 'chords', 'velocities', 'durations', 'offsets', 'ordered_events')

    def parse_cell(value):
        # literal_eval only evaluates Python literals, so CSV content cannot run code.
        return ast.literal_eval(value) if isinstance(value, str) else value

    for column in list_columns:
        df[column] = df[column].apply(parse_cell)
    return df

def load_dataframe_from_two_csvs(file1, file2):
    """
    Read two CSV files, stack them into one DataFrame and parse its list columns.

    The rows of ``file2`` follow the rows of ``file1`` and the combined frame gets
    a fresh 0..n-1 index before being passed to ``safe_parse_all_columns_df``.
    """
    parts = [pd.read_csv(path) for path in (file1, file2)]
    combined = pd.concat(parts, ignore_index=True)
    return safe_parse_all_columns_df(combined)

def save_dataframe_to_two_csvs(df, file1, file2):
    """
    Write the first half of ``df`` to ``file1`` and the remaining rows to ``file2``.

    The split point is ``len(df) // 2``, so with an odd row count the second file
    receives the extra row. The index is not written to either file.
    """
    split_at = len(df) // 2
    halves = ((df.iloc[:split_at], file1), (df.iloc[split_at:], file2))
    for part, path in halves:
        part.to_csv(path, index=False)

def load_dataframe_from_one_csv(file):
    """
    Read a single CSV file into a DataFrame.

    The frame is returned exactly as ``pd.read_csv`` produces it; list-valued
    columns are NOT parsed here and remain strings.
    """
    return pd.read_csv(file)

def save_dataframe_to_one_csv(df, file, index=True):
    """
    Save a DataFrame to a single CSV file.

    Args:
        df: DataFrame to write.
        file: Destination passed straight to ``DataFrame.to_csv``.
        index: Whether to write the DataFrame index as the first CSV column.
            Defaults to True, which is what this function always did. Pass False
            for frames that were themselves read from a CSV: otherwise every
            save/load cycle adds another unnamed index column.
    """
    df.to_csv(file, index=index)

def load_reconstructed_events(file):
    """
    Read the reconstructed-events CSV and turn its 'sequence' column into lists.

    Each 'sequence' cell is parsed from its string form: single notes become
    ``int`` and chords become lists of ``int``. All other columns are returned
    exactly as ``pd.read_csv`` read them. If a cell cannot be parsed, the
    offending text is printed and the original exception is re-raised.
    """
    def safe_parse(seq_str):
        try:
            events = ast.literal_eval(seq_str)
            if not isinstance(events, list):
                raise ValueError("Parsed sequence is not a list")

            # Chords (nested lists) keep their list shape; everything else is one note.
            return [
                [int(pitch) for pitch in item] if isinstance(item, list) else int(item)
                for item in events
            ]

        except Exception as e:
            print(f"Error parsing sequence: {seq_str}")
            raise e

    events_df = pd.read_csv(file)
    events_df['sequence'] = events_df['sequence'].apply(safe_parse)
    return events_df

In [4]:
# The processed dataset is stored as two CSV halves under `root`.
file1 = f"{root}data_part1.csv"
file2 = f"{root}data_part2.csv"

# Concatenate both halves and parse the stringified list columns.
df = load_dataframe_from_two_csvs(file1, file2)

In [5]:
# Event sequences with notes as ints and chords as lists of ints
# (parsed from the CSV's 'sequence' column by load_reconstructed_events).
reconstructed_dataset = load_reconstructed_events(root + 'reconstructed_ordered_events.csv')

In [6]:
# Work on a copy so `reconstructed_dataset` itself is not given a new column.
ordered_events_with_durations = reconstructed_dataset.copy()
# Column assignment aligns on the index, so this assumes both frames list the
# same songs in the same row order — TODO confirm.
ordered_events_with_durations['durations'] = df['durations']

In [7]:
# Display the merged frame (event sequences plus their durations).
ordered_events_with_durations

Unnamed: 0,index,sequence,durations
0,0,"[88, 38, 38, 45, 45, 45, [44, 45], 45, 45, 38,...","[0.25, 2.5, 0.25, 0.3333, 0.3333, 0.25, 0.3333..."
1,1,"[[56, 65, 68, 80, 60], [54, 61, 65, 75], 77, 7...","[1.25, 1.6667, 0.25, 1.0, 1.0, 0.5, 0.6667, 0...."
2,2,"[46, 70, [53, 60, 62], [65, 46, 53, 60], 70, [...","[2.6667, 0.25, 1.25, 0.75, 0.75, 1.6667, 1.75,..."
3,3,"[[52, 28, 40], 64, [65, 68], [53, 55, 56], [58...","[3.75, 2.5, 1.0, 2.0, 1.6667, 1.3333, 0.75, 0...."
4,4,"[59, 59, [60, 62, 38, 54], 62, [43, 62, 50, 52...","[0.75, 0.6667, 0.6667, 0.5, 1.0, 0.25, 2.0, 0...."
...,...,...,...
2770,2770,"[61, 49, [63, 60], 65, 70, [58, 61, 66, 49], [...","[1.3333, 1.6667, 1.0, 0.5, 0.5, 0.5, 0.6667, 0..."
2771,2771,"[[45, 55, 67, 69, 72, 75], [46, 74, 56, 62, 68...","[1.0, 1.0, 0.25, 1.25, 0.3333, 0.25, 0.25, 0.7..."
2772,2772,"[[45, 57], 71, 52, 61, 64, 45, 69, [45, 69], 4...","[1.6667, 0.6667, 3.3333, 2.75, 2.25, 0.75, 0.5..."
2773,2773,"[48, 41, 60, [67, 68, 72, 79], 66, [65, 68, 72...","[0.3333, 0.3333, 0.25, 0.25, 0.3333, 0.3333, 0..."


In [8]:
# Sanity check: every event needs exactly one duration, so both lists of a row
# must have the same length.
sequence_lengths = ordered_events_with_durations['sequence'].map(len)
duration_lengths = ordered_events_with_durations['durations'].map(len)
mismatches = sequence_lengths != duration_lengths

invalid_rows = ordered_events_with_durations[mismatches]

if invalid_rows.empty:
    print("All rows have matching 'sequence' and 'durations' lengths.")
else:
    print("Rows with mismatched 'sequence' and 'durations':")
    print(invalid_rows)

All rows have matching 'sequence' and 'durations' lengths.


In [27]:
# Build the path from `root` like every other data file in this notebook.
# NOTE(review): this CSV is only written by the `save_dataframe_to_one_csv` cell
# further down, so on a fresh kernel this cell needs the file from a previous run.
# Only 'sequence' is parsed by the loader; 'durations' stays a string column here.
df_reconstructed = load_reconstructed_events(root + 'reconstructed_with_durations.csv')

In [28]:
# Number of duration entries for the first song.
len(df['durations'][0])

2429

In [10]:
# Point `df_reconstructed` at the in-memory merged frame (same object, not a copy).
df_reconstructed=ordered_events_with_durations
# Number of events of the first song; should equal its number of durations.
len(df_reconstructed['sequence'][0])

2429

### Useful functions

In [11]:
def parse_chord_to_list(chord):
    """
    Convert a comma-separated chord string such as "60, 64,67" to a list of integers.

    Args:
        chord: Chord written as comma-separated pitch numbers.

    Returns:
        The integer tokens found in ``chord``, in order. Tokens that are not plain
        digit strings (after trimming surrounding whitespace) are skipped. Any
        non-string input yields an empty list.
    """
    if not isinstance(chord, str):
        return []

    # Trim each token before the digit check: without it the " 64" in "60, 64"
    # fails str.isdigit() and the note is silently dropped.
    tokens = (token.strip() for token in chord.split(','))
    return [int(token) for token in tokens if token.isdigit()]

In [13]:
def reconstruct_ordered_events(df):
    """
    Reconstruct the ordered list of events (notes and chords) for each song.

    Each row's 'ordered_events' holds one marker per event: 'n' takes the next
    unused entry of that row's 'notes', 'c' the next unused entry of its 'chords'.
    Chord entries are appended exactly as stored in the 'chords' column.

    Args:
        df: DataFrame with the columns 'ordered_events', 'notes' and 'chords',
            each cell being a list.

    Returns:
        DataFrame with a single column 'sequence' (one list of events per input
        row, in row order) and an index named 'index'.

    Raises:
        ValueError: if a marker is neither 'n' nor 'c'.
        IndexError: if a row has more markers of one kind than matching entries.
    """
    sequences = []

    # zip walks the columns positionally, so this also works when `df` has a
    # non-default index (e.g. after filtering rows).
    for markers, notes, chords in zip(df['ordered_events'], df['notes'], df['chords']):
        idx_note = 0
        idx_chord = 0
        reconstructed = []

        for element in markers:
            if element == 'n':
                reconstructed.append(notes[idx_note])
                idx_note += 1
            elif element == 'c':
                reconstructed.append(chords[idx_chord])
                idx_chord += 1
            else:
                raise ValueError(f"Unknown event type: {element}")

        sequences.append(reconstructed)

    reconstructed_dataset = pd.DataFrame({'sequence': sequences})
    reconstructed_dataset.index.name = 'index'

    return reconstructed_dataset

In [15]:
# Display the frame that is about to be written to disk.
df_reconstructed

Unnamed: 0,index,sequence,durations
0,0,"[88, 38, 38, 45, 45, 45, [44, 45], 45, 45, 38,...","[0.25, 2.5, 0.25, 0.3333, 0.3333, 0.25, 0.3333..."
1,1,"[[56, 65, 68, 80, 60], [54, 61, 65, 75], 77, 7...","[1.25, 1.6667, 0.25, 1.0, 1.0, 0.5, 0.6667, 0...."
2,2,"[46, 70, [53, 60, 62], [65, 46, 53, 60], 70, [...","[2.6667, 0.25, 1.25, 0.75, 0.75, 1.6667, 1.75,..."
3,3,"[[52, 28, 40], 64, [65, 68], [53, 55, 56], [58...","[3.75, 2.5, 1.0, 2.0, 1.6667, 1.3333, 0.75, 0...."
4,4,"[59, 59, [60, 62, 38, 54], 62, [43, 62, 50, 52...","[0.75, 0.6667, 0.6667, 0.5, 1.0, 0.25, 2.0, 0...."
...,...,...,...
2770,2770,"[61, 49, [63, 60], 65, 70, [58, 61, 66, 49], [...","[1.3333, 1.6667, 1.0, 0.5, 0.5, 0.5, 0.6667, 0..."
2771,2771,"[[45, 55, 67, 69, 72, 75], [46, 74, 56, 62, 68...","[1.0, 1.0, 0.25, 1.25, 0.3333, 0.25, 0.25, 0.7..."
2772,2772,"[[45, 57], 71, 52, 61, 64, 45, 69, [45, 69], 4...","[1.6667, 0.6667, 3.3333, 2.75, 2.25, 0.75, 0.5..."
2773,2773,"[48, 41, 60, [67, 68, 72, 79], 66, [65, 68, 72...","[0.3333, 0.3333, 0.25, 0.25, 0.3333, 0.3333, 0..."


In [16]:
# Persist sequences and durations together so later sessions can reload them.
save_dataframe_to_one_csv(df_reconstructed, root + 'reconstructed_with_durations.csv')

## Predict only Events (Notes and Chords)

### Creating the data: Fixed number of events 

Idea for creating the input sequences:
- we take subsets of the list of events representing each song 
- we take the next event of each subset as corresponding training output sequences

This is easy to implement and gives a consistent sequence length for batching, but it ignores the timing aspect.

In [17]:
class Vocabulary:
    """
    Maps every single note of the corpus to a vocabulary index.

    Events (a single note or a chord, i.e. a list of notes) are represented as
    multi-hot vectors of length ``vocab_size``.
    """

    def __init__(self, reconstructed_df):
        """
        Build vocabulary of unique single notes only.

        Chords contribute their member notes; a chord as a whole is not a
        vocabulary entry.

        Args:
            reconstructed_df: DataFrame whose 'sequence' column holds, per row, a
                list of events (notes, or lists of notes for chords).
        """
        unique_notes = set()
        # Iterate over the column values directly (positional) so frames with a
        # non-default index, e.g. after filtering rows, are handled too.
        for sequence in reconstructed_df['sequence']:
            for event in sequence:
                if isinstance(event, list):
                    unique_notes.update(event)
                else:
                    unique_notes.add(event)

        self.notes = sorted(unique_notes)
        self.note_to_idx = {note: idx for idx, note in enumerate(self.notes)}
        self.idx_to_note = {idx: note for idx, note in enumerate(self.notes)}
        self.vocab_size = len(self.notes)

    def encode_event(self, event):
        """
        Encode an event as a multi-hot vector over single notes.

        Args:
            event: A single note, or a list of notes for a chord.

        Returns:
            float32 numpy array of length ``vocab_size`` with 1.0 at the index of
            every note in the event.

        Raises:
            KeyError: if the event contains a note that is not in the vocabulary.
        """
        vec = np.zeros(self.vocab_size, dtype=np.float32)
        members = event if isinstance(event, list) else [event]
        for note in members:
            vec[self.note_to_idx[note]] = 1.0
        return vec

    def decode_event(self, vec, threshold=0.5):
        """
        Decode multi-hot vector to list of notes.

        Args:
            vec: numpy array of per-note activations, length ``vocab_size``.
            threshold: Minimum activation for a note to count as present.

        Returns:
            The note itself when exactly one activation reaches the threshold,
            otherwise a list of notes (empty when none reaches it).
        """
        indices = np.where(vec >= threshold)[0]
        notes = [self.idx_to_note[idx] for idx in indices]
        # A lone note is returned unwrapped, mirroring how single notes appear
        # in the event sequences.
        if len(notes) == 1:
            return notes[0]
        return notes

    def __len__(self):
        """Return the number of distinct notes in the vocabulary."""
        return self.vocab_size


Create Dataset object

In [32]:
class MusicEventDataset(Dataset):
    """Sliding-window dataset: ``seq_length`` events (plus durations) -> next event and duration."""

    def __init__(self, reconstructed_df, vocab, seq_length=50):
        """
        Constructs all valid (input_seq, input_dur_seq, target_event, target_dur) pairs.

        Args:
            reconstructed_df: DataFrame with 'sequence' and 'durations' columns
            vocab: Vocabulary object to encode events
            seq_length: Length of each training input sequence (target is the next event)
        """
        self.samples = []
        self.seq_length = seq_length
        self.vocab = vocab

        # Durations are read from the frame that was passed in (not from a
        # notebook-level global), so sequences and durations always belong to the
        # same rows. zip walks both columns positionally.
        for sequence, durations in zip(reconstructed_df['sequence'], reconstructed_df['durations']):
            # Durations loaded from CSV may still be strings; literal_eval parses
            # them without executing arbitrary code.
            if isinstance(durations, str):
                durations = ast.literal_eval(durations)

            n_events = min(len(sequence), len(durations))
            # Songs with at most `seq_length` events yield an empty range: a
            # window always needs one further event as its target.
            for i in range(n_events - seq_length):
                input_seq = sequence[i : i + seq_length]
                input_durs = durations[i : i + seq_length]
                target_event = sequence[i + seq_length]
                target_dur = durations[i + seq_length]

                self.samples.append((input_seq, input_durs, target_event, target_dur))

    def __len__(self):
        """Return the number of sliding-window samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return sample ``idx`` as float32 tensors.

        Returns:
            Tuple ``(input_tensor, dur_tensor, target_tensor, target_dur_tensor)``
            with shapes ``(seq_length, vocab_size)``, ``(seq_length, 1)``,
            ``(vocab_size,)`` and ``(1,)``.
        """
        input_seq, input_durs, target_event, target_dur = self.samples[idx]

        # Stack into one ndarray first: torch.tensor() on a list of ndarrays is
        # very slow and emits a UserWarning.
        input_encoded = np.array(
            [self.vocab.encode_event(event) for event in input_seq], dtype=np.float32
        )
        input_tensor = torch.from_numpy(input_encoded)

        dur_tensor = torch.tensor(input_durs, dtype=torch.float32).unsqueeze(-1)

        target_tensor = torch.from_numpy(self.vocab.encode_event(target_event))

        target_dur_tensor = torch.tensor([target_dur], dtype=torch.float32)
        return input_tensor, dur_tensor, target_tensor, target_dur_tensor


In [18]:
import torch
import numpy as np
from torch.utils.data import Dataset
import ast

class MusicEventDataset(Dataset):
    """Sliding-window dataset: ``seq_length`` events (plus durations) -> next event and duration."""

    def __init__(self, reconstructed_df, vocab, seq_length=50):
        """
        Constructs all valid (input_seq, input_dur_seq, target_event, target_dur) pairs.

        Args:
            reconstructed_df: DataFrame with 'sequence' and 'durations' columns
            vocab: Vocabulary object to encode events
            seq_length: Length of each training input sequence (target is the next event)
        """
        self.samples = []
        self.seq_length = seq_length
        self.vocab = vocab

        # zip walks both columns positionally, so frames with a non-default
        # index (e.g. after filtering rows) are handled as well.
        for sequence, durations in zip(reconstructed_df['sequence'], reconstructed_df['durations']):
            # Durations loaded from CSV may still be strings.
            if isinstance(durations, str):
                durations = ast.literal_eval(durations)

            n_events = min(len(sequence), len(durations))
            # Songs with at most `seq_length` events yield an empty range: a
            # window always needs one further event as its target.
            for start in range(n_events - seq_length):
                stop = start + seq_length
                self.samples.append(
                    (sequence[start:stop], durations[start:stop], sequence[stop], durations[stop])
                )

    def __len__(self):
        """Return the number of sliding-window samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return sample ``idx`` as float32 tensors.

        Returns:
            Tuple ``(input_tensor, dur_tensor, target_tensor, target_dur_tensor)``
            with shapes ``(seq_length, vocab_size)``, ``(seq_length, 1)``,
            ``(vocab_size,)`` and ``(1,)``.
        """
        input_seq, input_durs, target_event, target_dur = self.samples[idx]

        # Build one ndarray before converting: torch.tensor() on a list of
        # ndarrays is very slow and triggers a UserWarning.
        input_encoded = np.array(
            [self.vocab.encode_event(event) for event in input_seq], dtype=np.float32
        )
        input_tensor = torch.from_numpy(input_encoded)

        dur_tensor = torch.tensor(input_durs, dtype=torch.float32).unsqueeze(-1)

        target_tensor = torch.from_numpy(self.vocab.encode_event(target_event))

        target_dur_tensor = torch.tensor([target_dur], dtype=torch.float32)

        return input_tensor, dur_tensor, target_tensor, target_dur_tensor


In [21]:
# Reload the event sequences and derive the note vocabulary from them.
reconstructed_dataset = load_reconstructed_events(root + 'reconstructed_ordered_events.csv')
vocab = Vocabulary(reconstructed_dataset)
# Windows of 16 events; this has to match the `seq_length` the VAE is built with,
# because its encoder flattens exactly seq_length steps.
dataset = MusicEventDataset(ordered_events_with_durations, vocab=vocab, seq_length=16)




In [22]:
# Inspect the first training sample: shapes first, then the tensors themselves.
x, dur, y, target_dur = dataset[0]

for label, tensor in (
    ("Input sequence shape:", x),
    ("Duration sequence shape:", dur),
    ("Next event shape:", y),
    ("Target duration shape:", target_dur),
):
    print(label, tensor.shape)

for label, tensor in (
    ("Input sequence (multi-hot vectors):\n", x),
    ("Duration sequence:\n", dur),
    ("Next event (multi-hot vector):", y),
    ("Target duration:", target_dur),
):
    print(label, tensor)


Input sequence shape: torch.Size([16, 88])
Duration sequence shape: torch.Size([16, 1])
Next event shape: torch.Size([88])
Target duration shape: torch.Size([1])
Input sequence (multi-hot vectors):
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
Duration sequence:
 tensor([[0.2500],
        [2.5000],
        [0.2500],
        [0.3333],
        [0.3333],
        [0.2500],
        [0.3333],
        [0.2500],
        [0.3333],
        [0.7500],
        [1.5000],
        [0.3333],
        [0.2500],
        [1.0000],
        [0.3333],
        [0.3333]])
Next event (multi-hot vector): tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,

  input_tensor = torch.tensor(input_encoded, dtype=torch.float32)


In [23]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """
    Fully connected variational model over a fixed-length window of events.

    The encoder flattens a window of ``seq_length`` multi-hot event vectors and
    their durations into a Gaussian latent code; two decoder heads map a latent
    vector to one event (per-note probabilities) and one duration value.
    """

    def __init__(self, input_dim, latent_dim, seq_length):
        """
        Args:
            input_dim: Size of one multi-hot event vector.
            latent_dim: Size of the latent code.
            seq_length: Number of events in each input window.
        """
        super().__init__()
        self.seq_length = seq_length
        self.latent_dim = latent_dim

        # Event branch: the whole window is flattened into one long vector.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim * seq_length, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        )

        # Duration branch: one scalar per step of the window.
        self.duration_encoder = nn.Sequential(
            nn.Linear(seq_length, 64),
            nn.ReLU(),
        )

        # Both branches are concatenated (256 + 64 features) before the latent heads.
        joint_features = 256 + 64
        self.fc_mu = nn.Linear(joint_features, latent_dim)
        self.fc_var = nn.Linear(joint_features, latent_dim)

        # Event head: the sigmoid yields an independent probability per note.
        self.event_decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.Sigmoid(),
        )

        # Duration head: unbounded scalar regression output.
        self.duration_decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def encode(self, x, durs):
        """Return ``(mu, log_var)`` of the latent Gaussian for a batch of windows."""
        flat_events = x.view(-1, self.seq_length * x.size(-1))
        flat_durations = durs.view(-1, self.seq_length)
        hidden = torch.cat(
            (self.encoder(flat_events), self.duration_encoder(flat_durations)), dim=1
        )
        return self.fc_mu(hidden), self.fc_var(hidden)

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``."""
        std = (0.5 * log_var).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors to ``(event_probabilities, duration)``."""
        return self.event_decoder(z), self.duration_decoder(z)

    def forward(self, x, durs):
        """Encode, sample and decode; returns ``((event, duration), mu, log_var)``."""
        mu, log_var = self.encode(x, durs)
        latent = self.reparameterize(mu, log_var)
        return self.decode(latent), mu, log_var

# Width of one multi-hot event vector; matches the 88-wide tensors printed above.
input_dim = 88  
latent_dim = 20
# Must equal the seq_length used to build `dataset`.
seq_length = 16

vae = VAE(input_dim, latent_dim, seq_length).to(device)


In [24]:
from torch.utils.data import DataLoader

batch_size = 32
# Shuffled mini-batches over all sliding-window samples.
# NOTE(review): no random seed is set in the visible cells, so batch order (and
# therefore training) is not reproducible between runs.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [27]:
import torch.optim as optim

# Same hyperparameters as the earlier instantiation cell; running this cell
# creates a fresh, untrained model and discards the previous `vae` object.
input_dim = 88  
latent_dim = 20
seq_length = 16
vae = VAE(input_dim, latent_dim, seq_length).to(device)

# The optimizer is bound to the parameters of the model created just above.
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

def loss_function(recon_event, event, recon_dur, dur, mu, log_var):
    """
    Summed training loss: event BCE + duration MSE + KL divergence.

    Args:
        recon_event: Predicted per-note probabilities (values in [0, 1]).
        event: Target multi-hot event vectors, same shape as ``recon_event``.
        recon_dur: Predicted durations.
        dur: Target durations, same shape as ``recon_dur``.
        mu: Latent means.
        log_var: Latent log-variances.

    Returns:
        Scalar tensor. Every term uses a sum reduction, so the value grows with
        the batch size.
    """
    event_term = F.binary_cross_entropy(recon_event, event, reduction='sum')
    duration_term = F.mse_loss(recon_dur, dur, reduction='sum')
    # KL divergence between N(mu, exp(log_var)) and the standard normal prior.
    kl_term = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum()
    return event_term + duration_term + kl_term


In [28]:
# Train the model to predict the event and duration that FOLLOW each input
# window (the targets come from the dataset's next-event fields).
num_epochs = 50

for epoch in range(num_epochs):
    vae.train()
    # Running sum of the (sum-reduced) batch losses for this epoch.
    train_loss = 0

    for batch_idx, (input_seq, dur_seq, target_event, target_dur) in enumerate(data_loader):
        # Move the whole batch to the device the model lives on.
        input_seq = input_seq.to(device)
        dur_seq = dur_seq.to(device)
        target_event = target_event.to(device)
        target_dur = target_dur.to(device)
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        recon_batch, mu, log_var = vae(input_seq, dur_seq)
        recon_event, recon_dur = recon_batch

        # Compute loss
        loss = loss_function(recon_event, target_event, recon_dur, target_dur, mu, log_var)

        # Backward pass and optimize
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    # Print training progress
    # loss_function sums over the batch, so dividing the epoch total by the
    # number of samples gives the average loss per sample.
    print(f'Epoch: {epoch+1}, Loss: {train_loss / len(data_loader.dataset)}')


KeyboardInterrupt: 

In [29]:
# Persist only the weights (state_dict), not the whole module object.
# NOTE(review): the recorded training output ends in KeyboardInterrupt, so this
# checkpoint may hold partially trained weights.
torch.save(vae.state_dict(), 'vae_model.pth')


In [30]:
# Rebuild the architecture with the same hyperparameter variables used for
# training, so a change in the config cell cannot drift from this cell.
vae = VAE(input_dim=input_dim, latent_dim=latent_dim, seq_length=seq_length).to(device)
# map_location remaps the stored tensors onto the current device; without it a
# checkpoint saved on a GPU fails to load on a CPU-only machine.
vae.load_state_dict(torch.load('vae_model.pth', map_location=device))
# Evaluation mode for inference; as the last expression this also displays the model.
vae.eval()


VAE(
  (encoder): Sequential(
    (0): Linear(in_features=1408, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
  )
  (duration_encoder): Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): ReLU()
  )
  (fc_mu): Linear(in_features=320, out_features=20, bias=True)
  (fc_var): Linear(in_features=320, out_features=20, bias=True)
  (event_decoder): Sequential(
    (0): Linear(in_features=20, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=88, bias=True)
    (5): Sigmoid()
  )
  (duration_decoder): Sequential(
    (0): Linear(in_features=20, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
)

In [55]:
# Draw one latent vector from a standard normal and decode it into an
# (event probabilities, duration) pair. decode() only receives z, so this sample
# is not conditioned on any input sequence.
z = torch.randn(1, latent_dim).to(device)
recon_event, recon_dur = vae.decode(z)


In [61]:
# Bring the decoder outputs back to plain numpy / Python values.
event_vec = recon_event.detach().squeeze().cpu().numpy()
duration_val = recon_dur.item()

# Keep every note whose probability reaches 0.3; if none does, fall back to the
# single most probable note so the event is never empty.
decoded_event = vocab.decode_event(event_vec, threshold=0.3)
if isinstance(decoded_event, list) and not decoded_event:
    best_index = np.argmax(event_vec)
    decoded_event = [vocab.idx_to_note[best_index]]
print("🎵 Generated event (note or chord):", decoded_event)
print("🕒 Duration:", duration_val)


🎵 Generated event (note or chord): [60]
🕒 Duration: 0.9472131729125977


In [51]:
# Generate events by decoding independent random latent vectors with a fixed
# probability threshold.
generated_sequence, generated_durations = [], []

for _ in range(32):  # Generate 32 events
    z = torch.randn(1, latent_dim).to(device)
    recon_event, recon_dur = vae.decode(z)

    event_vec = recon_event.detach().squeeze().cpu().numpy()
    duration_val = recon_dur.item()

    decoded_event = vocab.decode_event(event_vec, threshold=0.3)
    # No note reached the threshold: keep the single most probable one.
    if decoded_event == []:
        best_index = np.argmax(event_vec)
        decoded_event = [vocab.idx_to_note[best_index]]

    generated_sequence.append(decoded_event)
    generated_durations.append(duration_val)

print("Generated sequence:", generated_sequence)
print("Durations:", generated_durations)


Generated sequence: [[60], [60], [60], [60], [60], [60], [60], [60], [60], [60], [60], [60], [60], [60], [62], [60], [60], [60], [60], [60], [60], [60], [60], [60], [60], [60], [60], [60], [60], [60], [57], [60]]
Durations: [0.9182330369949341, 0.9744817018508911, 0.8778908252716064, 0.6624100804328918, 0.8217076659202576, 0.8415372967720032, 0.9237037897109985, 0.9346046447753906, 1.1181890964508057, 0.9538810849189758, 1.0869005918502808, 1.4057949781417847, 0.8418909311294556, 1.0128835439682007, 1.3889644145965576, 0.9156970977783203, 0.6798717379570007, 1.1689029932022095, 1.359874963760376, 0.9114537239074707, 0.886300802230835, 1.005086064338684, 0.678954541683197, 0.8203637003898621, 0.7732694149017334, 0.8833941221237183, 0.8167291879653931, 0.9068774580955505, 1.088042140007019, 0.7365081310272217, 0.9041471481323242, 1.230223298072815]


In [62]:
# Inspect which outputs the decoder activates most strongly for random latents.
for _ in range(10):
    z = torch.randn(1, latent_dim).to(device)
    event_vec, dur = vae.decode(z)
    event_vec = event_vec.squeeze().detach().cpu().numpy()
    # np.argsort returns positions in the event vector, so the values printed
    # here are vocabulary indices (keys of vocab.idx_to_note), not note numbers.
    print("Top notes:", np.argsort(event_vec)[-5:][::-1])


Top notes: [39 41 36 46 44]
Top notes: [44 39 36 46 41]
Top notes: [39 41 36 43 40]
Top notes: [39 44 41 42 36]
Top notes: [39 41 44 36 46]
Top notes: [39 41 43 46 42]
Top notes: [39 41 46 44 36]
Top notes: [39 46 38 42 43]
Top notes: [39 43 42 44 46]
Top notes: [39 46 36 41 44]


In [69]:
# Generate events again, this time keeping the strongest activations of each
# decoded vector instead of applying one fixed threshold.
generated_sequence, generated_durations = [], []

top_k = 5  # number of top notes to include in multi-note events

for _ in range(32):  # Generate 32 events
    z = torch.randn(1, latent_dim).to(device)
    recon_event, recon_dur = vae.decode(z)

    event_vec = recon_event.detach().squeeze().cpu().numpy()
    duration_val = recon_dur.item()

    # The top_k strongest notes, in ascending order of activation; drop any
    # whose activation is 0.05 or lower.
    top_indices = np.argsort(event_vec)[-top_k:]
    decoded_event = []
    for idx in top_indices:
        if event_vec[idx] > 0.05:
            decoded_event.append(vocab.idx_to_note[idx])

    # Fallback: if still empty, take the single top note
    if len(decoded_event) == 0:
        decoded_event = [vocab.idx_to_note[np.argmax(event_vec)]]

    generated_sequence.append(decoded_event)
    generated_durations.append(duration_val)

print("Generated sequence:", generated_sequence)
print("Durations:", generated_durations)


Generated sequence: [[60], [63, 65, 57, 67, 60], [57, 62, 60], [63, 62, 60], [67, 62, 63, 65, 60], [65, 62, 57, 60], [64, 65, 57, 62, 60], [64, 65, 57, 62, 60], [64, 67, 62, 57, 60], [67, 62, 65, 57, 60], [64, 62, 65, 57, 60], [65, 61, 62, 57, 60], [61, 64, 57, 62, 60], [57, 63, 64, 62, 60], [57, 62, 60], [65, 67, 57, 62, 60], [62, 67, 65, 57, 60], [64, 58, 57, 62, 60], [62], [63, 65, 57, 62, 60], [67, 57, 65, 62, 60], [62, 58, 65, 67, 60], [64, 59, 65, 60, 62], [67, 57, 65, 62, 60], [70, 62, 64, 57, 60], [62, 65, 57, 60], [63, 62, 67, 57, 60], [65, 67, 63, 62, 60], [57, 63, 64, 67, 60], [64, 67, 63, 62, 60], [57, 64, 67, 62, 60], [64, 65, 57, 62, 60]]
Durations: [1.315582275390625, 1.1147102117538452, 1.1477317810058594, 1.1244229078292847, 0.932746171951294, 1.0402549505233765, 0.6961724162101746, 0.7498365640640259, 1.0500375032424927, 0.8795987963676453, 0.8913076519966125, 0.8326601386070251, 0.9814002513885498, 0.9566794633865356, 0.9741085171699524, 0.8632479906082153, 1.1383975

In [70]:
import pretty_midi

def extract_melody_from_midi(midi_path):
    """
    Read the notes of the first instrument of a MIDI file.

    Args:
        midi_path: Path of the MIDI file to load with pretty_midi.

    Returns:
        Tuple ``(melody, durations)``: the pitch of every note of the first
        instrument, and each note's ``end - start``, both in the order
        pretty_midi lists the notes.
    """
    midi = pretty_midi.PrettyMIDI(midi_path)

    # No melody detection is attempted: the first instrument is used as is.
    lead = midi.instruments[0]

    melody = [note.pitch for note in lead.notes]
    durations = [note.end - note.start for note in lead.notes]
    return melody, durations


In [73]:
def encode_input_sequence(melody, durations, seq_length, vocab):
    """
    Cut a melody into overlapping windows and encode them as model inputs.

    Windows start at positions ``0 .. len(melody) - seq_length - 1``, the same
    windowing the training dataset uses.

    Args:
        melody: List of events (notes or chords) known to ``vocab``.
        durations: One duration per event of ``melody``.
        seq_length: Number of events per window.
        vocab: Object whose ``encode_event`` turns an event into a numpy vector.

    Returns:
        Tuple ``(encoded_sequences, encoded_durations)`` of equally long lists
        holding, per window, a float32 tensor of shape ``(seq_length, vocab_size)``
        and a float32 tensor of shape ``(seq_length, 1)``. Both lists are empty
        when the melody has no more than ``seq_length`` events.
    """
    encoded_sequences = []
    encoded_durations = []

    for start in range(len(melody) - seq_length):
        stop = start + seq_length

        # Stack the per-event vectors into one ndarray first: torch.tensor() on a
        # list of ndarrays is very slow and emits a UserWarning.
        window = np.array(
            [vocab.encode_event(event) for event in melody[start:stop]], dtype=np.float32
        )
        encoded_sequences.append(torch.from_numpy(window))
        encoded_durations.append(
            torch.tensor(durations[start:stop], dtype=torch.float32).unsqueeze(-1)
        )

    return encoded_sequences, encoded_durations


In [76]:
# Pitches and note lengths (end - start) of the first instrument of the input file.
# NOTE(review): pretty_midi reports note times in seconds, while the training
# durations look like beat fractions (0.25, 0.3333, ...) — confirm the units match.
melody, durations = extract_melody_from_midi('HipsDontLie.mid')

In [79]:
# Cut the melody into 16-event windows, matching the seq_length the model was built with.
# vocab.encode_event raises KeyError for a pitch that never occurred in the training data.
encoded_sequences, encoded_durations = encode_input_sequence(
    melody, durations, seq_length=16, vocab=vocab
)


In [80]:
# Run every window of the input melody through the model and collect one
# predicted event and duration per window.
jazzified_events, jazzified_durations = [], []

vae.eval()
with torch.no_grad():
    # Decode with top-k
    top_k = 5

    for x_seq, dur_seq in zip(encoded_sequences, encoded_durations):
        x_seq = x_seq.unsqueeze(0).to(device)     # [1, seq_length, 88]
        dur_seq = dur_seq.unsqueeze(0).to(device) # [1, seq_length, 1]

        # Encode the window, sample a latent vector, then decode it.
        z_mu, z_logvar = vae.encode(x_seq, dur_seq)
        z = vae.reparameterize(z_mu, z_logvar)
        recon_event, recon_dur = vae.decode(z)

        event_vec = recon_event.squeeze().cpu().numpy()
        duration_val = recon_dur.item()

        # Keep the top_k strongest notes whose activation exceeds 0.05.
        top_indices = np.argsort(event_vec)[-top_k:]
        decoded_event = []
        for idx in top_indices:
            if event_vec[idx] > 0.05:
                decoded_event.append(vocab.idx_to_note[idx])
        # Nothing passed the cut: fall back to the single most probable note.
        if len(decoded_event) == 0:
            decoded_event = [vocab.idx_to_note[np.argmax(event_vec)]]

        jazzified_events.append(decoded_event)
        jazzified_durations.append(duration_val)


In [81]:
# Display the predicted events (one list of notes per input window).
jazzified_events

[[67, 64, 62, 63, 60],
 [65, 62, 67, 63, 60],
 [57, 62, 60],
 [57, 67, 62, 65, 60],
 [57, 65, 62, 60],
 [65, 57, 62, 60],
 [67, 57, 62, 65, 60],
 [65, 57, 62, 60],
 [65, 67, 57, 62, 60],
 [58, 64, 63, 62, 60],
 [67, 65, 62, 57, 60],
 [65, 62, 57, 60],
 [63, 57, 67, 62, 60],
 [67, 58, 55, 62, 60],
 [67, 62, 57, 64, 60],
 [57, 63, 65, 62, 60],
 [57, 65, 62, 67, 60],
 [43, 65, 62, 57, 60],
 [62, 65, 63, 67, 60],
 [64, 65, 62, 63, 60],
 [65, 58, 67, 64, 60],
 [67, 62, 57, 65, 60],
 [57, 65, 64, 62, 60],
 [67, 63, 62, 60],
 [63, 67, 57, 62, 60],
 [62, 60],
 [64, 67, 63, 65, 60],
 [57, 65, 63, 62, 60],
 [67, 64, 60],
 [43, 65, 57, 62, 60],
 [67, 63, 62, 65, 60],
 [57, 60, 62],
 [65, 64, 57, 62, 60],
 [57, 59, 65, 62, 60],
 [58, 67, 64, 62, 60],
 [57, 64, 65, 62, 60],
 [67, 62, 57, 60],
 [65, 64, 67, 63, 60],
 [64, 65, 67, 63, 60],
 [59, 57, 65, 62, 60],
 [63, 65, 67, 62, 60],
 [67, 65, 62, 57, 60],
 [55, 63, 57, 62, 60],
 [62, 65, 63, 57, 60],
 [58, 63, 65, 64, 60],
 [64, 57, 65, 62, 60],
 [

In [82]:
# Display the predicted durations (one value per input window).
jazzified_durations

[0.7764978408813477,
 0.9244034886360168,
 1.3015140295028687,
 0.9015295505523682,
 1.54661226272583,
 1.1763676404953003,
 0.8743936419487,
 0.9516654014587402,
 0.9062855839729309,
 0.7415578365325928,
 0.640366792678833,
 0.7332487106323242,
 1.0003092288970947,
 0.9289456605911255,
 0.936152994632721,
 0.8813924789428711,
 0.8187229633331299,
 0.888008713722229,
 1.0868819952011108,
 1.0769031047821045,
 1.2029246091842651,
 1.0335792303085327,
 0.7078806757926941,
 1.1397569179534912,
 0.9589075446128845,
 0.8732542991638184,
 1.0826421976089478,
 0.7601902484893799,
 1.2721787691116333,
 0.9611872434616089,
 0.8703715205192566,
 0.8654482364654541,
 1.1127575635910034,
 0.7526496648788452,
 0.9971199035644531,
 0.8963857293128967,
 0.9063019752502441,
 0.7944746017456055,
 1.0056101083755493,
 0.8124434947967529,
 0.9204176664352417,
 0.7082937955856323,
 0.6781424283981323,
 1.029675841331482,
 1.3585549592971802,
 1.1596977710723877,
 0.6099212169647217,
 0.8798952698707581,
 

In [86]:
from mido import Message, MidiFile, MidiTrack

def write_midi(events, durations, filename, velocity=64):
    """
    Converts a list of note/chord events and durations into a MIDI file.

    Events are written one after another on a single track; all notes of a chord
    start together and end together.

    Args:
        events: Sequence of events, each a single note number or a list of note
            numbers (chord).
        durations: One duration per event, in beats.
        filename: Path of the MIDI file to write.
        velocity: Velocity used for every note-on and note-off message.
    """
    mid = MidiFile()
    track = MidiTrack()
    mid.tracks.append(track)
    ticks_per_beat = mid.ticks_per_beat

    for event, dur in zip(events, durations):
        duration_ticks = int(ticks_per_beat * dur)
        notes = event if isinstance(event, list) else [event]

        # All notes of the event start at the same instant.
        for note in notes:
            track.append(Message('note_on', note=note, velocity=velocity, time=0))

        # Message.time is the delay since the PREVIOUS message, so only the first
        # note_off may carry the duration. Giving it to every note_off would
        # stretch a k-note chord to k times its duration.
        for position, note in enumerate(notes):
            delta = duration_ticks if position == 0 else 0
            track.append(Message('note_off', note=note, velocity=velocity, time=delta))

    mid.save(filename)
    print(f"🎼 Saved MIDI to {filename}")
    print(f"🎼 Saved MIDI to {filename}")


In [93]:
def write_midi(events, durations, filename="jazzified_output.mid", velocity=64, min_duration=0.2):
    """
    Write note/chord events and their durations to a single-track MIDI file.

    Events are written one after another; all notes of a chord start together
    and end together.

    Args:
        events: Sequence of events, each a single note number or a list of note
            numbers (chord).
        durations: One duration per event, in beats. The sequence is not modified.
        filename: Path of the MIDI file to write.
        velocity: Velocity used for every note-on and note-off message.
        min_duration: Duration in beats substituted for any duration below 0.01
            (including negative values, which a regression output can produce).
    """
    from mido import MidiFile, MidiTrack, Message

    mid = MidiFile()
    track = MidiTrack()
    mid.tracks.append(track)
    ticks_per_beat = mid.ticks_per_beat

    for event, dur in zip(events, durations):
        # Clamp the duration of THIS event before converting it to ticks. The
        # earlier version rewrote the notebook-level `jazzified_durations` list
        # inside this loop, which never affected the ticks actually written.
        if dur < 0.01:
            dur = min_duration
        duration_ticks = int(dur * ticks_per_beat)
        notes = event if isinstance(event, list) else [event]

        # All note_on events at time=0
        for note in notes:
            track.append(Message('note_on', note=note, velocity=velocity, time=0))

        # note_offs: give duration to just one note, rest time=0
        # (Message.time is the delay since the previous message).
        for position, note in enumerate(notes):
            delta = duration_ticks if position == 0 else 0
            track.append(Message('note_off', note=note, velocity=velocity, time=delta))

    mid.save(filename)
    print(f"✅ MIDI saved: {filename}")


In [94]:
# Write the predicted events to a MIDI file.
# NOTE(review): the output is named "happybday" although the input loaded above
# was 'HipsDontLie.mid' — confirm the intended file name.
write_midi(jazzified_events, jazzified_durations, filename="jazzified_happybday.mid")


✅ MIDI saved: jazzified_happybday.mid


In [96]:
# Display the durations again after writing the MIDI file.
jazzified_durations

[0.7764978408813477,
 0.9244034886360168,
 1.3015140295028687,
 0.9015295505523682,
 1.54661226272583,
 1.1763676404953003,
 0.8743936419487,
 0.9516654014587402,
 0.9062855839729309,
 0.7415578365325928,
 0.640366792678833,
 0.7332487106323242,
 1.0003092288970947,
 0.9289456605911255,
 0.936152994632721,
 0.8813924789428711,
 0.8187229633331299,
 0.888008713722229,
 1.0868819952011108,
 1.0769031047821045,
 1.2029246091842651,
 1.0335792303085327,
 0.7078806757926941,
 1.1397569179534912,
 0.9589075446128845,
 0.8732542991638184,
 1.0826421976089478,
 0.7601902484893799,
 1.2721787691116333,
 0.9611872434616089,
 0.8703715205192566,
 0.8654482364654541,
 1.1127575635910034,
 0.7526496648788452,
 0.9971199035644531,
 0.8963857293128967,
 0.9063019752502441,
 0.7944746017456055,
 1.0056101083755493,
 0.8124434947967529,
 0.9204176664352417,
 0.7082937955856323,
 0.6781424283981323,
 1.029675841331482,
 1.3585549592971802,
 1.1596977710723877,
 0.6099212169647217,
 0.8798952698707581,
 

In [95]:
from IPython.display import Audio
# NOTE(review): this passes a raw .mid file to IPython's Audio widget; browsers
# generally cannot play MIDI directly — confirm playback works, or synthesize the
# MIDI to a waveform first.
Audio("jazzified_happybday.mid")
