In [1]:
import torch
import torch.nn as nn
from models import GridMLMMelHarm
from GridMLM_tokenizers import CSGridMLMTokenizer
from data_utils import CSGridMLMDataset, CSGridMLM_collate_fn
from torch.utils.data import DataLoader
from train_utils import apply_masking
from generate_utils import random_progressive_generate, structured_progressive_generate,\
    load_model, overlay_generated_harmony, save_harmonized_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
val_dir = '/media/maindisk/maximos/data/hooktheory_all12_test'
tokenizer = CSGridMLMTokenizer(fixed_length=256)
val_dataset = CSGridMLMDataset(val_dir, tokenizer, 512)

In [3]:
mask_token_id = tokenizer.mask_token_id
pad_token_id = tokenizer.pad_token_id
nc_token_id = tokenizer.nc_token_id

In [4]:
curriculum_type = 'random'
device_name = 'cuda:1'
if device_name == 'cpu':
    device = torch.device('cpu')
else:
    if torch.cuda.is_available():
        device = torch.device(device_name)
    else:
        print('Selected device not available: ' + device_name)
model = GridMLMMelHarm(
    chord_vocab_size=len(tokenizer.vocab),
    device=device
)
model_path = 'saved_models/' + curriculum_type +  '.pt'
# checkpoint = torch.load(model_path, map_location=device_name, weights_only=True)
checkpoint = torch.load(model_path, map_location=device_name)
model.load_state_dict(checkpoint)
model.eval()
model.to(device)

GridMLMMelHarm(
  (condition_proj): Linear(in_features=16, out_features=512, bias=True)
  (melody_proj): Linear(in_features=100, out_features=512, bias=True)
  (harmony_embedding): Embedding(354, 512)
  (dropout): Dropout(p=0.3, inplace=False)
  (stage_embedding): Embedding(10, 64)
  (stage_proj): Linear(in_features=576, out_features=512, bias=True)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.3, inplace=False)
        (

In [5]:
import os
data_files = []
for dirpath, _, filenames in os.walk(val_dir):
    for file in filenames:
        if file.endswith('.xml') or file.endswith('.mxl') or file.endswith('.musicxml'):
            full_path = os.path.join(dirpath, file)
            data_files.append(full_path)
print(len(data_files))

10486


In [6]:
encoded = tokenizer.encode(data_files[1473])

  return self.iter().getElementsByClass(classFilterList)


In [7]:
print(encoded.keys())

dict_keys(['input_tokens', 'input_ids', 'pianoroll', 'time_signature', 'attention_mask', 'skip_steps', 'melody_part', 'ql_per_quantum'])


In [8]:
melody_grid = torch.stack([torch.tensor(encoded['pianoroll'], dtype=torch.float)])
conditioning_vec = torch.stack([torch.tensor(encoded['time_signature'], dtype=torch.float)])
harmony_gt = torch.stack([torch.tensor(encoded['input_ids'], dtype=torch.float)])

In [9]:
if curriculum_type == 'base2':
    generated_harmony = structured_progressive_generate(
        model=model,
        melody_grid=melody_grid,
        conditioning_vec=conditioning_vec,
        num_stages=10,
        mask_token_id=tokenizer.mask_token_id,
        temperature=1.0,
        strategy='topk'
    )
else:
    generated_harmony = random_progressive_generate(
        model=model,
        melody_grid=melody_grid,
        conditioning_vec=conditioning_vec,
        num_stages=10,
        mask_token_id=tokenizer.mask_token_id,
        temperature=1.0,
        strategy='topk',
        pad_token_id=pad_token_id,      # token ID for <pad>
        nc_token_id=nc_token_id,       # token ID for <nc>
        force_fill=True         # disallow <pad>/<nc> before melody ends
    )

last_active_index: 243


In [10]:
output_tokens = []
for i,t in enumerate(generated_harmony[0].tolist()):
    output_tokens.append( tokenizer.ids_to_tokens[t] )
print('output_tokens')
print(output_tokens)

harmony_gt_tokens = []
for i,t in enumerate(harmony_gt[0].tolist()):
    harmony_gt_tokens.append( tokenizer.ids_to_tokens[t] )
print('harmony_gt_tokens')
print(harmony_gt_tokens)

output_tokens
['G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:m

In [11]:
score = overlay_generated_harmony(encoded['melody_part'], output_tokens, encoded['ql_per_quantum'], encoded['skip_steps'])
save_harmonized_score(score, out_path="harmonized_output.mxl")