In [35]:
import miditok
import pretty_midi
import pandas as pd
import json
import glob
import os
from natsort import natsorted

In [2]:
def extract_guitar_and_drums(midi_file, guitar_programs=range(24, 32)):
    """Extract the busiest guitar track and the busiest drum track from a MIDI file.

    Parameters
    ----------
    midi_file : str
        Path to a MIDI file, e.g. 'raw_data/song_name.mid'.
    guitar_programs : container of int, optional
        Zero-based MIDI program numbers (as stored in
        ``pretty_midi.Instrument.program``) that count as guitars.
        Defaults to 24-31, the General MIDI guitar family.

    Returns
    -------
    dict
        'title'      : file name without directory or extension,
        'down_beats' : the result of ``PrettyMIDI.get_downbeats()``,
        'guitar'     : the guitar instrument with the most notes,
        'drums'      : the drum instrument with the most notes.

    Raises
    ------
    ValueError
        If the file contains no drum track or no guitar track.
    """
    mid = pretty_midi.PrettyMIDI(midi_file)

    drums = []
    guitars = []
    for instrument in mid.instruments:
        # Drum tracks are classified only as drums, even when their program
        # number happens to fall inside the guitar range.
        if instrument.is_drum:
            drums.append(instrument)
        elif instrument.program in guitar_programs:
            guitars.append(instrument)

    if not drums:
        raise ValueError(f"No drum track found in {midi_file!r}")
    if not guitars:
        raise ValueError(f"No guitar track found in {midi_file!r}")

    # When several tracks qualify keep the one with the most notes;
    # max() returns the first such track on ties.
    drum_track = max(drums, key=lambda inst: len(inst.notes))
    guitar_track = max(guitars, key=lambda inst: len(inst.notes))

    return {'title': os.path.splitext(os.path.basename(midi_file))[0],
            'down_beats': mid.get_downbeats(),
            'guitar': guitar_track,
            'drums': drum_track}

In [3]:
def tracks_to_bars(song_dict: dict) -> dict:
    """Split the guitar and drum tracks of a song into per-bar lists of notes.

    Parameters
    ----------
    song_dict : dict
        Must contain 'title', 'down_beats', 'guitar' and 'drums'. 'guitar' and
        'drums' are objects with a ``notes`` list whose items have a ``start``
        time (e.g. pretty_midi Instruments). 'down_beats' is expected to be an
        ascending sequence of bar start times in seconds.

    Returns
    -------
    dict
        'song_title'  : the song title,
        'guitar_bars' : one list of guitar notes per downbeat,
        'drum_bars'   : one list of drum notes per downbeat,
        'down_beats'  : the downbeats as a plain list of floats.

    A note belongs to the bar in which it *starts*
    (``down_beats[i] <= start < down_beats[i + 1]``), so notes that end exactly
    on the next downbeat or sustain across a bar line are kept. The last bar
    collects every note starting at or after the final downbeat. Notes that
    start before the first downbeat are dropped. Within a bar, notes keep the
    order they have in the track.
    """
    import bisect

    # float() accepts both numpy arrays (as returned by get_downbeats) and lists.
    down_beats = [float(t) for t in song_dict['down_beats']]

    def split_into_bars(notes):
        bars = [[] for _ in down_beats]
        for note in notes:
            # Index of the last downbeat <= note.start (requires ascending downbeats).
            bar_index = bisect.bisect_right(down_beats, note.start) - 1
            if bar_index >= 0:
                bars[bar_index].append(note)
        return bars

    return {'song_title': song_dict['title'],
            'guitar_bars': split_into_bars(song_dict['guitar'].notes),
            'drum_bars': split_into_bars(song_dict['drums'].notes),
            'down_beats': down_beats}

In [131]:
def standardize_bars(song_name, list_of_bars, downbeats):
    """Rescale note times so each bar's timing is expressed relative to its downbeat.

    Every note in bar ``i`` has its ``start`` and ``end`` divided by
    ``downbeats[i]``, so a note sitting exactly on its bar's downbeat maps to
    1.0. Bar 0 is divided by ``downbeats[1]`` instead, presumably because the
    first downbeat is usually at 0 s and cannot be used as a divisor. Exactly
    one factor is applied per note.

    The note objects are modified in place and the same ``list_of_bars``
    object is returned inside the result.

    Parameters
    ----------
    song_name : str
        Title stored in the result.
    list_of_bars : list of lists of notes
        Notes must have mutable ``start`` / ``end`` attributes.
    downbeats : sequence of float
        One entry per bar. ``downbeats[1]`` is only read when bar 0 has notes.

    Returns
    -------
    dict
        {'song_title': song_name, 'list_of_bars': list_of_bars,
         'down_beats': downbeats}
    """
    for bar_index, bar in enumerate(list_of_bars):
        if not bar:
            continue
        scale = downbeats[1] if bar_index == 0 else downbeats[bar_index]
        for note in bar:
            note.start = note.start / scale
            note.end = note.end / scale

    return {'song_title': song_name,
            'list_of_bars': list_of_bars,
            'down_beats': downbeats}

In [141]:
# NOTE(review): `song` is reassigned by the reconstruction cell before it is
# ever read, so this parse appears to be unused.
song = pretty_midi.PrettyMIDI(midi_file='test_data/Another-One-Bites-The-Dust-1.mid')

In [142]:
# Pick the guitar and drum tracks with the most notes out of the test song.
temp_dict = extract_guitar_and_drums(midi_file='test_data/Another-One-Bites-The-Dust-1.mid')

In [None]:
# Inspect the extracted title, downbeats and selected guitar/drum instruments.
temp_dict

In [133]:
# Slice both selected tracks into one list of notes per downbeat.
next_dict = tracks_to_bars(song_dict=temp_dict)

In [None]:
# Inspect the per-bar drum notes (one list per downbeat).
next_dict['drum_bars']

In [134]:
# Rescale every bar relative to its downbeat. The note objects inside
# next_dict['guitar_bars'] / next_dict['drum_bars'] are modified in place.
dict_guitar_bars = standardize_bars(song_name=next_dict['song_title'],
                                    list_of_bars=next_dict['guitar_bars'],
                                    downbeats=next_dict['down_beats'])
dict_drum_bars = standardize_bars(song_name=next_dict['song_title'],
                                  list_of_bars=next_dict['drum_bars'],
                                  downbeats=next_dict['down_beats'])

In [136]:
# First standardized guitar bar (the saved output below shows it is empty).
dict_guitar_bars['list_of_bars'][0]

[]

In [10]:
def bars_to_track(modelled_drum_bars: list, output_path: str = 'test_drum_track.mid') -> "pretty_midi.PrettyMIDI":
    """Concatenate per-bar drum notes into one drum track and write it to disk.

    Parameters
    ----------
    modelled_drum_bars : list of lists of pretty_midi.Note
        Bars in playback order. Notes are appended unchanged, so their
        start/end times are written exactly as given.
    output_path : str, optional
        Destination MIDI file. Defaults to 'test_drum_track.mid'.

    Returns
    -------
    pretty_midi.PrettyMIDI
        The object that was written. It holds a single drum instrument.
    """
    drum_track = pretty_midi.PrettyMIDI()
    drum_instrument = pretty_midi.Instrument(0, is_drum=True, name='')
    drum_instrument.notes.extend(note for bar in modelled_drum_bars for note in bar)
    drum_track.instruments.append(drum_instrument)
    drum_track.write(output_path)
    return drum_track

In [11]:
# NOTE(review): once the standardization cell has run, the notes in
# next_dict['drum_bars'] hold rescaled (standardized) times, because
# standardize_bars mutates them in place. Confirm the intended execution order.
bars_to_track(modelled_drum_bars=next_dict['drum_bars']);

In [16]:
def _write_bars_as_midi(bars, out_dir, program, is_drum, name):
    """Write each bar in ``bars`` to ``out_dir/bar_{n}.mid`` (n from 1) as a one-instrument MIDI file."""
    os.makedirs(out_dir, exist_ok=True)
    for bar_number, bar in enumerate(bars, start=1):
        bar_midi = pretty_midi.PrettyMIDI()
        instrument = pretty_midi.Instrument(program, is_drum=is_drum, name=name)
        instrument.notes.extend(bar)
        bar_midi.instruments.append(instrument)
        bar_midi.write(os.path.join(out_dir, f"bar_{bar_number}.mid"))


def bars_to_midi_bars(dict_guitar_bars: dict, dict_drum_bars: dict, output_root: str = 'midi_bars'):
    """Write every guitar bar and every drum bar to its own single-track MIDI file.

    Parameters
    ----------
    dict_guitar_bars, dict_drum_bars : dict
        Each needs 'song_title' and 'list_of_bars' (a list of lists of notes).
        Each part is stored under its own dict's title.
    output_root : str, optional
        Root directory. Defaults to 'midi_bars'.

    Files are written to ``{output_root}/{song_title}/guitar/bar_{n}.mid`` and
    ``{output_root}/{song_title}/drums/bar_{n}.mid``, with ``n`` counting from 1.
    Empty bars are still written, as files with an instrument but no notes.
    """
    _write_bars_as_midi(dict_guitar_bars['list_of_bars'],
                        os.path.join(output_root, dict_guitar_bars['song_title'], 'guitar'),
                        program=28, is_drum=False, name='Guitar')
    _write_bars_as_midi(dict_drum_bars['list_of_bars'],
                        os.path.join(output_root, dict_drum_bars['song_title'], 'drums'),
                        program=0, is_drum=True, name='Drums')

In [17]:
# Save every standardized bar as its own MIDI file under midi_bars/<song_title>/.
bars_to_midi_bars(dict_guitar_bars=dict_guitar_bars, dict_drum_bars=dict_drum_bars)

In [66]:
# Collect the per-bar files written above. glob does not guarantee an order,
# so the lists are natsorted before reconstruction.
guitar_filepaths = glob.glob('midi_bars/Another-One-Bites-The-Dust-1/guitar/*.mid')
drum_filepaths = glob.glob('midi_bars/Another-One-Bites-The-Dust-1/drums/*.mid')

In [None]:
# Inspect the collected guitar bar file paths.
guitar_filepaths

In [139]:
# Rebuild a full song from the per-bar MIDI files by undoing the per-bar time
# scaling (bar 0 uses down_beats[1], bar i uses down_beats[i]).
# NOTE(review): the guitar here uses program 25, while the bar files were
# written with program 28. Confirm which program is intended.
song = pretty_midi.PrettyMIDI()
song.instruments.append(pretty_midi.Instrument(25, is_drum=False, name='Guitar'))
song.instruments.append(pretty_midi.Instrument(0, is_drum=True, name='Drums'))

drum_filepaths = natsorted(drum_filepaths)
guitar_filepaths = natsorted(guitar_filepaths)


def _append_unstandardized_bars(bar_paths, down_beats, target_instrument):
    """Load each bar file once, rescale its notes to absolute time, and append them to target_instrument."""
    for bar_index, bar_path in enumerate(bar_paths):
        bar_midi = pretty_midi.PrettyMIDI(bar_path)
        # Bar files without any instrument (empty bars) contribute nothing.
        if not bar_midi.instruments:
            continue
        scale = down_beats[1] if bar_index == 0 else down_beats[bar_index]
        for note in bar_midi.instruments[0].notes:
            note.start = note.start * scale
            note.end = note.end * scale
            target_instrument.notes.append(note)


_append_unstandardized_bars(guitar_filepaths, dict_guitar_bars['down_beats'], song.instruments[0])
_append_unstandardized_bars(drum_filepaths, dict_drum_bars['down_beats'], song.instruments[1])

song.write('temp_reconst_midi_song.mid')

In [135]:
def revert_standardization(list_of_std_bars: list, down_beats: list) -> list:
    """Reverse the per-bar time scaling so notes return to absolute song time.

    Every note in bar ``i`` has its ``start`` and ``end`` multiplied by
    ``down_beats[i]``. Bar 0 is multiplied by ``down_beats[1]`` instead.
    Exactly one factor is applied per note, and each note appears once in
    the output.

    The note objects are modified in place, so the input bars end up holding
    the restored times as well. A new outer list with a new list per bar is
    returned.

    Parameters
    ----------
    list_of_std_bars : list of lists of notes
        Standardized bars. Notes need mutable ``start`` / ``end``.
    down_beats : list of float
        One entry per bar. ``down_beats[1]`` is only read when bar 0 has notes.

    Returns
    -------
    list of lists of notes, one per input bar, in the same order.
    """
    unstd_bars = []
    for bar_index, std_bar in enumerate(list_of_std_bars):
        if std_bar:
            scale = down_beats[1] if bar_index == 0 else down_beats[bar_index]
            for note in std_bar:
                note.start = note.start * scale
                note.end = note.end * scale
        unstd_bars.append(list(std_bar))
    return unstd_bars

In [137]:
# Restore absolute note times for both parts. The standardized note objects
# are modified in place by revert_standardization.
unstd_modelled_guitar_bars = revert_standardization(list_of_std_bars=dict_guitar_bars['list_of_bars'],
                                                    down_beats=dict_guitar_bars['down_beats'])
unstd_modelled_drum_bars = revert_standardization(list_of_std_bars=dict_drum_bars['list_of_bars'],
                                                  down_beats=dict_drum_bars['down_beats'])

In [138]:
# Restored guitar bars. Bars without guitar notes show up as empty lists.
# NOTE(review): the full list is very long; consider displaying a slice.
unstd_modelled_guitar_bars

[[],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [Note(start=62.000000, end=62.062500, pitch=71, velocity=47),
  Note(start=62.000000, end=62.067708, pitch=64, velocity=50),
  Note(start=62.000000, end=62.072917, pitch=67, velocity=53),
  Note(start=62.000000, end=62.072917, pitch=62, velocity=53),
  Note(start=62.125000, end=62.192708, pitch=62, velocity=96),
  Note(start=62.125000, end=62.197917, pitch=71, velocity=96),
  Note(start=62.125000, end=62.197917, pitch=67, velocity=96),
  Note(start=62.125000, end=62.208333, pitch=64, velocity=96),
  Note(start=62.250000, end=62.291667, pitch=62, velocity=24),
  Note(start=62.250000, end=62.296875, pitch=64, velocity=24),
  Note(start=62.375000, end=62.416667, pitch=62, velocity=35),
  Note(start=62.375000, end=62.427083, pitch=64, velocity=35),
  Note(start=62.500000, end=62.562500, pitch=71, velocity=60),
  Note(start=62.500000, en

In [None]:
# Guitar notes of the reconstructed song (instruments[0] is the guitar track).
song.instruments[0].notes

In [117]:
# Load one written drum bar back from disk to inspect its stored timing.
bar = pretty_midi.PrettyMIDI(midi_file='midi_bars/Another-One-Bites-The-Dust-1/drums/bar_87.mid')

In [122]:
# Notes of the loaded drum bar. The saved output shows standardized times near 1.0.
bar.instruments[0].notes

[Note(start=1.000000, end=1.002273, pitch=35, velocity=116),
 Note(start=1.000000, end=1.002273, pitch=42, velocity=92),
 Note(start=1.002273, end=1.004545, pitch=42, velocity=5),
 Note(start=1.002273, end=1.004545, pitch=42, velocity=92),
 Note(start=1.002273, end=1.006818, pitch=35, velocity=112),
 Note(start=1.002273, end=1.006818, pitch=40, velocity=113),
 Note(start=1.004545, end=1.006818, pitch=42, velocity=2),
 Note(start=1.006818, end=1.009091, pitch=35, velocity=113),
 Note(start=1.006818, end=1.009091, pitch=40, velocity=59),
 Note(start=1.006818, end=1.009091, pitch=40, velocity=73),
 Note(start=1.006818, end=1.009091, pitch=42, velocity=37),
 Note(start=1.006818, end=1.009091, pitch=42, velocity=91)]