* Karpathy video: https://www.youtube.com/watch?v=kCc8FmEb1nY
* Karpathy repo: https://github.com/karpathy/nanoGPT
* Colab for video: https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing

# Setup:

In [1]:
data_source = "jsb"

! rm -f midi_files
if data_source == "shakespeare":
    !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
elif data_source == "maestro":
    !wget -N https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
    !unzip -n -qq maestro-v3.0.0-midi.zip
    !ln -s maestro-v3.0.0 midi_files 
elif data_source == "jsb":
    !ln -s jsb_chorale_midi midi_files

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from glob import glob
import pathlib
import pretty_midi
from tqdm import tqdm
import numpy as np
import random 
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced, dark
import multiprocessing as mp
from tqdm.contrib.concurrent import process_map
import pandas as pd

In [3]:
use_wandb = True
if use_wandb:
    import wandb
    wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mdrscotthawley[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Karpathy's # hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0
weight_decay = 1e-2 # pytorch default added here
# ------------


# Mine for Pitches
# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
eval_iters = 200
n_embd = 128 # 64
n_head = 8 # 4
n_layer = 4 * 4
dropout = 0.1
weight_decay = 1e-2 # pytorch default added here


# Mine for Full notes? 
# hyperparameters
batch_size = 96 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
#learning_rate = 1e-3
learning_rate = 3e-3
eval_iters = 200
n_embd = 128 # 64
n_head = 8 # 4
n_layer = 4 * 4
dropout = 0.1
weight_decay = 1e-2 # pytorch default added here

# for comparison with giant_model code
batch_size = 128 # how many independent sequences will we process in parallel?
block_size = 128 # what is the maximum context length for predictions?
n_embd = 64 # 64  # dimensions of embeddings
n_head = 8 # 4
n_layer = 4 # # 4
dropout = 0.1 #0.1
max_iters = 25000
eval_interval = 100
eval_iters = 100
#use_alibi = True # n/a here
weight_decay = 1e-2 # pytorch default added here


# for jsb testing
batch_size = 128 # how many independent sequences will we process in parallel?
block_size = 128 # what is the maximum context length for predictions?
learning_rate = 1e-3
n_embd = 256
n_head = 16
n_layer = 8
dropout = 0.7 # drop that shit
weight_decay =  1e-2 # 1e-2 is pytorch default
# ------------


#"small model" but larger batch:
batch_size = 32
block_size =64  
learning_rate = 0.001
n_embd = 128
n_head = 8
n_layer = 4
dropout = 0.1

max_iters = 11000
eval_interval = 100
eval_iters = 200


config = {
    'learning_rate': learning_rate,
    'batch_size': batch_size,
    'block_size': block_size,
    'n_embd': n_embd,
    'n_head': n_head,
    'n_layer': n_layer,
    'dropout': dropout,
    'weight_decay': weight_decay,
}

In [5]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print('device =',device) 

device = cuda:0


# Dataset creation

In [6]:
# for reproducibility.
def set_seeds(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed) 
        
set_seeds(0)

In [7]:
data_dir = pathlib.Path('midi_files')
filenames = glob(str(data_dir/'**/*.mid*'), recursive=True)
print('Number of files:', len(filenames))

Number of files: 382


In [8]:
def midi_file_to_tensor(midi_file,
                        time_units='ticks', # beats, ticks, s
                        info=False,  # return info about the track
                       ):
    pm = pretty_midi.PrettyMIDI(midi_file) # read in the whole file. this is very slow for long midi files (e.g. in MAESTRO)
    # Sort the notes first by start time (then by pitch if two notes start at the same time)
    sorted_notes = sorted(pm.instruments[0].notes, key=lambda note: (note.start, note.pitch))
    notes = torch.empty( (len(sorted_notes), 5), dtype=torch.float32 ) # allocate storage
    
    prev_start = sorted_notes[0].start
    for i, note in enumerate(sorted_notes):
        notes[i] = note.pitch
        notes[i, 1] = note.start - prev_start  # step, time since last note
        notes[i, 2] = note.end - note.start    # duration
        # let's keep the start & end to make life easier elsewhere. we can always chop these off.
        notes[i, 3], notes[i, 4] = note.start, note.end 
        prev_start = note.start

    return notes


def files_to_tensor_list(filenames): 
    "runs in parallel so should be reasonably fast"
    tensor_list = process_map(midi_file_to_tensor, filenames, max_workers=mp.cpu_count(), chunksize=1)
    return tensor_list

In [9]:
# read in one song
notes = midi_file_to_tensor(filenames[0])
pitches = notes[:,0].type(torch.long)  # just the pitch info
notes.shape, pitches.shape

(torch.Size([432, 5]), torch.Size([432]))

In [10]:
# here we read midi files into a list of tensors
read_all_midi_files = True # sometime you don't wanna re-do this

if read_all_midi_files:
    notes_list = files_to_tensor_list(filenames)
    print(f"\nlen(notes_list) = {len(notes_list)}")
    torch.save(notes_list, f'{data_source}_tensor_list.pt') # save for next time

  0%|          | 0/382 [00:00<?, ?it/s]


len(notes_list) = 382


In [11]:
# load from previous save
notes_tensor_list = torch.load(f'{data_source}_tensor_list.pt')  # load from previous computation
#notes_list = torch.load('rastro-120bpm_16th_tensor_list.pt')  # load from previous computation
len(notes_tensor_list), notes_tensor_list[0].shape

(382, torch.Size([432, 5]))

In [12]:
def step_from_last(ts:torch.Tensor):
    "returns how far the next song should step from the last-struck note of the current song"
    if ts.shape[-1] < 5: return None # too much of a pain to re-integrate start & end times
    last_struck_note = ts.shape[0]-1  # the final midi note event
    last_held_note = ts[:,4].argmax() # the note with the final end time
    # time diff between (start of last-struck note) and (end of last-held note)
    return ts[last_held_note,4] - ts[last_struck_note,3]


def tl_to_notes(tensor_list, 
                shuffle=False, 
                delimit=True, # leave this on
                rest_pitch=127, # some datasets already prefer 127 or -1. you should check
                rest_dur=0.96,  # value used by jsb chorales iirc. in seconds
                ):
    """Takes list of tensors (of arbitrary length, for each song).
    converts to one big long tensor of notes all running togehter"""
    if shuffle: random.shuffle(tensor_list)  # shuffle order of songs
    # writing the following as a loop first so i get it right
    out_tl = []
    for si, ts in enumerate(tensor_list):  # ts = "tensor song" lol
        if si == 0 :  
            out_tl.append(ts)  # no work to do
        else:
            sfl = step_from_last(tensor_list[-1]) 
            add_rest = torch.tensor((rest_pitch, sfl, rest_dur, 0, rest_dur))  # jsb chorales used this method, so good enough for me
            out_tl.append(add_rest)
            ts[0,1] = rest_dur  # step of not of new song should be dur of rest that precedes it.
            out_tl.append(ts)
    # remember to strip out start & end values
    out = torch.vstack(out_tl).type(torch.float32)  # return one big tensor of floats
    out = out[:,:3]   #  leave only pitch, step, duration
    return out

set_seeds(0)
all_notes = tl_to_notes(notes_tensor_list, shuffle=True) 
notes_tensor_list = [q[:,:3] for q in notes_tensor_list] # make sure we chopped off any any extra values
all_pitches = all_notes[:,0].type(torch.long)  # just the pitch info
all_notes.shape, all_pitches.shape

(torch.Size([78799, 3]), torch.Size([78799]))

In [13]:
# save all_notes for inspection & later use
torch.save(all_notes, f'{data_source}_all_notes_tensor.pt')

In [14]:
def notes_arr_to_df(notes_arr) -> pd.DataFrame:
    columns = ['pitch','step','duration']
    df = pd.DataFrame(notes_arr, columns=columns)
    df["start"], df["end"] = "",  ""
    prev_start = 0
    for i, row in df.iterrows():
        start = prev_start + float(row['step'])
        df.at[i, 'start'] = start
        df.at[i, 'end']   = start + float(row['duration'])
        prev_start = start
    return df

def df_to_midi(
        notes_df: pd.DataFrame,
        out_file: str = '',  # output file to save to, if any
        instrument_name: str = 'Acoustic Grand Piano', # whatever you want to call this instrument
        velocity: int = 100,  # note loudness
    ) -> pretty_midi.PrettyMIDI:
    "converts a dataframe to valid midi"

    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(
        program=pretty_midi.instrument_name_to_program(
            instrument_name))

    prev_start = 0
    for i, note in notes_df.iterrows(): # this is a serial operation, not sure how to parallelize
        start = float(prev_start + note['step'])
        end = float(start + note['duration'])
        note = pretty_midi.Note( velocity=velocity, pitch=int(note['pitch']), start=start,  end=end, )
        instrument.notes.append(note)
        prev_start = start

    pm.instruments.append(instrument)
    if out_file: pm.write(out_file)
    return pm

def notes_to_midi(notes_tensor, time_rescale=None, out_file: str = '') -> pretty_midi.PrettyMIDI:
    notes_tensor = notes_tensor.clone() # just to avoid weird overwrites of memory
    #notes_tensor = notes_tensor * (notes_tensor>0)  # negative numbers clipped to zero
    if notes_tensor.min() < 0.0:
      print("WARNING: You have negative pitches, steps or durations. Setting them to zero")
      notes_tensor = notes_tensor * (notes_tensor >= 0)
    if time_rescale is not None :
        notes_tensor[:,1:] = notes_tensor[:,1:] *time_rescale # no quantization, just rescaling time
    notes_df = notes_arr_to_df(notes_tensor.cpu().detach().numpy())
    return df_to_midi(notes_df, out_file=out_file)

def midiplayer(notes_tensor, height=400, time_rescale=None, midi_file="/tmp/tmp.mid"):
    pm = notes_to_midi(notes_tensor, time_rescale=time_rescale, out_file=midi_file)
    return MIDIPlayer(midi_file, height, styler=dark, dl=True)

In [15]:
p1 = midiplayer(notes_tensor_list[0], time_rescale=1.1)
p1

## Codebook creation

In [16]:

# Housekeeping: clamp any step or dur values that are really long at, say 4 seconds
all_notes[:,1] = torch.clamp(all_notes[:,1], 0, 4.0) # step
all_notes[:,2] = torch.clamp(all_notes[:,2], 0, 4.0) # dur

raw_data = all_notes
print("raw_data.shape =",raw_data.shape)
n_codebooks = raw_data.shape[-1]
print("n_codebooks =",n_codebooks)


codebooks = []
for i in range(n_codebooks): 
    if i==0:
        cb_vals = torch.arange(128)  # use all possible pitches 
    else:
        cb_vals = raw_data[:,i].unique().sort()[0]  # sorted(set(raw_data[:,i]))
    print(f"\n---\ncb {i}: cb_vals = {cb_vals}")
    codebooks.append({'encode':{k.item(): int(v) for v, k in enumerate(cb_vals)}, 
                      'decode':{int(v): k for v, k in enumerate(cb_vals)}})
    print(f" cb {i}: cb keys = {codebooks[-1]['encode'].keys()}")
vocab_sizes = [len(cb['encode']) for cb in codebooks]
print("vocab_sizes = ",vocab_sizes)
vocab_size = vocab_sizes[0]
vocab_size

raw_data.shape = torch.Size([78799, 3])
n_codebooks = 3

---
cb 0: cb_vals = tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
 cb 0: cb keys = dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 

128

Encoder and decoder

In [17]:
#these are slow and could be improved a lot by doing all cbs at once
#encode = lambda s: torch.tensor( [[codebooks[cb]['encode'][note[cb].item()] for cb in range(n_codebooks)] for note in s], dtype=torch.long).to(s.device)
        
#decode = lambda si: torch.tensor([[codebooks[cb]['decode'][i[cb].item()]  for cb in range(n_codebooks) ] for i in si]).to(si.device)

def remap_vals(s, direction, dtype=torch.long):
    out = torch.zeros_like(s, dtype=dtype)
    for cb in range(s.shape[-1]): 
        dict_map = codebooks[cb][direction]
        out[:,cb] = torch.tensor([dict_map[x.item()] for x in s[:,cb]], dtype=dtype)
    return out 

encode = lambda s: remap_vals(s, 'encode')
decode = lambda s: remap_vals(s, 'decode', dtype=all_notes.dtype)

test_str = all_notes[0:6]
print("Before encoding, test_str =\n",test_str)
ilist = encode(test_str)
print("After encoding, ilist =\n",ilist)
ret_str = decode(ilist)
print("After decoding, ret_str =\n",ret_str)
assert torch.equal(test_str, ret_str), f"Oops. test_str={test_str}, but ret_str={ret_str}"
print("Checks pass! :-)")

Before encoding, test_str =
 tensor([[51.0000,  0.0000,  0.4795],
        [66.0000,  0.0000,  0.4795],
        [69.0000,  0.0000,  0.4795],
        [71.0000,  0.0000,  1.4409],
        [52.0000,  0.4795,  0.4795],
        [64.0000,  0.0000,  0.4795]])
After encoding, ilist =
 tensor([[51,  0,  6],
        [66,  0,  6],
        [69,  0,  6],
        [71,  0, 21],
        [52,  7,  6],
        [64,  0,  6]])
After decoding, ret_str =
 tensor([[51.0000,  0.0000,  0.4795],
        [66.0000,  0.0000,  0.4795],
        [69.0000,  0.0000,  0.4795],
        [71.0000,  0.0000,  1.4409],
        [52.0000,  0.4795,  0.4795],
        [64.0000,  0.0000,  0.4795]])
Checks pass! :-)


In [18]:
data = encode(all_notes) # this takes a while
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
print("train_data.shape, val_data.shape =",train_data.shape, val_data.shape)

train_data.shape, val_data.shape = torch.Size([70919, 3]) torch.Size([7880, 3])


In [19]:

def augment_data_db(db, pitch_shift=12, debug=True):
    dbnew = db.clone()
    change = torch.randint(-pitch_shift, pitch_shift, (1,))
    dbnew[ dbnew[:,:,0] < 127 ] += torch.tensor((change, 0, 0))
    dbnew[:,:,0] = torch.clamp(dbnew[:,:,0], 0, 127)
    return dbnew


# test data augmentation
db = torch.tensor([[54,12,6],[61,0,124],[127,31,4],[86,0,12],[126,7,12]]).unsqueeze(0)
print("db.shape = ",db.shape)
print("original db = \n",db)
ret = augment_data_db(db) # make sure the note with the "127" rest is left completely alone
print("augmented db = \n",ret)

assert torch.equal(db[0,2,:], ret[0,2,:]),f"Those lines don't match but they should: {db[0,2,:]} and {ret[0,2,:]}"
for i in [0,1,3,4]:
    assert not torch.equal(db[0,i,:], ret[0,i,:]), f"row {i}'s do match but probably shouldn't: {db[0,i,:]} and {ret[0,i,:]}"

print("Data aug looks good!")

db.shape =  torch.Size([1, 5, 3])
original db = 
 tensor([[[ 54,  12,   6],
         [ 61,   0, 124],
         [127,  31,   4],
         [ 86,   0,  12],
         [126,   7,  12]]])
augmented db = 
 tensor([[[ 62,  12,   6],
         [ 69,   0, 124],
         [127,  31,   4],
         [ 94,   0,  12],
         [127,   7,  12]]])
Data aug looks good!


In [20]:
# data loading
def get_batch(split, debug=False):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,)) # batch_size number of indices to grab from
    data_block = torch.stack([data[i:i+block_size+1] for i in ix])
    if split=='train': 
        data_block = augment_data_db(data_block)
    x, y = data_block[:,:-1,:], data_block[:,1:,:]
    #x = torch.stack([data[i:i+block_size] for i in ix])
    #y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    #x, y = augment_data(x,y)
    x, y = x.to(device), y.to(device)
    if debug: print(f"get_batch: x.shape = {x.shape}, y.shape = {y.shape}")
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [21]:
set_seeds(0)

In [22]:
x,y = get_batch('train',debug=True) # first time is slow
print(f"B, T = {batch_size}, {block_size}")
x[0][:10]

get_batch: x.shape = torch.Size([32, 64, 3]), y.shape = torch.Size([32, 64, 3])
B, T = 32, 64


tensor([[64,  0, 18],
        [68,  0,  6],
        [49,  4,  2],
        [47,  3,  3],
        [61,  0,  6],
        [69,  0,  6],
        [45,  4,  2],
        [44,  3,  3],
        [59,  0,  3],
        [71,  0,  7]], device='cuda:0')

y is x ahead back by one and including new data.
in this sense only y[:,-1] is the "next token" being predicted.

In [23]:
y[0][:10]

tensor([[68,  0,  6],
        [49,  4,  2],
        [47,  3,  3],
        [61,  0,  6],
        [69,  0,  6],
        [45,  4,  2],
        [44,  3,  3],
        [59,  0,  3],
        [71,  0,  7],
        [42,  4,  3]], device='cuda:0')

In [24]:
midiplayer(decode(x[0]), time_rescale=1)

# Model definition

In [25]:

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# super simple bigram model
class BigramLanguageModel(nn.Module):

    def __init__(self, debug=False):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        print("vocab sizes = ",[vocab_sizes[cb] for cb in range(n_codebooks)])
        self.token_embedding_tables = nn.ModuleList([nn.Embedding(vocab_sizes[cb], n_embd) for cb in range(n_codebooks)])
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_heads = nn.ModuleList([nn.Linear(n_embd, vocab_sizes[cb]) for cb in range(n_codebooks)])
        self.debug=False

    def forward(self, idx, targets=None):
        B, T, CBS = idx.shape
        tok_emb = 0
        for cb in range(CBS): # just sum codebook reps
            if self.debug: 
                print(f"cb = {cb}, idx.shape = {idx.shape}, idx[:,:,cb].shape = {idx[:,:,cb].shape}, idx[:,:,cb] =\n",idx[:,:,cb])
                print(f"   idx[:,:,{cb}].min(), idx[:,:,{cb}].max() = ",idx[:,:,cb].min(), idx[:,:,cb].max())
            tok_emb = tok_emb + self.token_embedding_tables[cb](idx[:,:,cb])
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,E)

        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits_list = [head(x) for head in self.lm_heads]  # list of (B,T,vocab_sizes) tensors, CBS long

        if targets is None:
            loss = None
        else:
            lambdas = [0.5]*CBS
            #lambdas =  [1.0, 0.25, 0.25]+[0.5]*(CBS-3)  # relative "weights" to pitch, step, dur losses. anything added later gets 0.5
            loss = 0.0
            for cb in range(CBS):  # loop over codebooks
                logits = logits_list[cb]  
                B, T, V = logits.shape   # V = vocab size
                targ = targets[:,:,cb]   # B, T 
                logits = logits.view(B*T, V)
                targ = targ.reshape(B*T)
                loss = loss + lambdas[cb]*F.cross_entropy(logits, targ)
        if self.debug: print("loss = ",loss.item()) 
        return logits_list, loss


    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0):
        # idx is (B, T, CBS) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits_list, loss = self(idx_cond)
            idx_next_list = []
            for cb in range(idx_cond.shape[-1]):
                # focus only on the last time step
                logits = logits_list[cb]  # B, T, V  where V = vocab/embedding size 
                logits = logits[:, -1, :] # get last time.  becomes (B, V)
                # apply softmax to get probabilities
                probs = F.softmax(logits/temperature, dim=-1) # (B, V)
                # sample from the distribution
                idx_next_list.append(torch.multinomial(probs, num_samples=1)) # (B, 1)
                
            idx_next = torch.tensor(idx_next_list).unsqueeze(0).unsqueeze(0).to(idx.device)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


Instantiate and get ready to run

In [26]:
model = BigramLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

vocab sizes =  [128, 28, 46]
0.851914 M parameters


In [27]:
def save_checkpoint(step, model, optimizer,loss, name):
    name = name + f'_{step}.pt' if step is not None else name+".pt"
    print("Saving checkpoint to",name)
    torch.save({
            'step': step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, name)

def load_checkpoint(name, model, optimizer):
    checkpoint = torch.load(name)
    model.load_state_dict(checkpoint['model_state_dict'])
    m = model.to(device)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    step = checkpoint['step']
    loss = checkpoint['loss']
    return step, m, optimizer, loss

In [28]:
start_from_checkpoint = False

if start_from_checkpoint:
    step, model, optimizer, loss = load_checkpoint('fullnotes_jsb_best_cp_do0.7.pt', model, optimizer)

# Do training

In [29]:
if use_wandb: wandb.init(project='testvis-jsb', config=config)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112386788914188, max=1.0…

In [None]:
demo_every = eval_interval*2
cp_every = eval_interval*2
prompt_length = 24
val_int = 4*((n+1)//4) # try to get at the start of 4 notes in a chord
context = data[val_int:val_int+prompt_length].unsqueeze(0).to(device) # same starting point from somewhere in validation set
best_loss = 999

print("Starting training loop...")
for iter in range(max_iters):
    wbl_dict = {} # wandb log dict
    # every once in a while evaluate the loss on train and val sets
    if (iter % eval_interval == 0 or iter == max_iters - 1) and iter>0:
        #print("      Estimating loss")
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        wbl_dict = wbl_dict | losses  | {'step':iter//eval_interval} 

    if use_wandb and iter % demo_every == 0 and iter>0:
        print("   Making demo...",end="",flush=True)
        with torch.no_grad():
            model.eval()
            new_notes = decode( model.generate(context, max_new_tokens=64, temperature=1)[0].cpu() )
            p2 = midiplayer(new_notes, time_rescale=1)
            wbl_dict = wbl_dict | {'player':wandb.Html(p2.html)}
        model.train()
        print(" ...Finished demo")
        
    if use_wandb and wbl_dict != {}: wandb.log(wbl_dict)
    
    if iter % cp_every==0 and iter>0: 
        if losses['val'] < best_loss:
            print("   New best val loss: ",end="",flush=True)
            best_loss = losses['val']
            save_checkpoint(None, model, optimizer,loss, f"fullnotes_jsb_best_cp")   
        
    
    xb, yb = get_batch('train') # sample a batch of data
    logits, loss = model(xb, yb) # evaluate the loss
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


Starting training loop...
step 100: train loss 3.3224, val loss 3.2836
step 200: train loss 2.9632, val loss 2.9451
   Making demo... ...Finished demo
   New best val loss: Saving checkpoint to fullnotes_jsb_best_cp.pt
step 300: train loss 2.8119, val loss 2.7761
step 400: train loss 2.7092, val loss 2.6804
   Making demo... ...Finished demo
   New best val loss: Saving checkpoint to fullnotes_jsb_best_cp.pt
step 500: train loss 2.5537, val loss 2.5406
step 600: train loss 2.4535, val loss 2.4346
   Making demo... ...Finished demo
   New best val loss: Saving checkpoint to fullnotes_jsb_best_cp.pt
step 700: train loss 2.3512, val loss 2.3711
step 800: train loss 2.2738, val loss 2.2752
   Making demo... ...Finished demo
   New best val loss: Saving checkpoint to fullnotes_jsb_best_cp.pt
step 900: train loss 2.2089, val loss 2.2131
step 1000: train loss 2.1368, val loss 2.1392
   Making demo... ...Finished demo
   New best val loss: Saving checkpoint to fullnotes_jsb_best_cp.pt
step 110

In [None]:
#save_checkpoint(max_iters, model, optimizer,loss, f"fullnotees_jsb")

In [None]:
if use_wandb: wandb.finish()

# Generate

In [None]:
# get the best model from our run 
step, model, optimizer, loss = load_checkpoint('fullnotes_jsb_best_cp.pt', model, optimizer)

In [None]:
# redefining model.generate  for customization
from tqdm.notebook import tqdm as pbar

@torch.no_grad()
def generator(model, idx, max_new_tokens, temperature=1.0):
    # idx is (B, T, CBS) array of indices in the current context
    for _ in pbar(range(max_new_tokens), leave=False):
        # crop idx to the last block_size tokens
        idx_cond = idx[:, -block_size:]
        # get the predictions
        logits_list, loss = model(idx_cond)
        idx_next_list = []
        for cb in range(idx_cond.shape[-1]):
            # focus only on the last time step
            logits = logits_list[cb][:, -1, :] # becomes (B, V)
            # apply softmax to get probabilities
            probs = F.softmax(logits/temperature, dim=-1) # (B, V)
            # sample from the distribution
            idx_next_list.append(torch.multinomial(probs, num_samples=1)) # (B, 1)
        idx_next = torch.tensor(idx_next_list).unsqueeze(0).unsqueeze(0).to(idx.device)
        # append sampled index to the running sequence
        
        idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx


In [None]:
ind, num_tokens = 0, 150
original = val_data[ind:ind+num_tokens].to(device) 
print("Original:")
display(midiplayer(decode(original), time_rescale=1.1))

prompt_len = 21
prompt = original[0:prompt_len] # grab from validation dataset
print("prompt.shape =",prompt.shape)
print("Prompt:")
midiplayer(decode(prompt), time_rescale=1.1)

In [None]:
new_tokens = num_tokens - prompt_len

for temperature in [0.7, 0.7, 0.85, 0.92, 1.0, 1.2, 1.5]:
    set_seeds(1337) # same temp for same seed will yield same output
    notes = decode( generator(model, prompt.unsqueeze(0), max_new_tokens=new_tokens, temperature=temperature)[0].cpu() )
    print(f"temperature = {temperature}:")
    display(midiplayer(notes, time_rescale=1.0))