In [1]:
# USER OPTIONS
# define tokenizer name - should be one among the keys in the cell below
tokenizer_name = 'ChordSymbolTokenizer' # or any other name from the keys in tokenizers dictionary
# folder to xmls
# val_dir = '/media/maindisk/maximos/data/gjt_melodies/Library_melodies'
# val_dir = '/media/datadisk/datasets/gjt_melodies/Library_melodies'
val_dir = '/media/maximos/9C33-6BBD/data/gjt_melodies/Library_melodies'
# generation or MLM
generation = True # True if generation, False is MLM
# define batch size depending on GPU availability / status
batchsize = 16
# select device name - could be 'cpu', 'cuda', 'coda:0', 'cuda:1'...
device_name = 'cpu'

In [2]:
from data_utils import MergedMelHarmDataset, MLMCollator, GenCollator
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
from torch.utils.data import DataLoader
from transformers import RobertaConfig, RobertaForMaskedLM, AutoConfig, GPT2LMHeadModel
import torch
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
tokenizers = {
    'ChordSymbolTokenizer': ChordSymbolTokenizer,
    'RootTypeTokenizer': RootTypeTokenizer,
    'PitchClassTokenizer': PitchClassTokenizer,
    'RootPCTokenizer': RootPCTokenizer,
    'GCTRootPCTokenizer': GCTRootPCTokenizer,
    'GCTSymbolTokenizer': GCTSymbolTokenizer,
    'GCTRootTypeTokenizer': GCTRootTypeTokenizer
}

In [4]:
melody_tokenizer = MelodyPitchTokenizer.from_pretrained('saved_tokenizers/MelodyPitchTokenizer')
harmony_tokenizer = tokenizers[tokenizer_name].from_pretrained('saved_tokenizers/' + tokenizer_name)

tokenizer = MergedMelHarmTokenizer(melody_tokenizer, harmony_tokenizer)

In [5]:
if generation:
    collator = GenCollator(tokenizer)
    val_dataset = MergedMelHarmDataset(val_dir, tokenizer, max_length=2048, return_harmonization_labels=True)
    model_path = 'saved_models/gen/' + tokenizer_name + '/' + tokenizer_name + '.pt'
else:
    collator = MLMCollator(tokenizer)
    val_dataset = MergedMelHarmDataset(val_dir, tokenizer, max_length=2048)
    model_path = 'saved_models/mlm/' + tokenizer_name + '/' + tokenizer_name + '.pt'

valloader = DataLoader(val_dataset, batch_size=batchsize, shuffle=True, collate_fn=collator)

In [6]:
if generation:
    config = AutoConfig.from_pretrained(
        "gpt2",
        vocab_size=len(tokenizer.vocab),
        n_positions=2048,
        n_layer=4,
        n_head=4,
        pad_token_id=tokenizer.vocab[tokenizer.pad_token],
        bos_token_id=tokenizer.vocab[tokenizer.bos_token],
        eos_token_id=tokenizer.vocab[tokenizer.eos_token],
        n_embd=256
    )

    model = GPT2LMHeadModel(config)
else:
    model_config = RobertaConfig(
        vocab_size=len(tokenizer.vocab),
        hidden_size=256,
        num_hidden_layers=4,
        num_attention_heads=4,
        pad_token_id=tokenizer.vocab[tokenizer.pad_token],
        bos_token_id=tokenizer.vocab[tokenizer.bos_token],
        eos_token_id=tokenizer.vocab[tokenizer.eos_token],
        mask_token_id=tokenizer.vocab[tokenizer.mask_token],
        max_position_embeddings=2048,
    )

    model = RobertaForMaskedLM(model_config)
# end if

checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint)

model.eval()

  return self.fget.__get__(instance, owner)()


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(545, 256)
    (wpe): Embedding(2048, 256)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-3): 4 x GPT2Block(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=768, nx=256)
          (c_proj): Conv1D(nf=256, nx=256)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=1024, nx=256)
          (c_proj): Conv1D(nf=256, nx=1024)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=256, out_features=545, bias=False)
)

In [7]:
if device_name == 'cpu':
    device = torch.device('cpu')
else:
    if torch.cuda.is_available():
        device = torch.device(device_name)
    else:
        print('Selected device not available: ' + device_name)
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(545, 256)
    (wpe): Embedding(2048, 256)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-3): 4 x GPT2Block(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=768, nx=256)
          (c_proj): Conv1D(nf=256, nx=256)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=1024, nx=256)
          (c_proj): Conv1D(nf=256, nx=1024)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=256, out_features=545, bias=False)
)

In [24]:
val_loss = 0
running_loss = 0
batch_num = 0
running_accuracy = 0
val_accuracy = 0
print('validation')
with torch.no_grad():
    with tqdm(valloader, unit='batch') as tepoch:
        tepoch.set_description(f'Running')
        print(tepoch)
        for batch in tepoch:
            input_ids = batch['input_ids'].to(device)
            if generation:
                attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            if generation:
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            else:
                outputs = model(input_ids, labels=labels)
            loss = outputs.loss
            
            # update loss
            batch_num += 1
            running_loss += loss.item()
            val_loss = running_loss/batch_num
            # accuracy
            if generation:
                predictions = outputs.logits.argmax(dim=-1).roll(shifts=(0,1), dims=(0,1))
            else:
                predictions = outputs.logits.argmax(dim=-1)
            mask = labels != -100
            running_accuracy += (predictions[mask] == labels[mask]).sum().item()/mask.sum().item()
            val_accuracy = running_accuracy/batch_num
            
            tepoch.set_postfix(loss=val_loss, accuracy=val_accuracy)

validation


Running:   0%|          | 0/41 [00:00<?, ?batch/s]

Running:   0%|          | 0/41 [00:00<?, ?batch/s]


  return self.iter().getElementsByClass(classFilterList)
Running:   2%|▏         | 1/41 [00:07<05:12,  7.82s/batch, accuracy=0.573, loss=1.89]

tensor([[111, 117,   6,  ..., 111,   6,   6],
        [111, 117,   6,  ..., 111,   6,   6],
        [  6, 117,   6,  ...,   6,   6,   6],
        ...,
        [  3, 117,   6,  ...,   3, 429,   3],
        [  6, 117,   6,  ...,   6,   6,   6],
        [111, 117,   6,  ..., 111, 436, 111]])


  return self.iter().getElementsByClass(classFilterList)
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7b8af36ab950>>
Traceback (most recent call last):
  File "/home/maximos/miniconda/envs/harmtok/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 
Running:   5%|▍         | 2/41 [00:17<05:42,  8.77s/batch, accuracy=0.546, loss=1.98]

tensor([[111, 117,   6,  ..., 111, 494,   6],
        [  6, 117,   6,  ..., 111,   6,   6],
        [  6, 117,   6,  ..., 111,   6,   6],
        ...,
        [  6, 117,   6,  ...,   6, 408,   6],
        [  6, 117,   6,  ...,   6,   6,   6],
        [  6, 117,   6,  ...,   6,   6,   6]])
