<a href="https://colab.research.google.com/github/brandonso994/AttnLSTMMusicGeneration/blob/main/V2/Generate_Notes_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from music21 import converter, instrument, note, chord, stream, volume
from fractions import Fraction
import matplotlib.pyplot as plt
import glob
import numpy as np
import nltk
import pandas as pd
import pickle
from keras.models import Sequential, Model
from keras.layers import LSTM, Dropout, Dense, Activation, Bidirectional, BatchNormalization, Input, Concatenate
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.metrics import CategoricalAccuracy
import tensorflow as tf
from keras import backend as K

# Mount Google Drive (Colab only) so the '/content/drive/My Drive/...' paths
# used for the model weights and the generated MIDI files are reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def flatten(array):
  """Return one flat list holding the items of every sub-iterable in ``array``.

  Only one level of nesting is removed; the order of items is preserved.
  """
  flat = []
  for inner in array:
    flat.extend(inner)
  return flat

# Create Sequence with window length = sequence_length for input into model
def sequence(pitches, durations, full_pitches, full_durations, pitch_set_len, duration_set_len, sequence_length, midi_input):
    """Encode pitch/duration tokens as integer windows of ``sequence_length``.

    Each token is replaced by its rank in the sorted set of ``full_pitches``
    (respectively ``full_durations``), then sliding windows are cut from the
    encoded sequences.

    Args:
        pitches: Pitch tokens to cut into windows.
        durations: Duration tokens, windowed at the same positions as
            ``pitches``.
        full_pitches: Tokens that define the pitch vocabulary.
        full_durations: Tokens that define the duration vocabulary.
        pitch_set_len: Divisor used to scale the encoded pitch windows.
        duration_set_len: Divisor used to scale the encoded duration windows.
        sequence_length: Number of tokens per window.
        midi_input: When truthy, only the first window is produced (it is
            used as a generation seed).

    Returns:
        A 4-tuple ``(pitch_network_input, duration_network_input,
        pitch_network_input_norm, duration_network_input_norm)``. The first
        two are lists of integer lists; the last two are NumPy arrays of
        shape ``(n_windows, sequence_length, 1)`` holding the same values
        divided by ``pitch_set_len`` / ``duration_set_len``.

    Raises:
        KeyError: If a token in ``pitches``/``durations`` is missing from the
            corresponding vocabulary.
    """
    pitch_encode = {pitch: num for num, pitch in enumerate(sorted(set(full_pitches)))}
    duration_encode = {duration: num for num, duration in enumerate(sorted(set(full_durations)))}

    if midi_input:
        # A seed needs exactly one full window, so an input of exactly
        # sequence_length tokens is usable (it previously produced nothing).
        n_windows = 1 if len(pitches) >= sequence_length else 0
    else:
        # Unchanged window count: the final full window is not emitted.
        n_windows = max(len(pitches) - sequence_length, 0)

    pitch_network_input = [
        [pitch_encode[pitch] for pitch in pitches[i:i + sequence_length]]
        for i in range(n_windows)
    ]
    duration_network_input = [
        [duration_encode[duration] for duration in durations[i:i + sequence_length]]
        for i in range(n_windows)
    ]

    # Shape to (windows, timesteps, 1 feature) and scale by the vocabulary sizes.
    pitch_network_input_norm = np.reshape(
        pitch_network_input, (n_windows, sequence_length, 1)) / float(pitch_set_len)
    duration_network_input_norm = np.reshape(
        duration_network_input, (n_windows, sequence_length, 1)) / float(duration_set_len)

    return pitch_network_input, duration_network_input, pitch_network_input_norm, duration_network_input_norm

def create_model(pitch_network_input, duration_network_input, pitch_set_len, duration_set_len, version_num):
    """Build the two-input / two-output LSTM network and load its saved weights.

    Args:
        pitch_network_input: Array whose ``shape[1]`` and ``shape[2]`` give the
            (timesteps, features) shape of the pitch input layer.
        duration_network_input: Array whose ``shape[1]`` and ``shape[2]`` give
            the (timesteps, features) shape of the duration input layer.
        pitch_set_len: Number of units in the softmax ``pitch`` output.
        duration_set_len: Number of units in the softmax ``duration`` output.
        version_num: Version tag inserted into the weights file name; any
            value is accepted and formatted with ``str()``.

    Returns:
        The Keras ``Model`` with weights loaded from
        ``/content/drive/My Drive/MRP/Model/model_weights_<version_num>_checkpoints.h5``.
    """
    # Pitch branch
    pitch_input = Input(shape=(pitch_network_input.shape[1], pitch_network_input.shape[2]))
    pitch_lstm1 = Bidirectional(LSTM(512, return_sequences=True))(pitch_input)

    # Duration branch
    duration_input = Input(shape=(duration_network_input.shape[1], duration_network_input.shape[2]))
    duration_lstm1 = Bidirectional(LSTM(512, return_sequences=True))(duration_input)

    # Joint stack on the concatenated branch outputs
    lstm_concat = Concatenate()([pitch_lstm1, duration_lstm1])
    joint_lstm1 = LSTM(1024, return_sequences=True)(lstm_concat)
    joint_dropout1 = Dropout(0.3)(joint_lstm1)
    joint_lstm2 = LSTM(1024, return_sequences=False)(joint_dropout1)
    joint_dropout2 = Dropout(0.3)(joint_lstm2)

    # One softmax head per predicted attribute
    pitch_output = Dense(pitch_set_len, activation='softmax', name='pitch')(joint_dropout2)
    duration_output = Dense(duration_set_len, activation='softmax', name='duration')(joint_dropout2)

    model = Model(inputs=[pitch_input, duration_input], outputs=[pitch_output, duration_output])
    # Format version_num via the f-string so non-str tags (e.g. ints) work,
    # consistent with how generate_midi builds its output path.
    model.load_weights(f'/content/drive/My Drive/MRP/Model/model_weights_{version_num}_checkpoints.h5')

    return model

# Re-scales a probability vector by a sampling temperature
def _apply_temperature(probabilities, temperature):
  """Return ``probabilities`` re-weighted by ``temperature`` and renormalised.

  The work is done in float64 and the largest scaled log-probability is
  subtracted before exponentiating, so the result sums to 1 closely enough
  for ``np.random.choice`` even when the model emits float32 values.
  """
  probabilities = np.asarray(probabilities, dtype=np.float64)
  # Zero probabilities legitimately map to -inf (and back to 0 after exp).
  with np.errstate(divide='ignore'):
    scaled = np.log(probabilities) / temperature
  scaled = scaled - np.max(scaled)
  weights = np.exp(scaled)
  return weights / np.sum(weights)

# Predicts the next note given the sequence
def note_prediction(model, pitch_network_input, duration_network_input, pitches, durations, seq_len, num_notes=500, midi_input=False, pitch_temp=0.5, duration_temp=0.5):
  """Generate ``num_notes`` pitch/duration tokens by repeated sampling.

  Starting from one seed window, the model predicts distributions for the
  next pitch and duration; a token is sampled from each, appended to the
  window, and the oldest entry is dropped so the window length is constant.

  Args:
    model: Object whose ``predict([pitch_batch, duration_batch], verbose=0)``
      returns ``[pitch_probs, duration_probs]``, each with one row.
    pitch_network_input: Integer-encoded pitch windows (not modified).
    duration_network_input: Integer-encoded duration windows (not modified).
    pitches: Tokens defining the pitch vocabulary (decoded by sorted rank).
    durations: Tokens defining the duration vocabulary.
    seq_len: Unused; kept for interface compatibility.
    num_notes: Number of tokens to generate.
    midi_input: When truthy the first window is the seed; otherwise a window
      is chosen uniformly at random.
    pitch_temp: Sampling temperature for pitches.
    duration_temp: Sampling temperature for durations.

  Returns:
    ``(generated_pitches, generated_durations)``: two lists of decoded tokens.
  """
  pitch_decode = dict(enumerate(sorted(set(pitches))))
  duration_decode = dict(enumerate(sorted(set(durations))))
  pitch_set_len = len(pitch_decode)
  duration_set_len = len(duration_decode)

  # Whether input was given, or start randomly from sequence
  if midi_input:
    start = 0
  else:
    # randint's upper bound is exclusive, so this covers every window and
    # also works when there is only one.
    start = np.random.randint(0, len(pitch_network_input))

  # Copy the seed so the caller's windows are never mutated.
  pitch_pattern = list(pitch_network_input[start])
  duration_pattern = list(duration_network_input[start])

  generated_pitches = []
  generated_durations = []

  for _ in range(num_notes):
    prediction_input_pitch = np.reshape(pitch_pattern, (1, len(pitch_pattern), 1)) / float(pitch_set_len)
    prediction_input_duration = np.reshape(duration_pattern, (1, len(duration_pattern), 1)) / float(duration_set_len)

    predictions = model.predict([prediction_input_pitch, prediction_input_duration], verbose=0)

    pitch_probs = _apply_temperature(predictions[0][0], pitch_temp)
    duration_probs = _apply_temperature(predictions[1][0], duration_temp)

    # Sample the next pitch and duration using the temperature-adjusted probabilities
    next_pitch = int(np.random.choice(len(pitch_probs), p=pitch_probs))
    next_duration = int(np.random.choice(len(duration_probs), p=duration_probs))

    generated_pitches.append(pitch_decode[next_pitch])
    generated_durations.append(duration_decode[next_duration])

    # Slide the window: drop the oldest entry, add the new prediction.
    pitch_pattern = pitch_pattern[1:] + [next_pitch]
    duration_pattern = duration_pattern[1:] + [next_duration]

    # Decrease temperature slowly if greater than 1
    if pitch_temp > 1:
      pitch_temp -= 0.02

    if duration_temp > 1:
      duration_temp -= 0.01

  return generated_pitches, generated_durations

# Parses one numeric field of a "duration:offset" token
def _parse_quarter_length(text):
  """Convert ``text`` to ``float``; also accepts ``"n/d"`` fraction strings."""
  try:
    return float(text)
  except ValueError:
    # Fall back for rational notation such as "1/3", which float() rejects.
    return float(Fraction(text))

# Generate midi from predicted notes
def generate_midi(pitch_sequence, duration_sequence, version_num, filenum):
  """Write the predicted notes to a MIDI file on the mounted Drive.

  Args:
    pitch_sequence: Pitch tokens. A token containing ``'.'`` or consisting
      only of digits is split on ``'.'`` and built as a chord from the
      integer parts; any other token is passed to ``note.Note`` as-is.
    duration_sequence: ``"<duration>:<offset step>"`` strings, paired with
      ``pitch_sequence`` by position. The first field is the element's
      quarter length; the second is added to the running offset after the
      element is placed.
    version_num: Version tag used in the output file name.
    filenum: File number used in the output file name.

  Returns:
    None. The file is written to
    ``/content/drive/My Drive/MRP/test_file_<version_num>_<filenum>.midi``.
  """
  output_notes = []
  offset = 0
  for pitch, duration_offset in zip(pitch_sequence, duration_sequence):
    fields = duration_offset.split(':')
    duration = _parse_quarter_length(fields[0])
    new_offset = _parse_quarter_length(fields[1])

    if ('.' in pitch) or pitch.isdigit():
      # Chord token: one Note per dot-separated integer.
      chord_notes = []
      for current_note in pitch.split('.'):
        new_note = note.Note(int(current_note))
        new_note.storedInstrument = instrument.Piano()
        chord_notes.append(new_note)
      element = chord.Chord(chord_notes)
    else:
      element = note.Note(pitch)
      element.storedInstrument = instrument.Piano()

    element.duration.quarterLength = duration
    element.offset = offset
    output_notes.append(element)
    offset += new_offset

  midi_stream = stream.Stream(output_notes)
  midi_stream.insert(0, instrument.Piano())
  midi_stream.write('midi', fp=f'/content/drive/My Drive/MRP/test_file_{version_num}_{filenum}.midi')

  return

def generate(version_num, file_num, midi_input=False, num_notes=500):
  """Run the full pipeline: load tokens, build the model, sample, write MIDI.

  Args:
    version_num: Version tag forwarded to ``create_model`` and
      ``generate_midi``.
    file_num: File number forwarded to ``generate_midi``.
    midi_input: When truthy, the seed window is cut from the
      ``*_midi.pkl`` pickles instead of from the main token pickles.
    num_notes: Number of notes to generate.

  The pickles are read from the current working directory; the window
  length is fixed at 60.
  """
  seq_len = 60

  def read_pickle(path):
    with open(path, 'rb') as handle:
      return pickle.load(handle)

  # Optional user-supplied seed material
  seed_pitches = seed_durations = None
  if midi_input:
    raw_seed_pitches = read_pickle('simple_pitches_midi.pkl')
    raw_seed_durations = read_pickle('duration_offsets_midi.pkl')
    seed_pitches = flatten(raw_seed_pitches)
    seed_durations = flatten(raw_seed_durations)

  # Main token sequences; these also define the vocabularies
  raw_pitches = read_pickle('simple_pitches.pkl')
  raw_durations = read_pickle('duration_offsets.pkl')
  pitches = flatten(raw_pitches)
  duration_offset = flatten(raw_durations)

  pitch_set_len = len(set(pitches))
  duration_set_len = len(set(duration_offset))

  # Without seed material the windows are cut from the main sequences.
  if not midi_input:
    seed_pitches, seed_durations = pitches, duration_offset

  windows = sequence(seed_pitches, seed_durations, pitches, duration_offset,
                     pitch_set_len, duration_set_len, seq_len, midi_input)
  pitch_network_input, duration_network_input, pitch_network_input_norm, duration_network_input_norm = windows

  model = create_model(pitch_network_input_norm, duration_network_input_norm,
                       pitch_set_len, duration_set_len, version_num)
  generated_pitches, generated_durations = note_prediction(
      model, pitch_network_input, duration_network_input, pitches, duration_offset,
      seq_len, num_notes=num_notes, midi_input=midi_input)

  generate_midi(generated_pitches, generated_durations, version_num, file_num)




In [None]:
# Generate four MIDI files (file numbers 1-4) from the "v2" weights; with the
# default midi_input=False each run starts from a randomly chosen seed window.
for x in range(1,5):
  generate("v2", x)