In [1]:
# USER OPTIONS
# define tokenizer name - should be one among the keys in the cell below
tokenizer_name = 'ChordSymbolTokenizer' # or any other name from the keys in tokenizers dictionary
# tokenizer_name = 'RootTypeTokenizer'
# tokenizer_name = 'PitchClassTokenizer'
# folder to xmls
val_dir = '/media/maindisk/maximos/data/hooktheory_test'
# val_dir = '/media/maindisk/maximos/data/gjt_melodies/Library_melodies'
# val_dir = '/media/datadisk/datasets/gjt_melodies/Library_melodies'
# val_dir = '/media/maximos/9C33-6BBD/data/gjt_melodies/Library_melodies'

# define batch size depending on GPU availability / status
batchsize = 16
# select device name - could be 'cpu', 'cuda', 'coda:0', 'cuda:1'...
device_name = 'cuda'

In [2]:
from data_utils import SeparatedMelHarmMarkovDataset
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torch.nn.functional as F
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
from tqdm import tqdm
from models import TransGraphVAE
import csv

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
tokenizers = {
    'ChordSymbolTokenizer': ChordSymbolTokenizer,
    'RootTypeTokenizer': RootTypeTokenizer,
    'PitchClassTokenizer': PitchClassTokenizer,
    'RootPCTokenizer': RootPCTokenizer,
    'GCTRootPCTokenizer': GCTRootPCTokenizer,
    'GCTSymbolTokenizer': GCTSymbolTokenizer,
    'GCTRootTypeTokenizer': GCTRootTypeTokenizer
}

In [4]:
melody_tokenizer = MelodyPitchTokenizer.from_pretrained('saved_tokenizers/MelodyPitchTokenizer')
harmony_tokenizer = tokenizers[tokenizer_name].from_pretrained('saved_tokenizers/' + tokenizer_name)

tokenizer = MergedMelHarmTokenizer(melody_tokenizer, harmony_tokenizer)

bart_path = 'saved_models/bart/' + tokenizer_name + '/' + tokenizer_name + '.pt'

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
bart_config = BartConfig(
    vocab_size=len(tokenizer.vocab),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    max_position_embeddings=512,
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_ffn_dim=512,
    decoder_layers=8,
    decoder_attention_heads=8,
    decoder_ffn_dim=512,
    d_model=512,
    encoder_layerdrop=0.3,
    decoder_layerdrop=0.3,
    dropout=0.3
)

bart = BartForConditionalGeneration(bart_config)

if torch.cuda.is_available():
    checkpoint = torch.load(bart_path, weights_only=True)
else:
    checkpoint = torch.load(bart_path, map_location="cpu", weights_only=True)
bart.load_state_dict(checkpoint)

bart.eval()
bart.to(device)

  return self.fget.__get__(instance, owner)()


BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(545, 512, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(545, 512, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(514, 512)
      (layers): ModuleList(
        (0-7): 8 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=512, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=512, bias=True)
          (final_layer_norm

In [7]:
# enc1 = bart.model.encoder
# enc2 = bart.get_encoder()
# dec1 = bart.model.decoder
# dec2 = bart.get_decoder()

In [8]:
test_dir = '/mnt/ssd2/maximos/data/hooktheory_test'
test_dataset = SeparatedMelHarmMarkovDataset(test_dir, tokenizer, max_length=512, num_bars=64)

# Data collator for BART
def create_data_collator(tokenizer, model):
    return DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)
# end create_data_collator

In [9]:
collator = create_data_collator(tokenizer, model=bart)
valloader = DataLoader(test_dataset, batch_size=1, shuffle=True, collate_fn=collator)

In [10]:
def sample_with_temperature(logits, temperature=1.0):
    # Scale logits by temperature
    logits = logits / temperature
    # Apply softmax to get probabilities
    probs = F.softmax(logits, dim=-1)

    # Flatten the logits if necessary
    batch_size, seq_len, vocab_size = probs.shape
    probs = probs.view(-1, vocab_size)  # Merge batch_size and seq_len dimensions
    
    # Sample from the probability distribution
    sampled_tokens = torch.multinomial(probs, num_samples=1)
    
    # Reshape back to [batch_size, seq_len, 1]
    sampled_tokens = sampled_tokens.view(batch_size, seq_len, 1)

    # # Sample from the probability distribution
    # sampled_token = torch.multinomial(probs, num_samples=1)
    return sampled_tokens

In [11]:
b = next(iter(valloader))

  return self.iter().getElementsByClass(classFilterList)
  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


In [12]:
input_ids = b['input_ids'].to(device)
print(input_ids)

tensor([[  2,   6, 180,  95,  50,  99,  57, 103,  62, 109,  57, 111,  50, 115,
          58, 119,  62, 125,  58,   6,  95,  48,  99,  55, 103,  62, 109,  55,
         111,  50, 115,  57, 119,  62, 125,  57,   6,  95,  50,  99,  58, 103,
          62, 109,  58, 111,  50, 115,  57, 119,  62, 125,  57,   6,  95,  48,
          99,  55, 103,  62, 109,  55, 111,  50, 115,  55, 119,  62, 125,  55,
           6,  95,  50,  99,  57, 103,  62, 109,  57, 111,  50, 115,  58, 119,
          62, 125,  58,   6,  95,  48,  99,  55, 103,  62, 109,  55, 111,  50,
         115,  57, 119,  62, 125,  57,   6,  95,  50,  99,  58, 103,  62, 109,
          57, 111,  50, 115,  57, 119,  62, 125,  57,   6,  95,  48,  99,  55,
         103,  62, 109,  55, 111,  50]], device='cuda:0')


In [13]:
outs = bart(input_ids)

In [14]:
print(outs.logits.shape)

torch.Size([1, 132, 545])


In [15]:
print(type(bart))
inputs = bart.prepare_inputs_for_generation(input_ids)
print(inputs)

<class 'transformers.models.bart.modeling_bart.BartForConditionalGeneration'>
{'input_ids': None, 'encoder_outputs': None, 'past_key_values': None, 'decoder_input_ids': tensor([[  2,   6, 180,  95,  50,  99,  57, 103,  62, 109,  57, 111,  50, 115,
          58, 119,  62, 125,  58,   6,  95,  48,  99,  55, 103,  62, 109,  55,
         111,  50, 115,  57, 119,  62, 125,  57,   6,  95,  50,  99,  58, 103,
          62, 109,  58, 111,  50, 115,  57, 119,  62, 125,  57,   6,  95,  48,
          99,  55, 103,  62, 109,  55, 111,  50, 115,  55, 119,  62, 125,  55,
           6,  95,  50,  99,  57, 103,  62, 109,  57, 111,  50, 115,  58, 119,
          62, 125,  58,   6,  95,  48,  99,  55, 103,  62, 109,  55, 111,  50,
         115,  57, 119,  62, 125,  57,   6,  95,  50,  99,  58, 103,  62, 109,
          57, 111,  50, 115,  57, 119,  62, 125,  57,   6,  95,  48,  99,  55,
         103,  62, 109,  55, 111,  50]], device='cuda:0'), 'attention_mask': None, 'decoder_attention_mask': None, 'head

In [28]:
max_length = 500
temperature = 1.0
top_k = 50

In [None]:
output_generate = bart.generate(
    input_ids,
    max_length=max_length,
    do_sample=True,
    temperature=temperature,
    top_k=top_k,
    return_dict_in_generate=True,
    output_scores=True
)

In [33]:
print(output_generate['sequences'])

tensor([[  2, 196,   6,  95, 314, 109, 314, 111, 314, 119, 314, 125, 314,   6,
          95, 459, 109, 207, 111, 524, 119, 517, 125, 314,   6,  95, 314, 109,
         314, 111, 314, 119, 314, 125, 314,   6,  95, 459, 109, 207, 111, 524,
         125, 314,   6,  95, 314, 109, 314, 111, 314, 119, 314, 125, 314,   6,
          95, 459, 109, 207, 111, 524, 119, 517, 125, 314,   6,  95, 314, 109,
         314, 111, 314, 119, 314, 125, 314,   6,  95, 459, 109, 207, 111, 524,
         119, 517, 125, 314,   3]], device='cuda:0')


In [22]:
generate_tokens = []
for i in output_generate['sequences'][0]:
    generate_tokens.append( tokenizer.ids_to_tokens[ int(i) ].replace(' ','x') )
print(generate_tokens)

['<s>', '<h>', '<bar>', 'position_0x00', 'E:min', 'position_1x75', 'E:min', 'position_2x00', 'E:min', 'position_3x00', 'E:min', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'A:min', 'position_1x75', 'C:maj6', 'position_2x00', 'B:min7', 'position_3x00', 'B:min', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'E:min', 'position_1x75', 'E:min', 'position_2x00', 'E:min', 'position_3x00', 'E:min', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'A:min', 'position_1x75', 'C:maj6', 'position_2x00', 'B:min7', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'E:min', 'position_1x75', 'E:min', 'position_2x00', 'E:min', 'position_3x00', 'E:min', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'A:min', 'position_1x75', 'C:maj6', 'position_2x00', 'B:min7', 'position_3x00', 'B:min', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'E:min', 'position_1x75', 'E:min', 'position_2x00', 'E:min', 'position_3x00', 'E:min', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'A:min

In [24]:
encoder = bart.model.encoder
decoder = bart.model.decoder

In [26]:
with torch.no_grad():
    encoder_outputs = encoder(input_ids)
    encoder_hidden_states = encoder_outputs.last_hidden_state

decoder_input_ids = torch.tensor([[tokenizer.bos_token_id]], device=device)

In [29]:
for _ in range(max_length):
    with torch.no_grad():
        decoder_outputs = decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states
        )
        logits = bart.lm_head(decoder_outputs.last_hidden_state[:, -1, :])  # Get logits for last token

        # Apply temperature scaling
        logits = logits / temperature

        # Apply top-k sampling
        if top_k > 0:
            top_k_values, top_k_indices = torch.topk(logits, top_k)
            logits = torch.full_like(logits, float('-inf'))
            logits.scatter_(1, top_k_indices, top_k_values)

        # Convert logits to probabilities and sample
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        next_token = torch.multinomial(probabilities, num_samples=1)

        # Append to decoder input
        decoder_input_ids = torch.cat((decoder_input_ids, next_token), dim=1)

        # Stop if EOS token is generated
        if next_token.item() == tokenizer.eos_token_id:
            break

In [30]:
print(decoder_input_ids)

tensor([[  2, 196,   6,  95, 342,   6,  95, 400,   6,  95, 459,   6,  95, 314,
           6,  95, 342,   6,  95, 400,   6,  95, 459,   6,  95, 314,   3]],
       device='cuda:0')


In [32]:
ar_tokens = []
for i in decoder_input_ids[0]:
    ar_tokens.append( tokenizer.ids_to_tokens[ int(i) ].replace(' ','x') )
print(generate_tokens)

['<s>', '<h>', '<bar>', 'position_0x00', 'E:min', 'position_1x75', 'E:min', 'position_2x00', 'E:min', 'position_3x00', 'E:min', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'A:min', 'position_1x75', 'C:maj6', 'position_2x00', 'B:min7', 'position_3x00', 'B:min', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'E:min', 'position_1x75', 'E:min', 'position_2x00', 'E:min', 'position_3x00', 'E:min', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'A:min', 'position_1x75', 'C:maj6', 'position_2x00', 'B:min7', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'E:min', 'position_1x75', 'E:min', 'position_2x00', 'E:min', 'position_3x00', 'E:min', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'A:min', 'position_1x75', 'C:maj6', 'position_2x00', 'B:min7', 'position_3x00', 'B:min', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'E:min', 'position_1x75', 'E:min', 'position_2x00', 'E:min', 'position_3x00', 'E:min', 'position_3x75', 'E:min', '<bar>', 'position_0x00', 'A:min

In [None]:
# def custom_autoregressive_sampling(model, input_ids, max_length, temperature, top_k):
#     generated = input_ids
#     for _ in range(max_length - input_ids.shape[1]):
#         outputs = model(generated)
#         logits = outputs.logits[:, -1, :]  # Take last token logits
        
#         # Apply temperature scaling
#         logits = logits / temperature

#         # Apply top-k sampling
#         if top_k > 0:
#             top_k_values, top_k_indices = torch.topk(logits, top_k)
#             logits = torch.full_like(logits, float('-inf'))
#             logits.scatter_(1, top_k_indices, top_k_values)

#         # Convert logits to probabilities and sample
#         probabilities = torch.nn.functional.softmax(logits, dim=-1)
#         next_token = torch.multinomial(probabilities, num_samples=1)

#         # Append to generated sequence
#         generated = torch.cat((generated, next_token), dim=1)

#         # Stop if EOS token is generated
#         if next_token.item() == tokenizer.eos_token_id:
#             break

#     return generated
# # end custom_autoregressive_sampling