In [22]:
import mido
import numpy as np
from mido import Message, MidiFile, MidiTrack
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd

In [23]:
# Parse one sample clip with mido so its raw tracks/messages can be inspected.
mid = MidiFile("midis/Q1__8v0MFBZoco_0.mid")

In [24]:
# Echo the parsed file; its repr lists every track and message.
mid

MidiFile(type=1, ticks_per_beat=384, tracks=[
  MidiTrack([
    MetaMessage('set_tempo', tempo=500000, time=0),
    MetaMessage('time_signature', numerator=4, denominator=4, clocks_per_click=24, notated_32nd_notes_per_beat=8, time=0),
    MetaMessage('end_of_track', time=1)]),
  MidiTrack([
    Message('note_on', channel=0, note=52, velocity=89, time=305),
    Message('note_on', channel=0, note=64, velocity=90, time=2),
    Message('note_on', channel=0, note=52, velocity=0, time=63),
    Message('note_on', channel=0, note=64, velocity=0, time=2),
    Message('note_on', channel=0, note=58, velocity=91, time=118),
    Message('note_on', channel=0, note=70, velocity=83, time=2),
    Message('note_on', channel=0, note=58, velocity=0, time=69),
    Message('note_on', channel=0, note=70, velocity=0, time=25),
    Message('note_on', channel=0, note=59, velocity=86, time=2),
    Message('note_on', channel=0, note=71, velocity=78, time=0),
    Message('note_on', channel=0, note=62, velocity=88,

In [25]:
def extract_notes_from_midi(file_path):
    """Return the sounding notes of a MIDI file with absolute onset times.

    Args:
        file_path: path handed to ``mido.MidiFile``.

    Returns:
        A list of dicts, one per note onset in playback order, with keys
        ``'pitch'`` (MIDI note number), ``'start_time'`` (time elapsed since
        the start of the file, summed from the per-message ``time`` values)
        and ``'velocity'``.
    """
    midi_file = mido.MidiFile(file_path)
    notes = []
    elapsed = 0.0
    for msg in midi_file:
        # Iterating a MidiFile yields every message (meta ones included) with
        # `time` being the delta since the previous message, so it has to be
        # accumulated to obtain an onset time.
        elapsed += msg.time
        # A note_on with velocity 0 is the conventional encoding of a note-off
        # (this dataset uses it), so it must not be counted as a new note.
        if msg.type == 'note_on' and msg.velocity > 0:
            notes.append({
                'pitch': msg.note,
                'start_time': elapsed,
                'velocity': msg.velocity
            })
    return notes

In [26]:
# Location of the label CSV that is read into `label_df`.
data_label_path = "midis/label.csv"

In [27]:
# Load the label table; its first two columns are later used as (clip id, label).
label_df = pd.read_csv(data_label_path)

In [21]:
# Build a {clip id: label} lookup from the first two columns of label_df.
label_dict = {
    label_df.iloc[row, 0]: label_df.iloc[row, 1]
    for row in range(len(label_df))
}

label_dict

{'Q1_0vLPYiPN7qY_0': 1,
 'Q1_0vLPYiPN7qY_1': 1,
 'Q1_0vLPYiPN7qY_2': 1,
 'Q1_1Qc15G0ZHIg_1': 1,
 'Q1_1Qc15G0ZHIg_2': 1,
 'Q1_1Qc15G0ZHIg_3': 1,
 'Q1_1vjy9oMFa8c_2': 1,
 'Q1_1vjy9oMFa8c_4': 1,
 'Q1_1vjy9oMFa8c_5': 1,
 'Q1_2Z9SjI131jA_0': 1,
 'Q1_2Z9SjI131jA_1': 1,
 'Q1_2Z9SjI131jA_10': 1,
 'Q1_2Z9SjI131jA_11': 1,
 'Q1_2Z9SjI131jA_12': 1,
 'Q1_2Z9SjI131jA_13': 1,
 'Q1_2Z9SjI131jA_14': 1,
 'Q1_2Z9SjI131jA_15': 1,
 'Q1_2Z9SjI131jA_2': 1,
 'Q1_2Z9SjI131jA_3': 1,
 'Q1_2Z9SjI131jA_4': 1,
 'Q1_2Z9SjI131jA_7': 1,
 'Q1_2Z9SjI131jA_8': 1,
 'Q1_2Z9SjI131jA_9': 1,
 'Q1_3N2G21U7guk_3': 1,
 'Q1_3N2G21U7guk_5': 1,
 'Q1_3N2G21U7guk_6': 1,
 'Q1_3ahg_eQZhxs_0': 1,
 'Q1_3ahg_eQZhxs_1': 1,
 'Q1_4dXC1cC7crw_0': 1,
 'Q1_4dXC1cC7crw_1': 1,
 'Q1_4ydjOX3pWds_0': 1,
 'Q1_4ydjOX3pWds_1': 1,
 'Q1_5Ju9q1N2x0E_2': 1,
 'Q1_5NW0zDu6IYM_2': 1,
 'Q1_60LLKmpgzRM_0': 1,
 'Q1_6Uf9XBUD3wE_0': 1,
 'Q1_6Uf9XBUD3wE_1': 1,
 'Q1_6kRPHamGDSo_0': 1,
 'Q1_6kRPHamGDSo_1': 1,
 'Q1_6kRPHamGDSo_3': 1,
 'Q1_6wFJhmhNeeg_0': 1,
 'Q1_7yW9c

In [30]:
class TransformerDataset(torch.utils.data.Dataset):
    """Fixed-length next-token-prediction examples cut from token sequences.

    Every sequence is chopped into non-overlapping windows of ``seq_len + 1``
    tokens; ``__getitem__`` returns the first ``seq_len`` tokens as the input
    and the last ``seq_len`` tokens (shifted by one) as the target.
    """

    def __init__(self, dataset, seq_len, cond=False):
        """
        Args:
            dataset: dict with parallel lists under ``'ids'`` and
                ``'sequences'`` (one token list per id), plus ``'conditions'``
                when ``cond`` is True.
            seq_len: number of input tokens per example; each stored window
                holds ``seq_len + 1`` tokens.
            cond: when True, the per-sequence condition is stored for every
                window and returned by ``__getitem__``.
        """
        self.cond = cond
        window = seq_len + 1

        ids, windows, conditions = [], [], []
        for i, data in enumerate(dataset['sequences']):
            # The stop is len(data) - seq_len (not len(data) - window) so that
            # a final window ending exactly on the last token is kept; shorter
            # leftovers are still dropped.
            for j in range(0, len(data) - seq_len, window):
                ids.append(dataset['ids'][i])
                # we will use seq[:-1] as input and seq[1:] as target
                windows.append(data[j: j + window])
                if cond:
                    conditions.append(dataset['conditions'][i])

        # String ids cannot be stored in a tensor, so they are replaced by the
        # running example index; numeric ids are kept as given. The emptiness
        # guard avoids an IndexError when no sequence is long enough.
        if ids and isinstance(ids[0], str):
            self.ids = torch.arange(len(ids)).unsqueeze(-1)
        else:
            self.ids = torch.Tensor(ids)

        self.dataset = torch.Tensor(windows)
        self.conditions = torch.Tensor(conditions) if cond else []

    def __len__(self):
        """Number of windows across all sequences."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return one example as a dict.

        Keys: ``'ids'``, ``'inputs'`` and ``'targets'`` (long tensors of
        length ``seq_len``) and ``'conditions'`` (the stored condition as a
        long tensor, or a single-NaN tensor when ``cond`` is False).
        """
        window = self.dataset[idx]
        inputs = window[:-1].long()
        targets = window[1:].long()

        if self.cond:
            conditions = self.conditions[idx].long()
        else:
            conditions = torch.Tensor([float('nan')])
        return {'ids': self.ids[idx], 'inputs': inputs, 'targets': targets, 'conditions': conditions}


In [33]:
import itertools
import os
import glob
import time
from collections import Counter
from re import S
import itertools
import pickle as pkl
from note_seq.chord_symbols_lib import ChordSymbolError
from note_seq.musicxml_parser import ChordSymbol
import numpy as np
import note_seq
from collections import defaultdict

In [31]:
class MidiEncoder():
    """Convert MIDI files to integer token sequences and back via note_seq.

    The vocabulary is made of four event families whose ids are assigned in
    this order: ``NOTE_ON_<pitch>`` and ``NOTE_OFF_<pitch>`` for every pitch in
    ``min_pitch..max_pitch`` (inclusive), ``TIME_SHIFT_<1..steps_per_sec>`` and
    ``VELOCITY_<1..num_vel_bins>``.

    Results of ``encode_midi_list`` accumulate in ``self.encoded_sequences``
    (``{'ids': [...], 'sequences': [...]}``) across calls.
    """

    def __init__(self, steps_per_sec, num_vel_bins, min_pitch, max_pitch, stretch_factors=[1.0], pitch_transpose_range=[0]):
        """
        Args:
            steps_per_sec: quantisation steps per second; also the largest
                TIME_SHIFT value in the vocabulary.
            num_vel_bins: number of velocity bins.
            min_pitch: lowest pitch kept (inclusive).
            max_pitch: highest pitch kept (inclusive).
            stretch_factors: time-stretch factors used for augmentation.
            pitch_transpose_range: sequence whose first and last items bound
                the (inclusive) range of transposition amounts.
        """
        self.steps_per_sec = steps_per_sec
        self.num_vel_bins = num_vel_bins
        self.min_pitch = min_pitch
        self.max_pitch = max_pitch

        self.events_to_ids = self.make_vocab()
        self.ids_to_events = {value: key for key, value in self.events_to_ids.items()}
        self.vocab_size = len(self.events_to_ids)

        # Copied so the shared default list / the caller's list is never
        # aliased. The attribute keeps its historical spelling for callers.
        self.strech_factors = list(stretch_factors)
        self.transpose_amounts = list(range(pitch_transpose_range[0],
                                            pitch_transpose_range[-1] + 1))

        # Every (stretch, transpose) combination that encode_midi_list applies.
        self.augment_params = [(s, t) for s in self.strech_factors for t in self.transpose_amounts]

        self.encoded_sequences = {'ids': [], 'sequences': []}

    def make_vocab(self):
        """Return the event-name -> id mapping described in the class docstring."""
        pitches = range(self.min_pitch, self.max_pitch + 1)
        event_names = (
            [f"NOTE_ON_{note}" for note in pitches]
            + [f"NOTE_OFF_{note}" for note in pitches]
            + [f"TIME_SHIFT_{shift}" for shift in range(1, self.steps_per_sec + 1)]
            + [f"VELOCITY_{vel}" for vel in range(1, self.num_vel_bins + 1)]
        )
        return {name: idx for idx, name in enumerate(event_names)}

    def filter_pitches(self, note_sequence):
        """Drop notes outside ``[min_pitch, max_pitch]`` in place.

        ``note_sequence.total_time`` is reset to the latest ``end_time`` among
        the notes that were kept (0 when none remain).
        """
        note_list = []
        deleted_note_count = 0
        end_time = 0
        for note in note_sequence.notes:
            if self.min_pitch <= note.pitch <= self.max_pitch:
                end_time = max(end_time, note.end_time)
                note_list.append(note)
            else:
                deleted_note_count += 1

        # Only rewrite the note container when something was actually removed.
        if deleted_note_count > 0:
            del note_sequence.notes[:]
            note_sequence.notes.extend(note_list)
        note_sequence.total_time = end_time

    def encode_midi_file(self, midi_file, strech_factor=1, transpose_amount=0):
        """Encode one MIDI file into a list of token ids.

        Sustain control changes are applied and then cleared, out-of-range
        pitches are filtered, and the sequence is augmented when
        ``strech_factor`` or ``transpose_amount`` differ from their defaults.
        """
        note_sequence = note_seq.midi_file_to_note_sequence(midi_file)

        note_sequence = note_seq.apply_sustain_control_changes(note_sequence)
        del note_sequence.control_changes[:]

        self.filter_pitches(note_sequence)

        if strech_factor != 1 or transpose_amount != 0:
            note_sequence = self.augment(note_sequence, strech_factor, transpose_amount)
        return self.encode_note_sequence(note_sequence)

    def encode_note_sequence(self, note_sequence):
        """Quantise a NoteSequence at ``steps_per_sec`` and encode it as token ids."""
        quantized_seq = note_seq.quantize_note_sequence_absolute(note_sequence, self.steps_per_sec)
        performance = note_seq.Performance(quantized_seq, num_velocity_bins=self.num_vel_bins)
        return self.encode_performance(performance)

    def encode_performance(self, performance):
        """Map each performance event to its vocabulary id.

        Raises:
            ValueError: if an event has a type other than NOTE_ON, NOTE_OFF,
                TIME_SHIFT or VELOCITY (the message reports its position).
            KeyError: if the event's value is not part of the vocabulary.
        """
        prefixes = {
            note_seq.PerformanceEvent.NOTE_ON: "NOTE_ON",
            note_seq.PerformanceEvent.NOTE_OFF: "NOTE_OFF",
            note_seq.PerformanceEvent.TIME_SHIFT: "TIME_SHIFT",
            note_seq.PerformanceEvent.VELOCITY: "VELOCITY",
        }
        encoded_performance = []
        for position, event in enumerate(performance):
            prefix = prefixes.get(event.event_type)
            if prefix is None:
                raise ValueError(f"Unknown event type: {event.event_type} at position {position}")
            encoded_performance.append(self.events_to_ids[f"{prefix}_{event.event_value}"])
        return encoded_performance

    def decode_to_performance(self, encoded_performance):
        """Rebuild a ``note_seq.Performance`` from token ids.

        Raises:
            ValueError: if an id is not in the vocabulary.
        """
        event_types = {
            'NOTE_ON': note_seq.PerformanceEvent.NOTE_ON,
            'NOTE_OFF': note_seq.PerformanceEvent.NOTE_OFF,
            'TIME_SHIFT': note_seq.PerformanceEvent.TIME_SHIFT,
            'VELOCITY': note_seq.PerformanceEvent.VELOCITY,
        }
        decoded_performance = note_seq.Performance(
            quantized_sequence=None,
            steps_per_second=self.steps_per_sec,
            num_velocity_bins=self.num_vel_bins)

        # TODO: an earlier draft skipped consecutive 'TIME_SHIFT_100' tokens
        # before decoding; decide whether that filtering should be restored.

        for event_id in encoded_performance:
            try:
                event_name = self.ids_to_events[event_id]
            except (KeyError, TypeError) as err:
                raise ValueError("Unknown event index: %s" % event_id) from err
            # Names look like '<TYPE>_<value>' where TYPE itself may contain '_'.
            type_name, _, value = event_name.rpartition('_')
            decoded_performance.append(note_seq.PerformanceEvent(
                event_type=event_types[type_name], event_value=int(value)))
        return decoded_performance

    def decode_to_note_sequence(self, encoded_performance):
        """Decode token ids into a NoteSequence."""
        decoded_performance = self.decode_to_performance(encoded_performance)
        # max_note_duration is forwarded to Performance.to_sequence
        # (presumably seconds - confirm against note_seq).
        return decoded_performance.to_sequence(max_note_duration=3)

    def decode_to_midi_file(self, encoded_performance, save_path):
        """Decode token ids and write the result to ``save_path`` as MIDI."""
        note_sequence = self.decode_to_note_sequence(encoded_performance)
        note_seq.note_sequence_to_midi_file(note_sequence, save_path)

    def augment(self, note_sequence, stretch_factor, transpose_amount):
        """Return a stretched copy of ``note_sequence``, transposed in place.

        Transposition is limited to ``[min_pitch, max_pitch]``; a message is
        printed when chord-symbol transposition fails or when notes fell out
        of range.
        """
        augmented_note_sequence = note_seq.sequences_lib.stretch_note_sequence(
            note_sequence, stretch_factor, in_place=False)

        # Initialised up front so the check below is safe when the
        # transposition raises before returning a count.
        num_deleted_notes = 0
        try:
            _, num_deleted_notes = note_seq.sequences_lib.transpose_note_sequence(
                augmented_note_sequence, transpose_amount,
                min_allowed_pitch=self.min_pitch, max_allowed_pitch=self.max_pitch,
                in_place=True
            )
        except ChordSymbolError:
            print('Transposition of chord symbol(s) failed.')
        if num_deleted_notes:
            print('Transposition caused out-of-range pitch(es)')
        return augmented_note_sequence

    def encode_midi_list(self, midi_list, pkl_path=None):
        """Encode every file under every augmentation and collect the results.

        Args:
            midi_list: iterable of MIDI file paths.
            pkl_path: when given, ``self.encoded_sequences`` is pickled there.

        Returns:
            ``self.encoded_sequences`` (which also contains results from
            earlier calls on this instance).
        """
        for midi_file in midi_list:
            print(midi_file)
            root, ext = os.path.splitext(os.path.basename(midi_file))
            for sf, ta in self.augment_params:
                time0 = time.time()
                encoded_sequence = self.encode_midi_file(midi_file, sf, ta)
                print(time.time() - time0)

                # Record id and sequence together, after encoding succeeded,
                # so the two lists can never get out of step.
                self.encoded_sequences['ids'].append(root + '_' + str(sf) + '_' + str(ta) + ext)
                self.encoded_sequences['sequences'].append(encoded_sequence)

        if pkl_path is not None:
            with open(pkl_path, 'wb') as handle:
                pkl.dump(self.encoded_sequences, handle, protocol=pkl.HIGHEST_PROTOCOL)
        return self.encoded_sequences

    def calculate_scores(self, midi_file, which_scores='all'):
        """Run scoring functions over one or more MIDI files.

        Args:
            midi_file: a single path or an iterable of paths.
            which_scores: ``'all'`` for the default set, or an iterable of
                scoring callables.

        Returns:
            dict mapping each function's ``__name__`` to a list with one
            score per file.
        """
        if isinstance(midi_file, (str, os.PathLike)):
            midi_file = [midi_file]
        if which_scores == 'all':
            # These scoring functions are not defined in this notebook section
            # and must already be in scope when 'all' is requested.
            func_list = [pitch_count, note_count, note_range, average_inter_onset_interval, average_pitch_interval]
        else:
            func_list = list(which_scores)

        scores = {func.__name__: [] for func in func_list}
        for filename in midi_file:
            for func in func_list:
                if func.__name__ == 'average_inter_onset_interval':
                    # NOTE(review): the original call omitted the filename;
                    # assumes the signature is (filename, min_pitch,
                    # max_pitch, steps_per_sec) - confirm against the helper.
                    value = func(filename, self.min_pitch, self.max_pitch, self.steps_per_sec)
                else:
                    value = func(filename, self.min_pitch, self.max_pitch)
                scores[func.__name__].append(value)

        return scores

### data representation

In [6]:
# # Map notes to integers
# note_to_int = {note['pitch']: idx for idx, note in enumerate(notes)}
# int_to_note = {idx: note for note, idx in note_to_int.items()}

# # Convert notes to integers
# notes_as_int = [note_to_int[note['pitch']] for note in notes]

# # Define sequence length and create input-output pairs
# sequence_length = 100  # Define the length of input sequences
# input_sequences = []
# output_sequences = []

# for i in range(0, len(notes_as_int) - sequence_length, 1):
#     input_seq = notes_as_int[i:i + sequence_length]
#     output_seq = notes_as_int[i + sequence_length]
#     input_sequences.append(input_seq)
#     output_sequences.append(output_seq)

# # Convert to PyTorch tensors
# input_sequences = torch.LongTensor(input_sequences)
# output_sequences = torch.LongTensor(output_sequences)


In [7]:
# input_sequences.shape, output_sequences.shape

(torch.Size([1154, 100]), torch.Size([1154]))

### transformer model

In [8]:
import torch
import torch.nn as nn

class MusicTransformer(nn.Module):
    """Token embedding -> Transformer encoder stack -> logits for the next token.

    Args:
        vocab_size: number of distinct token ids (embedding rows and output logits).
        d_model: embedding / model width.
        num_layers: number of stacked encoder layers.
        num_heads: attention heads per layer.
        d_feedforward: hidden width of each layer's feed-forward sub-block.
        dropout: dropout probability inside the encoder layers.
    """

    def __init__(self, vocab_size, d_model=256, num_layers=6, num_heads=8, d_feedforward=512, dropout=0.1):
        super(MusicTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)

        # nn.TransformerEncoder expects ONE prototype layer and clones it
        # num_layers times itself. Passing a ModuleList of layers (as before)
        # makes the encoder call the ModuleList, which has no forward().
        encoder_layer = nn.TransformerEncoderLayer(d_model, num_heads, d_feedforward, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        self.decoder = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return logits of shape (batch, vocab_size) for a (batch, seq) tensor of token ids."""
        x = self.embedding(x)
        # The encoder layers use the default batch_first=False layout,
        # i.e. (seq, batch, d_model).
        x = x.permute(1, 0, 2)

        output = self.transformer_encoder(x)
        output = self.decoder(output[-1])  # Only use the output of the last time step

        return output


In [9]:
# The embedding table must have a row for every id that can appear in the
# data. Sizing it from the largest id (instead of the number of entries) stays
# correct even when the ids in note_to_int are not contiguous from 0.
vocab_size = max(note_to_int.values(), default=-1) + 1  # Adjust based on your vocabulary size
model = MusicTransformer(vocab_size)

In [10]:
# Cross-entropy between the model's vocabulary logits and the target token ids.
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters with a fixed learning rate of 0.001.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [11]:
# NOTE(review): `dataset` is rebound to a TensorDataset in the following cell,
# so this path string is discarded before it is used - confirm it is still needed.
dataset = "midis/"

In [12]:
from torch.utils.data import DataLoader, TensorDataset

# Convert input and output sequences to PyTorch Dataset
# NOTE(review): input_sequences / output_sequences are only built by a cell
# that is currently commented out, so this fails on a fresh kernel.
dataset = TensorDataset(input_sequences, output_sequences)

# Create a DataLoader
batch_size = 32  # Define your batch size
# shuffle=True draws the batches in a new random order every epoch.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [None]:
num_epochs = 10  # Define the number of epochs

# Make sure dropout is active even if an earlier cell switched the model to eval().
model.train()

for epoch in range(num_epochs):
    total_loss = 0

    for inputs, targets in dataloader:
        optimizer.zero_grad()
        output = model(inputs)

        # Reshape targets to match the expected shape [batch_size]
        targets = targets.view(-1)

        # Calculate the loss
        loss = criterion(output, targets)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Calculate average loss per epoch (guarding against an empty dataloader)
    average_loss = total_loss / max(len(dataloader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {average_loss}")
