In [None]:
import os
import json
import pickle
from utils.processor import encode_midi
from utils.processor import decode_midi
import pretty_midi
from utils.processor import (_control_preprocess,
    _note_preprocess,_divide_note,
    _make_time_sift_events,_snote2events)
import torch
from torch.utils.data import DataLoader
from torch import nn
from utils.modelutil import Model

"""
Archivos y modelos generados en carpeta files
"""

In [2]:
file="files/maestro-v2.0.0/maestro-v2.0.0.json"

with open(file,"r") as fb:
    maestro_json=json.load(fb)

for x in maestro_json:
    mid=rf'files/maestro-v2.0.0/{x["midi_filename"]}'
    split_type = x["split"]
    f_name = mid.split("/")[-1] + ".pickle"
    if(split_type == "train"):
        o_file = rf'files/maestro-v2.0.0/train/{f_name}'
    elif(split_type == "validation"):
        o_file = rf'files/maestro-v2.0.0/val/{f_name}'
    elif(split_type == "test"):
        o_file = rf'files/maestro-v2.0.0/test/{f_name}'
    prepped = encode_midi(mid)
    with open(o_file,"wb") as f:
        pickle.dump(prepped, f)

In [3]:
train_size=len(os.listdir('files/maestro-v2.0.0/train'))
print(f"there are {train_size} files in the train set")
val_size=len(os.listdir('files/maestro-v2.0.0/val'))
print(f"there are {val_size} files in the validation set")
test_size=len(os.listdir('files/maestro-v2.0.0/test'))
print(f"there are {test_size} files in the test set")

there are 967 files in the train set
there are 137 files in the validation set
there are 178 files in the test set


In [4]:
file='MIDI-Unprocessed_Chamber1_MID--AUDIO_07_R3_2018_wav--2'
name=rf'files/maestro-v2.0.0/2018/{file}.midi'

# encode
events=[]
notes=[]

# convert song to an easily-manipulable format
song=pretty_midi.PrettyMIDI(name)
for inst in song.instruments:
    inst_notes=inst.notes
    ctrls=_control_preprocess([ctrl for ctrl in 
       inst.control_changes if ctrl.number == 64])
    notes += _note_preprocess(ctrls, inst_notes)
dnotes = _divide_note(notes)    
dnotes.sort(key=lambda x: x.time)    
for i in range(5):
    print(dnotes[i])

<[SNote] time: 1.0325520833333333 type: note_on, value: 74, velocity: 86>
<[SNote] time: 1.0442708333333333 type: note_on, value: 38, velocity: 77>
<[SNote] time: 1.2265625 type: note_off, value: 74, velocity: None>
<[SNote] time: 1.2395833333333333 type: note_on, value: 73, velocity: 69>
<[SNote] time: 1.2408854166666665 type: note_on, value: 37, velocity: 64>


In [5]:
cur_time = 0
cur_vel = 0
for snote in dnotes:
    events += _make_time_sift_events(prev_time=cur_time,
                                     post_time=snote.time)
    events += _snote2events(snote=snote, prev_vel=cur_vel)
    cur_time = snote.time
    cur_vel = snote.velocity    
indexes=[e.to_int() for e in events]   
for i in range(15):
    print(events[i])

<Event type: time_shift, value: 99>
<Event type: time_shift, value: 2>
<Event type: velocity, value: 21>
<Event type: note_on, value: 74>
<Event type: time_shift, value: 0>
<Event type: velocity, value: 19>
<Event type: note_on, value: 38>
<Event type: time_shift, value: 17>
<Event type: note_off, value: 74>
<Event type: time_shift, value: 0>
<Event type: velocity, value: 17>
<Event type: note_on, value: 73>
<Event type: velocity, value: 16>
<Event type: note_on, value: 37>
<Event type: time_shift, value: 0>


In [6]:
max_seq=2048
def create_xys(folder):  
    files=[os.path.join(folder,f) for f in os.listdir(folder)]
    xys=[]
    for f in files:
        with open(f,"rb") as fb:
            music=pickle.load(fb)
        music=torch.LongTensor(music)      
        x=torch.full((max_seq,),389, dtype=torch.long)
        y=torch.full((max_seq,),389, dtype=torch.long)
        length=len(music)
        if length<=max_seq:
            print(length)
            x[:length]=music
            y[:length-1]=music[1:]
            y[length-1]=388    
        else:
            x=music[:max_seq]
            y=music[1:max_seq+1]   
        xys.append((x,y))
    return xys

In [7]:
trainfolder='files/maestro-v2.0.0/train'
train=create_xys(trainfolder)

586
1643
1771
5
15


In [8]:
valfolder='files/maestro-v2.0.0/val'
testfolder='files/maestro-v2.0.0/test'
print("processing the validation set")
val=create_xys(valfolder)
print("processing the test set")
test=create_xys(testfolder)

processing the validation set
processing the test set
1837


In [9]:
val1, _ = val[0]
print(val1.shape)
print(val1)

torch.Size([2048])
tensor([355, 260, 374,  ..., 294, 172, 269])


In [10]:
file_path="files/val1.midi"
decode_midi(val1.cpu().numpy(), file_path=file_path)

<pretty_midi.pretty_midi.PrettyMIDI at 0x7f5fbc1b1940>

In [11]:
train1, _ = train[0]
file_path="files/train1.midi"
decode_midi(train1.cpu().numpy(), file_path=file_path)

<pretty_midi.pretty_midi.PrettyMIDI at 0x7f5f56374890>

In [12]:
batch_size=2
trainloader=DataLoader(train,batch_size=batch_size,
                       shuffle=True)

In [13]:
class Config():
    def __init__(self):
        self.n_layer = 6
        self.n_head = 8
        self.n_embd = 512
        self.vocab_size = 390
        self.block_size = 2048 
        self.embd_pdrop = 0.1
        self.resid_pdrop = 0.1
        self.attn_pdrop = 0.1
        
# instantiate a Config() class
config=Config()
device="cuda" if torch.cuda.is_available() else "cpu"

In [14]:
model=Model(config)
model.to(device)
num=sum(p.numel() for p in model.transformer.parameters())
print("number of parameters: %.2fM" % (num/1e6,))
print(model)

number of parameters: 20.16M
Model(
  (transformer): ModuleDict(
    (wte): Embedding(390, 512)
    (wpe): Embedding(2048, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=512, out_features=1536, bias=True)
          (c_proj): Linear(in_features=512, out_features=512, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): ModuleDict(
          (c_fc): Linear(in_features=512, out_features=2048, bias=True)
          (c_proj): Linear(in_features=2048, out_features=512, bias=True)
          (act): GELU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)


In [15]:
lr=0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=lr) 
# ignore the padding index
loss_func=torch.nn.CrossEntropyLoss(ignore_index=389)

In [None]:
model.train()  
for i in range(1,101):
    tloss = 0.
    for idx, (x,y) in enumerate(trainloader):
        x,y=x.to(device),y.to(device)
        output = model(x)
        loss=loss_func(output.view(-1,output.size(-1)),
                           y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(),1)
        optimizer.step()
        tloss += loss.item()
    print(f'epoch {i} loss {tloss/(idx+1)}') 
torch.save(model.state_dict(),f'files/musicTransAdj.pth') 

  return F.linear(input, self.weight, self.bias)


epoch 1 loss 4.233509041061086
epoch 2 loss 3.9397175617454465
epoch 3 loss 3.808679642263523
epoch 4 loss 3.7136669843650063
epoch 5 loss 3.6235392039472405
epoch 6 loss 3.535602620317916
epoch 7 loss 3.4552007017056803
epoch 8 loss 3.386085465919873
epoch 9 loss 3.3248091965667474
epoch 10 loss 3.2722869343008876
epoch 11 loss 3.221428602194983
epoch 12 loss 3.176633269826243
epoch 13 loss 3.1344467721694755
epoch 14 loss 3.093346937628817
epoch 15 loss 3.0516386647854956
epoch 16 loss 3.006042497709763
epoch 17 loss 2.964790206802778
epoch 18 loss 2.9159620276167373
epoch 19 loss 2.8703508973121643
epoch 20 loss 2.8233438946984033
epoch 21 loss 2.7739959703989268
epoch 22 loss 2.728675560025144
epoch 23 loss 2.682811163181116
epoch 24 loss 2.639080459421331
epoch 25 loss 2.5949006553523795
epoch 26 loss 2.5536869114095513
epoch 27 loss 2.5121113153035974
epoch 28 loss 2.471019365324462
epoch 29 loss 2.4328709690531425
epoch 30 loss 2.393633859216674
epoch 31 loss 2.354637351164148
e

In [18]:
prompt, _  = test[42]
prompt = prompt.to(device)
len_prompt=250

file_path = "files/prompt.midi"
decode_midi(prompt[:len_prompt].cpu().numpy(),
            file_path=file_path)

<pretty_midi.pretty_midi.PrettyMIDI at 0x7f5f56304980>

In [19]:
# define the softmax function for later use
softmax=torch.nn.Softmax(dim=-1)
def sample(prompt,seq_length=1000,temperature=1):
    # create input to feed to the transformer
    gen_seq=torch.full((1,seq_length),389,dtype=torch.long).to(device)
    idx=len(prompt)
    gen_seq[..., :idx]=prompt.type(torch.long).to(device)
    while(idx < seq_length):
        y=softmax(model(gen_seq[..., :idx])/temperature)[...,:388]
        probs=y[:, idx-1, :]
        distrib=torch.distributions.categorical.Categorical(probs=probs)
        next_token=distrib.sample()
        gen_seq[:, idx]=next_token
        idx+=1
    return gen_seq[:, :idx]

In [22]:
file_path = "files/prompt.midi"
prompt = torch.tensor(encode_midi(file_path))
generated_music=sample(prompt, seq_length=1000,temperature=2)

In [23]:
music_data = generated_music[0].cpu().numpy()
file_path = 'files/musicTrans4.midi'
decode_midi(music_data, file_path=file_path)

info removed pitch: 82
info removed pitch: 86
info removed pitch: 65
info removed pitch: 89
info removed pitch: 85
info removed pitch: 25
info removed pitch: 44
info removed pitch: 43
info removed pitch: 85
info removed pitch: 84
info removed pitch: 49
info removed pitch: 64
info removed pitch: 104
info removed pitch: 97
info removed pitch: 103
info removed pitch: 50
info removed pitch: 32
info removed pitch: 24
info removed pitch: 57
info removed pitch: 92
info removed pitch: 32
info removed pitch: 41
info removed pitch: 28


<pretty_midi.pretty_midi.PrettyMIDI at 0x7f5debeb8530>