In [1]:
# Imports and notebook-wide configuration.
# Automatically re-import edited project modules (model, utils, mingpt_utils, ...) before each cell runs.
%reload_ext autoreload
%autoreload 2

from model import GPT, GPTConfig
import torch
# NOTE(review): star import. TokenDatasetMidi and load_model are used below but not imported
# explicitly anywhere in this notebook, so they presumably come from here — prefer explicit names.
from utils import *

# NOTE(review): set_seed is not called in the visible cells, so the sampling below is not reproducible.
from mingpt_utils import set_seed
# `sample` (the function) is shadowed inside generateSample by that function's boolean `sample` parameter.
from mingpt_utils import sample_new, sample
# NOTE(review): pytz / datetime / timezone are not used in the visible cells.
import pytz
from datetime import datetime, timezone
import numpy as np
# Release unoccupied memory held by PyTorch's CUDA caching allocator (e.g. left over in this kernel).
torch.cuda.empty_cache()

In [2]:
# Load the token vocabulary and the training split: token sequences plus their MIDI
# counterparts (presumably aligned piece-by-piece — confirm against the preprocessing step).
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute
# arbitrary code; only load these files from a trusted source.
tokens, train, midi_train = (
    np.load(npy_path, allow_pickle=True)
    for npy_path in (
        '../data/formatted/tokens.npy',
        '../data/shuffled/dataset_train.npy',
        '../data/shuffled/midi_train.npy',
    )
)

In [3]:
# Sequence length shared by the dataset and the GPTConfig built below
# (presumably the model's context window, in tokens).
block_size = 2_048

dataset = TokenDatasetMidi(
    train,
    midi_train,
    block_size,
    tokens,
)

data has 43272 pieces, 195 unique tokens.


In [64]:
# Hyper-parameters describing the checkpoint to load. Only the values fed into GPTConfig or
# MODEL_NAME matter in this notebook; learning_rate / num_workers are not used in the visible cells.
epochs = 90
embedding = 512
heads = 4
layers = 4
batch_size = 32
learning_rate = 3e-5
num_workers = 4
midi_vocab = 128
token_size = len(tokens)

mconf = GPTConfig(token_size, block_size, midi_vocab, n_layer=layers, n_head=heads, n_embd=embedding)

# The checkpoint path encodes epochs/heads/embd/batch.
# NOTE(review): it does not encode `layers` or `block_size`, so checkpoints that differ
# only in those values would share a file name.
MODEL_NAME = (
    f"../models/model_epochs->{epochs}_heads->{heads}"
    f"_embd->{embedding}_batch->{batch_size}_new_midi_embeddings"
)

# Build the architecture from mconf and let load_model restore the saved checkpoint into it.
# NOTE(review): this cell never calls session_model.eval(); confirm load_model or sample_new
# does, otherwise dropout stays active during generation.
session_model = load_model(MODEL_NAME, GPT(mconf))

03/17/2024 16:55:29 - INFO - model -   number of parameters: 1.282714e+07


Checkpoint loaded ../models/model_epochs->90_heads->4_embd->512_batch->32_new_midi_embeddings


In [69]:
# Project helpers: `fmt` builds the prompt tokens, `vc` converts chord tokens to MIDI.
# NOTE(review): these imports belong in the first cell with the others.
import formats as fmt
import voicing as vc
# Notebook-global Voicing instance. generateSample reads it as a global (get_midi, durations,
# all_notes), and the export cell uses it (convert_chords_to_voicing, export_to_midi).
voicing = vc.Voicing()

def _trim_invalid_tail(data, durations, all_notes):
    """Drop malformed trailing tokens left by the most recent sampling step.

    The rules are applied in this order, each to the result of the previous one:

    1. If more than two tokens are present and the last token repeats the one before it,
       the repeat is dropped (and reported).
    2. If '.' is followed by a token not in ``durations``, both tokens are dropped.
    3. If a token in ``durations`` is followed by a token not in ``all_notes``, both
       tokens are dropped.

    Rules 2 and 3 only run while at least two tokens remain, so heavy trimming can no
    longer raise ``IndexError``. ``data`` itself is never mutated; trimming uses slices.
    """
    if len(data) > 2 and data[-1] == data[-2]:
        print("Duplicated element: ", data[-1], data[-2])
        data = data[:-1]

    if len(data) >= 2 and data[-2] == '.' and data[-1] not in durations:
        data = data[:-2]

    if len(data) >= 2 and data[-2] in durations and data[-1] not in all_notes:
        data = data[:-2]

    return data


def generateSample(context, duration, style, tonality, session_model, dataset,
                   temperature=1.0, sample=True, top_k=None, top_p=0.99,
                   max_steps=500, device='cuda'):
    """Autoregressively extend a chord progression one sampling step at a time.

    The prompt is ``['<style>', style, 'Tonality', tonality, '<start>', '|']`` followed by
    ``fmt.getArrayOfElementsInChord(context, duration)``. Each step:

    * asks ``sample_new`` for one step;
    * decodes the ids with ``dataset.itos``;
    * repairs a malformed tail with ``_trim_invalid_tail``;
    * recomputes the MIDI side-input with the global ``voicing.get_midi``.

    Args:
        context: chord symbols used as the prompt (e.g. ``['D']``).
        duration: one duration per entry of ``context``.
        style, tonality: conditioning tokens; they must be keys of ``dataset.stoi``.
        session_model: model passed through to ``sample_new``.
        dataset: provides the ``stoi`` / ``itos`` token mappings.
        temperature, sample, top_k, top_p: forwarded unchanged to ``sample_new``.
            Inside this function ``sample`` is this argument, not the imported function.
        max_steps: number of sampling steps (previously hard-coded to 500).
        device: device the input tensors are moved to (previously hard-coded 'cuda').
            It should match the device ``session_model`` lives on.

    Returns:
        The generated token list. If ``voicing.get_midi`` reports failure for a step,
        generation stops and the sequence from the previous step (initially the prompt)
        is returned instead of the rejected one.

    NOTE(review): apart from a MIDI failure there is no early stop (e.g. on an
    end-of-piece token), so all ``max_steps`` steps always run. TODO: confirm whether
    the vocabulary has such a token.
    """
    data = fmt.getArrayOfElementsInChord(context, duration)
    data = ['<style>', style, 'Tonality', tonality, '<start>', '|'] + data
    print(data)
    # The prompt's own conversion status is not checked, matching the original behaviour.
    midi, _ = voicing.get_midi(data)
    last_valid = data

    # Pure inference: no autograd graph is needed for sampling.
    with torch.no_grad():
        for _step in range(max_steps):
            x = torch.tensor([dataset.stoi[s] for s in data], dtype=torch.long)[None, ...].to(device)
            m = torch.tensor(midi, dtype=torch.long)[None, ...].to(device)

            y = sample_new(session_model, x, m, 1, temperature=temperature,
                           sample=sample, top_k=top_k, top_p=top_p)[0]

            # Decode ids back to tokens, dropping any whose itos entry is falsy (e.g. '').
            decoded = (dataset.itos[int(tok_id)] for tok_id in y)
            data = [tok for tok in decoded if tok]
            data = _trim_invalid_tail(data, voicing.durations, voicing.all_notes)

            midi, status = voicing.get_midi(data)
            if status == False:
                # Keep the last sequence voicing accepted rather than returning the rejected one.
                print("Error creating the MIDI format")
                data = last_valid
                break
            last_valid = data

    print(data)
    return data

In [70]:
# Prompt: a single D chord with duration 4.0, conditioned on Rock style in C major.
myStyle = 'Rock'
tonality = 'C major'
context = ['D']
duration = np.full(len(context), fill_value=4.0, dtype=float)

data = generateSample(
    context,
    duration,
    myStyle,
    tonality,
    session_model,
    dataset,
    temperature=1.0,
    sample=True,
    top_k=None,
    top_p=0.99,
)

['<style>', 'Rock', 'Tonality', 'C major', '<start>', '|', '.', '4.0', 'D']
['<style>', 'Rock', 'Tonality', 'C major', '<start>', '|', '.', '4.0', 'D', 'm7', '|', '.', '4.0', 'G', 'dom7', '|', '.', '4.0', 'C', 'maj7', '|', '.', '2.0', 'C', 'maj', '/', 'E', '|', '.', '4.0', 'D', 'm7', '|', '.', '4.0', 'Bb', 'maj7', '|', '.', '2.0', 'A', 'm7', '.']


In [71]:
# Convert the generated token sequence into voiced MIDI and export it under the name
# "generated_2" (the output location/extension is decided by export_to_midi).
# The second value returned by convert_chords_to_voicing is discarded.
# NOTE(review): assigning `_` here overwrites IPython's last-output variable.
midi, _ = voicing.convert_chords_to_voicing(data)
voicing.export_to_midi(midi, "generated_2")

song: generated_2
MIDI file created! 
---------------------------------
