In [1]:
import os
import re
import numpy as np
import pandas as pd
import pretty_midi
import mido
import math
from sklearn.preprocessing import MinMaxScaler

TODO: Add a brief explanation of the preprocessing steps.

References:
<br/>
Simon et al. (2018) "Learning Latent Space..." use:
- 128 note on/offs
- 8 velocity change events -> quantized into 8 bins
- 96 time-shift (offset) events -> 24 step/quarter note.
- LSTM VAE (bidire. LSTM encoder w 1024 nodes, and a forward LSTM decoder with 3 layers of 512 nodes each)
- Conditioning
- latent space dim 512
- batch size 256
- Adam opt., l_rate 1e3 to 1e-5 with exponential decay rate 0.9999
- 100000 gradient steps
- Sampling autoregressively with a temperature param that controls the uniformity of the distribution
- 176K MIDI files (all 4/4) ->4092681 (1-bar) splits (deduped measures)
- ...more in the paper. <br/>
<br/>

Gillick et al. (2019) "Learning to Groove..." use:
- Groove MIDI dataset
- ...

In [2]:
#Utility functions

def mido2pretty(midi_data):
    '''Convert Mido data structure to pretty_midi
    TODO: Check the loop. Somethings odd...'''
    # Create a PrettyMIDI object
    pm = pretty_midi.PrettyMIDI()

    # Create an instrument object to add notes to
    instrument = pretty_midi.Instrument(program=0, is_drum=False)

    # Keep track of the current time in ticks
    current_time = 0

    # Loop through each message in the mido track
    for msg in midi_data.tracks[0]:
        # Update the current time by the message time delta
        current_time += msg.time
        
        # If the message is a note on message, add the note to the instrument
        if msg.type == 'note_on':
            note = pretty_midi.Note(
                velocity=msg.velocity,
                pitch=msg.note,
                start=current_time / midi_data.ticks_per_beat,
                end=(current_time + msg.time) / midi_data.ticks_per_beat,
            )
            instrument.notes.append(note)

    # Add the instrument to the PrettyMIDI object
    pm.instruments.append(instrument)
    return pm


def get_pitches(midi_data, inst=0, mido=True):
    # Get the note pitches of a single instrument in a MIDI track
    #TODO: add NOT mido
    pitches = []
    if mido:
        pmidi_data = mido2pretty(midi_data)
    # else:
    #     pmidi_data = midi_data
        for i, _ in enumerate(pmidi_data.instruments[inst].notes):
            pitch = pmidi_data.instruments[inst].notes[i].pitch
            pitches.append(pitch)
        pitches = np.array(pitches)
    return pitches


def get_onsets(midifile):
    '''This function returns the note onsets of a midi file using Mido.'''
    note_onsets = []
    current_time = 0
    current_notes = {}
    
    for msg in mido.MidiFile(midifile):
        current_time += msg.time
        
        if msg.type == 'note_on':
            if msg.velocity > 0:
                current_notes[msg.note] = current_time
            else:
                try:
                    note_onset = current_notes.pop(msg.note)
                    note_duration = current_time - note_onset
                    note_onsets.append((note_onset, msg.note)) #you can add note_duration too
                except KeyError:
                    pass

    return note_onsets


#Offset calculation - Looks the most reliable but REQUIREs more work!
def get_offsets(midifile, subdivision=8):
    '''This function returns the note onsets of a midi file using Mido, 
    and calculate the temporal distance with quantized onsets.'''
    note_onsets = []
    current_time = 0
    current_notes = {}
    mid = mido.MidiFile(midifile)

    for msg in mid:
        current_time += msg.time
        
        if msg.type == 'note_on':
            if msg.velocity > 0:
                current_notes[msg.note] = current_time
            else:
                try:
                    note_onset = current_notes.pop(msg.note)
                    note_duration = current_time - note_onset
                    note_onsets.append(note_onset) 
                except KeyError:
                    pass

    #quant interval:
    ticks_per_beat = mid.ticks_per_beat
    subdivision = subdivision
    ticks_per_div = ticks_per_beat/subdivision


    quantized_onsets = [round(onset / ticks_per_div) * ticks_per_div for onset in note_onsets]
    temporal_distance = [(q - o) / ticks_per_div for o, q in zip(note_onsets, quantized_onsets)]
    norm_list = [max(-0.5, min(0.5, d)) for d in temporal_distance]
    offsets_df = pd.DataFrame(norm_list, columns=['offsets'])

    return offsets_df


def get_meta(midi_file):
    '''This function returns the tempo of a midi file using Mido.'''
    if isinstance(midi_file, str):
        input_midi = mido.MidiFile(midi_file)
    elif isinstance(midi_file, mido.MidiFile):
        input_midi = midi_file
         
    meta_msgs = [msg for msg in input_midi if msg.is_meta]
    metas={}
    for msg in meta_msgs:
        metas.update(msg.dict())
    return metas


def get_midi_type(input_file):
    '''as the name suggests...'''
    if isinstance(input_file, str):
        input_midi = mido.MidiFile(input_file)
    elif isinstance(input_file, mido.MidiFile):
        input_midi = input_file
    return input_midi.type


def show_score(midi_file):
    '''Returns the score of a midi file as sheet music'''
    from music21 import converter, instrument, note, chord
    score = converter.parse(midi_file).chordify()
    return score.show()


def np_onehot(indices, depth, dtype=bool):
    """Converts 1D array of indices to a one-hot 2D array with given depth."""
    onehot_seq = np.zeros((len(indices), depth), dtype=dtype)
    onehot_seq[np.arange(len(indices)), indices] = 1.0
    return onehot_seq

def is_shorter_or_longer_than_n_bars(midi_file_path, n_bars=2, time_signature=(4, 4), shorter=True):
    if isinstance(midi_file_path, str):
        midi_file = mido.MidiFile(midi_file_path)
    elif isinstance(midi_file_path, mido.MidiFile):
        midi_file = midi_file_path

    # Calculate the length of the MIDI file in seconds
    midi_length_seconds = midi_file.length

    # Calculate the length of n bars in seconds
    ticks_per_beat = midi_file.ticks_per_beat
    time_sig_numerator = time_signature[0]
    time_sig_denominator = time_signature[1]

    for msg in midi_file.tracks[0]:
        if msg.type == "time_signature":
            time_sig_numerator = msg.numerator
            time_sig_denominator = msg.denominator
            break

    quarter_notes_per_bar = time_sig_numerator * (4 / time_sig_denominator)
    ticks_per_bar = ticks_per_beat * quarter_notes_per_bar
    n_bars_length_ticks = n_bars * ticks_per_bar

    # Convert ticks to seconds
    tempo = 500000  # Default tempo is 120 BPM
    for msg in midi_file.tracks[0]:
        if msg.type == "set_tempo":
            tempo = msg.tempo
            break

    n_bars_length_seconds = mido.tick2second(n_bars_length_ticks, ticks_per_beat, tempo)

    # Compare the length of the MIDI file with the length of n bars
    if shorter:
        return midi_length_seconds < n_bars_length_seconds
    return midi_length_seconds > n_bars_length_seconds



def print_zero_time_events(midi_path):
    midi_file = mido.MidiFile(midi_path)
    for i, track in enumerate(midi_file.tracks):
        print(f"Track {i}:")
        for msg in track:
            if msg.time == 0:
                print(msg)
# # Example usage
# midi_file_path = longs_w_issue[0]
# print_zero_time_events(midi_file_path)


def beat_threshold(df, bars=2):
    #TODO define a function that ignores the dataframes shorter than the given amount of bars
    pass 

def findstr(text, word_to_search):
    # Compile the regular expression pattern using the re.IGNORECASE flag
    pattern = re.compile(word_to_search, re.IGNORECASE)
    return pattern.findall(text) 

def list2text(list, filename):
    with open(filename, 'w') as f:
        for item in list:
            f.write(str(item))
            f.write('\n')

def extract_time_signature(midi_file_path):
    midi_file = mido.MidiFile(midi_file_path)

    for track in midi_file.tracks:
        for msg in track:
            if msg.type == 'time_signature':
                return f"{msg.numerator}/{msg.denominator}"

    # If no time signature is found, return a default value (e.g., '4/4').
    return "Other"

def print_uncommon_events(midi_path):
    uncommon_event_types = ["sysex", "polytouch", "aftertouch"]

    midi_file = mido.MidiFile(midi_path)
    for i, track in enumerate(midi_file.tracks):
        print(f"Track {i}:")
        for msg in track:
            if msg.type in uncommon_event_types:
                print(msg)

def add_delay_to_zero_time_notes(midi_path, output_path, delay=1):
    input_midi = mido.MidiFile(midi_path)
    output_midi = mido.MidiFile(type=input_midi.type)
    
    for track in input_midi.tracks:
        new_track = mido.MidiTrack()
        output_midi.tracks.append(new_track)

        for msg in track:
            if msg.type == 'note_on' and msg.time == 0:
                msg.time = delay
            new_track.append(msg)

    output_midi.save(output_path)



In [3]:
#More UTILITY functions for the preprocessing of MIDI files

def extend_midi_bars(input_file, num_bars=4, time_signature=(4, 4)):

    '''This function extends the drum loop to the desired number of bars without time-stretching. 
    It will turn a 1-bar loop into a 2-bar loop, a 2-bar loop into a 4-bar loop, and so on.'''

    input_midi = mido.MidiFile(input_file)

    # Calculate ticks per bar and total ticks for desired number of bars
    ticks_per_beat = input_midi.ticks_per_beat
    ticks_per_bar = ticks_per_beat * time_signature[0]
    total_ticks = ticks_per_bar * num_bars

    # Create a new MidiFile object for the output
    output_midi = mido.MidiFile(ticks_per_beat=ticks_per_beat)

    for track in input_midi.tracks:
        output_track = mido.MidiTrack()
        output_midi.tracks.append(output_track)

        # Duplicate meta messages to the output track
        for msg in track:
            if msg.is_meta:
                output_track.append(msg.copy())

        # Calculate the original loop length in ticks
        loop_length = 0
        for msg in track:
            if not msg.is_meta:
                loop_length += msg.time

        # Extend the loop to the desired number of bars
        accumulated_ticks = 0
        while accumulated_ticks < total_ticks:
            current_ticks = 0
            for msg in track:
                if not msg.is_meta:
                    # Calculate the tick position in the output track
                    output_ticks = current_ticks + accumulated_ticks

                    # If the message goes beyond the desired length, break the loop
                    if output_ticks >= total_ticks:
                        break

                    # Clone the message and add it to the output track
                    output_msg = msg.copy(time=msg.time)
                    output_track.append(output_msg)
                    current_ticks += msg.time

            # Update the accumulated_ticks for the next loop iteration
            accumulated_ticks += loop_length

    return output_midi


def truncate_midi_bars(input_file, num_bars=4, time_signature=(4, 4)):
    '''This function removes any extra content that extends beyond the desired number of bars. 
    It takes an input MIDI file, processes it to match the desired number of bars.'''

    if isinstance(input_file, str):
        input_midi = mido.MidiFile(input_file)
    elif isinstance(input_file, mido.MidiFile):
        input_midi = input_file

    # Calculate ticks per bar and total ticks for desired number of bars
    ticks_per_beat = input_midi.ticks_per_beat
    ticks_per_bar = ticks_per_beat * time_signature[0]
    total_ticks = ticks_per_bar * num_bars

    # Create a new MidiFile object for the output
    output_midi = mido.MidiFile(ticks_per_beat=ticks_per_beat)

    for track in input_midi.tracks:
        output_track = mido.MidiTrack()
        output_midi.tracks.append(output_track)
        current_ticks = 0

        for msg in track:
            if msg.is_meta:
                output_track.append(msg.copy())
            else:
                # Calculate the tick position in the output track
                output_ticks = current_ticks + msg.time

                # If the message goes beyond the desired length, break the loop
                if output_ticks >= total_ticks:
                    break

                # Clone the message and add it to the output track
                output_msg = msg.copy(time=msg.time)
                output_track.append(output_msg)
                current_ticks += msg.time

    return output_midi

def proc_meta(df, add_genre=False, tensor=False):
    '''Process the metadata and output DataFrame or tensor'''
    tempo = df.loc[df['type']=='set_tempo', 'tempo'].values[0]
    data = {'type': ['time_signature', 'tempo'],
            'value': [(df.numerator[0], df.denominator[0]), tempo]}
    if add_genre:
        #TODO: one-hot or integer encode the genre info later
        pass
    
    if tensor:
        return tf.constant([df.numerator[0], df.denominator[0], tempo], dtype=tf.float32)
    return pd.DataFrame(data)

In [4]:
#File management

class MIDIfiles:

    def __init__(self, midi_dir):
        '''This class imports a list of midi files in a given path.'''
        self.midi_dir = midi_dir
    
    def get_paths(self):
        '''This function returns a list of midi files in a given path.'''
        midi_files=[]
        for path, subdirs, files in os.walk(self.midi_dir): 
            for file in files:
                if file.endswith('.mid'):
                    midfile = os.path.join(path, file)
                    midi_files.append(midfile)
        return midi_files

In [5]:
class MIDIgroup:

    def __init__(self, midi_path, n_bars=2):
        self.midi_path = midi_path
        self.n_bars = n_bars

    def is_44(self):
        '''This function returns True if the midi file is in 4/4 time signature.'''
        if extract_time_signature(self.midi_path) == '4/4':
            return True
        return False
        
    def is_short(self):
        '''This function returns True if the midi file is shorter than 2 bars.'''
        if is_shorter_or_longer_than_n_bars(self.midi_path, self.n_bars):
            return True
        return False
    
    def is_long(self):
        '''This function returns True if the midi file is longer than 2 bars.'''
        if is_shorter_or_longer_than_n_bars(self.midi_path, self.n_bars, shorter=False):
            return True
        return False
        
    def is_fill(self):
        '''This function returns True if the midi file is a fill.'''
        if findstr(self.midi_path, 'fill'):
            return True
        return False


In [20]:
#DIRECTORY
dat_dir = "/Users/cagrierdem/Desktop/ongoing/POSTDOC/dB_workspace/drumbot/dB_dat"
MAIN_DIR = os.path.join(dat_dir, "DOOM")
TEST_DIR = os.path.join(dat_dir, "TEST")
midi_files = MIDIfiles(TEST_DIR)

In [21]:
#Grouping MIDI files in the dataset 
#based on their time signature and whether they are loops or fills

fills44 = []
shortfills44 = []
longfills44 = []
loops44 = []
rest=[]
num_bars = 2
for file in midi_files.get_paths():
    check = MIDIgroup(file, n_bars=num_bars)
    #1) Is the file is 4/4 (has metadata as 4/4)?
    if check.is_44():
        #2) Is the file a fill (is it in the name)?
        if check.is_fill():
            fills44.append(file)
            #3) Is the fill shorter or longer than 4 bars?
            if check.is_short():
                shortfills44.append(file)
            else:
                longfills44.append(file)
        else:
            #All the 4/4 files that are not fills
            loops44.append(file)
    else:
        #All the files that are not 4/4
        rest.append(file)

#Finally:
print(
    'Currently, there are {} MIDI files in the dataset.\n'.format(len(midi_files.get_paths())),
    '{} are 4/4 loops that are longer than {} bars.\n'.format(len(loops44), num_bars),
    '{} are 4/4 fills, of which {} are short and {} are long.\n'.format(len(fills44), len(shortfills44), len(longfills44)),
    'We disregard a total of {} files, which are not 4/4.\n'.format(len(rest))
)

type0 = []
type1 = []
for loop in loops44:
    type0.append(loop) if get_midi_type(loop) == 0 else type1.append(loop)
print(f' Among 4/4 loops, {len(type0)} are Type0 and {len(type1)} are Type1 MIDI')

Currently, there are 2953 MIDI files in the dataset.
 406 are 4/4 loops that are longer than 2 bars.
 98 are 4/4 fills, of which 39 are short and 59 are long.
 We disregard a total of 2449 files, which are not 4/4.

 Among 4/4 loops, 176 are Type0 and 230 are Type1 MIDI


In [22]:
#TODO: EXPLORE HOW TO DEAL WITH DIFFERENT RESOLUTIONS

resolutions=[]
res=[]
for i, midi_file in enumerate(loops44):
    midi_data = mido.MidiFile(midi_file)
    ticks = midi_data.ticks_per_beat
    resolutions.append([i, ticks])
    res.append(ticks)

resarr = np.array(res)
occured_ticks = np.unique(resarr)
print("There are {} different resolutions in the dataset:\n".format(len(occured_ticks)))
for t in occured_ticks:
    print("{}-tick occurs {} times.".format(t, res.count(t)))

There are 3 different resolutions in the dataset:

96-tick occurs 137 times.
960-tick occurs 39 times.
15360-tick occurs 230 times.


In [23]:
# q = input("Wanna export as text?")
# if q.startswith('y') or q.startswith('Y'):
#     with open ('dataset_midi_res.txt', 'w') as f:
#         for item in resolutions:
#             f.write(str(item))
#             f.write('\n')
# else:
#     pass

In [24]:
# Check if the ticks per beat value is the same across all MIDI files
non_consts = []
midi_ticks = []
ticks_per_beat = None
for midi_file in loops44:
    midi_data = mido.MidiFile(midi_file)
    if ticks_per_beat is None:
        ticks_per_beat = midi_data.ticks_per_beat
    elif midi_data.ticks_per_beat != ticks_per_beat:
        # print(f'Error: ticks per beat value is not consistent across MIDI files')
        non_consts.append([midi_file, midi_data.ticks_per_beat])

if ticks_per_beat is not None:
    print(f'Ticks per beat value is consistent across all MIDI files: {ticks_per_beat}')


Ticks per beat value is consistent across all MIDI files: 960


In [25]:
#Find and list the MIDI files that don't have tempo information
wout_meta = []
for i, midi_path in enumerate(loops44):
    metadata = get_meta(midi_path)
    if 'tempo' not in metadata:
        wout_meta.append(midi_path)

# list2text(wout_meta, 'midi_wout_tempo')

# There are 1744 files, each has metadata but NO tempo information. 
#And these are all from 8000000 dataset

In [26]:
#MIDI PREPROCESSING UTILS

def update_tempo_from_path(midi_path, save_to_file=False, verbose=False):
    '''Add tempo information to a MIDI file if it's not present.
    The function looks for a number followed by "BPM" (case-insensitive) in the file path, 
    extracts the BPM, checks for the presence of "_to<number>" at the end of the filename 
    and updates the tempo information accordingly, even if the MIDI file already has tempo information.'''
    
    midi_file = mido.MidiFile(midi_path)

    # Extract BPM from the file path using a regex pattern
    bpm_pattern = re.compile(r'\d{2,3}(?=BPM)', re.IGNORECASE)
    match = bpm_pattern.search(midi_path)

    # Extract the custom BPM from the file path, if present
    custom_bpm_pattern = re.compile(r'_to(\d{2,3})', re.IGNORECASE)
    custom_match = custom_bpm_pattern.search(midi_path)
    custom_bpm = None

    if custom_match:
        custom_bpm = int(custom_match.group(1))

    # Check if the MIDI file has tempo information
    has_tempo = False
    for track in midi_file.tracks:
        for msg in track:
            if msg.is_meta and msg.type == 'set_tempo':
                has_tempo = True
                break
        if has_tempo:
            break

    if match or custom_bpm or not has_tempo:
        # Add or update tempo information
        if not has_tempo:
            bpm = int(match.group()) if match else 120
            tempo = mido.bpm2tempo(bpm)
            new_track = [mido.MetaMessage('set_tempo', tempo=tempo, time=0)]
            new_track.extend(midi_file.tracks[0])
            midi_file.tracks[0] = new_track
            if verbose:
                print(f"Tempo information added to: {midi_path}")

        if custom_bpm:
            tempo = mido.bpm2tempo(custom_bpm)
            for track in midi_file.tracks:
                for msg in track:
                    if msg.is_meta and msg.type == 'set_tempo':
                        msg.tempo = tempo
            if verbose:
                print(f"Tempo information updated to {custom_bpm} in: {midi_path}")

        if save_to_file:
            midi_file.save(midi_path)

    else:
        if verbose:
            print(f"No BPM information found in the path: {midi_path}")

    return midi_file


def merge_tracks(input_midi):
    '''This function takes a multi-track MIDI file 
    and merges its tracks into a single track while preserving the timing & metadata'''
    if isinstance(input_midi, str):
        input_midi = mido.MidiFile(input_midi)

    merged_midi = mido.MidiFile(type=0, ticks_per_beat=input_midi.ticks_per_beat)
    merged_track = mido.MidiTrack()
    merged_midi.tracks.append(merged_track)

    # To keep track of whether the meta messages have been added to the merged track
    time_signature_added = False
    set_tempo_added = False

    for track in input_midi.tracks:
        for msg in track:
            if msg.is_meta:
                if msg.type == 'time_signature' and not time_signature_added:
                    merged_track.append(msg.copy())
                    time_signature_added = True
                elif msg.type == 'set_tempo' and not set_tempo_added:
                    merged_track.append(msg.copy())
                    set_tempo_added = True
                elif msg.type not in ('time_signature', 'set_tempo'):
                    merged_track.append(msg.copy())
            else:
                merged_track.append(msg.copy())

    return merged_midi


def keep_initial_metadata(midi_file):
    '''This function reads the input MIDI file and creates a new MIDI file with the same type. 
    It then goes through each track and message, copying the messages to the new MIDI file. 
    The function keeps only the initial tempo and time signature events, skipping any subsequent changes. 
    All other meta messages and non-meta messages are preserved.'''
    if isinstance(midi_file, str):
        input_midi = mido.MidiFile(midi_file)
    elif isinstance(midi_file, mido.MidiFile):
        input_midi = midi_file

    output_midi = mido.MidiFile(type=input_midi.type)
    
    for track in input_midi.tracks:
        new_track = mido.MidiTrack()
        output_midi.tracks.append(new_track)

        # Keep track of whether the initial tempo and time signature have been found
        initial_tempo_found = False
        initial_time_signature_found = False

        for msg in track:
            if msg.is_meta:
                if msg.type == 'set_tempo' and not initial_tempo_found:
                    new_track.append(msg)
                    initial_tempo_found = True
                elif msg.type == 'time_signature' and not initial_time_signature_found:
                    new_track.append(msg)
                    initial_time_signature_found = True
                elif msg.type not in ('set_tempo', 'time_signature'):
                    new_track.append(msg)
            else:
                new_track.append(msg)

    return output_midi


def unify_midi_res(midi_path, target_resolution=480, save_midi=False):
    '''
    Returns the MIDI file with a given resolution using Mido.
    '''
    if isinstance(midi_path, str):
        midi_data = mido.MidiFile(midi_path)
    elif isinstance(midi_path, mido.MidiFile):
        midi_data = midi_path

    # Calculate the conversion factor
    source_resolution = midi_data.ticks_per_beat
    conversion_factor = target_resolution / source_resolution

    # Check if source resolution is different from the target resolution
    if source_resolution != target_resolution:
        for track in midi_data.tracks:
            for event in track:
                if event.type in ['note_on', 'note_off', 'control_change']:
                    # Scale the tick values of note on/off and control change events
                    event.time = round(event.time * conversion_factor)
                # Tempo events are not modified, as the relationship between ticks and real-world time is maintained automatically

        # Update the ticks_per_beat of the MIDI file to the target resolution
        midi_data.ticks_per_beat = target_resolution
        if save_midi:
            # Save the modified MIDI file
            output_path = os.path.splitext(midi_path)[0] + '_standardized.mid'
            midi_data.save(output_path)

    return midi_data


def extract_note_info(midi_file):
    '''This function extracts note information from the original and merged MIDI files 
    and then compare the results to ensure that the note information is preserved during the merging process'''
    note_info = []
    
    for track in midi_file.tracks:
        current_time = 0
        for msg in track:
            current_time += msg.time
            if not msg.is_meta and msg.type == 'note_on':
                note_info.append((msg.channel, msg.note, msg.velocity, current_time))
    
    return note_info


def compare_midi_files(original_midi, converted_midi, tolerance=0.01):
    '''This function compares MIDI file durations. 
    --> it compares the number of tracks and messages instead of instruments and notes
    This is basically for a safety measure, in case TPQN conversion causes issues 
    (e.g., time-stretches) in the MIDI file structure'''

    # Calculate the total duration of both files
    original_duration = original_midi.length
    converted_duration = converted_midi.length

    # Compare the duration of the files
    duration_difference = abs(original_duration - converted_duration)

    # Initialize a dictionary to store the differences
    differences = {}

    # Check if the duration difference is above the tolerance
    if duration_difference > tolerance:
        differences['duration'] = (original_duration, converted_duration)

    # Compare the number of tracks
    if len(original_midi.tracks) != len(converted_midi.tracks):
        differences['track_count'] = (len(original_midi.tracks), len(converted_midi.tracks))

    # Compare the number of messages for each track
    for i, (original_track, converted_track) in enumerate(zip(original_midi.tracks, converted_midi.tracks)):
        if len(original_track) != len(converted_track):
            differences[f'track_{i}_message_count'] = (len(original_track), len(converted_track))

    return differences

def midi_length_ticks(midi_object):
    '''This function returns the maximum track length in ticks
    by iterating through all the tracks in a MIDI file and
    summing the ticks of all the messages'''

    midi_length_ticks = 0
    for track in midi_object.tracks:
        track_length_ticks = sum(msg.time for msg in track)
        midi_length_ticks = max(midi_length_ticks, track_length_ticks)

    return midi_length_ticks


def extract_metadata(midi_object):
    '''as the name suggests... same as get_data() so TODO: combine them into one'''
    metadata = []

    for track in midi_object.tracks:
        for msg in track:
            if msg.is_meta:
                metadata.append(msg.copy())

    return metadata


def print_midi_metadata(midi_path):
    '''not necessary, just for laziness...'''
    midi_file = mido.MidiFile(midi_path)
    for i, track in enumerate(midi_file.tracks):
        print(f"Track {i}:")
        for msg in track:
            if msg.is_meta:
                print(msg)


def split_midi(input_file, num_bars_per_section=2, time_signature=(4, 4), verbose=False):
    '''This function splits a MIDI file into sections of the desired length.'''

    if isinstance(input_file, str):
        input_midi = mido.MidiFile(input_file)
    elif isinstance(input_file, mido.MidiFile):
        input_midi = input_file
    
    #this verbose is new
    if verbose:
        print("Note messages before slicing:")
        for track in input_midi.tracks:
            for msg in track:
                if not msg.is_meta and msg.type in ('note_on', 'note_off'):
                    print(msg)

    ticks_per_beat = input_midi.ticks_per_beat #extract TPQN from the input file
    ticks_per_bar = ticks_per_beat * time_signature[0] #calculate ticks per bar
    ticks_per_section = ticks_per_bar * num_bars_per_section 

    #TODO: Make sure to cover the entire length of the input MIDI file!
    # --> Calculate the total number of bars and round it up to the nearest integer before calculating the section start ticks.
    # total_bars = math.ceil(input_midi.length * ticks_per_beat / ticks_per_bar)
    total_bars = math.ceil(midi_length_ticks(input_midi) / ticks_per_bar)

    metadata = extract_metadata(input_midi) #extract metadata from the input file

    if verbose:
        print("Total length of the MIDI file: {} seconds & {} bars".format(round(input_midi.length, 2), total_bars))
        print("Ticks per beat: {}".format(ticks_per_beat))
        print("The loop will start from 0, will go to {} with steps of {}".format(int(input_midi.length * ticks_per_beat), ticks_per_section))
        s = 0

    sections = []
    # for section_start_tick in range(0, int(input_midi.length * ticks_per_beat), ticks_per_section):
    for section_start_tick in range(0, total_bars * ticks_per_bar, ticks_per_section):
        section_midi = mido.MidiFile(ticks_per_beat=ticks_per_beat)

        if verbose:
            s += 1
            print(f'Section {s}')

        for track in input_midi.tracks:
            section_track = mido.MidiTrack()
            section_track.extend(metadata) #adding the extracted metadata to each section
            section_midi.tracks.append(section_track)

            current_ticks = 0
            section_ticks = 0
            for msg in track:
                current_ticks += msg.time
                if section_start_tick <= current_ticks < section_start_tick + ticks_per_section:
                    if msg.is_meta:
                        if msg.type == "end_of_track":
                            section_track.append(msg.copy(time=0))
                        else:
                            section_track.append(msg.copy())
                    else:
                        section_msg = msg.copy(time=current_ticks - section_ticks - section_start_tick)
                        section_track.append(section_msg)
                        section_ticks = current_ticks - section_start_tick  # Update the section_ticks

        sections.append(section_midi)

    #this verbose is new
    if verbose: 
        for i, section in enumerate(sections):
            print(f"Note messages in section {i + 1}:")
            for track in section.tracks:
                for msg in track:
                    if not msg.is_meta and msg.type in ('note_on', 'note_off'):
                        print(msg)


    return sections



In [27]:
TESTMID = loops44[49]
test_midi = mido.MidiFile(TESTMID)  
new_test_midi = unify_midi_res(TESTMID, target_resolution=480)
print("{} TPQN -> {} TPQN".format(test_midi.ticks_per_beat, new_test_midi.ticks_per_beat))

#7440 idx'li track 15360 TPQN bi sacmaliyo baska bisi deneyelim
#0 idx (240 TPQN) fistik gibi calisti
#7399 idx 96 TPQN de fistik
#5427 idx 6900 TPQN de gauyet oluyo

96 TPQN -> 480 TPQN


In [28]:
#Compare (inspect) MIDI files of adjusted resolutions to their original versions 
#and save them in a single folder in the given directory
# output_directory = '/Users/cagrierdem/Desktop/ongoing/POSTDOC/practice/practiceMIDI/datatest'

suspicious_files = []
diffs=[]

normal_loops=[]
short_loops=[]

splitted_loops = []
splitted_loop_path=[]
# Process each MIDI file in the input directory
for i, midi_file in enumerate(loops44):

    original_midi = mido.MidiFile(midi_file)

    #Update the tempo information in metadata
    updated_midi = update_tempo_from_path(midi_file, save_to_file=False, verbose=False)

    #Convert all your MIDI files into single-track MIDI files
    merged_updated_midi = merge_tracks(updated_midi)

    # Convert the MIDI file
    converted_midi = unify_midi_res(merged_updated_midi, target_resolution=480)

    # Compare the original and converted MIDI files
    differences = compare_midi_files(merged_updated_midi, converted_midi)
    diffs.append(differences)

    #According to the result of comparison:
    if not differences:

        #Disregard the files that are shorter than 2 bars
        if is_shorter_or_longer_than_n_bars(converted_midi, n_bars=1, shorter=False):

            #Keep the initial metadata –prevent duplicates
            processed_midi = keep_initial_metadata(converted_midi)
            normal_loops.append(processed_midi)

            #Split the MIDI file into 2-bar sections
            splits = split_midi(processed_midi, num_bars_per_section=2, verbose=False)

            #Disregard split lists that have more than 8 splits
            #TODO: Understand the problems when not disregarding
            if len(splits) <= 8:
                splitted_loops.append(splits) 
                splitted_loop_path.append(loop)
            else:
                continue
        else:
            short_loops.append(midi_file)

    else:
        # print(f"Suspicious files: {original_midi} vs {converted_midi}")
        suspicious_files.append(midi_file)
        print("Differences:")
        for key, value in differences.items():
            print(f"{key}: {value}")

print(f"Out of {len(loops44)} 4/4 loops in the dataset, {len(normal_loops)} are 2 bars or longer, {len(short_loops)} are shorter.")
print(f"There are a total of {len(splitted_loops)} lists of 2-bar splits now.")
print(f"{len(suspicious_files)} files are suspicious.\n")

#Check the lengths of the splitted loops
splits = [len(loop_list) for loop_list in splitted_loops]
total_loops = sum(splits)
print(f'In total, There are {total_loops} sliced loops in the dataset, and there are\n{np.unique(np.array(splits))} splits per loop.')

Out of 406 4/4 loops in the dataset, 328 are 2 bars or longer, 78 are shorter.
There are a total of 98 lists of 2-bar splits now.
0 files are suspicious.

In total, There are 218 sliced loops in the dataset, and there are
[1 2 3 4 5] splits per loop.


In [None]:
#Find out the loops with anormal number of splits
#Now that we disregard len(list) > 8, this loop may seem redundant, but can be functional in other cases, so keep it.
anorms=[]
for i, loop in enumerate(splitted_loops):
    if len(loop) > 8:
        anorms.append([i, len(loop)])

print(len(anorms))
#In the end, we should discard loops with more than 12 splits

In [29]:
#Iterate over each split in the new directory,
#DISREGARD all splits that are longer than 2 bars
#Save them into a new folder
#TODO: It may be highly unnecessary to create a new directory for the second time--> Make this process more efficient
splits_dir = '/Users/cagrierdem/Desktop/ongoing/POSTDOC/practice/practiceMIDI/2bar_loops_test'
longer=[]
normal=[]
for i, loop in enumerate(splitted_loops):
    for j, split in enumerate(loop):
        save_name = f'{i}_{j}.mid'
        if not is_shorter_or_longer_than_n_bars(split, n_bars=2, shorter=False) and not is_shorter_or_longer_than_n_bars(split, n_bars=1, shorter=True):
            normal.append(os.path.join(splits_dir, save_name))
            if not os.path.exists(os.path.join(splits_dir, save_name)):
                split.save(os.path.join(splits_dir, save_name))
            else:
                continue
        else:
            longer.append(os.path.join(splits_dir, save_name))

print(f'There are {len(os.listdir(splits_dir))} 2-bar (splitted) loops in the new dataset.')

There are 76 2-bar (splitted) loops in the new dataset.


In [30]:
#Inspect (once again) if there are splits longer than 2 bars or not
#If not, you are good to proceed to the make-matrix stage
sliced_loops = MIDIfiles(splits_dir)
sliced_datadir = sliced_loops.get_paths()
longs_w_issue=[]
longs2bar = []
for idx, file in enumerate(sliced_datadir):
    check = MIDIgroup(file, n_bars=2)
    if check.is_long():
        longs_w_issue.append(file)
    else:
        longs2bar.append(file)
        
#Finally:
print(f'{len(longs_w_issue)} over {len(sliced_datadir)} files are longer than 2 bars.\n')

0 over 75 files are longer than 2 bars.



In [94]:
DATASET = '/Users/cagrierdem/Desktop/ongoing/POSTDOC/practice/practiceMIDI/2bar_loops'
print(f"Processed DATASET dir: {DATASET}")

Loops = MIDIfiles(DATASET).get_paths()
Loops.sort()

# for loop in Loops:
#     metadata = get_meta(loop)
#     if not 'tempo' in metadata:
#         print(f"Tempo not found in {loop}")
#         break 

Processed DATASET dir: /Users/cagrierdem/Desktop/ongoing/POSTDOC/practice/practiceMIDI/2bar_loops
