In [2]:
import random
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import cbor2

assert torch.cuda.is_available()
random.seed(42)
torch.manual_seed(42)
summary = SummaryWriter("runs/attention-v3")

In [3]:
# CONSTANTS
CONTEXT_LEN = 256
BATCH_SIZE = 512
EMBED_LEN = 32
ADDED_PARAM_LEN = 3 # duration, delay, velocity
PARAM_LEN = 1 + ADDED_PARAM_LEN
DROPOUT_RATE = 0.2
TOKEN_LEN = 128

In [71]:
track_paths = glob.glob('/home/bob/vmshare/midi/solo-piano/*.tokens')
#display(track_paths)
tracks = []
for path in track_paths:
    with open(path, 'rb') as f:
        track_data = cbor2.load(f)
        if len(track_data) < 100:
            print("too short:", path)
        tracks.append(track_data)

In [72]:
print(tracks[5][0:5])
print(len(tracks))

[[2, 0.0, 0.0, 0.0], [67, 0.3125, 0.0, 0.4094488188976378], [69, 1.2708333333333333, 0.71875, 0.4015748031496063], [67, 0.578125, 0.005208333333333333, 0.3858267716535433], [50, 0.7760416666666666, 0.3489583333333333, 0.33858267716535434]]
456


In [73]:
data = torch.tensor([event for track in tracks for event in track])
split_n = int(0.9 * len(data))
train_data = data[:split_n]
eval_data = data[split_n:]

In [74]:
def get_batch(split):
    data = train_data if split == 'train' else eval_data
    ix = torch.randint(len(data) - CONTEXT_LEN, (BATCH_SIZE,))
    x = torch.stack([data[i:i+CONTEXT_LEN] for i in ix])
    y = torch.stack([data[i+1:i+CONTEXT_LEN+1] for i in ix])
    return x.cuda(), y.cuda()

In [75]:
get_batch('train')[0].shape

torch.Size([512, 256, 4])

In [31]:
class SelfAttention(nn.Module):
    def __init__(self, head_size, n_embed=(EMBED_LEN + ADDED_PARAM_LEN)):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.dropout = nn.Dropout(DROPOUT_RATE)
        self.register_buffer('tril', torch.tril(torch.ones(CONTEXT_LEN, CONTEXT_LEN)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        w = q @ k.transpose(-2, -1) * C**-0.5
        w = w.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        w = F.softmax(w, dim=-1)
        w = self.dropout(w)
        v = self.value(x)
        out = w @ v
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size, n_embed=(EMBED_LEN + ADDED_PARAM_LEN)):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttention(head_size) for _ in range(num_heads)])
        self.projection = nn.Linear(num_heads * head_size, n_embed, bias=False)
        self.dropout = nn.Dropout(DROPOUT_RATE)
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.projection(out)
        out = self.dropout(out)
        return out

class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(DROPOUT_RATE),
        )
    def forward(self, x):
        return self.net(x)

class SABlock(nn.Module):
    def __init__(self, n_embed, n_head, n_added=ADDED_PARAM_LEN):
        super().__init__()
        head_size = (n_embed + n_added) // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embed + n_added)
        self.ln1 = nn.LayerNorm(n_embed + n_added)
        self.ln2 = nn.LayerNorm(n_embed + n_added)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [47]:
class Model(nn.Module):
    def __init__(self, n_hidden_neurons, n_sa_heads=4):
        super().__init__()
        self.token_embedding_table = nn.Embedding(TOKEN_LEN, EMBED_LEN)
        self.position_embedding_table = nn.Embedding(CONTEXT_LEN, EMBED_LEN)
        # self.sa_head = SelfAttention(EMBED_LEN)
        self.blocks = nn.Sequential(
            SABlock(EMBED_LEN, 4),
            SABlock(EMBED_LEN, 4),
            SABlock(EMBED_LEN, 4),
            SABlock(EMBED_LEN, 4),
            SABlock(EMBED_LEN, 4),
            SABlock(EMBED_LEN, 4),
            SABlock(EMBED_LEN, 4),
            SABlock(EMBED_LEN, 4),
        )
        self.lm_head = nn.Linear(EMBED_LEN + ADDED_PARAM_LEN, TOKEN_LEN + ADDED_PARAM_LEN, bias=False)
        
    def forward(self, x, targets=None):
        B, T, C = x.shape
        tokens = x[:,:,0].long()
        added = x[:,:,1:].view(-1, CONTEXT_LEN, ADDED_PARAM_LEN)
        tok_emb = self.token_embedding_table(tokens)
        pos_emb = self.position_embedding_table(torch.arange(T, device='cuda'))
        x = tok_emb + pos_emb
        x = torch.cat((x, added), dim=2)
        x = self.blocks(x)
        logits = self.lm_head(x)
        if targets is None:
            return logits, None, None, None
        B, T, C = logits.shape
        classes = targets[:, :, 0].long()
        added_params = targets[:, :, 1:]
        label_loss = F.cross_entropy(logits.view(B*T, C), classes.view(B*T))
        sq_err = ((added_params[:, :, -ADDED_PARAM_LEN:] - logits[:, :, -ADDED_PARAM_LEN:]) ** 2)
        added_param_loss = sq_err.mean(dim=1).mean(dim=0).sum()
        loss = label_loss + added_param_loss
        return logits, loss, label_loss, added_param_loss
    
    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        model.eval()
        out = []
        for _ in range(max_new_tokens):
            idx = idx[:, -CONTEXT_LEN:] # truncate context
            logits, loss, _, _ = self(idx)
            probs = F.softmax(logits[:, -1, :TOKEN_LEN], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1) # B, 1
            token = next_token.int().item()
            added_params = logits[:,-1,-ADDED_PARAM_LEN:].view(1, ADDED_PARAM_LEN)
            next_idx = torch.cat((next_token, added_params), dim=-1) # B, 3
            out.append([token] + added_params.tolist()[0])
            idx = torch.cat((idx, next_idx.view(1, -1, PARAM_LEN)), dim=1)
            if next_token.item() == 3:
                break
        return out


In [48]:
model = Model(BATCH_SIZE).cuda()
summary.add_graph(model, get_batch('train'))
global_training_steps = 0

In [83]:
lossi = []
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
max_steps = 100_000
# print(sum(1 for p in model.parameters()))
# assert Xtrain.device == 'cuda:0', Xtrain.device
for i in range(max_steps+1):
    global_training_steps += 1
    Xb, Yb = get_batch('train')
    # forward
    #if logits.isnan().any():
    #    print("NAN ALERT")
    #    break
    logits, loss, label_loss, added_param_loss = model(Xb, Yb)
    if i % 10 == 0:
        summary.add_scalar('loss', loss.item(), global_training_steps)
        summary.add_scalar('label_loss', label_loss.item(), global_training_steps)
        summary.add_scalar('added_param_loss', added_param_loss.item(), global_training_steps)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # track

    if i % 100 == 0:
        print(f"{i:7d}/{max_steps:7d}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

      0/ 100000: 3.1253
    100/ 100000: 3.1523
    200/ 100000: 3.0603
    300/ 100000: 3.0908
    400/ 100000: 3.1333
    500/ 100000: 3.0978
    600/ 100000: 3.0702
    700/ 100000: 3.0938
    800/ 100000: 3.0388
    900/ 100000: 3.0902
   1000/ 100000: 3.0236
   1100/ 100000: 3.0595
   1200/ 100000: 3.0815
   1300/ 100000: 3.1109
   1400/ 100000: 3.0992
   1500/ 100000: 3.0373
   1600/ 100000: 3.0796
   1700/ 100000: 3.0577
   1800/ 100000: 3.1496
   1900/ 100000: 3.1160
   2000/ 100000: 3.0873
   2100/ 100000: 3.0738
   2200/ 100000: 3.0961
   2300/ 100000: 3.1164
   2400/ 100000: 3.0711
   2500/ 100000: 3.0453
   2600/ 100000: 3.0843
   2700/ 100000: 3.0884
   2800/ 100000: 3.0816
   2900/ 100000: 3.0455
   3000/ 100000: 3.1328
   3100/ 100000: 3.0615
   3200/ 100000: 3.1581
   3300/ 100000: 3.1135
   3400/ 100000: 3.1224
   3500/ 100000: 3.0880
   3600/ 100000: 3.1145
   3700/ 100000: 3.0935
   3800/ 100000: 3.0201
   3900/ 100000: 3.0578
   4000/ 100000: 3.0558
   4100/ 100000:

In [None]:
plt.figure(figsize=(16, 6))
# instead, average across 1_000 points
plt.plot(torch.tensor(lossi)[:-1].view(-1, 1000).mean(1))


In [87]:
for sample in range(1):
    gen_ctx = torch.tensor([[[1, 0.0, 0.0, 0.0]] * (CONTEXT_LEN - 1) + [[2, 0.0, 0.0, 0.0]]], device='cuda')
    generated = model.generate(gen_ctx, max_new_tokens=1_000)
    with open(f"sample-{sample}.tokens", 'wb') as f:
        cbor2.dump(generated, f)
        print("wrote", f.name)

wrote sample-0.tokens


In [None]:
@torch.no_grad()
def split_loss(split):
    model.eval()
    data = {
        "train": train_data,
        "eval": eval_data,
    }[split]
    ix = torch.randint(len(data) - CONTEXT_LEN, (BATCH_SIZE,))
    x = torch.stack([data[i:i+CONTEXT_LEN] for i in ix]).cuda()
    y = torch.stack([data[i+1:i+CONTEXT_LEN+1] for i in ix]).cuda()
    logits, loss = model(x, y)
    print(split, loss.item())
    model.train()

split_loss("train")
split_loss("eval")

In [84]:
torch.save(model, "transformer-v3-200_000_iters.pth")