# Super Piano 3: Google Music Transformer
## Generating Music with Long-Term Structure
### Based on the 2019 ICLR paper by Cheng-Zhi Anna Huang et al. (Google Brain) and on Damon Gwinn's code/repo https://github.com/gwinndr/MusicTransformer-Pytorch

Huge thanks go out to the following people who contributed the code/repos used in this colab. Additional contributors are listed in the code as well.

1) Kevin-Yang https://github.com/jason9693/midi-neural-processor

2) gudgud96 for fixing Kevin's MIDI Encoder properly https://github.com/gudgud96

3) jinyi12, Zac Koh, Akamight, Zhang https://github.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248

Thank you so much for your hard work and for sharing it with the world :)


Modified slightly by Maya Shen for Group 4's Project 3 for 10-615: Art and ML at CMU

### Setup Environment and Dependencies. Check GPU.

In [None]:
#@title Check if GPU (driver) is available (you do not want to run this on CPU, trust me)
!nvcc --version
!nvidia-smi

In [None]:
#@title Clone/Install all dependencies
!git clone https://github.com/asigalov61/midi-neural-processor
!git clone https://github.com/asigalov61/MusicTransformer-Pytorch
!pip install tqdm
!pip install progress
!pip install pretty-midi
!pip install pypianoroll
!pip install matplotlib
!pip install librosa
!pip install scipy
!pip install pillow
!apt install fluidsynth #Pip does not work for some reason. Only apt works
!pip install midi2audio
!pip install mir_eval
!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 /content/font.sf2

In [None]:
#@title Import all needed modules
import numpy as np
import pickle
import os
import sys
import math
import random
# For plotting
import pypianoroll
from pypianoroll import Multitrack, Track
import matplotlib
import matplotlib.pyplot as plt
#matplotlib.use('SVG')
#%matplotlib inline
#matplotlib.get_backend()
import mir_eval.display
import librosa
import librosa.display
# For rendering output audio
import pretty_midi
from midi2audio import FluidSynth
from google.colab import output
from IPython.display import display, Javascript, HTML, Audio

In [None]:
#@title (Optional) Pre-trained models download (2 models trained for 100 epochs to 1.968 FLoss and 0.420 acc)
!mkdir -p /content/MusicTransformer-Pytorch/rpr/results
%cd /content/MusicTransformer-Pytorch/rpr/results
!wget 'https://superpiano.s3-us-west-1.amazonaws.com/SuperPiano3models.zip'
!unzip SuperPiano3models.zip
%cd /content/MusicTransformer-Pytorch/

# Please note: you MUST download and process one of the datasets, whether you train a model or use a pre-trained one, because generation primes the model from the dataset files.

# Option 1: MAESTRO DataSet

In [None]:
#@title Download Google Magenta MAESTRO v.2.0.0 Piano MIDI Dataset (~1300 MIDIs)
%cd /content/MusicTransformer-Pytorch/dataset/
!wget 'https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip'
!unzip maestro-v2.0.0-midi.zip
%cd /content/MusicTransformer-Pytorch/

In [None]:
#@title Prepare directory structure and MIDI processor
%cd /content/
!mv midi-neural-processor midi_processor
%cd /content/MusicTransformer-Pytorch/

In [None]:
#@title Process MAESTRO MIDI DataSet
!python3 preprocess_midi.py '/content/MusicTransformer-Pytorch/dataset/maestro-v2.0.0'

# Option 2: Your own Custom MIDI DataSet

In [None]:
#@title Create directory structure for the DataSet and prep MIDI processor

!mkdir '/content/MusicTransformer-Pytorch/dataset/e_piano/'
!mkdir '/content/MusicTransformer-Pytorch/dataset/e_piano/train'
!mkdir '/content/MusicTransformer-Pytorch/dataset/e_piano/test'
!mkdir '/content/MusicTransformer-Pytorch/dataset/e_piano/val'
!mkdir '/content/MusicTransformer-Pytorch/dataset/e_piano/custom_midis'

%cd /content/
!mv midi-neural-processor midi_processor
%cd /content/MusicTransformer-Pytorch/

In [None]:
#@title Delete old custom MIDI dataset if necessary

!rm -f /content/MusicTransformer-Pytorch/dataset/e_piano/custom_midis/*

In [None]:
#@title Upload your custom MIDI DataSet to the created "dataset/e_piano/custom_midis" folder through this cell or manually by any other means. You can also use the ready-to-use DataSets below
from google.colab import files
%cd '/content/MusicTransformer-Pytorch/dataset/e_piano/custom_midis'
uploaded = files.upload()

for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))

In [None]:
#@title (The Best Choice/Works best stand-alone) Super Piano 2 Original 2500 MIDIs of Piano Music
%cd /content/MusicTransformer-Pytorch/dataset/e_piano/custom_midis
!wget 'https://github.com/asigalov61/SuperPiano/raw/master/Super_Piano_2_MIDI_DataSet_CC_BY_NC_SA.zip'
!unzip -j 'Super_Piano_2_MIDI_DataSet_CC_BY_NC_SA.zip'
!rm Super_Piano_2_MIDI_DataSet_CC_BY_NC_SA.zip

In [None]:
#@title (Second Best Choice/Works best stand-alone) Alex Piano Only Original 450 MIDIs 
%cd /content/MusicTransformer-Pytorch/dataset/e_piano/custom_midis
!wget 'https://github.com/asigalov61/AlexMIDIDataSet/raw/master/AlexMIDIDataSet-CC-BY-NC-SA-Piano-Only.zip'
!unzip -j 'AlexMIDIDataSet-CC-BY-NC-SA-Piano-Only.zip'
!rm AlexMIDIDataSet-CC-BY-NC-SA-Piano-Only.zip

For now, the dataset is split at random: every file is written to the "train" dir and additionally to either the "test" or the "val" dir, which is not ideal. Feel free to modify the code to your liking to achieve better training results with this implementation.
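
One possible alternative to the random assignment described above is a seeded split that puts each file into exactly one of the three sets. The sketch below is a minimal illustration, not part of the original notebook: `split_dataset`, its fraction parameters, and the example file names are all hypothetical.

```python
import random

def split_dataset(file_names, val_fraction=0.1, test_fraction=0.1, seed=0):
    """Return disjoint (train, val, test) lists of file names (hypothetical helper)."""
    if val_fraction < 0 or test_fraction < 0 or val_fraction + test_fraction >= 1:
        raise ValueError("val_fraction + test_fraction must lie in [0, 1)")
    names = sorted(file_names)          # sort so the result does not depend on listing order
    random.Random(seed).shuffle(names)  # seeded shuffle: the same split on every run
    n_val = int(len(names) * val_fraction)
    n_test = int(len(names) * test_fraction)
    val = names[:n_val]
    test = names[n_val:n_val + n_test]
    train = names[n_val + n_test:]      # everything left over is training data
    return train, val, test

# Example with made-up file names
train, val, test = split_dataset([f"piece_{i:03d}.mid" for i in range(100)])
print(len(train), len(val), len(test))
```

With a split like this, the validation and test files are never seen during training, so the evaluation numbers are less likely to be inflated by memorized pieces.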

In [None]:
#@title Process your custom MIDI DataSet :)
%cd /content/MusicTransformer-Pytorch
from processor import encode_midi

import os
import pickle
import random



%cd '/content/MusicTransformer-Pytorch/dataset/e_piano/custom_midis'

custom_MIDI_DataSet_dir = '/content/MusicTransformer-Pytorch/dataset/e_piano/custom_midis'

train_dir = '/content/MusicTransformer-Pytorch/dataset/e_piano/train' # split_type = 0
test_dir = '/content/MusicTransformer-Pytorch/dataset/e_piano/test' # split_type = 1  
val_dir = '/content/MusicTransformer-Pytorch/dataset/e_piano/val' # split_type = 2

total_count = 0
train_count = 0
val_count   = 0
test_count  = 0

f_ext = '.pickle'
fileList = os.listdir(custom_MIDI_DataSet_dir)
for file in fileList:
    # Split by random selection for now: every file goes to train, plus
    # either test (split == 1) or val (split == 2).
    split = random.randint(1, 2)
    o_file0 = os.path.join(train_dir, file+f_ext)
    if split == 2:
        o_file = os.path.join(val_dir, file+f_ext)
    else:
        o_file = os.path.join(test_dir, file+f_ext)

    try:
        # Encode once and write the same encoding to both destinations
        prepped = encode_midi(file)
        with open(o_file0, "wb") as o_stream:
            pickle.dump(prepped, o_stream)
        with open(o_file, "wb") as o_stream:
            pickle.dump(prepped, o_stream)
    except Exception:
        # KeyboardInterrupt is not an Exception subclass, so it still propagates
        print('Bad file. Skipping...')
        continue

    # Count only files that were converted successfully
    total_count += 1
    train_count += 1
    if split == 2:
        val_count += 1
    else:
        test_count += 1

    print(file)
    print(o_file)
    print('Converted!')

print('Done')
print("Num Train:", train_count)
print("Num Val:", val_count)
print("Num Test:", test_count)
print("Total Count:", total_count)

%cd /content/MusicTransformer-Pytorch

In [None]:
%cd /content/MusicTransformer-Pytorch
from processor import encode_midi

# Train the Model

In [None]:
#@title Delete old training weights if necessary

!rm -f /content/MusicTransformer-Pytorch/rpr/weights/*

In [None]:
#@title Activate Tensorboard Graphs/Stats to monitor/evaluate model performance during and after training runs
# Load the TensorBoard notebook extension
%reload_ext tensorboard
import tensorflow as tf
import datetime, os
%tensorboard --logdir /content/MusicTransformer-Pytorch/rpr

In [None]:
#@title Start to Train the Model
batch_size = 4 #@param {type:"slider", min:0, max:8, step:1}
number_of_training_epochs = 100 #@param {type:"slider", min:0, max:200, step:1}
maximum_output_MIDI_sequence = 2048 #@param {type:"slider", min:0, max:8192, step:128}
!python3 train.py -output_dir rpr --rpr -batch_size=$batch_size -epochs=$number_of_training_epochs -max_sequence=$maximum_output_MIDI_sequence #-n_layers -num_heads -d_model -dim_feedforward

In [None]:
# Download specific set of weights (epoch)
from google.colab import files
files.download('/content/MusicTransformer-Pytorch/rpr/weights/epoch_0025.pickle')

In [None]:
# Zip all weights and save
!zip -r '/content/MusicTransformer-Pytorch/rpr/weights/model3_weights.zip' '/content/MusicTransformer-Pytorch/rpr/weights'
files.download('/content/MusicTransformer-Pytorch/rpr/weights/model3_weights.zip')

In [None]:
# Re-Start Training from a certain checkpoint and epoch
batch_size = 4 # {type:"slider", min:0, max:8, step:1}
number_of_training_epochs = 100 # {type:"slider", min:0, max:200, step:1}
maximum_output_MIDI_sequence = 2048 # {type:"slider", min:0, max:8192, step:128}
saved_checkpoint_full_path = "/content/MusicTransformer-Pytorch/rpr/weights/epoch_0018.pickle" # {type:"string"}
continue_epoch_number =  18 # {type:"integer"}

!python3 train.py -output_dir rpr --rpr -batch_size=$batch_size -epochs=$number_of_training_epochs -max_sequence=$maximum_output_MIDI_sequence -continue_weights $saved_checkpoint_full_path -continue_epoch $continue_epoch_number #-n_layers -num_heads -d_model -dim_feedforward
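
The two training cells above differ only in the `-continue_weights` and `-continue_epoch` flags. As a rough sketch of how the same command could be assembled in Python, the hypothetical helper `build_train_command` below uses only the flags that already appear in those cells; it does not run anything and makes no claim about other options of `train.py`.

```python
import shlex

def build_train_command(batch_size, epochs, max_sequence,
                        continue_weights=None, continue_epoch=None):
    """Assemble the train.py argument list used in the cells above (hypothetical helper)."""
    # A checkpoint path and its epoch number only make sense together
    if (continue_weights is None) != (continue_epoch is None):
        raise ValueError("pass both continue_weights and continue_epoch, or neither")
    args = ["python3", "train.py", "-output_dir", "rpr", "--rpr",
            f"-batch_size={batch_size}",
            f"-epochs={epochs}",
            f"-max_sequence={max_sequence}"]
    if continue_weights is not None:
        args += ["-continue_weights", continue_weights,
                 "-continue_epoch", str(continue_epoch)]
    return args

# Fresh run, then a restart from the checkpoint path used in the cell above
print(shlex.join(build_train_command(4, 100, 2048)))
print(shlex.join(build_train_command(
    4, 100, 2048,
    continue_weights="/content/MusicTransformer-Pytorch/rpr/weights/epoch_0018.pickle",
    continue_epoch=18)))
```

The returned list could be passed to `subprocess.run` from the repository directory, which avoids relying on shell interpolation of notebook variables.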

### Evaluate the resulting models

In [None]:
#@title Evaluate Best Resulting Accuracy Model (best_acc_weights.pickle)
!python3 evaluate.py -model_weights rpr/results/best_acc_weights.pickle --rpr

In [None]:
#@title Evaluate Best Resulting Loss Model (best_loss_weights.pickle)
!python3 evaluate.py -model_weights rpr/results/best_loss_weights.pickle --rpr

In [None]:
import pandas as pd
results = pd.read_csv("/content/MusicTransformer-Pytorch/rpr/results/results.csv")[:101]

In [None]:
results

In [None]:
plt.plot(results[['Avg Train loss']], label = "Avg Train Loss")
plt.plot(results[['Avg Eval loss']], label = "Avg Eval Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Super Piano 3 Ghibli Melody Generation: Loss by Training Epoch")
plt.savefig("/content/MusicTransformer-Pytorch/rpr/results/loss.jpg", dpi  = 300)

In [None]:
plt.plot(results[['Train Accuracy']], label = "Train Accuracy")
plt.plot(results[['Eval accuracy']], label = "Eval Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.title("Super Piano 3 Ghibli Melody: Accuracy by Training Epoch")
plt.savefig("/content/MusicTransformer-Pytorch/rpr/results/acc.jpg", dpi  = 300)
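
Besides plotting the curves, it can help to read off which epoch performed best. The sketch below assumes that `results.csv` has the column names used in the plotting cells above; `best_epoch` is a hypothetical helper, and the numbers in `sample_csv` are made up purely for illustration.

```python
import csv
import io

def best_epoch(rows, column, higher_is_better):
    """Return (row_index, value) of the best entry in `column` (hypothetical helper)."""
    values = [float(row[column]) for row in rows]
    pick = max if higher_is_better else min
    index = pick(range(len(values)), key=values.__getitem__)
    return index, values[index]

# Made-up numbers for illustration only; real values come from results.csv.
sample_csv = io.StringIO(
    "Avg Train loss,Avg Eval loss,Train Accuracy,Eval accuracy\n"
    "3.10,3.30,0.20,0.18\n"
    "2.40,2.60,0.31,0.29\n"
    "2.00,2.70,0.40,0.28\n"
)
rows = list(csv.DictReader(sample_csv))

print(best_epoch(rows, "Avg Eval loss", higher_is_better=False))
print(best_epoch(rows, "Eval accuracy", higher_is_better=True))
```

For the real file, the rows can be loaded the same way by opening `/content/MusicTransformer-Pytorch/rpr/results/results.csv` and passing the file object to `csv.DictReader`. The returned index is the row position, which matches the x-axis of the plots above.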

To have the model continue your own MIDI file, enter the following into the seed_MIDI field of the generation cell below:

-primer_file '/content/some_dir/some_seed_midi.mid'

For example: -primer_file '/content/MusicTransformer-Pytorch/seed.mid'
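
If the seed path contains spaces or other shell-special characters, the quoting has to be right, because the string is substituted into a shell command. A minimal sketch of building the value safely follows; `primer_argument` is a hypothetical helper, and the second path is an invented example.

```python
import shlex

def primer_argument(midi_path):
    """Build the string for the seed_MIDI field (hypothetical helper)."""
    # shlex.quote adds single quotes only when the path needs them
    return "-primer_file " + shlex.quote(midi_path)

print(primer_argument("/content/MusicTransformer-Pytorch/seed.mid"))
print(primer_argument("/content/my seeds/seed 1.mid"))
```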

# Generate and Explore the output :)

In [None]:
#@title MIDI seed length for priming_sequence_length (for full seed)
seed_MIDI = '/content/MusicTransformer-Pytorch/Seed1-Ascending-Am.mid' #@param {type:"string"}

len(encode_midi(seed_MIDI)) - 1
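
The seed length computed above is useful for choosing consistent values in the generation cell. The sketch below is a hypothetical sanity check, not part of the repository. It assumes that `-target_seq_length` is the total output length including the primer tokens, which should be verified against `generate.py` in the cloned repo before relying on it.

```python
def check_generation_lengths(seed_length, num_prime, target_seq_length, max_sequence):
    """Clamp the priming length to the seed and check that room is left to generate.

    Hypothetical helper. ASSUMPTION: target_seq_length counts the primer
    tokens too; verify this against generate.py.
    """
    if min(seed_length, num_prime, target_seq_length, max_sequence) < 1:
        raise ValueError("all lengths must be positive")
    # Cannot prime with more tokens than the encoded seed contains
    num_prime = min(num_prime, seed_length)
    # The requested output cannot exceed the model's maximum sequence length
    target_seq_length = min(target_seq_length, max_sequence)
    if num_prime >= target_seq_length:
        raise ValueError("the primer fills the whole output, so no new tokens would be generated")
    return num_prime, target_seq_length

# Example values only: a seed that encodes to 300 tokens
print(check_generation_lengths(seed_length=300, num_prime=1024,
                               target_seq_length=1024, max_sequence=2048))
```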

In [None]:
#@title Generate, Plot, Graph, Save, Download, and Render the resulting output
number_of_tokens_to_generate = 1024 #@param {type:"slider", min:1, max:2048, step:1}
priming_sequence_length = 1024 #@param {type:"slider", min:1, max:2048, step:8}
maximum_possible_output_sequence = 2048 #@param {type:"slider", min:0, max:2048, step:8}
select_model = "/content/MusicTransformer-Pytorch/rpr/results/best_acc_weights.pickle" #@param ["/content/MusicTransformer-Pytorch/rpr/results/best_acc_weights.pickle", "/content/MusicTransformer-Pytorch/rpr/results/best_loss_weights.pickle"]
seed_MIDI = "-primer_file '/content/MusicTransformer-Pytorch/seed.mid'" #@param {type:"string"}

import processor
from processor import encode_midi, decode_midi

!python generate.py -output_dir output -model_weights=$select_model --rpr -target_seq_length=$number_of_tokens_to_generate -num_prime=$priming_sequence_length -max_sequence=$maximum_possible_output_sequence $seed_MIDI #

print('Exported the output to the output folder as primer.mid and rand.mid')

# Render the generated MIDI to WAV with FluidSynth so it can be played below
FluidSynth("/content/font.sf2").midi_to_audio('/content/MusicTransformer-Pytorch/output/rand.mid', '/content/MusicTransformer-Pytorch/output/output.wav')

from google.colab import files
files.download('/content/MusicTransformer-Pytorch/output/rand.mid')
files.download('/content/MusicTransformer-Pytorch/output/primer.mid')

Audio('/content/MusicTransformer-Pytorch/output/output.wav')


In [None]:
files.download('/content/MusicTransformer-Pytorch/output/output.wav')

In [None]:
#@title Plot and Graph the Output :)
graphs_length_inches = 18 #@param {type:"slider", min:0, max:20, step:1}
notes_graph_height = 6 #@param {type:"slider", min:0, max:20, step:1}
highest_displayed_pitch = 92 #@param {type:"slider", min:1, max:128, step:1}
lowest_displayed_pitch = 24 #@param {type:"slider", min:1, max:128, step:1}
piano_roll_color_map = "Blues"

import librosa
import numpy as np
import pretty_midi
import pypianoroll
from pypianoroll import Multitrack, Track
import matplotlib
import matplotlib.pyplot as plt
#matplotlib.use('SVG')
# For plotting
import mir_eval.display
import librosa.display
%matplotlib inline


midi_data = pretty_midi.PrettyMIDI('/content/MusicTransformer-Pytorch/output/rand.mid')

def plot_piano_roll(pm, start_pitch, end_pitch, fs=100):
    # Use librosa's specshow function for displaying the piano roll
    librosa.display.specshow(pm.get_piano_roll(fs)[start_pitch:end_pitch],
                             hop_length=1, sr=fs, x_axis='time', y_axis='cqt_note',
                             fmin=pretty_midi.note_number_to_hz(start_pitch))



roll = np.zeros([int(graphs_length_inches), 128])
# Plot the output

track = Multitrack('/content/MusicTransformer-Pytorch/output/rand.mid', name='track')
plt.figure(figsize=[graphs_length_inches, notes_graph_height])
fig, ax = track.plot()
fig.set_size_inches(graphs_length_inches, notes_graph_height)
plt.figure(figsize=[graphs_length_inches, notes_graph_height])
ax2 = plot_piano_roll(midi_data, int(lowest_displayed_pitch), int(highest_displayed_pitch))
plt.show(block=False)

### Save to Google Drive (Standard GD connect code) -- DOES NOT WORK WHEN CONNECTED TO VM INSTANCE

In [None]:
from google.colab import drive
drive.mount('/content/drive')