In [1]:
import pretty_midi
from tqdm.auto import tqdm
from pathlib import Path
import torch
import numpy as np
from collections import defaultdict
import warnings

In [9]:
# Drum notes are clamped into this inclusive pitch range (General MIDI percussion keys).
min_drum_pitch, max_drum_pitch = 35, 81
n_drum_pitches = (max_drum_pitch + 1) - min_drum_pitch

# Full inclusive MIDI pitch range.
min_pitch, max_pitch = 0, 127
n_pitches = (max_pitch + 1) - min_pitch

sequence_length = 128  # note events per training window
n_velocities = 128     # MIDI velocities 0-127
n_instruments = 4      # merged categories: Drums, Bass, Chords, Lead

max_files = 10         # stop after this many usable MIDI files

In [None]:
# Sorted so the first `max_files` usable files picked later are the same on
# every machine and run (raw rglob order depends on the filesystem).
midi_files = sorted(Path(r"../Lakh_MIDI_Dataset").rglob("*.mid"))
print(len(midi_files))

In [4]:
# Map an instrument onto the coarse category it is merged into.
def categorize_instrument(instrument):
  """Return the coarse category ("Drums", "Bass", "Lead" or "Chords") for a track.

  - drum tracks -> "Drums"
  - programs 32-39 -> "Bass"
  - programs 80-87 -> "Lead"
  - programs 40-47 whose track name contains "violin" -> "Lead"
  - everything else (pianos, organs, guitars, pads, other strings, ...) -> "Chords"
  """
  if instrument.is_drum:
    return "Drums"

  program = instrument.program
  if 32 <= program < 40:
    return "Bass"
  if 80 <= program < 88:
    return "Lead"
  # Inside the 40-47 program family, tracks named like a violin are treated as lead.
  if 40 <= program < 48 and "violin" in instrument.name.lower():
    return "Lead"

  # Pianos, organs, guitars, synth pads and any other program fall back here.
  return "Chords"

# (program, is_drum) for each merged category track. Program numbers are the
# 0-based General MIDI numbers pretty_midi expects.
CATEGORY_PROGRAMS = {
  "Drums": (0, True),      # drum track; program left at 0
  "Bass": (32, False),     # Acoustic Bass (0-based 32; 33 would be Electric Bass (finger))
  "Chords": (0, False),    # Acoustic Grand Piano
  "Lead": (56, False)      # Trumpet
}

# Build a new PrettyMIDI whose tracks are the merged per-category instruments.
def merge_instruments(midi_obj):
  """Collapse every track of ``midi_obj`` into at most one track per category.

  Each source instrument is classified with ``categorize_instrument``; its notes
  are appended to a new ``pretty_midi.Instrument`` named after the category and
  configured from ``CATEGORY_PROGRAMS``. Only categories that actually occur are
  emitted, in the fixed order Drums, Bass, Chords, Lead. Only notes are carried
  over; the returned object is a fresh ``pretty_midi.PrettyMIDI()``.
  """
  tracks_by_category = {}

  for source in midi_obj.instruments:
    label = categorize_instrument(source)
    if not label:
      continue
    if label not in tracks_by_category:
      program, is_drum = CATEGORY_PROGRAMS[label]
      tracks_by_category[label] = pretty_midi.Instrument(
          program=program,
          is_drum=is_drum,
          name=label
      )
    tracks_by_category[label].notes.extend(source.notes)

  merged = pretty_midi.PrettyMIDI()
  merged.instruments.extend(
      tracks_by_category[label]
      for label in ("Drums", "Bass", "Chords", "Lead")
      if label in tracks_by_category
  )
  return merged

In [5]:
# Pitch Range of Drum: 35-81
# Each row of the roll is (pitch, velocity, duration, step, instrument) with
# instrument 0:drums, 1:bass, 2:chords, 3:lead.
def create_roll(midi_file, drum_pitch_range=(35, 81)):
  """Load a MIDI file and flatten it into a time-ordered note-event tensor.

  Tracks are first merged with ``merge_instruments``. Each note then becomes one
  row ``(pitch, velocity, duration, step, instrument)`` where ``duration`` is
  ``note.end - note.start``, ``step`` is the gap between this note's start and
  the previous row's start (0 for the first row), and ``instrument`` is the
  category index 0=Drums, 1=Bass, 2=Chords, 3=Lead. Drum pitches are clamped
  into ``drum_pitch_range``.

  Args:
    midi_file: path to the MIDI file (passed through ``str()``).
    drum_pitch_range: inclusive (low, high) range drum pitches are clamped into.

  Returns:
    A float32 tensor of shape (n_notes, 5) sorted by note start time, or None if
    the file cannot be parsed or contains no notes.
  """
  try:
    with warnings.catch_warnings():
      warnings.simplefilter("ignore", RuntimeWarning)
      midi = pretty_midi.PrettyMIDI(str(midi_file))
  except Exception:
    # Deliberate best-effort: unreadable/corrupt files are reported as None so
    # the caller can skip them.
    return None

  merged = merge_instruments(midi)
  drum_low, drum_high = drum_pitch_range
  # Index by track name, not list position: merge_instruments only emits the
  # categories that are present, so positions shift when e.g. a file has no
  # drums (bass would otherwise be labelled 0 and clamped like drums).
  category_index = {"Drums": 0, "Bass": 1, "Chords": 2, "Lead": 3}

  rows = []
  for instrument in merged.instruments:
    ins_idx = category_index.get(instrument.name)
    if ins_idx is None:
      continue
    for note in instrument.notes:
      pitch = max(drum_low, min(note.pitch, drum_high)) if ins_idx == 0 else note.pitch
      rows.append((pitch, note.velocity, note.end - note.start, note.start, ins_idx))

  if not rows:
    # Nothing to encode; an empty array would also break the 2-D indexing below.
    return None

  # Work in float64 so absolute start times keep their precision through the
  # differencing; cast to float32 only for the returned tensor.
  events = np.array(rows, dtype=np.float64)
  # Stable sort keeps simultaneous notes in track order (deterministic rows).
  events = events[np.argsort(events[:, 3], kind="stable")]
  # Absolute start time -> step since the previous note's start (first row 0).
  events[:, 3] = np.diff(events[:, 3], prepend=events[0, 3])

  return torch.from_numpy(events.astype(np.float32))

In [6]:
def create_sequence(roll, sequence_length):
  """Slice a note-event roll into overlapping (input window, next event) pairs.

  Window ``i`` is ``roll[i:i + sequence_length]`` and its target is
  ``roll[i + sequence_length]``, for every ``i`` where that target exists, so
  there are ``len(roll) - sequence_length`` pairs (none if the roll is too short).

  Args:
    roll: 2-D tensor of note events, one row per event (create_roll yields (n, 5)).
    sequence_length: number of events per input window.

  Returns:
    (sequences, targets): float32 tensors of shape
    (n_windows, sequence_length, n_features) and (n_windows, n_features).
    Both have n_windows == 0 when ``len(roll) <= sequence_length``.
  """
  # Clamp at 0 so short rolls yield empty tensors instead of a negative-size error.
  n_windows = max(len(roll) - sequence_length, 0)

  window_starts = torch.arange(n_windows)
  # (n_windows, sequence_length) matrix of row indices into the roll.
  window_index = window_starts[:, None] + torch.arange(sequence_length)[None, :]

  # Advanced indexing copies, so the outputs never alias `roll`.
  sequences = roll[window_index].to(torch.float32)
  targets = roll[window_starts + sequence_length].to(torch.float32)
  return sequences, targets

In [None]:
# Collect (window, target) pairs from the first `max_files` usable MIDI files.
sequences = []
targets = []
files_done = 0
files_seen = 0
with tqdm(total=max_files, desc="files") as bar:
  for midi_file in midi_files:
    if files_done == max_files:
      break
    files_seen += 1
    roll = create_roll(midi_file)
    # Skip unreadable/empty files and files too short to yield a single
    # (window, target) pair.
    if roll is None or len(roll) <= sequence_length:
      continue
    file_sequences, file_targets = create_sequence(roll, sequence_length)
    sequences.append(file_sequences)
    targets.append(file_targets)
    files_done += 1
    bar.update(1)
# files_seen counts only files actually examined, so this is accurate whether
# the loop stopped at max_files or ran out of files.
print(f"{files_done} valid files out of {files_seen}")

In [11]:
sequences_tensor = torch.cat(sequences, dim=0)
targets_tensor = torch.cat(targets, dim=0)

# z-score duration (column 2) and step (column 3). The statistics come from the
# input windows and are reused for the targets so both share one scale; they
# are saved below so model outputs can be de-normalised later.
std_duration = sequences_tensor[:, :, 2].std()
mean_duration = sequences_tensor[:, :, 2].mean()
std_step = sequences_tensor[:, :, 3].std()
mean_step = sequences_tensor[:, :, 3].mean()

sequences_tensor[:, :, 2] = (sequences_tensor[:, :, 2] - mean_duration) / std_duration
sequences_tensor[:, :, 3] = (sequences_tensor[:, :, 3] - mean_step) / std_step
targets_tensor[:, 2] = (targets_tensor[:, 2] - mean_duration) / std_duration
targets_tensor[:, 3] = (targets_tensor[:, 3] - mean_step) / std_step

dataset_name = "lmd"
# torch.save does not create missing directories.
output_dir = Path("data")
output_dir.mkdir(parents=True, exist_ok=True)
torch.save({"sequences": sequences_tensor, "targets": targets_tensor}, output_dir / f"{dataset_name}-{max_files}.pth")
torch.save({"std_duration": std_duration, "mean_duration": mean_duration, "std_step": std_step, "mean_step": mean_step}, output_dir / f"{dataset_name}-{max_files}-denorm.pth")

In [None]:
print(targets_tensor.size())  # (n_pairs, 5)

In [None]:
print(sequences_tensor[0, :, :])  # first normalised input window

In [42]:
def tensor_to_midi(seqs, output_file="output.mid"):
  """Render a note-event tensor back into a MIDI file.

  Each row is read as ``(pitch, velocity, duration, step, instrument)`` in the
  un-normalised units create_roll produces: ``step`` is the gap from the
  previous note's start, ``duration`` is the note length, and ``instrument`` is
  0=Drums, 1=Bass, 2=Chords, 3=Lead. Values are rounded to the nearest integer
  where MIDI needs integers. Out-of-range values (e.g. from model predictions)
  are clamped: pitch/velocity to 0-127, step/duration to >= 0, and instrument
  to 0-3.

  Args:
    seqs: tensor or array of shape (n_events, >=5); extra columns are ignored.
    output_file: destination passed to ``PrettyMIDI.write``.

  Returns:
    The ``pretty_midi.PrettyMIDI`` object that was written.
  """
  # index -> (0-based GM program, is_drum, track name)
  instrument_specs = {
    0: (0, True, "Drums"),
    1: (32, False, "Bass"),     # Acoustic Bass
    2: (0, False, "Chords"),    # Acoustic Grand Piano
    3: (56, False, "Lead"),     # Trumpet
  }
  tracks = {
    idx: pretty_midi.Instrument(program=program, is_drum=is_drum, name=name)
    for idx, (program, is_drum, name) in instrument_specs.items()
  }
  max_track = len(instrument_specs) - 1

  current_time = 0.0
  for row in seqs.tolist():
    pitch, velocity, duration, step, instrument = row[:5]
    # Clamping leaves real data untouched, but predicted values can drift out of
    # range and would otherwise be rejected when the file is written.
    current_time += max(float(step), 0.0)
    note = pretty_midi.Note(
      velocity=min(max(int(round(velocity)), 0), 127),
      pitch=min(max(int(round(pitch)), 0), 127),
      start=current_time,
      end=current_time + max(float(duration), 0.0),
    )
    track_idx = min(max(int(round(instrument)), 0), max_track)
    tracks[track_idx].notes.append(note)

  midi = pretty_midi.PrettyMIDI()
  midi.instruments.extend(tracks.values())
  midi.write(output_file)
  return midi

In [43]:
tensor_to_midi(seqs=sequences[0][0], output_file="test9090.mid");  # first raw window of the first file

In [None]:
# Sanity check: every instrument index (column 4) must be a category id in 0..3.
# Vectorised; the element-wise Python loop over N x 128 tensor scalars was very slow.
input_ids = sequences_tensor[:, :, 4]
target_ids = targets_tensor[:, 4]
n_bad = int(((input_ids < 0) | (input_ids > 3)).sum()) + int(((target_ids < 0) | (target_ids > 3)).sum())
assert n_bad == 0, f"impossible: {n_bad} out-of-range instrument indices"