# Creating the RASTRO Dataset
The new dataset called RASTRO ("Reduced Accessible Songs for Teaching, Rhythmically Oversimplified") is a time-quantized version of the MIDI portion of Google Magenta's MAESTRO dataset of piano recordings. 
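It is a list of PyTorch tensors, one tensor per song; each tensor has one row per note with the columns (pitch, step, duration), where step is the time since the previous note's onset.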

In [None]:
# Download the source MIDI archive and expose its folder under the common name `midi_files`.
#!wget -N https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
#!unzip -n -qq maestro-v3.0.0-midi.zip

# 'groove' selects the Groove MIDI archive; any other value selects MAESTRO v3.0.0.
# The same value is reused later as the prefix of the saved tensor-list file name.
source_dataset = 'maestro'

# Remove any previous `midi_files` link so the `ln -s` below can recreate it.
!rm -rf midi_files
if source_dataset == 'groove':
    !wget -N https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip
    !unzip -n -qq groove-v1.0.0-midionly.zip
    !ln -s groove midi_files
else:
    !wget -N https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
    !unzip -n -qq maestro-v3.0.0-midi.zip
    !ln -s maestro-v3.0.0 midi_files

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from glob import glob
import mido 
import pathlib
import pretty_midi
from tqdm import tqdm_notebook as tqdm
import multiprocessing as mp
from tqdm.contrib.concurrent import process_map
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional
from IPython.display import Audio, display

In [None]:
# Collect every MIDI file (*.mid / *.midi) below the `midi_files` link.
data_dir = pathlib.Path('midi_files')
# glob returns paths in arbitrary, filesystem-dependent order; sort them so that
# `filenames[0]` and any index into the concatenated notes are reproducible.
filenames = sorted(glob(str(data_dir/'**/*.mid*'), recursive=True))
print('Number of files:', len(filenames))

In [None]:
# routines for processing midi files
def time_convert(time_s, bpm, pm, time_units='ticks'):
    """Convert a time given in seconds into the requested unit.

    time_units='ticks' delegates to `pm.time_to_tick`, 'beats' multiplies by
    beats-per-second (bpm/60), and any other value returns `time_s` unchanged.
    """
    if time_units == 'ticks':
        return pm.time_to_tick(time_s)
    if time_units == 'beats':
        return time_s * (bpm/60)
    return time_s  # leave it in seconds

def midi_file_to_tensor_old(midi_file,
                        time_units='ticks', # beats, ticks, s
                        info=False,  # return info about the track
                       ):
    """Legacy loader: read one MIDI file into a (num_notes, 3) float32 tensor.

    Columns are pitch, step (time since the previous note's onset) and
    duration, with both time columns passed through `time_convert`.
    With info=True a dict of tempo/tick metadata is returned as well.
    Only the first instrument of the file is read.
    """
    pm = pretty_midi.PrettyMIDI(midi_file)  # parsing the whole file is the slow part
    bpm = pm.estimate_tempo()
    tpb = mido.MidiFile(midi_file).ticks_per_beat
    spt = mido.tick2second(1, tpb, 500000)  # seconds per tick at a tempo value of 500000

    # Order by onset, breaking ties by pitch.
    ordered = sorted(pm.instruments[0].notes, key=lambda n: (n.start, n.pitch))
    notes = torch.empty((len(ordered), 3), dtype=torch.float32)

    last_onset = ordered[0].start
    for row, n in enumerate(ordered):
        notes[row, 0] = n.pitch
        notes[row, 1] = n.start - last_onset  # step
        notes[row, 2] = n.end - n.start       # duration
        last_onset = n.start
        # Convert the two time columns from their stored float32 values.
        for col in (1, 2):
            notes[row, col] = time_convert(notes[row, col], bpm, pm, time_units=time_units)

    if not info:
        return notes
    return notes, {'bpm': bpm, 'ticks_per_beat':tpb, 'seconds_per_tick':spt}

def midi_file_to_tensor_ugh(midi_file,
                        to_bpm=None, # e.g. 120. None=no change
                        divs_per_beat=4, # quantize to 16th-note-time-steps  
                        info=False,  # return info about the track
                       ):
    """Legacy loader that can rescale note timing to a target tempo.

    Returns a (num_notes, 3) float32 tensor of pitch, step and duration read
    from the first instrument. When `to_bpm` is given, both time columns are
    multiplied by estimated_bpm / to_bpm. `divs_per_beat` is accepted but not
    used by this function. With info=True the rescaling metadata (empty when
    `to_bpm` is None) is returned as a second value.
    """
    pm = pretty_midi.PrettyMIDI(midi_file)  # parsing the whole file is the slow part
    ordered = sorted(pm.instruments[0].notes, key=lambda n: (n.start, n.pitch))
    notes = torch.empty((len(ordered), 3), dtype=torch.float32)

    rescale = to_bpm is not None
    return_info = {}
    time_mult = 1.0
    if rescale:
        bpm = pm.estimate_tempo()
        time_mult = bpm / to_bpm
        return_info = {'orig_bpm':bpm, 'time_mult':time_mult}

    last_onset = ordered[0].start
    for row, n in enumerate(ordered):
        notes[row, 0] = n.pitch
        notes[row, 1] = n.start - last_onset  # step
        notes[row, 2] = n.end - n.start       # duration
        last_onset = n.start

    if rescale:
        notes[:, 1:] *= time_mult  # rescale timing for the tempo change

    return (notes, return_info) if info else notes


def midi_file_to_tensor(midi_file,
                        time_units='ticks', # beats, ticks, s
                        info=False,  # return info about the track
                       ):
    """Read one MIDI file into a (num_notes, 3) float32 tensor.

    Columns are pitch, step (time since the previous note's onset, 0 for the
    first note) and duration. Times are the differences of the note start/end
    values reported by pretty_midi; no unit conversion is applied.

    Parameters
    ----------
    midi_file : path of the MIDI file; only its first instrument is read.
    time_units : accepted for backward compatibility but currently ignored.
    info : when True, also return a dict with 'bpm', 'ticks_per_beat' and
        'seconds_per_tick'.

    Returns
    -------
    notes, or (notes, info_dict) when info=True. A file without instruments
    or notes yields an empty (0, 3) tensor instead of raising IndexError.
    """
    pm = pretty_midi.PrettyMIDI(midi_file) # read in the whole file. this is incredibly slow
    track_notes = pm.instruments[0].notes if pm.instruments else []
    # Sort the notes first by start time (then by pitch if two notes start at the same time)
    sorted_notes = sorted(track_notes, key=lambda note: (note.start, note.pitch))

    rows = []
    prev_start = sorted_notes[0].start if sorted_notes else 0.0
    for note in sorted_notes:
        rows.append((note.pitch, note.start - prev_start, note.end - note.start))
        prev_start = note.start
    # reshape keeps the (0, 3) shape when there are no notes
    notes = torch.tensor(rows, dtype=torch.float32).reshape(len(rows), 3)

    if not info:
        # Skip the tempo estimate and the second parse of the file: they only
        # feed the info dict and are wasted work in the bulk conversion.
        return notes

    bpm = pm.estimate_tempo()
    tpb = mido.MidiFile(midi_file).ticks_per_beat
    spt = mido.tick2second(1, tpb, 500000)  # seconds per tick at a tempo value of 500000
    return notes, {'bpm': bpm, 'ticks_per_beat':tpb, 'seconds_per_tick':spt}

In [None]:
# Sanity check: convert the first listed file and look at its metadata and step range.
notes, info = midi_file_to_tensor(filenames[0], info=True)  
print("info = ",info)
pitches = notes[:,0].type(torch.long)  # just the pitch info
print("notes.shape, pitches.shape =",notes.shape, pitches.shape)
print("notes[:,1] min, max = ", notes[:,1].min(), notes[:,1].max())


In [None]:
def time_quantize(notes_tensor,  # a single song
                  time_res=0.008, # resolution in seconds.  8ms is from Google "This Time With Feeling" paper
                  t_max=1.0, # again, from Google paper. This will give us from 0 to 1 second. Anything beyond that gets clipped
                 ):
    """Return a copy of `notes_tensor` with its time columns quantized.

    Every column after the first is floored to a multiple of `time_res` and
    then clipped to the range [0, t_max]. The first column and the input
    tensor itself are left untouched.
    """
    quantized = notes_tensor.contiguous().clone()
    snapped = torch.floor(quantized[:, 1:] / time_res) * time_res
    quantized[:, 1:] = snapped.clamp(0.0, t_max)
    return quantized


In [None]:
# Try the seconds-based quantizer on the sample song and count how many
# distinct step / duration values remain.
notes_quant = time_quantize(notes)
print("notes_quant[:,1] min, max = ", notes_quant[:,1].min(), notes_quant[:,1].max())
print("notes_quant[:,2] min, max = ", notes_quant[:,2].min(), notes_quant[:,2].max())
print("Number of unique steps, durations = ", len(notes_quant[:,1].unique()), len(notes_quant[:,2].unique()) )

# Read all files into a list of tensors

In [None]:
def files_to_tensor_list(filenames): 
    """Convert every file in `filenames` with `midi_file_to_tensor`, in parallel.

    Work is handed out one file at a time to one worker process per CPU;
    the result is the list returned by `process_map`.
    """
    n_workers = mp.cpu_count()
    return process_map(midi_file_to_tensor, filenames, max_workers=n_workers, chunksize=1)

In [None]:
# Convert the whole dataset; notes_list holds one tensor per input file.
notes_list = files_to_tensor_list(filenames)
print(f"\n{len(notes_list)} files read")

In [None]:
# Cache the converted tensors on disk; the file name is prefixed with source_dataset.
torch.save(notes_list, f'{source_dataset}_tensorlist.pt') # save for next time

In [None]:
# Uncomment to reuse a previously saved tensor list instead of re-reading the MIDI files.
#notes_list = torch.load('maestro3_tensorlist_120bpm.pt')  # load from previous computation
len(notes_list)

For easier analysis, put all notes into one big long tensor called "`all_notes`"

In [None]:
def tl_to_notes(tensor_list, shuffle=False, delimit=True):
  """Stack a list of per-song note tensors into one long (total_notes, 3) tensor.

  tensor_list: per-song tensors of arbitrary length; the list is not modified.
  shuffle: randomize the song order (on a copy, so the caller's list keeps its order).
  delimit: append an all-zeros row after every song to mark where it ends.
  """
  songs = list(tensor_list)  # work on a copy so shuffling never reorders the caller's list
  if shuffle:
    random.shuffle(songs)
  if delimit:
    delimiter = torch.zeros(3)  # use all zeros to show ends of songs
    songs = [element for item in songs for element in (item, delimiter)]
  return torch.vstack(songs)

all_notes = tl_to_notes(notes_list, shuffle=False) # every song in list order, each followed by an all-zeros delimiter row
all_notes.shape

In [None]:
# routines for displaying midi / notes

def notes_arr_to_df(notes_arr) -> pd.DataFrame:
    """Turn an (N, 3) array of (pitch, step, duration) rows into a DataFrame.

    Besides the three input columns, the frame gets absolute 'start' and
    'end' times: start is the running sum of the steps and end is
    start + duration. Both are float64 columns.
    """
    columns = ['pitch','step','duration']
    df = pd.DataFrame(notes_arr, columns=columns)

    # A cumulative sum in float64 reproduces the sequential accumulation of a
    # row-by-row loop, without object-dtype columns or per-row Python work.
    start = df['step'].astype('float64').cumsum()
    df['start'] = start
    df['end'] = start + df['duration'].astype('float64')
    return df

def df_to_midi(
        notes_df: pd.DataFrame,
        out_file: str = '',  # output file to save to, if any
        instrument_name: str = 'Acoustic Grand Piano', # whatever you want to call this instrument
        velocity: int = 100,  # note loudness
    ) -> pretty_midi.PrettyMIDI:
    """Build a single-instrument PrettyMIDI object from a notes DataFrame.

    Only the 'pitch', 'step' and 'duration' columns are read: each onset is
    the previous onset plus the row's step, and every note is given the same
    velocity. The result is written to `out_file` when that is non-empty.
    """
    program = pretty_midi.instrument_name_to_program(instrument_name)
    track = pretty_midi.Instrument(program=program)

    onset = 0
    for _, row in notes_df.iterrows():  # onsets accumulate, so rows are handled in order
        onset = float(onset + row['step'])
        track.notes.append(
            pretty_midi.Note(
                velocity=velocity,
                pitch=int(row['pitch']),
                start=onset,
                end=float(onset + row['duration']),
            )
        )

    midi = pretty_midi.PrettyMIDI()
    midi.instruments.append(track)
    if out_file:
        midi.write(out_file)
    return midi

def plot_piano_roll(notes_df: pd.DataFrame, count: Optional[int] = None, vocab_size=128):
    """Draw a piano-roll plot: one horizontal segment per note from start to end.

    `count` limits the plot to the first `count` notes; a falsy value plots
    the whole track. The pitch axis runs from 0 to `vocab_size`.
    """
    if count:
        title = f'First {count} notes'
    else:
        title, count = 'Whole track', len(notes_df['pitch'])

    # Row 0 / row 1 hold the two end points of every note segment.
    times = np.stack([notes_df['start'], notes_df['end']], axis=0)[:, :count]
    pitches = np.stack([notes_df['pitch'], notes_df['pitch']], axis=0)[:, :count]

    plt.figure(figsize=(20, 4))
    plt.plot(times, pitches, color="b", marker=".")
    plt.xlabel('Time [s]')
    plt.ylabel('Pitch')
    plt.gca().set_ylim([0, vocab_size])
    plt.title(title)
    plt.show()


def midi_to_audio(pm: pretty_midi.PrettyMIDI, seconds=30, sr=16000):
    """Synthesize `pm` with fluidsynth and show an audio player in the notebook.

    pm: the PrettyMIDI object to render.
    seconds: length of the excerpt to keep; None keeps the whole rendering.
    sr: sample rate used for synthesis and playback.
    Returns whatever `display` returns.
    """
    waveform = pm.fluidsynth(fs=float(sr))
    # Keep only an excerpt of the waveform to mitigate kernel resets.
    # seconds=None is handled explicitly rather than through a bare `except`,
    # which also hid every unrelated error.
    if seconds is not None:
        waveform = waveform[:int(seconds * sr)]
    return display(Audio(waveform, rate=sr))

def pitches_to_midi(pitch_list, seconds=30):
    """Plot and play a list of pitches, giving every note a step and duration of 0.25."""
    notes_tensor = torch.full((len(pitch_list), 3), 0.25)
    for row, pitch in enumerate(pitch_list):
        notes_tensor[row, 0] = pitch
    frame = notes_arr_to_df(notes_tensor.cpu().detach().numpy())
    rendered = df_to_midi(frame)
    plot_piano_roll(frame)
    return midi_to_audio(rendered, seconds=seconds)

def notes_to_midi(notes_tensor, seconds=30, time_rescale=2/((120/60))):
    """Plot and play a (N, 3) tensor of (pitch, step, duration) rows.

    The input is cloned first. Negative entries are zeroed after a warning,
    and the step/duration columns are multiplied by `time_rescale` unless it
    is None. Returns the result of `midi_to_audio`.
    """
    work = notes_tensor.clone()  # never write into the caller's tensor
    has_negative = work.min() < 0.0
    if has_negative:
      print("WARNING: You have negative pitches, steps or durations. Setting them to zero")
      work = work * (work >= 0)
    if time_rescale is not None:
        work[:, 1:] = work[:, 1:] * time_rescale
    frame = notes_arr_to_df(work.cpu().detach().numpy())
    rendered = df_to_midi(frame)
    plot_piano_roll(frame)
    return midi_to_audio(rendered, seconds=seconds)

In [None]:
# Listen to a 70-note excerpt of the unquantized notes; `start` is a row offset
# into all_notes and is reused by the later playback cells.
start, playlen = 920000, 70
orig1 = all_notes[start:start+playlen]
#notes_to_midi(orig1, time_rescale=info['seconds_per_tick'])
notes_to_midi(orig1)

In [None]:
#quantization
# Express step and duration as whole multiples of grid_resolution:
# steps are rounded to the nearest grid line, durations are rounded up.
grid_resolution = 1/8
quant_notes = all_notes.clone()
quant_notes[:,1] = torch.round(all_notes[:,1]/grid_resolution)
quant_notes[:,2] = torch.ceil(all_notes[:,2]/grid_resolution)  # ceil to avoid zero duration notes



print("steps: min max, unique = ",quant_notes[:,1].min(), quant_notes[:,1].max(), len(quant_notes[:,1].unique()))
print("dur: min max, unique =",quant_notes[:,2].min(), quant_notes[:,2].max(), len(quant_notes[:,2].unique()))

In [None]:
# Factor that turns quantized grid counts back into playback time for notes_to_midi.
time_rescale = 2/(120/60) * grid_resolution
time_rescale

In [None]:
# Play the same excerpt after quantization (before any capping).
#notes_to_midi(quant_notes[start:start+playlen], time_rescale=info['seconds_per_tick']*grid_resolution)
notes_to_midi(quant_notes[start:start+playlen], time_rescale = 2/(120/60) * grid_resolution)

In [None]:
# Caps for the quantized values: half the median of the distinct step values
# and half the median of the distinct duration values.
step_vals = quant_notes[:,1].unique().sort()[0]  # sorted distinct steps, kept for inspection
# Both caps are plain ints so they can be passed to torch.clamp the same way
# (the earlier index-based step_cap was immediately overwritten and is dropped).
step_cap = int(quant_notes[:,1].unique().median() // 2)
dur_cap = int(quant_notes[:,2].unique().median()/ 2 )
step_cap, dur_cap

In [None]:
# Apply the caps: steps end up in [0, step_cap], durations in [1, dur_cap].
quant_notes[:,1] = torch.clamp( quant_notes[:,1], 0, step_cap) 
quant_notes[:,2] = torch.clamp( quant_notes[:,2], 1, dur_cap) 

In [None]:
# Number of distinct step values left after capping.
len(quant_notes[:,1].unique())

In [None]:
# Play the excerpt again, now with the caps applied.
#notes_to_midi(quant_notes[start:start+playlen], time_rescale=info['seconds_per_tick']*grid_resolution)
notes_to_midi(quant_notes[start:start+playlen], time_rescale = 2/(120/60) * grid_resolution)

That looks good.  Let's "burn it in"

In [None]:
# build a "notes list" based on original notes list
# Each song is quantized and capped on its own, then stored as int16.
# Loop-local names are used so the notebook-level `notes` and `quant_notes`
# are not overwritten by the last song.
quant_notes_list = []
for song_notes in notes_list:
    song_quant = song_notes.clone()
    song_quant[:,1] = torch.round(song_notes[:,1]/grid_resolution)
    song_quant[:,2] = torch.ceil(song_notes[:,2]/grid_resolution)  # ceil to avoid zero duration notes

    song_quant[:,1] = torch.clamp(song_quant[:,1], 0, step_cap)
    # Floor of 1 matches the capping applied to the concatenated notes and
    # keeps notes whose start and end coincide from getting duration 0.
    song_quant[:,2] = torch.clamp(song_quant[:,2], 1, dur_cap)
    quant_notes_list.append(song_quant.type(torch.int16))
len(quant_notes_list)

In [None]:
# Write the finished dataset: a list of per-song int16 tensors.
torch.save(quant_notes_list, 'rastro-120bpm_16th_tensor_list.pt')

In [None]:
# Number of grid steps per unit of playback time.
1/time_rescale

# J.S. Bach chorales

In [None]:
import json

# Load the JSB chorales JSON file; the cells below read its
# 'train', 'valid' and 'test' entries.
path = "JSB-Chorales-dataset/Jsb16thSeparated.json"
with open(path) as f:
    data_dict = json.load(f)

data_dict.keys()

In [None]:
# Length (number of frames) of the first training chorale.
len(data_dict['train'][0])

In [None]:
# Use the second training chorale as the working example and show its first 12 frames.
song = np.array(data_dict['train'][1])
song[0:12]

In [None]:
def song_to_piano_roll(song):
    """Convert a (frames, voices) grid of pitch numbers into a (128, frames) piano roll.

    Every pitch listed in a frame gets the fixed value 64 in that frame's
    column; all other cells stay 0.
    """
    roll = np.zeros((128, song.shape[0]))
    for position, pitch in np.ndenumerate(song):
        roll[pitch, position[0]] = 64  # position[0] is the frame index
    return roll

# Piano roll of the example chorale; shape is (128, number of frames).
piano_roll = song_to_piano_roll(song)
piano_roll.shape

In [None]:
# NOTE(review): the librosa import in the next cell is commented out, so this install may be unnecessary — confirm.
%pip install -qq librosa

In [None]:
import pretty_midi as pm 


from __future__ import division
import sys
import argparse
import numpy as np
import pretty_midi
#import librosa


def piano_roll_to_pretty_midi(piano_roll, fs=8*(25/24), program=0):
    '''Convert a Piano Roll array into a PrettyMidi object
     with a single instrument.

    Parameters
    ----------
    piano_roll : np.ndarray, shape=(128,frames), dtype=int
        Piano roll of one instrument
    fs : float
        Sampling frequency of the columns, i.e. each column is spaced apart
        by ``1./fs`` seconds. The default gives 0.12 s per column.
    program : int
        The program number of the instrument.

    Returns
    -------
    midi_object : pretty_midi.PrettyMIDI
        A pretty_midi.PrettyMIDI class instance describing
        the piano roll.

    '''
    notes, frames = piano_roll.shape
    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=program)

    # pad 1 column of zeros so we can acknowledge initial and ending events
    piano_roll = np.pad(piano_roll, [(0, 0), (1, 1)], 'constant')

    # use changes in velocities to find note on / note off events
    # (the diff is transposed so np.nonzero lists the changes in time order)
    velocity_changes = np.nonzero(np.diff(piano_roll).T)

    # keep track on velocities and note on times
    prev_velocities = np.zeros(notes, dtype=int)
    note_on_time = np.zeros(notes)

    for time, note in zip(*velocity_changes):
        # use time + 1 because of padding above
        velocity = piano_roll[note, time + 1]
        time = time / fs  # column index -> seconds
        if velocity > 0:
            # a change to a non-zero velocity while the pitch is already
            # sounding is ignored; only the first onset is recorded
            if prev_velocities[note] == 0:
                note_on_time[note] = time
                prev_velocities[note] = velocity
        else:
            # velocity dropped to zero: close the note opened at note_on_time
            pm_note = pretty_midi.Note(
                velocity=prev_velocities[note],
                pitch=note,
                start=note_on_time[note],
                end=time)
            instrument.notes.append(pm_note)
            prev_velocities[note] = 0
    pm.instruments.append(instrument)
    return pm


# Columns per second for the chorale piano rolls; 1/fs is the length of one column.
fs = 8*(25/24)
print("fs, 1/fs = ",fs, 1/fs)
pm2 = piano_roll_to_pretty_midi(piano_roll, fs=fs)

In [None]:
def pm_to_tensor(pm):
    """Convert the first instrument of a PrettyMIDI-like object into a notes tensor.

    Returns a (num_notes, 3) float32 tensor with columns pitch, step (time
    since the previous note's onset, 0 for the first note) and duration.
    Notes are ordered by start time, then by pitch. An object without
    instruments or notes yields an empty (0, 3) tensor.
    """
    track_notes = pm.instruments[0].notes if pm.instruments else []
    # Sort the notes first by start time (then by pitch if two notes start at the same time)
    sorted_notes = sorted(track_notes, key=lambda note: (note.start, note.pitch))

    rows = []
    prev_start = sorted_notes[0].start if sorted_notes else 0.0
    for note in sorted_notes:
        rows.append((note.pitch, note.start - prev_start, note.end - note.start))
        prev_start = note.start
    # reshape keeps the (0, 3) shape when there are no notes
    return torch.tensor(rows, dtype=torch.float32).reshape(len(rows), 3)

# (pitch, step, duration) rows for the example chorale.
notes2 = pm_to_tensor(pm2)
notes2.shape

In [None]:
# Display the tensor to check the values by eye.
notes2

In [None]:
# seconds=None: play the whole rendering rather than an excerpt.
notes_to_midi(notes2, seconds=None)

In [None]:
# Time spanned by 16 piano-roll columns at this fs.
16*1/fs

In [None]:
# Convert every chorale of every split to a notes tensor and put a rest
# marker (pitch 127) in front of each song.
notes_tensor_list = []
prev_dur = None   # duration of the last note of the previous song (carried across splits)
total_16ths = 0   # running count of grid frames over all songs

for sub in ['train','valid','test']:
    print(sub)
    for i in range(len(data_dict[sub])):
        song = np.array(data_dict[sub][i])
        print(f"   i = {i}, len(song) = {len(song)}")
        piano_roll = song_to_piano_roll(song)
        pm2 = piano_roll_to_pretty_midi(piano_roll, fs=fs)
        # how to set step for first/last note of song? 
        extra_pitch = 127 # a rest
        # the rest's step is the previous song's final note duration (0 for the very first song)
        extra_step = 0 if prev_dur is None else prev_dur
        extra_dur = 0.96  # that's what's used elsewhere
        extra_note = torch.tensor([extra_pitch, extra_step, extra_dur]).unsqueeze(0)  
        notes_tensor = pm_to_tensor(pm2)
        notes_tensor[0,1] = extra_dur # first note comes after the initial rest
        notes_tensor = torch.cat((extra_note, notes_tensor),dim=0)
        notes_tensor_list.append(notes_tensor)
        total_16ths = total_16ths + len(song)
        prev_dur = notes_tensor[-1,2]
print("total_16ths,  total_16ths/4 =",total_16ths, total_16ths/4)

In [None]:
# Number of rows (rest marker included) in the first converted chorale.
len(notes_tensor_list[0])

In [None]:
def tl_to_notes(tensor_list, shuffle=False, delimit=False):
  """Stack a list of per-song note tensors into one long float32 tensor.

  tensor_list: per-song tensors of arbitrary length; the list is not modified.
  shuffle: randomize the song order (on a copy, so the caller's list keeps its order).
  delimit: append an all-zeros row after every song to mark where it ends
    (off by default in this version).
  """
  songs = list(tensor_list)  # work on a copy so shuffling never reorders the caller's list
  if shuffle:
    random.shuffle(songs)
  if delimit:
    delimiter = torch.zeros(3)  # use all zeros to show ends of songs
    songs = [element for item in songs for element in (item, delimiter)]
  return torch.vstack(songs).type(torch.float32)

In [None]:
# Concatenate all chorales; this rebinds all_notes, replacing the earlier tensor of that name.
all_notes = tl_to_notes(notes_tensor_list)
all_notes.shape

In [None]:
# Distinct step values in the chorale data.
all_notes[:,1].unique()

In [None]:
# Distinct duration values in the chorale data.
all_notes[:,2].unique()

In [None]:
#all_notes[:,1:] = torch.clamp(all_notes[:,1:], 0, 5.7600)

In [None]:
# Save the chorale notes, including the pitch-127 rest markers.
torch.save(all_notes, 'jsb_tensor_rests.pt')

In [None]:
# Listen to rows 2000-4999 in full (seconds=None disables the excerpt limit).
notes_to_midi(all_notes[2000:5000], seconds=None)

In [None]:
# Locate the rows whose pitch is 127, the value used for the inserted rest markers.
b = all_notes[:,0] == 127
indices = b.nonzero()
len(indices)

In [None]:
# Row index of the third pitch-127 row.
i= indices[2].item()
i

In [None]:
# Rows just before and after that marker.
all_notes[i-3:i+3]

In [None]:
# Listen to 40 rows on either side of the marker.
notes_to_midi(all_notes[i-40:i+40])

---
------------------- My old way (earlier grid-to-step/duration conversion of the chorales, kept for reference)

In [None]:
def zero_repeats(song):
    """Zero every grid cell that repeats the value directly above it.

    `song` is a 2-D (frames, voices) array. It is modified in place and also
    returned. After the call, a non-zero cell marks the frame where a voice
    changes value; held values become 0. Any number of voice columns is
    handled.
    """
    # Compare each row with the unmodified row above it in one shot; the mask
    # is computed before any cell is cleared, which matches walking the rows
    # from the bottom up.
    repeats = song[1:] == song[:-1]
    song[1:][repeats] = 0
    return song
    
# Applies to whatever `song` currently holds in the notebook; the array is changed in place.
song = zero_repeats(song)

In [None]:
# Reverse the order of the columns (voices) within each frame.
song = song[:,::-1]
print("len(song) = ",len(song))
song[0:20]

In [None]:
# Scratch check: counting the non-zero entries of one frame.
len(np.nonzero(np.array([58, 65, 70, 74]))[0])

In [None]:
def grid_to_stepdur(song):
    """Convert a (frames, voices) onset grid into [pitch, step, duration] triples.

    A cell greater than 0 is treated as a new note. Its step is 0 when an
    earlier column of the same frame is non-zero; otherwise it is the number
    of frames back to the nearest earlier frame whose cells sum to a
    non-zero value (0 if there is none). Its duration is 1 plus the number of
    zero cells directly below it in the same column, or 1 when nothing
    non-zero follows. The first note's step is always forced to 0.

    Returns a list of [pitch, step, duration] lists, empty when the grid
    contains no notes.
    """
    n_rows = song.shape[0]
    n_voices = song.shape[1] if song.ndim > 1 else 0
    row_sums = np.sum(song, axis=-1)  # computed once instead of once per note

    notelist = []
    last_active = None  # nearest earlier frame whose cells sum to non-zero
    for i in range(n_rows):
        for j in range(n_voices):
            if song[i,j] > 0: # we have a new note
                # time step
                if np.count_nonzero(song[i,:j]) >= 1:
                    step = 0  # shares its frame with an earlier voice
                elif last_active is None:
                    step = 0  # nothing sounded before this frame
                else:
                    step = i - last_active
                # to get the duration, count how many zeros are under it
                later = np.flatnonzero(song[i+1:,j])
                dur = later[0] if later.size else 0

                notelist.append([song[i,j], step, 1+dur])
        if row_sums[i] != 0:
            last_active = i
    # Done once after the loop and only when a note exists, so a grid whose
    # first frame is silent no longer raises IndexError.
    if notelist:
        notelist[0][1] = 0 # 0 step at start
    return notelist
    
# Convert the example chorale (already passed through zero_repeats and column reversal above).
notelist = grid_to_stepdur(song)


In [None]:
# Convert every chorale of every split and collect all notes in one flat list.
notelist = []
songs_count = 0 
for sub in ['train','valid','test']:
    for i, chorale in enumerate(data_dict[sub]):
        songs_count += 1
        song = np.array(chorale)
        song = zero_repeats(song)
        song = song[:,::-1]  # reverse the voice order within each frame
        # extend in place: rebuilding the list with `+` copies it for every song
        notelist.extend(grid_to_stepdur(song))

In [None]:
# Total number of chorales processed across the three splits.
songs_count

In [None]:
# One row per note: pitch, step, duration (in grid frames).
jsbnotes = np.array(notelist)
print("jsbnotes.shape = ",jsbnotes.shape)
jsbnotes[0:10]

In [None]:
# Scale the step and duration columns by 16, in place.
# NOTE: re-running this cell multiplies the values again.
jsbnotes[:,1:] *= 16
jsbnotes[0:10]

In [None]:
# Shape check after scaling.
jsbnotes.shape

In [None]:
# Disabled alternative kept as a string literal: it is never executed and refers
# to `pitchvals`, which is not defined in the visible notebook.
'''jsbnotes = np.zeros((len(pitchvals),3))
jsbnotes[:,0] = pitchvals
note_dur = int(0.5/.008)
jsbnotes[:,2] = note_dur
jsbnotes[4::4,1] = note_dur
jsbnotes.shape'''

In [None]:
# First ten rows again, for a visual check.
jsbnotes[0:10]

In [None]:
# Same data as a torch tensor (dtype follows the NumPy array).
jsbtensor = torch.tensor(jsbnotes) 
jsbtensor.shape

In [None]:
# Clamp every column, pitch included, to the range [0, 256].
jsbtensor = torch.clamp(jsbtensor, 0, 4*64)

In [None]:
# Save the step/duration version of the chorale notes.
torch.save(jsbtensor, 'jsb_tensor_sd.pt')

In [None]:
# Distinct durations, divided by 16 to undo the earlier scaling.
jsbtensor[:,2].unique()/16

In [None]:
# Smallest non-zero step in the data.
jsbtensor[jsbtensor[:,1]>0][:,1].min()