In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import torch
from torch import nn

In [None]:
import torch.nn.functional as F

In [None]:
import os

In [None]:
# os.makedirs("files/maestro-v2.0.0/train", exist_ok=True)
# os.makedirs("files/maestro-v2.0.0/val", exist_ok=True)
# os.makedirs("files/maestro-v2.0.0/test", exist_ok=True)

In [None]:
# import json
import pickle
# from utils.processor import encode_midi
# file="files/maestro-v2.0.0/maestro-v2.0.0.json"
# with open(file,"r") as fb:
#     maestro_json=json.load(fb)

In [None]:
# for x in maestro_json:
#     mid=rf'files/maestro-v2.0.0/{x["midi_filename"]}'
#     split_type = x["split"]
#     f_name = mid.split("/")[-1] + ".pickle"
#     if(split_type == "train"):
#         o_file = rf'files/maestro-v2.0.0/train/{f_name}'
#     elif(split_type == "validation"):
#         o_file = rf'files/maestro-v2.0.0/val/{f_name}'
#     elif(split_type == "test"):
#         o_file = rf'files/maestro-v2.0.0/test/{f_name}'
#     prepped = encode_midi(mid)
#     with open(o_file,"wb") as f:
#         pickle.dump(prepped, f)

In [None]:
# Report how many files each dataset split contains. The directory listing and
# the print stay interleaved per split, exactly as in a straight-line version.
_split_sizes = {}
for _subdir, _label in (('train', 'train'), ('val', 'validation'), ('test', 'test')):
    _split_sizes[_subdir] = len(os.listdir(f'files/maestro-v2.0.0/{_subdir}'))
    print(f"there are {_split_sizes[_subdir]} files in the {_label} set")

# Keep the individual counters available under their usual names.
train_size = _split_sizes['train']
val_size = _split_sizes['val']
test_size = _split_sizes['test']

In [None]:

from utils.processor import encode_midi
import pretty_midi
from utils.processor import (_control_preprocess,
    _note_preprocess,_divide_note,
    _make_time_sift_events,_snote2events)

# One MAESTRO performance used to illustrate the preprocessing steps.
file='MIDI-Unprocessed_Chamber1_MID--AUDIO_07_R3_2018_wav--2'
name=rf'files/maestro-v2.0.0/2018/{file}.midi'

# encode
events=[]
notes=[]

# convert song to an easily-manipulable format
song=pretty_midi.PrettyMIDI(name)
for inst in song.instruments:
    inst_notes=inst.notes
    # Only control-change events with controller number 64 are passed on
    # (CC 64 is the sustain pedal in the standard MIDI controller list).
    pedal_changes=[ctrl for ctrl in inst.control_changes if ctrl.number == 64]
    # NOTE(review): _control_preprocess/_note_preprocess come from
    # utils.processor; presumably they fold the pedal events into the note
    # list -- confirm against that module.
    ctrls=_control_preprocess(pedal_changes)
    notes += _note_preprocess(ctrls, inst_notes)
dnotes = _divide_note(notes)
# Order the divided notes chronologically by their `time` attribute.
dnotes.sort(key=lambda x: x.time)
# Preview the first five entries. Slicing (instead of indexing 0..4) avoids an
# IndexError when the piece yields fewer than five entries.
for dnote in dnotes[:5]:
    print(dnote)

In [None]:
# Fixed length (in tokens) of every input/target sequence: create_xys pads
# shorter pieces and truncates longer ones to this many positions.
max_seq = 2048

In [None]:
def create_xys(folder, seq_len=None, pad_token=389, end_token=388):
    """Build fixed-length (input, target) pairs from pickled token sequences.

    Every regular, non-hidden file in ``folder`` is unpickled and converted to
    a ``LongTensor`` of token ids. The target is the input shifted left by one
    position (next-token prediction):

    * A sequence no longer than ``seq_len`` is right-padded with ``pad_token``;
      its target gets ``end_token`` directly after the last real token.
    * A longer sequence is truncated: the input is the first ``seq_len``
      tokens and the target is tokens ``1 .. seq_len``.

    Args:
        folder: Directory holding one pickled token sequence per file.
        seq_len: Length of each returned tensor. ``None`` (the default) falls
            back to the notebook-level ``max_seq``.
        pad_token: Id used to fill unused positions (default 389).
        end_token: Id written after the last target token of a sequence that
            fits completely in the window (default 388).

    Returns:
        list[tuple[torch.Tensor, torch.Tensor]]: ``(x, y)`` pairs of dtype
        ``torch.long`` and shape ``(seq_len,)``, ordered by file name.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """
    if seq_len is None:
        seq_len = max_seq  # notebook-level default window length
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    xys = []
    # os.listdir returns entries in arbitrary order; sort so the dataset is
    # built in the same order on every run and platform.
    for entry in sorted(os.listdir(folder)):
        path = os.path.join(folder, entry)
        # Skip sub-directories and hidden files (e.g. .ipynb_checkpoints,
        # .DS_Store), which are not pickled sequences.
        if entry.startswith('.') or not os.path.isfile(path):
            continue
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point this at data you produced yourself.
        with open(path, 'rb') as fb:
            music = pickle.load(fb)
        music = torch.LongTensor(music)
        length = len(music)
        if length == 0:
            # Nothing to learn from; the padding branch below would also fail
            # on `y[:-1] = <empty>`.
            continue
        if length <= seq_len:
            x = torch.full((seq_len, ), pad_token, dtype=torch.long)
            y = torch.full((seq_len, ), pad_token, dtype=torch.long)
            x[:length] = music
            y[:length - 1] = music[1:]
            y[length - 1] = end_token
        else:
            x = music[:seq_len]
            y = music[1:seq_len + 1]
        xys.append((x, y))
    return xys

In [None]:
# Build the list of (input, target) tensor pairs for the training split.
trainfolder='files/maestro-v2.0.0/train'
train=create_xys(trainfolder)

In [None]:
# Locations of the held-out splits.
valfolder, testfolder = 'files/maestro-v2.0.0/val', 'files/maestro-v2.0.0/test'

# Build (input, target) pairs for validation first, then for test.
print("processing the validation set")
val = create_xys(valfolder)

print("processing the test set")
test = create_xys(testfolder)

In [None]:
# Take the input sequence of the first validation pair; its target is discarded.
val1,_ =val[0]

In [None]:
# Show the tensor's shape as the cell output; create_xys produces inputs of
# length max_seq.
val1.shape

In [None]:
from utils.processor import decode_midi
# Turn the token sequence back into MIDI. decode_midi comes from
# utils.processor; presumably it writes the result to `file_path` -- confirm.
file_path="files/val1.midi"
# The tensor is moved to CPU and converted to a NumPy array before decoding.
decode_midi(val1.cpu().numpy(), file_path=file_path)

In [None]:
from torch.utils.data import DataLoader
# Number of (input, target) pairs per mini-batch fed to the model.
batch_size = 2

In [None]:
# Batch the training pairs; shuffle=True reshuffles them at every epoch.
trainloader=DataLoader(train,batch_size=batch_size, shuffle=True)

In [None]:
class Config:
    """Hyperparameter container handed to ``Model``.

    Every setting can be overridden by keyword; calling ``Config()`` with no
    arguments yields the values that used to be hard-coded here.

    The attribute names follow the usual GPT-style convention. How each one is
    consumed is defined by ``utils.ch14util.Model`` (not shown in this
    notebook), so the descriptions below are hedged accordingly.

    Args:
        n_layer: Presumably the number of stacked transformer blocks.
        n_head: Presumably the number of attention heads per block.
        n_embd: Presumably the embedding / hidden width.
        vocab_size: Number of distinct token ids. The default 390 covers ids
            0-389; this notebook uses 388 as the end token and 389 as padding.
        block_size: Presumably the maximum sequence length the model accepts;
            the default equals the notebook's ``max_seq`` (2048).
        embd_pdrop: Presumably the dropout probability on the embeddings.
        resid_pdrop: Presumably the dropout probability on residual branches.
        attn_pdrop: Presumably the dropout probability on attention weights.
    """

    def __init__(
        self,
        n_layer=6,
        n_head=8,
        n_embd=512,
        vocab_size=390,
        block_size=2048,
        embd_pdrop=0.1,
        resid_pdrop=0.1,
        attn_pdrop=0.1,
    ):
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embd_pdrop = embd_pdrop
        self.resid_pdrop = resid_pdrop
        self.attn_pdrop = attn_pdrop

# Default hyperparameters for the model.
config = Config()

# Pick the compute device: CUDA first, then Apple's MPS backend, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
from utils.ch14util import Model

In [None]:
# Build the model from the hyperparameters and move it to the chosen device.
model=Model(config)
model.to(device)
# Count parameters of the `transformer` sub-module only; anything Model keeps
# outside that attribute is not included in this figure.
num=sum(p.numel() for p in model.transformer.parameters())
print("number of parameters: %.2fM" % (num/1e6,))
# Bare expression: displays the model's repr as the cell output.
model

In [None]:
# Learning rate for the Adam optimizer.
lr=0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Targets equal to 389 are excluded from the loss; 389 is the value create_xys
# uses to pad sequences shorter than max_seq.
loss_func=torch.nn.CrossEntropyLoss(ignore_index=389)

In [None]:
# Put the model in training mode before running the training loop.
model.train()

In [24]:
from tqdm import tqdm
# Train for 100 epochs; the weights file is overwritten after every epoch.
for epoch in range(1, 101):
    progress = tqdm(trainloader, leave=True)
    running_loss = 0
    # Steps are counted from 1 so the running mean is simply total / step.
    for step, (x, y) in enumerate(progress, start=1):
        x = x.to(device)
        y = y.to(device)
        output = model(x)
        # Flatten the output to (N, last_dim) and the targets to (N,) so the
        # loss is computed per position.
        flat_output = output.view(-1, output.size(-1))
        flat_target = y.view(-1)
        loss = loss_func(flat_output, flat_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Show the mean loss over the batches seen so far in this epoch.
        progress.set_postfix(loss=running_loss / step, epoch=epoch)
    torch.save(model.state_dict(), "files/musicTrans.pth")

100%|██████████| 484/484 [07:52<00:00,  1.02it/s, epoch=54, loss=1.6]
100%|██████████| 484/484 [07:52<00:00,  1.02it/s, epoch=55, loss=1.58]
100%|██████████| 484/484 [07:53<00:00,  1.02it/s, epoch=56, loss=1.55]
100%|██████████| 484/484 [07:52<00:00,  1.02it/s, epoch=57, loss=1.53]
100%|██████████| 484/484 [07:53<00:00,  1.02it/s, epoch=58, loss=1.51]
100%|██████████| 484/484 [07:52<00:00,  1.02it/s, epoch=59, loss=1.48]
100%|██████████| 484/484 [07:53<00:00,  1.02it/s, epoch=60, loss=1.46]
100%|██████████| 484/484 [07:53<00:00,  1.02it/s, epoch=61, loss=1.43]
100%|██████████| 484/484 [07:53<00:00,  1.02it/s, epoch=62, loss=1.41]
100%|██████████| 484/484 [07:52<00:00,  1.02it/s, epoch=63, loss=1.39]
100%|██████████| 484/484 [07:53<00:00,  1.02it/s, epoch=64, loss=1.37]
100%|██████████| 484/484 [07:52<00:00,  1.02it/s, epoch=65, loss=1.35]
100%|██████████| 484/484 [07:53<00:00,  1.02it/s, epoch=66, loss=1.33]
100%|██████████| 484/484 [07:53<00:00,  1.02it/s, epoch=67, loss=1.31]
100%|██