In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import torch
from torch import nn

In [3]:
import torch.nn.functional as F

In [4]:
import os

In [5]:
# One output directory per split; exist_ok=True keeps the cell safe to re-run.
for split_dir in ("train", "val", "test"):
    os.makedirs(f"files/maestro-v2.0.0/{split_dir}", exist_ok=True)

In [6]:
import json
import pickle
from utils.processor import encode_midi
# MAESTRO metadata; the next cell iterates it as a sequence of records and
# reads each record's "midi_filename" and "split" keys.
file = "files/maestro-v2.0.0/maestro-v2.0.0.json"
# JSON text is UTF-8 by specification. Pass the encoding explicitly so the read
# does not depend on the platform's locale default (e.g. cp1252 on Windows).
with open(file, "r", encoding="utf-8") as fb:
    maestro_json = json.load(fb)

In [7]:
# Map the split names found in the metadata to the output sub-directories
# created earlier (note "validation" -> "val").
SPLIT_DIRS = {"train": "train", "validation": "val", "test": "test"}

for x in maestro_json:
    mid = f'files/maestro-v2.0.0/{x["midi_filename"]}'
    split_type = x["split"]
    # Fail loudly on an unexpected split. An if/elif chain without an else
    # would leave `o_file` pointing at the previous record's output (or
    # undefined on the first record), silently overwriting a pickle.
    if split_type not in SPLIT_DIRS:
        raise ValueError(f"unknown split {split_type!r} for {mid}")
    f_name = os.path.basename(mid) + ".pickle"
    o_file = f"files/maestro-v2.0.0/{SPLIT_DIRS[split_type]}/{f_name}"
    # encode_midi (utils.processor) turns the MIDI file into the token
    # sequence that create_xys later loads.
    prepped = encode_midi(mid)
    with open(o_file, "wb") as f:
        pickle.dump(prepped, f)

In [8]:
def _count_files(split_dir):
    """Return the number of directory entries in one output split directory."""
    return len(os.listdir(f"files/maestro-v2.0.0/{split_dir}"))


train_size = _count_files("train")
print(f"there are {train_size} files in the train set")
val_size = _count_files("val")
print(f"there are {val_size} files in the validation set")
test_size = _count_files("test")
print(f"there are {test_size} files in the test set")

there are 967 files in the train set
there are 137 files in the validation set
there are 178 files in the test set


In [9]:

from utils.processor import encode_midi
import pretty_midi
from utils.processor import (_control_preprocess,
    _note_preprocess,_divide_note,
    _make_time_sift_events,_snote2events)

# Walk one performance through the utils.processor helpers step by step and
# print its first few events.
# NOTE(review): presumably these are the same stages encode_midi applies
# internally; confirm against utils/processor.py.
file = 'MIDI-Unprocessed_Chamber1_MID--AUDIO_07_R3_2018_wav--2'
name = f'files/maestro-v2.0.0/2018/{file}.midi'

# encode
events = []  # NOTE(review): unused in this cell; kept in case a later cell reads it
notes = []

# convert song to an easily-manipulable format
song = pretty_midi.PrettyMIDI(name)
for inst in song.instruments:
    inst_notes = inst.notes
    # Keep only control-change number 64 (the sustain pedal in the MIDI spec).
    ctrls = _control_preprocess([ctrl for ctrl in
        inst.control_changes if ctrl.number == 64])
    notes += _note_preprocess(ctrls, inst_notes)
dnotes = _divide_note(notes)
# Order all events by timestamp (the printed output shows note_on/note_off events).
dnotes.sort(key=lambda x: x.time)
# Slicing prints at most five events without raising IndexError on short songs.
for note in dnotes[:5]:
    print(note)

<[SNote] time: 1.0325520833333333 type: note_on, value: 74, velocity: 86>
<[SNote] time: 1.0442708333333333 type: note_on, value: 38, velocity: 77>
<[SNote] time: 1.2265625 type: note_off, value: 74, velocity: None>
<[SNote] time: 1.2395833333333333 type: note_on, value: 73, velocity: 69>
<[SNote] time: 1.2408854166666665 type: note_on, value: 37, velocity: 64>


In [10]:
# Tokens per training example: create_xys pads/truncates every song to this length.
# NOTE(review): presumably must not exceed Config.block_size (also 2048).
max_seq: int = 2048

In [11]:
def create_xys(folder, max_len=None, pad_idx=389, eos_idx=388):
    """Build (input, target) token pairs for next-token training from pickled songs.

    Each ``*.pickle`` file in ``folder`` holds one encoded song (a sequence of
    integer event tokens). For every non-empty song:

    * if it has at most ``max_len`` tokens, ``x`` is the song right-padded with
      ``pad_idx``, and ``y`` is the song shifted left by one token with
      ``eos_idx`` placed right after the last real token (rest padded);
    * otherwise ``x`` is the first ``max_len`` tokens and ``y`` is the same
      window shifted one token to the right.

    Args:
        folder: directory containing the pickled token sequences.
        max_len: output sequence length; defaults to the notebook-level
            ``max_seq`` constant.
        pad_idx: padding token (389 by default).
        eos_idx: end-of-sequence token appended to short songs (388 by default).

    Returns:
        A list of ``(x, y)`` pairs of 1-D ``torch.long`` tensors of length
        ``max_len``, in sorted file-name order. Empty songs are skipped.
    """
    if max_len is None:
        max_len = max_seq  # notebook-level constant defined in an earlier cell
    # Sort for a deterministic order (os.listdir order is arbitrary). Skip
    # anything that is not a pickle, e.g. a Jupyter ".ipynb_checkpoints" dir.
    files = sorted(
        os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".pickle")
    )
    xys = []
    for f in files:
        # NOTE: pickle.load can execute arbitrary code; only load pickles this
        # notebook produced itself.
        with open(f, "rb") as fb:
            music = pickle.load(fb)
        music = torch.as_tensor(music, dtype=torch.long)
        length = len(music)
        if length == 0:
            # No next-token targets exist. Also, `y[:-1] = music[1:]` below
            # would fail with a shape mismatch.
            continue
        if length <= max_len:
            x = torch.full((max_len,), pad_idx, dtype=torch.long)
            y = torch.full((max_len,), pad_idx, dtype=torch.long)
            x[:length] = music
            y[:length - 1] = music[1:]
            y[length - 1] = eos_idx
        else:
            x = music[:max_len]
            y = music[1:max_len + 1]
        xys.append((x, y))
    return xys

In [12]:
# Build the (input, target) training pairs from the pickled train split.
trainfolder = 'files/maestro-v2.0.0/train'
train = create_xys(trainfolder)

In [13]:
# Same (input, target) preprocessing for the held-out splits.
valfolder, testfolder = (
    'files/maestro-v2.0.0/val',
    'files/maestro-v2.0.0/test',
)

print("processing the validation set")
val = create_xys(valfolder)

print("processing the test set")
test = create_xys(testfolder)

processing the validation set
processing the test set


In [14]:
# Input sequence (x) of the first validation pair. Index directly instead of
# unpacking into `_`: assigning `_` in IPython shadows the output-history
# variable and stops `_`, `__` and `___` from being updated.
val1 = val[0][0]

In [15]:
# Expect a single dimension of length max_seq.
val1.size()

torch.Size([2048])

In [16]:
from utils.processor import decode_midi
# Turn the first validation input back into a MIDI file via decode_midi.
# NOTE(review): val1 may end in padding tokens (389) if the song was short;
# presumably decode_midi ignores them — confirm.
file_path = "files/val1.midi"
val1_tokens = val1.cpu().numpy()
decode_midi(val1_tokens, file_path=file_path)

<pretty_midi.pretty_midi.PrettyMIDI at 0x7f42584e9e90>

In [17]:
from torch.utils.data import DataLoader
batch_size: int = 2  # (x, y) pairs per training batch

In [18]:
# Reshuffle the training pairs every epoch. The default collate function
# stacks the x's and the y's into batches of up to batch_size rows.
trainloader = DataLoader(
    train,
    batch_size=batch_size,
    shuffle=True,
)

In [19]:
class Config:
    """Hyper-parameters for the GPT-style music Transformer built by ``Model``.

    Every field defaults to the value used in this notebook. Pass keyword
    arguments to override individual settings, e.g. ``Config(n_layer=4)``.
    """

    def __init__(self, n_layer=6, n_head=8, n_embd=512, vocab_size=390,
                 block_size=2048, embd_pdrop=0.1, resid_pdrop=0.1,
                 attn_pdrop=0.1):
        self.n_layer = n_layer  # number of Transformer blocks
        self.n_head = n_head  # attention heads per block
        self.n_embd = n_embd  # embedding / hidden width
        # Token vocabulary size. create_xys uses 388 as the end-of-sequence
        # token and 389 as padding, so both must be < vocab_size.
        self.vocab_size = vocab_size
        self.block_size = block_size  # maximum sequence length (positional table size)
        # NOTE(review): presumably the dropout after token + position embeddings.
        self.embd_pdrop = embd_pdrop
        self.resid_pdrop = resid_pdrop  # dropout on residual branches
        self.attn_pdrop = attn_pdrop  # dropout on attention weights


config = Config()

# Prefer a CUDA GPU, then Apple's Metal backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [20]:
from utils.ch14util import Model

In [21]:
# nn.Module.to moves the parameters in place and returns the module itself.
model = Model(config).to(device)
# Only the `transformer` sub-module is counted (lm_head is excluded).
num = sum(param.numel() for param in model.transformer.parameters())
print(f"number of parameters: {num / 1e6:.2f}M")
model

number of parameters: 20.16M


Model(
  (transformer): ModuleDict(
    (wte): Embedding(390, 512)
    (wpe): Embedding(2048, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=512, out_features=1536, bias=True)
          (c_proj): Linear(in_features=512, out_features=512, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): ModuleDict(
          (c_fc): Linear(in_features=512, out_features=2048, bias=True)
          (c_proj): Linear(in_features=2048, out_features=512, bias=True)
          (act): GELU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_fe

In [22]:
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# 389 is the padding token written by create_xys; padded target positions
# contribute nothing to the loss.
loss_func = torch.nn.CrossEntropyLoss(ignore_index=389)

In [23]:
# Enable dropout for training. The trailing ';' suppresses re-displaying the
# module repr that the model-construction cell already shows.
model.train();

Model(
  (transformer): ModuleDict(
    (wte): Embedding(390, 512)
    (wpe): Embedding(2048, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=512, out_features=1536, bias=True)
          (c_proj): Linear(in_features=512, out_features=512, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): ModuleDict(
          (c_fc): Linear(in_features=512, out_features=2048, bias=True)
          (c_proj): Linear(in_features=2048, out_features=512, bias=True)
          (act): GELU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_fe

In [None]:
#from tqdm import tqdm
# for i in range(1, 101):
#     loop = tqdm(trainloader, leave=True)
#     tloss = 0
#     for idx, (x, y) in enumerate(loop):
#         x, y = x.to(device), y.to(device)
#         output = model(x)
#         loss = loss_func(output.view(-1, output.size(-1)), y.view(-1))
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         tloss += loss.item()
#         loop.set_postfix(loss=tloss/(idx+1), epoch=i)
#     torch.save(model.state_dict(), f"files/musicTrans.pth")

100%|██████████| 484/484 [01:00<00:00,  7.95it/s, epoch=1, loss=4.24] 
100%|██████████| 484/484 [01:00<00:00,  7.99it/s, epoch=2, loss=3.94]
100%|██████████| 484/484 [01:00<00:00,  7.99it/s, epoch=3, loss=3.81]
100%|██████████| 484/484 [01:00<00:00,  7.97it/s, epoch=4, loss=3.72] 
100%|██████████| 484/484 [01:00<00:00,  7.98it/s, epoch=5, loss=3.62] 
100%|██████████| 484/484 [01:00<00:00,  7.98it/s, epoch=6, loss=3.52] 
100%|██████████| 484/484 [01:00<00:00,  7.95it/s, epoch=7, loss=3.44] 
100%|██████████| 484/484 [01:00<00:00,  7.96it/s, epoch=8, loss=3.38] 
100%|██████████| 484/484 [01:40<00:00,  4.80it/s, epoch=9, loss=3.32]
100%|██████████| 484/484 [00:20<00:00, 23.52it/s, epoch=10, loss=3.26]   
100%|██████████| 484/484 [01:00<00:00,  7.96it/s, epoch=11, loss=3.21] 
100%|██████████| 484/484 [01:00<00:00,  7.96it/s, epoch=12, loss=3.17] 
100%|██████████| 484/484 [01:00<00:00,  7.97it/s, epoch=13, loss=3.13] 
100%|██████████| 484/484 [01:00<00:00,  7.96it/s, epoch=14, loss=3.09] 
10