In [2]:
%reload_ext autoreload
%autoreload 2

from model import GPT, GPTConfig
import torch
from utils import *

from mingpt_utils import set_seed
from mingpt_utils import sample_new, sample
import pytz
from datetime import datetime, timezone
import numpy as np
# Release unoccupied GPU memory held by PyTorch's caching allocator
# (e.g. from an earlier run in this kernel) before building the model.
torch.cuda.empty_cache()

In [3]:
# Load the token vocabulary, the shuffled training sequences and their MIDI
# counterparts, in that order.
# NOTE: allow_pickle=True lets np.load unpickle arbitrary Python objects,
# so only point these paths at trusted files.
tokens, train, midi_train = (
    np.load(npy_path, allow_pickle=True)
    for npy_path in (
        '../data/formatted/tokens.npy',
        '../data/shuffled/dataset_train.npy',
        '../data/shuffled/midi_train.npy',
    )
)

In [4]:
# Context window length (2048 tokens); shared by the dataset and by GPTConfig.
block_size = 2 ** 11
# Arguments are positional: training sequences, their MIDI data, window, vocabulary.
dataset = TokenDatasetMidi(train, midi_train, block_size, tokens)

data has 43272 pieces, 195 unique tokens.


In [5]:
# Hyperparameters of the checkpoint being loaded. They must match the values the
# checkpoint was trained with, so the checkpoint path below is derived from them
# instead of being spelled out a second time.
epochs = 90
embedding = 256
heads = 4
layers = 4
batch_size = 32
learning_rate = 3e-5  # training-only; not used by the cells shown here
num_workers = 4       # training-only; not used by the cells shown here
midi_vocab = 128
token_size = len(tokens)

mconf = GPTConfig(token_size, block_size, midi_vocab, n_layer=layers, n_head=heads, n_embd=embedding)
session_model = GPT(mconf)

# Built from the config above so that editing e.g. `heads` also changes which
# checkpoint is loaded (for the current values this is exactly
# ".../model_epochs->90_heads->4_embd->256_batch->32_new_midi_embeddings").
MODEL_DIR = "/workspace/models"
MODEL_NAME = (
    f"{MODEL_DIR}/model_epochs->{epochs}_heads->{heads}"
    f"_embd->{embedding}_batch->{batch_size}_new_midi_embeddings"
)
session_model = load_model(MODEL_NAME, session_model)

# Inference only: switch off dropout and other train-time behaviour. This is
# idempotent, so it is harmless if load_model already did it. The trailing ';'
# suppresses the module repr in the cell output.
session_model.eval();

03/15/2024 12:53:57 - INFO - model -   number of parameters: 3.268096e+06


Checkpoint loaded /workspace/models/model_epochs->90_heads->4_embd->256_batch->32_new_midi_embeddings


In [8]:
import formats as fmt
import voicing as vc

# Shared voicing helper that generateSample uses to turn token lists into MIDI arrays.
voicing = vc.Voicing()

def generateSample(context, duration, style, tonality, session_model, dataset,
                   temperature=1.0, sample=True, top_k=None, top_p=0.99,
                   max_steps=100, device='cuda'):
    """Extend a chord context with the model, calling ``sample_new`` once per step.

    The prompt is ``['<style>', style, 'Tonality', tonality, '<start>', 'Form_A', '|']``
    followed by ``fmt.getArrayOfElementsInChord(context, duration)``. Each step
    does three things:

    * encodes the current tokens with ``dataset.stoi``;
    * passes them, together with the MIDI array from ``voicing.get_midi``,
      to ``sample_new``;
    * decodes the result back to tokens with ``dataset.itos``.

    Args:
        context: chord symbols to start from, e.g. ``['D']``.
        duration: durations paired with ``context`` (passed to
            ``fmt.getArrayOfElementsInChord``).
        style: token placed after ``'<style>'`` (e.g. ``'Jazz'``).
        tonality: token placed after ``'Tonality'`` (e.g. ``'D major'``).
        session_model: model passed to ``sample_new``.
        dataset: object exposing ``stoi`` (token -> id) and ``itos`` (id -> token).
        temperature, sample, top_k, top_p: forwarded unchanged to ``sample_new``.
            Inside this function ``sample`` is that flag and shadows the
            imported ``sample`` helper.
        max_steps: number of sampling steps (previously hard-coded to 100).
        device: device the input tensors are moved to (previously hard-coded 'cuda').

    Returns:
        The token list with every '<end>' and '<pad>' removed. If
        ``voicing.get_midi`` rejects a sequence, generation stops and the last
        element of that sequence is dropped before returning.
    """
    data = fmt.getArrayOfElementsInChord(context, duration)
    data = ['<style>', style, 'Tonality', tonality, '<start>', 'Form_A', '|'] + data
    print(data)
    # The status of this initial conversion is not checked.
    midi, _ = voicing.get_midi(data)
    print(midi)

    stoi, itos = dataset.stoi, dataset.itos
    dropped_tokens = ('<end>', '<pad>')

    for _ in range(max_steps):
        x = torch.tensor([stoi[s] for s in data], dtype=torch.long)[None, ...].to(device)
        m = torch.tensor(midi, dtype=torch.long)[None, ...].to(device)

        # Pure inference: don't build autograd graphs. This is harmless if
        # sample_new already disables gradients itself. The `1` argument is
        # presumably the number of new tokens per call; TODO confirm against
        # sample_new.
        with torch.no_grad():
            y = sample_new(session_model, x, m, 1, temperature=temperature,
                           sample=sample, top_k=top_k, top_p=top_p)[0]

        # Decode each id once, then drop the '<end>'/'<pad>' control tokens.
        decoded = (itos[int(token_id)] for token_id in y)
        data = [tok for tok in decoded if tok not in dropped_tokens]

        # Suppress an immediate repetition of the newest token.
        if len(data) >= 2 and data[-1] == data[-2]:
            print("Duplicated element: ", data[-1], data[-2])
            data = data[:-1]

        midi, status = voicing.get_midi(data)
        if not status:
            print("Error creating the MIDI format")
            # Erase the last element, presumably the token just sampled that
            # made get_midi fail, so it is not returned to the caller.
            data = data[:-1]
            break

    print(data)
    return data

In [10]:
# Sampling (sample=True, top_p) is stochastic. Fix the RNGs so re-running this
# cell reproduces its output. This assumes mingpt_utils.set_seed seeds torch,
# as minGPT's does. Some CUDA kernels may still be nondeterministic.
SEED = 42
set_seed(SEED)

context = ['D']
# One duration value (4.0) per context chord.
duration = np.full(len(context), 4.0, dtype=float)
myStyle = 'Jazz'
tonality = 'D major'
data = generateSample(context, duration, myStyle, tonality, session_model, dataset,
                      temperature=1.0, sample=True, top_k=None, top_p=0.9)

['<style>', 'Jazz', 'Tonality', 'D major', '<start>', 'Form_A', '|', '.', '4.0', 'D']
[[ 0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0]
 [50  0  0  0  0  0  0  0]]
Duplicated element:  . .
['<style>', 'Jazz', 'Tonality', 'D major', '<start>', 'Form_A', '|', '.', '4.0', 'D', 'maj7', '|', '.', '4.0', '/', 'G', '|', '.', '4.0', 'F#', 'm7', 'alter b5', '|', '.', '2.0', 'B', 'dom7', '|', '.', '2.0', 'E', 'dom7', '|', '.', '2.0', 'E', 'dom7', '|', '.', '4.0', '/', 'D', '|', '.', '2.0', 'C#', 'm7', '.', '2.0', 'Bbb', 'o7', '|', '.', '/', 'C#', '|', '.', '|', '.', '2.0', 'add 9', 'alter b5', '|', '.', '|', '.', 'add b6', '|', '.', '4.0', '|', '.', '/', 'A', '|', '.', '|', '.', '|', '.', '|', '.', '2.0', 'Ab', 'dom7', '|', '.', '4.0', 'alter b5', '|', '.', 'alter #5', '|', '.', '2.0', 'F#', 'm', '|', '.',