In [1]:
import sys
sys.path.insert(0, '..')
# from transformer.models import DecoderOnlyModel
from data_utils.Datasets import SerializedConcatDataset, PermSerializedConcatDataset, BinarySerializer
import pickle
import torch
import numpy as np

from transformers import AutoConfig, GPT2LMHeadModel

In [2]:
# Serializer pickled at training time: provides the vocabulary size, the context
# length and the special-token ids used below.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('serializer_jazz.pkl', 'rb') as serializer_file:
    binser = pickle.load(serializer_file)

# Freshly constructed serializer; later cells read token ids from it and use it
# for decoding. NOTE(review): presumably shares binser's token layout — confirm.
binser2 = BinarySerializer()

# define model (hyperparameters must match the saved checkpoint)
vocab_size = binser.vocab_size
d_model = 256   # embedding width -> GPT-2 `n_embd`
num_heads = 4
num_layers = 4
d_ff = 256      # not passed to the config: GPT-2's MLP width (`n_inner`) defaults to 4 * n_embd
max_seq_length = binser.max_seq_length
dropout = 0.3   # not passed to the config; dropout is inactive anyway after transformer.eval()

# CPU is forced here; use torch.device("cuda:0") when a GPU is available.
dev = torch.device("cpu")

config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=vocab_size,
    n_positions=max_seq_length,
    n_layer=num_layers,
    n_head=num_heads,
    pad_token_id=binser.padding,
    bos_token_id=binser.padding,
    eos_token_id=binser.padding,
    n_embd=d_model,
)
transformer = GPT2LMHeadModel(config).to(dev)

saved_model_path = '../saved_models/melboost_jazz_GPT2/melboost_jazz_GPT2.pt'
# map_location lets a checkpoint that was saved on a GPU load on `dev`.
state_dict = torch.load(saved_model_path, map_location=dev)
# strict=False tolerates key mismatches; report them instead of silently
# running a partially initialised model.
load_result = transformer.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('missing keys:', load_result.missing_keys)
    print('unexpected keys:', load_result.unexpected_keys)

transformer.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(30, 256)
    (wpe): Embedding(1063, 256)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-3): 4 x GPT2Block(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=256, out_features=30, bias=False)
)

In [3]:
# Test data: serialized melody + harmony token sequences, padded to the model's
# context length (left_padding=False).
npz_path = '../data/augmented_and_padded_data.npz'
dataset = SerializedConcatDataset(
    npz_path,
    pad_to_length=max_seq_length,
    left_padding=False,
)

In [4]:
# Take the first example and keep the melody prompt: every token up to and
# including the first start_harmonizing token; the model generates the rest.
x, mask = dataset[0]
start_positions = np.flatnonzero(x == binser.start_harmonizing)
if start_positions.size == 0:
    raise ValueError('example 0 contains no start_harmonizing token')
idx = start_positions[0]
x_mel = x[:idx + 1]
with np.printoptions(threshold=np.inf):
    print(x_mel)

[ 1  2  3 12  2 14  2  3 12  2 14  2 10 14  2 12  2 10 14  2 12  2  3 12
  2 14  2  3 12  2 10 14  2  7  2  7  2  4  7  2  6  2  4  7  2  6  2  6
 14  2  4  2  6 14  2  4  2  4  7  2  6  2  4  7  2  6 14  2 11  2 10  2
  3 12  2 14  2  3 12  2 14  2 10 14  2 12  2 10 14  2 12  2  3 12  2 14
  2  3 12  2 10 14  2 10  2 10  2 12  2 10  2  8  2  7  2 10  2  8  2  5
  7  2  3 12  2 14  2  3 12  2  7 14  2  3  2  3 15]


In [5]:
# Teacher-forced forward pass over the melody prompt — this is NOT generation
# (see the cells below); it shows the model's argmax prediction at every position.
inp = torch.from_numpy(np.expand_dims(x_mel, axis=0)).to(dev)
with torch.no_grad():  # inference only: no autograd graph is needed
    output = transformer(
        inp,
        # mask out padding tokens; binser.padding is the config's pad_token_id
        attention_mask=(inp != binser.padding).long(),
        output_attentions=True,
    )
# Most likely next token at each position of the single batch row -> shape (seq_len,)
prediction = output.logits.argmax(dim=-1)[0]
z = prediction.cpu().numpy()
with np.printoptions(threshold=np.inf):
    print('output.logits: ', output.logits[0, -1, :])
    print('prediction.shape: ', prediction.shape)
    print('input: ', inp.cpu().numpy())
    print('z: ', z)
# Position of the first start_harmonizing token in the prompt.
curr_idx = (inp[0] == binser.start_harmonizing).nonzero()[0].item()
print(curr_idx)

output.logits:  tensor([-25.4960, -12.1390, -13.3605,  -4.6104,  -8.6326, -14.7351,  -7.0679,
        -18.4973,  -7.6556, -21.0319,  -9.1271, -13.6450,  -2.3931, -14.3436,
        -28.7538, -13.2978,  35.4094,  -2.6035, -14.3177, -12.4976, -14.6406,
        -13.6371, -12.8388, -15.9796, -11.3975, -16.5874,   4.0534,  -5.8872,
         -0.2046, -20.1894], grad_fn=<SliceBackward0>)
prediction.shape:  torch.Size([137])
input:  [[ 1  2  3 12  2 14  2  3 12  2 14  2 10 14  2 12  2 10 14  2 12  2  3 12
   2 14  2  3 12  2 10 14  2  7  2  7  2  4  7  2  6  2  4  7  2  6  2  6
  14  2  4  2  6 14  2  4  2  4  7  2  6  2  4  7  2  6 14  2 11  2 10  2
   3 12  2 14  2  3 12  2 14  2 10 14  2 12  2 10 14  2 12  2  3 12  2 14
   2  3 12  2 10 14  2 10  2 10  2 12  2 10  2  8  2  7  2 10  2  8  2  5
   7  2  3 12  2 14  2  3 12  2  7 14  2  3  2  3 15]]
z:  [15 15 18 17 15 16 15 19 16 15 16 15 17 16 15 16 15 16 16 15 16 15 19 16
 15 16 15 19 16 15 17 16 15 17 15 24 15 21 22 15 21 15 21 22 15 24 15 

In [6]:
# Point-by-point generation: sample exactly one new token after the prompt.
inp = torch.from_numpy(np.expand_dims(x_mel, axis=0)).to(dev)
# This form is for when more chords are still required: end_harmonizing is
# banned, so generation cannot stop yet.
output = transformer.generate(
    inputs=inp,
    # The prompt has no padding. Pass the mask explicitly: the config's pad and eos
    # ids are equal, so `generate` cannot infer it and would warn.
    attention_mask=torch.ones_like(inp),
    max_new_tokens=1,
    do_sample=True,
    bad_words_ids=[[binser.end_harmonizing]],
)
# When stopping is allowed, pass eos_token_id=binser.end_harmonizing instead of
# bad_words_ids.
print('input shape: ', inp.shape)
print('output shape: ', output.shape)
print('input: ', inp)
print('output: ', output)

input shape:  torch.Size([1, 137])
output shape:  torch.Size([1, 138])
input:  tensor([[ 1,  2,  3, 12,  2, 14,  2,  3, 12,  2, 14,  2, 10, 14,  2, 12,  2, 10,
         14,  2, 12,  2,  3, 12,  2, 14,  2,  3, 12,  2, 10, 14,  2,  7,  2,  7,
          2,  4,  7,  2,  6,  2,  4,  7,  2,  6,  2,  6, 14,  2,  4,  2,  6, 14,
          2,  4,  2,  4,  7,  2,  6,  2,  4,  7,  2,  6, 14,  2, 11,  2, 10,  2,
          3, 12,  2, 14,  2,  3, 12,  2, 14,  2, 10, 14,  2, 12,  2, 10, 14,  2,
         12,  2,  3, 12,  2, 14,  2,  3, 12,  2, 10, 14,  2, 10,  2, 10,  2, 12,
          2, 10,  2,  8,  2,  7,  2, 10,  2,  8,  2,  5,  7,  2,  3, 12,  2, 14,
          2,  3, 12,  2,  7, 14,  2,  3,  2,  3, 15]])
output:  tensor([[ 1,  2,  3, 12,  2, 14,  2,  3, 12,  2, 14,  2, 10, 14,  2, 12,  2, 10,
         14,  2, 12,  2,  3, 12,  2, 14,  2,  3, 12,  2, 10, 14,  2,  7,  2,  7,
          2,  4,  7,  2,  6,  2,  4,  7,  2,  6,  2,  6, 14,  2,  4,  2,  6, 14,
          2,  4,  2,  4,  7,  2,  6,  2,  4,  7

In [7]:
# Greedy generation of the whole harmonization, stopping at end_harmonizing.
inp = torch.from_numpy(np.expand_dims(x_mel, axis=0)).to(dev)
output = transformer.generate(
    inputs=inp,
    attention_mask=torch.ones_like(inp),  # the prompt contains no padding
    eos_token_id=binser.end_harmonizing,
    max_new_tokens=500,
)
# bad_words_ids=[[binser.end_harmonizing]] cannot be combined with
# eos_token_id=binser.end_harmonizing: the stop token itself would be banned.
print('input shape: ', inp.shape)
# NOTE(review): assumes the dataset pads with 0 — confirm this matches binser.padding
print('true output.shape:', x[x!=0].shape)
print('output shape: ', output.shape)

input shape:  torch.Size([1, 137])
true output.shape: (414,)
output shape:  torch.Size([1, 400])


In [8]:
# Show the prompt, the ground-truth sequence (non-zero tokens) and the generated
# sequence in full, without numpy's '...' summarisation.
prompt_np = inp.cpu().numpy()
true_np = x[x != 0]
generated_np = output.cpu().numpy()
with np.printoptions(threshold=np.inf):
    for label, ids in (('input: ', prompt_np), ('true output:', true_np), ('output: ', generated_np)):
        print(label, ids)

input:  [[ 1  2  3 12  2 14  2  3 12  2 14  2 10 14  2 12  2 10 14  2 12  2  3 12
   2 14  2  3 12  2 10 14  2  7  2  7  2  4  7  2  6  2  4  7  2  6  2  6
  14  2  4  2  6 14  2  4  2  4  7  2  6  2  4  7  2  6 14  2 11  2 10  2
   3 12  2 14  2  3 12  2 14  2 10 14  2 12  2 10 14  2 12  2  3 12  2 14
   2  3 12  2 10 14  2 10  2 10  2 12  2 10  2  8  2  7  2 10  2  8  2  5
   7  2  3 12  2 14  2  3 12  2  7 14  2  3  2  3 15]]
true output: [ 1  2  3 12  2 14  2  3 12  2 14  2 10 14  2 12  2 10 14  2 12  2  3 12
  2 14  2  3 12  2 10 14  2  7  2  7  2  4  7  2  6  2  4  7  2  6  2  6
 14  2  4  2  6 14  2  4  2  4  7  2  6  2  4  7  2  6 14  2 11  2 10  2
  3 12  2 14  2  3 12  2 14  2 10 14  2 12  2 10 14  2 12  2  3 12  2 14
  2  3 12  2 10 14  2 10  2 10  2 12  2 10  2  8  2  7  2 10  2  8  2  5
  7  2  3 12  2 14  2  3 12  2  7 14  2  3  2  3 15 16 17 19 22 26 16 19
 22 24 28 16 17 19 22 26 16 19 22 24 28 16 17 21 24 28 16 17 20 22 26 16
 19 21 24 28 16 18 21 24 26 16 17 19 22 26 

In [9]:
# Point-by-point sampling loop.
# The harmonization needs one chord segment per melody segment, so:
#   * until that many chord separators have been produced, end_harmonizing is banned;
#   * afterwards the chord separator is banned and sampling continues until
#     end_harmonizing is produced.
torch.manual_seed(0)  # reproducible sampling
MAX_GENERATION_STEPS = 2000  # safety cap in case end_harmonizing is never sampled

# `output` always holds the FULL sequence (prompt + everything generated so far).
output = torch.from_numpy(np.expand_dims(x_mel, axis=0)).to(dev)
num_melody_segments = (output == binser2.melody_segment_separator).sum().item()
num_chord_segments = 0
chords_number_ok = False
for _ in range(MAX_GENERATION_STEPS):
    # Feed only the last max_seq_length tokens. output.shape[1] is the sequence
    # length; len(output) would be the batch size.
    # NOTE(review): on very long sequences this window drops the start of the melody prompt.
    inp = output[:, -binser.max_seq_length:]
    if not chords_number_ok:
        step_output = transformer.generate(
            inputs=inp, attention_mask=torch.ones_like(inp), max_new_tokens=1,
            do_sample=True, bad_words_ids=[[binser2.end_harmonizing]])
    else:
        # prevent a new chord segment from starting; allow stopping
        step_output = transformer.generate(
            inputs=inp, attention_mask=torch.ones_like(inp), max_new_tokens=1,
            do_sample=True, eos_token_id=binser2.end_harmonizing,
            bad_words_ids=[[binser2.chord_segment_separator]])
    new_token = step_output[:, -1:]
    output = torch.cat([output, new_token], dim=1)
    token_id = new_token.item()
    num_chord_segments += int(token_id == binser2.chord_segment_separator)
    print(output.shape, num_chord_segments, num_melody_segments, chords_number_ok, token_id, end='\r')
    if chords_number_ok and token_id == binser2.end_harmonizing:
        break
    if num_chord_segments >= num_melody_segments:
        chords_number_ok = True
else:
    print(f'\nNo end_harmonizing after {MAX_GENERATION_STEPS} steps; output is incomplete.')

torch.Size([1, 421]) 55 55 True 296

In [10]:
# Full generated sequence (prompt + harmonization) next to the dataset's ground truth.
true_ids = x[x != 0]
print(output)
print('true output:', true_ids)

tensor([[ 1,  2,  3, 12,  2, 14,  2,  3, 12,  2, 14,  2, 10, 14,  2, 12,  2, 10,
         14,  2, 12,  2,  3, 12,  2, 14,  2,  3, 12,  2, 10, 14,  2,  7,  2,  7,
          2,  4,  7,  2,  6,  2,  4,  7,  2,  6,  2,  6, 14,  2,  4,  2,  6, 14,
          2,  4,  2,  4,  7,  2,  6,  2,  4,  7,  2,  6, 14,  2, 11,  2, 10,  2,
          3, 12,  2, 14,  2,  3, 12,  2, 14,  2, 10, 14,  2, 12,  2, 10, 14,  2,
         12,  2,  3, 12,  2, 14,  2,  3, 12,  2, 10, 14,  2, 10,  2, 10,  2, 12,
          2, 10,  2,  8,  2,  7,  2, 10,  2,  8,  2,  5,  7,  2,  3, 12,  2, 14,
          2,  3, 12,  2,  7, 14,  2,  3,  2,  3, 15, 16, 17, 19, 22, 26, 16, 19,
         22, 24, 28, 16, 17, 19, 22, 26, 16, 19, 22, 24, 28, 16, 17, 21, 24, 28,
         16, 17, 21, 24, 26, 16, 20, 23, 25, 28, 16, 18, 22, 25, 28, 16, 17, 21,
         24, 27, 16, 17, 20, 22, 26, 16, 19, 21, 24, 28, 16, 17, 19, 22, 26, 16,
         19, 22, 24, 28, 16, 19, 21, 24, 28, 16, 18, 21, 24, 26, 16, 17, 19, 22,
         26, 16, 19, 22, 24,

In [11]:
# Decode the token ids of the single batch row back into melody/chord matrices.
harmonized_ids = output[0]
bin_info = binser2.indexes2binary(harmonized_ids)

In [12]:
# Shapes of the decoded matrices, followed by any decoding error messages.
for part in ('melody', 'chords'):
    print(bin_info[part].shape)
print(bin_info['error_messages'])

(54, 12)
(54, 12)
[]
