In [42]:
import pickle
import sys
import warnings

import numpy as np
import torch
from transformers import AutoConfig, GPT2LMHeadModel

# make the parent directory importable so the local data_utils package resolves
sys.path.insert(0, '..')
# from transformer.models import DecoderOnlyModel
from data_utils.Datasets import SerializedConcatDataset, PermSerializedConcatDataset, BinarySerializer

In [43]:
def load_model(model_id='any_tonality'):
    '''
    Build the small GPT-2 harmonizer and load the saved weights for `model_id`.

    model_id: 'c_major', 'perm' or 'mel_boost' select those checkpoints; any
        other value (e.g. 'jazz' or the default 'any_tonality') loads the plain
        jazz checkpoint. 'c_major' sizes the model with the Nottingham C-major
        serializer, every other id with the jazz serializer.

    Returns the model on the CPU, switched to eval mode.

    Warns (UserWarning) if the checkpoint does not provide every parameter of
    the model, since those parameters would otherwise stay randomly initialised.

    NOTE(review): the serializer is restored with pickle.load, which can execute
    arbitrary code - only use pickle files from a trusted source.
    '''
    serializer_path = (
        'serializer_cmaj_nottingham.pkl' if model_id == 'c_major'
        else 'serializer_jazz.pkl'
    )
    with open(serializer_path, 'rb') as inp:
        binser = pickle.load(inp)

    # architecture hyper-parameters; must match those used to train the checkpoints
    num_heads = 4
    num_layers = 4
    d_ff = 256  # passed to GPT-2 as the embedding width (n_embd)

    # dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dev = torch.device("cpu")

    config = AutoConfig.from_pretrained(
        "gpt2",
        vocab_size=binser.vocab_size,
        n_positions=binser.max_seq_length,
        n_layer=num_layers,
        n_head=num_heads,
        pad_token_id=binser.padding,
        bos_token_id=binser.padding,
        eos_token_id=binser.padding,
        n_embd=d_ff,
    )
    transformer = GPT2LMHeadModel(config).to(dev)

    checkpoints = {
        'c_major': '../saved_models/melboost_cmaj_nottingham_GPT2/melboost_cmaj_nottingham_GPT2.pt',
        'perm': '../saved_models/perm_jazz_GPT2/perm_jazz_GPT2.pt',
        'mel_boost': '../saved_models/melboost_jazz_GPT2/melboost_jazz_GPT2.pt',
    }
    saved_model_path = checkpoints.get(model_id, '../saved_models/jazz_GPT2/jazz_GPT2.pt')

    # map_location remaps tensors saved from a GPU session onto `dev`, so the
    # checkpoint also loads on machines without CUDA
    state_dict = torch.load(saved_model_path, map_location=dev)
    load_result = transformer.load_state_dict(state_dict, strict=False)
    # strict=False tolerates extra keys in the checkpoint, but missing keys mean
    # some weights were never loaded - surface that instead of failing silently
    if load_result.missing_keys:
        warnings.warn(
            f"checkpoint {saved_model_path!r} is missing {len(load_result.missing_keys)} "
            f"parameter(s), e.g. {load_result.missing_keys[:3]}; they keep their random init"
        )

    transformer.eval()
    return transformer
# end load_model

In [44]:
def harmonize_melody_pcps_with_model_id(melody_pcps, model_id, model=None,
                                        max_new_tokens=300, eos_token_id=29):
    '''
    Generate a chord accompaniment for a melody with one of the saved models.

    melody_pcps: melody passed to BinarySerializer.sequence_serialization
        (the notebook passes a (129, 12) array - presumably time steps x pitch
        classes; TODO confirm against the serializer).
    model_id: checkpoint name forwarded to load_model when `model` is None.
    model: optional already-loaded model; pass one to avoid reloading the
        checkpoint on every call.
    max_new_tokens: upper bound on the number of generated tokens.
    eos_token_id: token id that stops generation (29 in the original code -
        presumably the serializer's end-of-sequence index; TODO confirm).

    Returns (chords, melody) as produced by BinarySerializer.indexes2binary,
    with the chord matrix truncated or zero-padded along axis 1 so it has the
    same number of columns as the melody matrix.
    '''
    if model is None:
        model = load_model(model_id)
    binser2 = BinarySerializer()
    # melody_pcps to serialized token indexes
    x_mel, _ = binser2.sequence_serialization(melody_pcps, np.array([]))
    # x_mel has 'end harmonizing' at the end - remove it
    x_mel = x_mel[:-1]
    # feed the prompt on whatever device the model lives on (CPU for load_model)
    inp = torch.from_numpy(np.expand_dims(x_mel, axis=0)).to(model.device)
    output = model.generate(inputs=inp, eos_token_id=eos_token_id,
                            max_new_tokens=max_new_tokens)
    # back to binary; only one sequence was generated
    bin_all = binser2.indexes2binary(output[0])
    c = bin_all['chords']
    m = bin_all['melody']
    # make sure melody and chords span the same number of columns
    n_steps = m.shape[1]
    if c.shape[1] > n_steps:
        c = c[:, :n_steps]
    elif c.shape[1] < n_steps:
        c = np.c_[c, np.zeros((c.shape[0], n_steps - c.shape[1]))]
    return c, m

In [45]:
# test: harmonize the first melody of the augmented dataset
# load data
npz_path = '../data/augmented_and_padded_data.npz'
# np.load on an .npz keeps the archive open; the context manager closes it
# (astype makes an in-memory copy, so the array outlives the file handle)
with np.load(npz_path) as data:
    melody_pcps = data['melody_pcps'].astype('float32')
print(melody_pcps.shape)
melody_pcp = melody_pcps[0]
print(melody_pcp.shape)

# other checkpoints: 'c_major', 'perm', 'mel_boost'
model_id = 'any_tonality'

c, m = harmonize_melody_pcps_with_model_id(melody_pcp, model_id)

(5328, 129, 12)
(129, 12)
x_mel:  tensor([[ 1,  2,  4, 13,  2,  2,  3,  4, 13,  2,  2,  3, 11,  2,  3, 13,  2, 11,
          2,  3, 13,  2,  4, 13,  2,  2,  3,  4, 13,  2, 11,  2,  3,  8,  2,  8,
          2,  5,  8,  2,  7,  2,  5,  8,  2,  7,  2,  7,  2,  3,  5,  2,  7,  2,
          3,  5,  2,  5,  8,  2,  7,  2,  5,  8,  2,  7,  2,  3, 12,  2, 11,  2,
          4, 13,  2,  2,  3,  4, 13,  2,  2,  3, 11,  2,  3, 13,  2, 11,  2,  3,
         13,  2,  4, 13,  2,  2,  3,  4, 13,  2, 11,  2,  3, 11,  2, 11,  2, 13,
          2, 11,  2,  9,  2,  8,  2, 11,  2,  9,  2,  6,  8,  2,  4, 13,  2,  2,
          3,  4, 13,  2,  8,  2,  3,  4,  2,  4,  2, 15]])
output:  tensor([[ 1,  2,  4, 13,  2,  2,  3,  4, 13,  2,  2,  3, 11,  2,  3, 13,  2, 11,
          2,  3, 13,  2,  4, 13,  2,  2,  3,  4, 13,  2, 11,  2,  3,  8,  2,  8,
          2,  5,  8,  2,  7,  2,  5,  8,  2,  7,  2,  7,  2,  3,  5,  2,  7,  2,
          3,  5,  2,  5,  8,  2,  7,  2,  5,  8,  2,  7,  2,  3, 12,  2, 11,  2,
       

In [46]:
# Show the complete melody and chord matrices, disabling numpy's '...' summarisation.
labelled_matrices = (('input: ', m), ('output: ', c))
with np.printoptions(threshold=np.inf):
    for label, matrix in labelled_matrices:
        print(label, matrix)

input:  [[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 