Setup:

In [None]:
#!sudo apt install -qq -y fluidsynth >/dev/null;  # that's for Linux.  on Mac: "brew install fluidsynth"
#!pip install -Uqq pyfluidsynth pretty_midi wandb

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from glob import glob
import mido 
import pathlib
import pretty_midi
from tqdm import tqdm_notebook as tqdm
import multiprocessing as mp
from tqdm.contrib.concurrent import process_map
import random
torch.multiprocessing.set_sharing_strategy('file_system')

In [None]:

# Choose which MIDI collection the `midi_files` symlink should point at.
# NOTE(review): 'jsb' matches none of the branches below, so after the `rm`
# no `midi_files` link is created for it in this cell — confirm that is intended.
source_dataset = 'jsb'

# Remove any `midi_files` link/directory left over from a previous run
!rm -rf midi_files
if source_dataset == 'groove':
    # wget -N only re-downloads when the remote file is newer; unzip -n never overwrites
    !wget -N https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip
    !unzip -n -qq groove-v1.0.0-midionly.zip
    !ln -s groove midi_files
elif source_dataset == 'maestro':
    !wget -N https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
    !unzip -n -qq maestro-v3.0.0-midi.zip
    !ln -s maestro-v3.0.0 midi_files
elif source_dataset == 'js-fakes':
    # nothing is downloaded here: a local `js-fakes` directory must already exist
    !ln -s js-fakes midi_files

In [None]:
# Recursively collect every file under midi_files/ whose name contains ".mid"
# (the pattern '*.mid*' matches both .mid and .midi)
data_dir = pathlib.Path('midi_files')
filenames = glob(str(data_dir/'**/*.mid*'), recursive=True)
print('Number of files:', len(filenames))

In [None]:
# Global sizes; vocab_size is re-assigned (to the same value) in later cells
vocab_size=128 # number of midi notes
n_contin = 2   # number of continuous variables besides pitches


def time_convert(time_s, bpm, pm, time_units='ticks'):
    """Convert a time given in seconds into the requested unit.

    Args:
        time_s: time in seconds.
        bpm: tempo in beats per minute; only used for ``time_units='beats'``.
        pm: object exposing ``time_to_tick(seconds)``; only used for
            ``time_units='ticks'``.
        time_units: ``'ticks'``, ``'beats'``, or anything else to leave the
            value in seconds.

    Returns:
        The converted time, or ``time_s`` unchanged for unrecognised units.
    """
    if time_units == 'ticks':
        # the tick conversion is delegated entirely to the supplied object
        return pm.time_to_tick(time_s)
    if time_units == 'beats':
        beats_per_second = bpm/60
        return time_s * beats_per_second
    return time_s  # leave it in seconds


def midi_file_to_tensor_old(midi_file):
    """Read a MIDI file into an (n_notes, 3) float32 tensor of [pitch, step, duration].

    Only the first instrument in the file is used. ``step`` is the time since the
    previous note started and ``duration`` is ``end - start``, both in the time
    values reported by pretty_midi.

    Args:
        midi_file: path of the MIDI file to read.

    Returns:
        A tensor of shape (n_notes, 3); shape (0, 3) when the file has no
        instruments or the first instrument has no notes.
    """
    pm = pretty_midi.PrettyMIDI(midi_file) # read in the whole file. this is incredibly slow

    # A file without instruments/notes yields an empty tensor instead of an IndexError
    if not pm.instruments or not pm.instruments[0].notes:
        return torch.empty( (0, 3), dtype=torch.float32 )

    # Sort the notes first by start time (then by pitch if two notes start at the same time)
    sorted_notes = sorted(pm.instruments[0].notes, key=lambda note: (note.start, note.pitch))
    notes = torch.empty( (len(sorted_notes), 3), dtype=torch.float32 ) # allocate storage

    prev_start = sorted_notes[0].start
    for i, note in enumerate(sorted_notes):
        notes[i, 0] = note.pitch
        notes[i, 1] = note.start - prev_start  # step, i.e. time since last note started
        notes[i, 2] = note.end - note.start    # duration
        prev_start = note.start

    return notes


def midi_file_to_tensor(midi_file,
                        time_units='ticks', # beats, ticks, s
                        info=False,  # return info about the track
                       ):
    """Read a MIDI file into an (n_notes, 3) float32 tensor of [pitch, step, duration].

    Only the first instrument in the file is used. ``step`` is the time since the
    previous note started and ``duration`` is ``end - start``; no rescaling or
    quantization is applied here.

    Args:
        midi_file: path of the MIDI file to read.
        time_units: currently unused (the time conversion is disabled); kept so
            existing callers keep working.
        info: when True, also return a dict with 'bpm', 'ticks_per_beat' and
            'seconds_per_tick'.

    Returns:
        ``notes`` or ``(notes, info_dict)``. ``notes`` has shape (0, 3) when the
        file has no instruments or the first instrument has no notes.
    """
    pm = pretty_midi.PrettyMIDI(midi_file) # read in the whole file. this is incredibly slow

    # Sort the notes first by start time (then by pitch if two notes start at the same time)
    if pm.instruments and pm.instruments[0].notes:
        sorted_notes = sorted(pm.instruments[0].notes, key=lambda note: (note.start, note.pitch))
    else:
        sorted_notes = []  # nothing to read: fall through and return an empty tensor
    notes = torch.empty( (len(sorted_notes), 3), dtype=torch.float32 ) # allocate storage

    prev_start = sorted_notes[0].start if sorted_notes else 0.0
    for i, note in enumerate(sorted_notes):
        notes[i, 0] = note.pitch
        notes[i, 1] = note.start - prev_start  # step, time since last note
        notes[i, 2] = note.end - note.start    # duration
        prev_start = note.start

    if not info:
        return notes

    # Tempo estimation and the second (mido) parse are only needed for the info
    # dict, so they are skipped on the bulk-loading path.
    bpm = pm.estimate_tempo()
    tpb = mido.MidiFile(midi_file).ticks_per_beat
    # 500000 is the tempo argument passed to mido.tick2second; it is fixed here
    # and does not depend on the estimated bpm
    spt = mido.tick2second(1, tpb, 500000 )
    return notes, {'bpm': bpm, 'ticks_per_beat':tpb, 'seconds_per_tick':spt}

In [None]:
#notes, info = midi_file_to_tensor(filenames[0], info=True)
#print(info)
#pitches = notes[:,0].type(torch.long)  # just the pitch info
#notes.shape, pitches.shape

In [None]:
# @title Tensor to MIDI Display Code
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional
from IPython.display import Audio, display

def notes_arr_to_df(notes_arr) -> pd.DataFrame:
    """Turn an (n_notes, 3) array of [pitch, step, duration] into a DataFrame.

    Adds absolute ``start`` and ``end`` columns: ``start`` is the running sum of
    ``step`` and ``end`` is ``start + duration``.

    Args:
        notes_arr: 2-D array-like with three columns (pitch, step, duration).

    Returns:
        DataFrame with columns pitch, step, duration, start, end. ``start`` and
        ``end`` are float64.
    """
    columns = ['pitch','step','duration']
    df = pd.DataFrame(notes_arr, columns=columns)

    # Accumulate in float64 (as the previous row-by-row Python loop did) so a
    # float32 input does not lose precision over long tracks. Numeric columns
    # also avoid writing floats into ""-initialised object/string columns.
    start = df['step'].astype('float64').cumsum()
    df['start'] = start
    df['end'] = start + df['duration'].astype('float64')
    return df

def df_to_midi(
        notes_df: pd.DataFrame,
        out_file: str = '',  # output file to save to, if any
        instrument_name: str = 'Acoustic Grand Piano', # whatever you want to call this instrument
        velocity: int = 100,  # note loudness
    ) -> pretty_midi.PrettyMIDI:
    """Build a single-instrument PrettyMIDI object from a notes DataFrame.

    Reads the 'pitch', 'step' and 'duration' columns row by row; each note starts
    ``step`` after the previous note's start. Every note gets the same velocity.
    The object is also written to ``out_file`` when that is non-empty.
    """
    program = pretty_midi.instrument_name_to_program(instrument_name)
    instrument = pretty_midi.Instrument(program=program)

    # Rows are processed in order because each start depends on the previous one
    last_start = 0
    for _, row in notes_df.iterrows():
        start = float(last_start + row['step'])
        end = float(start + row['duration'])
        instrument.notes.append(
            pretty_midi.Note(velocity=velocity, pitch=int(row['pitch']), start=start, end=end)
        )
        last_start = start

    pm = pretty_midi.PrettyMIDI()
    pm.instruments.append(instrument)
    if out_file:
        pm.write(out_file)
    return pm

def plot_piano_roll(notes_df: pd.DataFrame, count: Optional[int] = None):
    """Draw a piano-roll plot: one horizontal segment per note from start to end.

    Args:
        notes_df: DataFrame with 'pitch', 'start' and 'end' columns.
        count: plot only the first ``count`` notes; a falsy value plots all of them.
    """
    if not count:
        title = f'Whole track'
        count = len(notes_df['pitch'])
    else:
        title = f'First {count} notes'

    # Row 0 holds each segment's left endpoint, row 1 its right endpoint
    pitch_pairs = np.stack([notes_df['pitch'], notes_df['pitch']], axis=0)
    time_pairs = np.stack([notes_df['start'], notes_df['end']], axis=0)

    plt.figure(figsize=(20, 4))
    plt.plot(time_pairs[:, :count], pitch_pairs[:, :count], color="b", marker=".")
    plt.xlabel('Time [s]')
    plt.ylabel('Pitch')
    plt.gca().set_ylim([0, vocab_size])  # y-axis spans the module-level vocab_size
    plt.title(title)
    plt.show()


def midi_to_audio(pm: pretty_midi.PrettyMIDI, seconds=30, sr=16000):
    """Synthesize ``pm`` with fluidsynth and display an audio player in the notebook.

    Args:
        pm: the PrettyMIDI object to render.
        seconds: length of the excerpt to keep; may be fractional.
        sr: sample rate used both for synthesis and playback.

    Returns:
        Whatever ``IPython.display.display`` returns for the audio widget.
    """
    waveform = pm.fluidsynth(fs=float(sr))
    # Take a sample of the generated waveform to mitigate kernel resets.
    # int() keeps the slice bound valid when `seconds` is a float.
    n_samples = int(seconds*sr)
    waveform_short = waveform[:n_samples]
    return display(Audio(waveform_short, rate=sr))

def pitches_to_midi(pitch_list, seconds=30):
    """Plot and play a list of pitches, giving every note step = duration = 0.25.

    Args:
        pitch_list: sequence of pitch values, one per note.
        seconds: length of the audio excerpt passed on to midi_to_audio.

    Returns:
        The result of midi_to_audio for the rendered notes.
    """
    n_notes = len(pitch_list)
    notes_tensor = torch.zeros((n_notes, 3)) + 0.25  # columns 1 and 2 stay at 0.25
    for row, pitch in enumerate(pitch_list):
        notes_tensor[row,0] = pitch

    notes_df = notes_arr_to_df(notes_tensor.cpu().detach().numpy())
    midi = df_to_midi(notes_df)
    plot_piano_roll(notes_df)
    return midi_to_audio(midi, seconds=seconds)

def notes_to_midi(notes_tensor, seconds=30, time_rescale=None):
    """Plot and play an (n_notes, 3) tensor of [pitch, step, duration].

    Works on a copy of the input. Negative entries are zeroed (with a printed
    warning), and the two time columns are multiplied by ``time_rescale`` when
    one is given.

    Args:
        notes_tensor: tensor of notes; left unmodified.
        seconds: length of the audio excerpt passed on to midi_to_audio.
        time_rescale: optional multiplier for the step and duration columns.

    Returns:
        The result of midi_to_audio for the rendered notes.
    """
    notes = notes_tensor.clone() # just to avoid weird overwrites of memory
    has_negative = notes.min() < 0.0
    if has_negative:
        print("WARNING: You have negative pitches, steps or durations. Setting them to zero")
        notes = notes * (notes >= 0)  # multiply by a 0/1 mask
    if time_rescale is not None:
        # no quantization, just rescaling time
        notes[:,1:] = notes[:,1:] * time_rescale

    notes_df = notes_arr_to_df(notes.cpu().detach().numpy())
    midi = df_to_midi(notes_df)
    plot_piano_roll(notes_df)
    return midi_to_audio(midi, seconds=seconds)

In [None]:
def files_to_tensor_list(filenames):
    """Convert MIDI files to note tensors in parallel, one worker per CPU.

    Returns the list produced by tqdm's ``process_map`` over
    ``midi_file_to_tensor``.
    """
    n_workers = mp.cpu_count()
    return process_map(midi_file_to_tensor, filenames, max_workers=n_workers, chunksize=1)

In [None]:
#uncomment if needed
#notes_list = files_to_tensor_list(filenames)
#print(f"\n{len(notes_list)} files read")
#torch.save(notes_list, 'maestro3_tensor_list.pt') # save for next time

In [None]:
# Load the per-song note tensors from the cache file, building the cache first
# when it does not exist yet so the notebook survives a fresh "Run All".
notes_cache_path = pathlib.Path('maestro3_tensor_list.pt')
if notes_cache_path.exists():
    notes_list = torch.load(notes_cache_path)  # load from previous computation
else:
    notes_list = files_to_tensor_list(filenames)
    print(f"\n{len(notes_list)} files read")
    if notes_list:  # never cache an empty result (e.g. no MIDI files were found)
        torch.save(notes_list, notes_cache_path) # save for next time
#notes_list = torch.load('rastro-120bpm_16th_tensor_list.pt')  # load from previous computation
len(notes_list)

Quantize in time. 

In [None]:
def time_quantize(notes_tensor,  # a single song
                  time_res=0.008, # resolution in seconds.  8ms is from Google "This Time With Feeling" paper
                  t_max=1.0, # again, from Google paper. This will give us from 0 to 1 second. Anything beyond that gets clipped
                  use_buckets=True,
                 ):
    """Quantize the time columns (everything after column 0) of one song.

    With ``use_buckets`` the times are replaced by bucket indices, where the
    bucket boundaries sit halfway between multiples of ``time_res``. Otherwise
    the times are floored to a multiple of ``time_res`` and clamped to
    [0, t_max], staying in the original units.

    Returns a new tensor; the input is left untouched.
    """
    quantized = notes_tensor.contiguous().clone()
    times = quantized[:,1:]

    if not use_buckets:
        floored = torch.floor(times/time_res)*time_res
        quantized[:,1:] = torch.clamp(floored, 0.0, t_max)
        return quantized

    boundaries = torch.arange(time_res/2, t_max - time_res/2, time_res)
    # time is now in divisions with resolution time_res
    quantized[:,1:] = torch.bucketize(times.contiguous(), boundaries)
    return quantized

# Quantize every song with the default settings (bucket indices), then cache
# the result as int16 tensors
quant_notes_list = [time_quantize(q) for q in notes_list]
torch.save([q.type(torch.int16) for q in quant_notes_list], 'maestro3_tensor_list_quant.pt') # all integers


In [None]:
# From here on notes_list refers to the quantized songs (the unquantized list is no longer bound to this name)
notes_list = quant_notes_list 

In [None]:
# Toggle Weights & Biases logging; wandb is only imported when it is enabled.
# use_wandb is assigned again in the cell that calls wandb.init.
use_wandb = True
if use_wandb:
    import wandb
    wandb.login()

In [None]:
# Pick the compute device: CUDA first, then Apple MPS, then CPU.
# The second GPU ('cuda:1') is still preferred when it exists, but a single-GPU
# machine now falls back to 'cuda:0' instead of requesting a missing device.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print("device =",device)

# Dataset creation

RASTRO: "Reduced Accessible Songs for Training, Rhythmically Oversimplified"

**TODO:** change dataset from one big string of integers to reading from individual songs? 

In [None]:
def tl_to_notes(tensor_list, shuffle=False, delimit=True):
    """Stack a list of per-song tensors into one long float32 tensor of notes.

    Args:
        tensor_list: list of (n_notes, 3) tensors, one per song.
        shuffle: shuffle the songs first. Note this uses ``random.shuffle`` and
            therefore reorders the caller's list in place.
        delimit: append an all-zero row after every song to mark where it ends.

    Returns:
        A float32 tensor with all songs (and delimiters) stacked vertically.
    """
    if shuffle:
        random.shuffle(tensor_list)
    if delimit:
        delimiter = torch.zeros(3)  # use all zeros to show ends of songs
        with_delimiters = []
        for song in tensor_list:
            with_delimiters.append(song)
            with_delimiters.append(delimiter)
        tensor_list = with_delimiters
    stacked = torch.vstack(tensor_list)
    return stacked.type(torch.float32)  # return one big tensor of floats

# Seed numpy, Python's random and torch so the shuffle below is repeatable
seed = 1337
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

# shuffle=True also reorders notes_list itself (tl_to_notes shuffles in place)
all_notes = tl_to_notes(notes_list, shuffle=True) 

In [None]:
# jsb: replace all_notes with a tensor loaded from 'jsb_tensor_rests.pt',
# discarding the tensor built from notes_list in the previous cell
if source_dataset == 'jsb': 
    all_notes = torch.load('jsb_tensor_rests.pt').type(torch.float32)

In [None]:
# Column 0 holds the pitch; columns 1 and 2 hold step and duration
all_pitches = all_notes[:,0].type(torch.long)  # just the pitch info

#info about our data
print("all_notes.shape =",all_notes.shape)
print("steps: min max = ",all_notes[:,1].min(), all_notes[:,1].max())
print("dur: min max =",all_notes[:,2].min(), all_notes[:,2].max())

In [None]:
# Number of distinct step values and distinct duration values (displayed as a tuple)
len(torch.unique(all_notes[:,1])), len(torch.unique(all_notes[:,2])),

In [None]:
# Peek at the first 28 rows of [pitch, step, duration]
all_notes[0:28]

In [None]:
# Clip both time columns to [0, 5.76], modifying all_notes in place.
# NOTE(review): 5.76 is an unexplained cap — presumably chosen for the JSB data; confirm.
# If the times are bucket indices from time_quantize, this cap would collapse almost all of them.
all_notes[:,1:] = torch.clamp(all_notes[:,1:], 0, 5.76) 

In [None]:
# Express both time columns as multiples of the smallest positive time value.
# This rewrites all_notes in place: re-running the cell recomputes smallest_time
# from the already-divided data (giving 1) and so loses the original unit that
# outbound_t_rescale needs later.
smallest_time = all_notes[:,1:][all_notes[:,1:]>0].min()
print("smallest_time = ",smallest_time)
all_notes[:,1:] = all_notes[:,1:]/smallest_time  #integer times, need to convert later

In [None]:
def outbound_t_rescale(notes_tensor, time_min=1.0):
    """Return a copy of ``notes_tensor`` with its time columns multiplied by ``time_min``.

    Column 0 (pitch) is left as is; every later column is scaled. The input
    tensor is not modified.
    """
    rescaled = notes_tensor.contiguous().clone()
    scaled_times = rescaled[:,1:] * time_min
    rescaled[:,1:] = scaled_times
    return rescaled

In [None]:
# Listen to / plot the first 120 notes, converting the times back to their
# original units with smallest_time before rendering
vocab_size=128
#time_rescale = 0.008 # per Magenta specs
time_rescale = 1 # for my JSB retread
notes_to_midi(outbound_t_rescale(all_notes[0:120],time_min=smallest_time), time_rescale=time_rescale)

In [None]:
# Simple 90/10 split of the note rows.
# NOTE(review): in the cells visible here, train_data/val_data are only referenced
# by the disabled code inside get_batch; the DataLoader pipeline does its own 80/20 split.
data = all_notes # all_notes
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
print("len(train_data), len(val_data) =",len(train_data), len(val_data) )
#assert n_contin == train_data.shape[-1]-1  # data consists of pitches plus continuous variables

In [None]:
# Number of distinct step values after rescaling
len(all_notes[:,1].unique())

In [None]:
# Number of distinct duration values after rescaling
len(all_notes[:,2].unique())

In [None]:
# Peek at the first 20 rows after rescaling
all_notes[0:20]

# TODO: need to do proper vocab mapping for pitches, steps and durations. 


In [None]:
# Vocabulary sizes for the three embedding tables / output heads.
# Step and duration sizes are max value + 1, so every value up to the maximum
# gets an index even if it never occurs in the data.
vocab_size = 128 
print("pitch spread = ",len(all_notes[:,0].unique()))
pitch_vocab_size = vocab_size # 
#steps_size = len(all_notes[:,1].unique())
#durs_size = len(all_notes[:,2].unique())
step_vocab_size = int(all_notes[:,1].max()+1)
dur_vocab_size = int(all_notes[:,2].max()+1)

pitch_vocab_size, step_vocab_size, dur_vocab_size

# Hyperparameters

In [None]:

# hyperparameters
# Learning rate handed to AdamW; the alternative below goes with the disabled OneCycle scheduler
learning_rate = 1e-3
#learning_rate = 5e-5  # for onecycle schedule

'''
# Overfit parameters:
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 128 # what is the maximum context length for predictions?
n_embd = 384 #256# 64  # dimensions of embeddings
n_head = 16# 4
n_layer = 16# 4
dropout = 0.1 #0.1

max_iters = 25000
eval_interval = 100
eval_iters = 100
use_alibi = False
'''

'''
Underfit parameters:
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 128 # what is the maximum context length for predictions?
n_embd = 64 #256# 64  # dimensions of embeddings
n_head = 8# 4
n_layer = 6# 4
dropout = 0.1 #0.1
'''

'''
# large sequence length
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
n_embd = 256 # 64  # dimensions of embeddings
n_head = 16 # 4
n_layer = 8 # # 4
dropout = 0.1 #0.1
'''

# jsb
# Active settings. These are read as module-level globals by the model classes
# and the dataset code further down.
batch_size = 128 # how many independent sequences will we process in parallel?
block_size = 128 # what is the maximum context length for predictions?
n_embd = 64 # 64  # dimensions of embeddings
n_head = 8 # 4
n_layer = 4 # # 4
dropout = 0.1 #0.1


max_iters = 25000     # number of training iterations
eval_interval = 100   # evaluate every this many iterations
eval_iters = 100      # batches averaged per split in estimate_loss
use_alibi = True      # True: additive distance bias in attention instead of a position embedding table


# ------------

# Settings passed to wandb.init as the run config
config = {
  'learning_rate': learning_rate,
  'batch_size': batch_size,
  'block_size': block_size,
  'n_embd': n_embd,
  'n_head': n_head,
  'n_layer': n_layer,
  'dropout': dropout,
  'use_alibi': use_alibi
}

In [None]:
# Keep only songs with more than block_size+1 notes, so a full training window
# (block_size inputs plus the next note) can be cut from each one
print(f"removing short songs, for block_size = {block_size}. Current len(notes_list) = {len(notes_list)}")
notes_list = [q for q in notes_list if q.shape[0] > block_size+1 ]
print("new len(notes_list) =",len(notes_list))

In [None]:
def augment_data(db, pb=12): # db = datablock - a sequence, likely of length 1+block_size
    """Return a copy of ``db`` with its pitches transposed by one random shift.

    Column 0 of ``db`` holds the pitches. Pitches equal to or above 127 are left
    alone (the original code exempted them the same way). All other pitches get
    the same shift, drawn uniformly from [-pb, pb] and then narrowed so that
    every shifted pitch stays within 0..126. This keeps the result a valid
    embedding index and never turns a pitch into the exempt value 127.

    Args:
        db: (n_notes, 3) tensor of [pitch, step, duration]; not modified.
        pb: maximum transposition in either direction, in pitch units.

    Returns:
        The transposed copy. It is unchanged when there is nothing to shift or
        no shift fits.
    """
    db = db.clone()  # avoid in-place alterations of data
    pitches = db[:,0]
    shiftable = pitches < 127
    if not bool(shiftable.any()):
        return db

    lowest = int(pitches[shiftable].min().item())
    highest = int(pitches[shiftable].max().item())
    # Narrow the shift range so lowest+shift >= 0 and highest+shift <= 126
    min_shift = max(-pb, -lowest)
    max_shift = min(pb, 126 - highest)
    if min_shift > max_shift:  # only possible for already out-of-range data
        return db

    # randint's upper bound is exclusive, hence the +1 to make +pb reachable
    pitch_shift = torch.randint(low=min_shift, high=max_shift + 1, size=(1,), dtype=torch.long).item()
    db[:,0] = torch.where(shiftable, pitches + pitch_shift, pitches)
    return db

class NotesDataset(Dataset):
    """Dataset that serves random (input, target) windows of notes.

    ``notes`` is either a list of per-song tensors or a single tensor holding all
    notes. Each item is a randomly positioned window of ``block_size + 1``
    consecutive notes, split into ``x`` (the first ``block_size`` notes) and
    ``y`` (the same window shifted by one note). When ``notes`` is a single
    tensor the requested index is ignored and only determines the epoch length.
    """

    def __init__(self,
                 notes, # notes is either a list of tensors (i.e. songs) or a tensor (all notes together)
                 block_size,
                 augment=False):
        super().__init__()
        self.notes = notes
        self.block_size = block_size  # use the value passed in rather than a module-level global
        self.augment = augment   # set to True for training

    def __len__(self):
        """Number of songs (list input) or number of note rows (tensor input)."""
        return len(self.notes)

    def __getitem__(self, idx):
        """Return one random window as ``(x, y)``, each with ``block_size`` notes.

        Raises:
            ValueError: if the selected sequence has fewer than ``block_size + 1`` notes.
        """
        data = self.notes[idx] if isinstance(self.notes, list) else self.notes
        # A window needs block_size+1 notes, so there are this many valid start positions
        n_starts = data.shape[0] - self.block_size
        if n_starts < 1:
            raise ValueError(f"data.shape[0] = {data.shape[0]}, block_size = {self.block_size}. type(self.notes) = {type(self.notes)}")
        i = torch.randint(0, n_starts, (1,)).item()
        data_block = data[i:i+self.block_size+1]  # grab input sequence plus the next note
        if self.augment:
            data_block = augment_data( data_block )
        x, y = data_block[:-1], data_block[1:]
        return x, y
        

random.seed(1337)

# Choose between per-song sampling (list of tensors) and sampling from the one
# long tensor of all notes; the first 80% goes to training, the rest to validation
use_song_list = False
notes = notes_list if use_song_list else all_notes 
n_split = int(0.8*len(notes))
train_ds = NotesDataset(notes[:n_split], block_size, augment=True)
val_ds   = NotesDataset(notes[n_split:], block_size, augment=False)


# NOTE(review): num_workers=12 is hard-coded; lower it on machines with fewer cores
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=12)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=12)

In [None]:
from itertools import cycle

def _endless_batches(dataloader):
    """Yield batches from ``dataloader`` forever, restarting it after every pass.

    itertools.cycle would keep a copy of every batch from the first pass and
    replay those copies, so later passes would get no reshuffling, no fresh
    random windows or augmentation, and memory would grow with the epoch size.
    Re-iterating the DataLoader avoids all three.
    """
    while True:
        n_yielded = 0
        for batch in dataloader:
            n_yielded += 1
            yield batch
        if n_yielded == 0:
            return  # empty loader: stop instead of spinning forever

train_dl_iter = _endless_batches(train_dl)
val_dl_iter = _endless_batches(val_dl)

In [None]:



def get_batch(split, debug=False):
    """Fetch the next (x, y) batch for ``split`` and move it to ``device``.

    Args:
        split: 'train' reads from the training iterator; any other value reads
            from the validation iterator.
        debug: accepted for compatibility; not used.

    Returns:
        Tuple ``(x, y)`` of tensors on ``device``.
    """
    batch_iter = train_dl_iter if split == 'train' else val_dl_iter
    x, y = next(batch_iter)
    return x.to(device), y.to(device)
    

@torch.no_grad()
def estimate_loss():
    """Average the loss over ``eval_iters`` batches for the train and val splits.

    Returns:
        dict with:
          - 'train', 'val': mean total loss per split;
          - '<split>_pitch_loss', '<split>_step_loss', '<split>_dur_loss': mean
            weighted sub-losses per split;
          - 'pitch_loss', 'step_loss', 'dur_loss': kept for existing callers;
            as before, these hold the values of the last split evaluated ('val').

    The model is put in eval mode for the measurement and switched back to
    train mode afterwards, even if evaluation raises.
    """
    sub_names = ('pitch_loss', 'step_loss', 'dur_loss')
    out = {}
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            sub_hist = {name: torch.zeros(eval_iters) for name in sub_names}

            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, _, _, loss, sublosses = model(X, Y)
                losses[k] = loss.item()
                for name in sub_names:
                    sub_hist[name][k] = sublosses[name]

            out[split] = losses.mean()
            for name in sub_names:
                mean_val = sub_hist[name].mean()
                out[f'{split}_{name}'] = mean_val  # previously the train values were lost
                out[name] = mean_val               # legacy key: ends up holding the 'val' value
    finally:
        model.train()
    return out

In [None]:
# Pull one validation batch to inspect; B and T printed below are the configured
# batch_size and block_size, not values read from x
x,y = get_batch('val',debug=True)
print("x.shape = ",x.shape)
print(f"B, T = {batch_size}, {block_size}")

x is a sequence

In [None]:
#time_rescale = info['seconds_per_tick']*grid_resolution
# Render sequence number `ind` of the batch, then show its first 10 notes
ind = 2
notes_to_midi(x[ind], time_rescale=time_rescale)
x[ind][:10]

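`y` is `x` shifted by one note: `y[:, t]` is the note that follows `x[:, t]`, so `y` drops the first note of `x` and gains one new note at the end.
In this sense only `y[:, -1]` is a genuinely new "next token" being predicted.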

In [None]:
# Same sequence index as above, but from the target tensor y
notes_to_midi(y[ind], time_rescale=time_rescale)
y[ind][:10]

Model definition

In [None]:
def create_alibi_mask(n):
    """Build an (n, n) additive attention bias that is also a causal mask.

    Entry [i, j] is ``-(i - j)`` for j < i (earlier positions are penalised by
    their distance), 0 on the diagonal, and -inf for j > i (future positions).
    """
    positions = torch.arange(n)
    # offsets[i, j] = j - i: negative below the diagonal, positive above it
    offsets = positions.unsqueeze(0) - positions.unsqueeze(1)
    bias = offsets.to(torch.get_default_dtype())
    # everything above the diagonal is a future position
    return bias.masked_fill(offsets > 0, float('-inf'))


class Head(nn.Module):
    """One head of causal self-attention.

    Reads the module-level ``n_embd``, ``block_size``, ``dropout`` and
    ``use_alibi`` settings at construction time. With ``use_alibi`` the causal
    mask and a linear distance penalty come from ``create_alibi_mask``;
    otherwise a plain lower-triangular mask is used.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key   = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.use_alibi = use_alibi
        if use_alibi:
            self.register_buffer('alibi', create_alibi_mask(block_size))
        else:
            self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)


    def forward(self, x):
        """Attend over ``x`` of shape (B, T, n_embd); returns (B, T, head_size)."""
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities"), scaled by 1/sqrt(head_size).
        # The dot product runs over the head dimension, so that is the size to
        # normalise by (previously the full embedding width C was used).
        # NOTE: checkpoints trained with the old scaling will behave differently.
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        if self.use_alibi:
            wei = wei + self.alibi[:T,:T]
        else:
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out


class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention in parallel, followed by a projection.

    The head outputs are concatenated along the last dimension, giving
    ``num_heads * head_size`` features. These are projected back to the
    module-level ``n_embd`` and passed through dropout.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The projection input is the concatenated head width. It equals n_embd
        # whenever n_embd is divisible by num_heads, so existing weights keep
        # their shape; other combinations no longer fail with a shape mismatch.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x`` and merge the results."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """Position-wise MLP: expand to 4x the width, ReLU, project back, dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden_dim = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_embd),
            nn.Dropout(dropout),  # rate comes from the module-level `dropout` setting
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)

class Block(nn.Module):
    """Pre-norm transformer block: self-attention, then a feed-forward MLP.

    Both sub-layers are applied to a layer-normalised input and added back to
    the residual stream.
    """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))


def mse_with_positive_pressure(y_true: torch.Tensor, y_pred: torch.Tensor):
    """Mean squared error plus a penalty on negative predictions.

    From the Magenta example, converted to pytorch. The penalty is ten times
    the magnitude of every negative prediction (zero for non-negative ones).
    """
    squared_error = (y_true - y_pred).pow(2)
    negative_penalty = 10 * (-y_pred).clamp(min=0)
    return (squared_error + negative_penalty).mean()


# super simple bigram model
class BigramLanguageModel(nn.Module):
    """Transformer over note triples with one categorical head per field.

    Each input position is a (pitch, step, duration) triple of integer-valued
    entries. The three fields are embedded separately, combined (concatenated
    by default), run through ``n_layer`` transformer blocks, and decoded by three
    linear heads. Sizes come from the module-level settings (``n_embd``,
    ``n_head``, ``n_layer``, ``block_size``, ``use_alibi`` and the three
    ``*_vocab_size`` values).
    """

    def __init__(self):
        super().__init__()
        self.add_embeddings = False
        if self.add_embeddings:
            pitch_embd, step_emb, dur_embd = [n_embd]*3
        else:
            # Half the width for pitch, a quarter each for step and duration.
            # Pitch takes whatever is left so the parts always sum to n_embd,
            # even when n_embd is not a multiple of 4.
            step_emb = dur_embd = int(n_embd*0.25)
            pitch_embd = n_embd - step_emb - dur_embd
        # one lookup table per field
        self.pitch_embedding_table = nn.Embedding(pitch_vocab_size, pitch_embd)
        self.step_embedding_table = nn.Embedding(step_vocab_size, step_emb)
        self.dur_embedding_table = nn.Embedding(dur_vocab_size, dur_embd)

        if not use_alibi: self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, pitch_vocab_size, bias=False)
        self.step_head = nn.Linear(n_embd, step_vocab_size, bias=False)
        self.dur_head = nn.Linear(n_embd, dur_vocab_size, bias=False)

    def forward(self, inputs, targets=None):
        """Compute logits (and losses when ``targets`` is given).

        Args:
            inputs: (B, T, 3) tensor of [pitch, step, duration]; cast to long
                before the embedding lookups.
            targets: optional (B, T, 3) tensor of the notes to predict.

        Returns:
            Without targets: ``(pitch_logits, step_logits, dur_logits, None)``
            with logits of shape (B, T, vocab).
            With targets: ``(pitch_logits, step_logits, dur_logits, loss, sublosses)``
            where the logits are flattened to (B*T, vocab), ``loss`` is the sum of
            the three weighted cross-entropies and ``sublosses`` holds each
            weighted term as a Python float.
        """
        device = inputs.device
        pitch_idx, step_idx, dur_idx = [inputs[:,:,i].type(torch.long).to(device) for i in [0,1,2]]  # indices
        B, T = pitch_idx.shape

        pitch_tok_emb = self.pitch_embedding_table(pitch_idx) # (B,T,pitch_embd) # embed the pitches
        step_tok_emb  = self.step_embedding_table(step_idx)   # (B,T,step_emb)  # embed the steps
        dur_tok_emb   = self.dur_embedding_table(dur_idx)     # (B,T,dur_embd)  # embed the durations

        if self.add_embeddings:
            x = pitch_tok_emb + step_tok_emb + dur_tok_emb  # try just adding them?
        else:
            x = torch.cat((pitch_tok_emb, step_tok_emb, dur_tok_emb), dim=-1)   # concat the separate embeddings

        if not use_alibi:
            pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
            x = x + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x)   # (B,T,C)
        pitch_logits = self.lm_head(x)    # (B,T,pitch_vocab_size)
        step_logits  = self.step_head(x)  # (B,T,step_vocab_size)
        dur_logits   = self.dur_head(x)   # (B,T,dur_vocab_size)

        # we can also compute losses if targets are included in .forward call
        if targets is None:
            return pitch_logits, step_logits, dur_logits, None
        else:
            B, T, C = pitch_logits.shape
            pitch_logits, step_logits, dur_logits = [q.view(B*T,-1) for q in [pitch_logits, step_logits, dur_logits]]
            pitch_targets, step_targets, dur_targets = [targets[:,:,i].type(torch.long).to(device).view(B*T) for i in [0,1,2]]

            lambda_pitch, lambda_step, lambda_dur = 0.5, 0.5, 0.5  # scale relative losses
            #lambda_pitch, lambda_step, lambda_dur = 0.5, 0.5, 0.1  # scale relative losses
            #lambda_pitch, lambda_step, lambda_dur = 0.5, 0.1666, 0.2  # scale relative losses

            pitch_loss = lambda_pitch * F.cross_entropy(pitch_logits, pitch_targets)
            step_loss  = lambda_step  * F.cross_entropy(step_logits,  step_targets)
            dur_loss   = lambda_dur   * F.cross_entropy(dur_logits,   dur_targets)

            sublosses = {'pitch_loss':pitch_loss.item(), 'step_loss':step_loss.item(), 'dur_loss':dur_loss.item() }

            loss = pitch_loss + step_loss + dur_loss
            return pitch_logits, step_logits, dur_logits, loss, sublosses

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0):
        """Autoregressively append ``max_new_tokens`` sampled notes to ``idx``.

        Args:
            idx: (B, T, 3) tensor holding the starting context.
            max_new_tokens: number of notes to sample.
            temperature: divisor applied to the logits before the softmax.

        Returns:
            (B, T + max_new_tokens, 3) tensor on the CPU, in the dtype of ``idx``.

        Sampling runs in eval mode so dropout does not perturb the predictions.
        The previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # crop idx to the last block_size tokens
                idx_cond = idx[:, -block_size:]
                # get the predictions
                pitch_logits, step_logits, dur_logits, _ = self(idx_cond)
                # focus only on the last time step
                pitch_logits, step_logits, dur_logits = [q[:, -1, :] for q in [pitch_logits, step_logits, dur_logits]] # becomes (B, C)
                # apply softmax to get probabilities
                pitch_probs = F.softmax(pitch_logits/temperature, dim=-1) # (B, C)
                step_probs  = F.softmax(step_logits /temperature, dim=-1) # (B, C)
                dur_probs   = F.softmax(dur_logits  /temperature, dim=-1) # (B, C)

                # sample from the distribution
                pitch_idx_next = torch.multinomial(pitch_probs, num_samples=1) # (B, 1)
                step_idx_next  = torch.multinomial(step_probs,  num_samples=1) # (B, 1)
                dur_idx_next   = torch.multinomial(dur_probs,   num_samples=1) # (B, 1)

                # (B, 3) -> (B, 1, 3): the new axis is the time axis, so this also works for B > 1
                idx_next = torch.cat((pitch_idx_next, step_idx_next, dur_idx_next), dim=-1).unsqueeze(1)
                # append sampled notes to the running sequence, keeping idx's dtype
                idx = torch.cat((idx, idx_next.to(idx.dtype)), dim=1) # (B, T+1, 3)
        finally:
            self.train(was_training)

        return idx.cpu()



Instantiate and get ready to run

In [None]:
# Build the model and move it to the selected device; `m` is the name used by
# the generation cells below
model = BigramLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
#scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 0.002, total_steps=max_iters)

In [None]:
def save_checkpoint(step, model, optimizer,loss, name):
    """Save the training state to ``<name>_<step>.pt``.

    The file holds a dict with keys 'step', 'model_state_dict',
    'optimizer_state_dict' and 'loss'.
    """
    path = name + f'_{step}.pt'
    print("Saving checkpoint to",path)
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(state, path)

def load_checkpoint(name, model, optimizer):
    """Restore model and optimizer state from a checkpoint written by save_checkpoint.

    Args:
        name: path of the checkpoint file.
        model: model whose parameters are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.

    Returns:
        ``(step, model_on_device, optimizer, loss)`` as stored in the checkpoint.
    """
    # map_location lets a checkpoint saved on one device (e.g. a specific GPU)
    # be loaded on whatever `device` this session is using
    checkpoint = torch.load(name, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    m = model.to(device)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    step = checkpoint['step']
    loss = checkpoint['loss']
    return step, m, optimizer, loss

In [None]:
# Set to True to resume from a previously saved checkpoint instead of training from scratch
load_the_checkpoint = False

if load_the_checkpoint:
    step, model, optimizer, loss = load_checkpoint('augemented_cuda_checkpoint.pt', model, optimizer)

In [None]:
# Start the wandb run. NOTE: this re-assigns use_wandb, overriding the value
# set in the login cell earlier in the notebook.
use_wandb = True
if use_wandb: wandb.init(project='musicbox-jsb', config=config)

# Do training

In [None]:
loss_hist = []
checkpoint_every = 1000
# Periodic checkpoints are tagged "alibi"/"noalibi" according to the actual
# setting (the tag used to be hard-coded to "noalibi")
checkpoint_name = f"giant_model_quantized_checkpoint-{'no' if use_alibi==False else ''}alibi"
for iteri in range(max_iters): # "iter" already means something in python, so "iteri" here

    # every once in a while evaluate the loss on train and val sets
    if iteri % eval_interval == 0 or iteri == max_iters - 1:
        losses = estimate_loss()
        print(f"iter {iteri}: losses: train {losses['train']:.4f}, val {losses['val']:.4f}, pitch {losses['pitch_loss']:.4f},",
              f"step_loss {losses['step_loss']:.4f} dur_loss {losses['dur_loss']:.4f}")
        if use_wandb: wandb.log(losses | {'iter':(iteri)//eval_interval,  'lr':optimizer.param_groups[0]['lr'] })

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    pitch_logits, step_logits, dur_logits, loss, sublosses = model(xb, yb)

    # the checkpoint is written before this iteration's optimizer step
    if (iteri>0) and (iteri % checkpoint_every==0):
        save_checkpoint(iteri, model, optimizer, loss, checkpoint_name)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    #scheduler.step()

In [None]:
# Save the final state; the file name ends in "-alibi_<iteri>.pt" or "-noalibi_<iteri>.pt"
save_checkpoint(iteri, model, optimizer,loss, f"giant_model_quantized-{'no' if use_alibi==False else ''}alibi")

In [None]:
# Close the wandb run started before training
if use_wandb: wandb.finish()

# Evaluate Generative Capabilities

In [None]:
# Reload the checkpoint written after training. save_checkpoint appends
# "_<step>.pt" to the name, and the final save happens at step max_iters - 1,
# so the path is rebuilt from the same pieces (no cell writes a plain
# 'giant_model_quantized.pt').
final_checkpoint_path = f"giant_model_quantized-{'no' if use_alibi==False else ''}alibi_{max_iters - 1}.pt"
iteri, model, optimizer, loss = load_checkpoint(final_checkpoint_path, model, optimizer)

In [None]:
# input example (from training dataset)
# time_rescale converts the integer time units back to seconds for playback:
# 0.12 for 'jsb', 0.008 for every other dataset
time_rescale=0.008 if source_dataset != 'jsb' else .12
input_example, _  = next(train_dl_iter)
input_example = input_example[1]  # keep only the second sequence of the batch
print("len(input_example) =", len(input_example))
notes_to_midi(input_example, time_rescale=time_rescale)

In [None]:
# Use the first notes of the example as the prompt for generation
n_starting_notes = 32  # how many real notes to start with
starting_notes = input_example[:n_starting_notes]
notes_to_midi(starting_notes, time_rescale=time_rescale)

Generate

In [None]:
# generate output variations
# unsqueeze(0) adds the batch dimension expected by generate
context = starting_notes.detach().to(device=device).unsqueeze(0)

# Each pass samples a fresh continuation of the same prompt and renders it
for variation in range(4):
  notes = m.generate(context, max_new_tokens=128, temperature=1.0)[0]
  notes_to_midi(notes, time_rescale=time_rescale)

In [None]:
# Show the last generated variation (prompt plus sampled notes)
notes

In [None]:

def generator(model:nn.Module, idx, max_new_tokens=10, temperature=1.0, verbose=False):
    """Autoregressively sample ``max_new_tokens`` notes from ``model``.

    Stand-alone version of the model's ``generate`` method that also accepts an
    unbatched (T, 3) context.

    Args:
        model: module whose forward pass, called without targets, returns
            ``(pitch_logits, step_logits, dur_logits, None)``.
        idx: (B, T, 3) or (T, 3) tensor holding the starting context.
        max_new_tokens: number of notes to sample.
        temperature: divisor applied to the logits before the softmax.
        verbose: print the logits' shapes at every step (off by default so a
            long generation does not flood the output).

    Returns:
        (B, T + max_new_tokens, 3) tensor on the CPU, in the dtype of ``idx``.

    Sampling runs in eval mode so dropout is inactive. The model's previous
    train/eval mode is restored afterwards.
    """
    # idx is (B, T, 1+contin_vars) array of values in the current context
    if len(idx.shape) < 3: idx = idx.unsqueeze(0)
    idx = idx.to(device)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # crop idx to the last block_size tokens
                idx_cond = idx[:, -block_size:]
                # get the predictions
                pitch_logits, step_logits, dur_logits, _ = model(idx_cond)
                if verbose:
                    print(f"p s d shapes: {[q.shape for q in [pitch_logits, step_logits, dur_logits]]}")
                # focus only on the last time step
                pitch_logits, step_logits, dur_logits = [q[:, -1, :] for q in [pitch_logits, step_logits, dur_logits]] # becomes (B, C)
                # apply softmax to get probabilities
                pitch_probs = F.softmax(pitch_logits/temperature, dim=-1) # (B, C)
                step_probs  = F.softmax(step_logits /temperature, dim=-1) # (B, C)
                dur_probs   = F.softmax(dur_logits  /temperature, dim=-1) # (B, C)

                # sample from the distribution
                pitch_idx_next = torch.multinomial(pitch_probs, num_samples=1) # (B, 1)
                step_idx_next  = torch.multinomial(step_probs,  num_samples=1) # (B, 1)
                dur_idx_next   = torch.multinomial(dur_probs,   num_samples=1) # (B, 1)

                # (B, 3) -> (B, 1, 3): the new axis is the time axis, so this also works for B > 1
                idx_next = torch.cat((pitch_idx_next, step_idx_next, dur_idx_next), dim=-1).unsqueeze(1)
                # append sampled notes to the running sequence, keeping idx's dtype
                idx = torch.cat((idx, idx_next.to(idx.dtype)), dim=1) # (B, T+1, 3)
    finally:
        model.train(was_training)
    return idx.cpu()

# Quick check of generator: extend a 5-note prompt by 5 sampled notes and
# print both for comparison
sn = starting_notes[:5]
notes_pred = generator(model, sn, max_new_tokens=5)
print("starting notes =\n",sn)
print("notes_pred =\n",notes_pred)

Links:

https://archives.ismir.net/ismir2021/latebreaking/000005.pdf

https://dl.acm.org/doi/10.1145/3394171.3413671

https://miditok.readthedocs.io/en/latest/index.html

https://github.com/lucasnfe/adl-piano-midi

https://medium.com/mlearning-ai/generating-music-with-gpt-b0f4ab738b58

https://arxiv.org/pdf/1809.04281.pdf

https://magenta.tensorflow.org/music-transformer

https://colab.research.google.com/notebooks/magenta/piano_transformer/piano_transformer.ipynb

https://github.com/gwinndr/MusicTransformer-Pytorch