In [1]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from models import GridMLMMelHarm
from GridMLM_tokenizers import CSGridMLMTokenizer
from data_utils import CSGridMLMDataset, CSGridMLM_collate_fn
from train_utils import apply_masking
from generate_utils import random_progressive_generate, structured_progressive_generate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Validation data loader (one example per batch, fixed order).
# The data directory defaults to the original local mount but can be overridden
# with the HOOKTHEORY_VAL_DIR environment variable so the notebook runs elsewhere.
batchsize = 1
val_dir = os.environ.get(
    'HOOKTHEORY_VAL_DIR',
    '/media/maindisk/maximos/data/hooktheory_all12_test',
)
tokenizer = CSGridMLMTokenizer(fixed_length=256)
# NOTE(review): meaning of the third positional argument (512) is defined in data_utils — confirm.
val_dataset = CSGridMLMDataset(val_dir, tokenizer, 512)
valloader = DataLoader(
    val_dataset,
    batch_size=batchsize,
    shuffle=False,
    collate_fn=CSGridMLM_collate_fn,
)

In [3]:
# Id of the token that hides harmony positions, plus the curriculum stage index to inspect.
mask_token_id, stage = tokenizer.mask_token_id, 0

In [4]:
# Load the checkpoint saved as saved_models/random.pt and put the model in eval mode.
curriculum_type = 'random'
device_name = 'cuda:1'
# Fall back to CPU when CUDA is missing; previously `device` was left undefined
# in that case and the model construction below failed with NameError.
if device_name != 'cpu' and not torch.cuda.is_available():
    print('Selected device not available: ' + device_name + '; falling back to cpu')
    device_name = 'cpu'
device = torch.device(device_name)

model = GridMLMMelHarm(
    chord_vocab_size=len(tokenizer.vocab),
    device=device
)
model_path = 'saved_models/' + curriculum_type + '.pt'
# NOTE(review): torch.load without weights_only=True unpickles arbitrary objects;
# only load trusted checkpoints. Re-enable weights_only=True once confirmed compatible.
# map_location uses the resolved device so a CPU fallback also loads correctly.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint)
model.eval()
model.to(device)

GridMLMMelHarm(
  (condition_proj): Linear(in_features=16, out_features=512, bias=True)
  (melody_proj): Linear(in_features=100, out_features=512, bias=True)
  (harmony_embedding): Embedding(354, 512)
  (dropout): Dropout(p=0.3, inplace=False)
  (stage_embedding): Embedding(10, 64)
  (stage_proj): Linear(in_features=576, out_features=512, bias=True)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.3, inplace=False)
        (

In [5]:
# Select one validation batch by position. The loader is not shuffled, so the
# same index always yields the same example.
SAMPLE_INDEX = 13
for i, batch in enumerate(valloader):
    if i == SAMPLE_INDEX:
        break
else:
    # Previously an exhausted loader silently left `batch` as the *last* batch.
    raise IndexError(
        f'valloader yielded fewer than {SAMPLE_INDEX + 1} batches; '
        f'cannot select batch {SAMPLE_INDEX}'
    )

  return self.iter().getElementsByClass(classFilterList)


In [6]:
# Move the selected sample's inputs onto the model's device.
melody_grid, harmony_gt, conditioning_vec = (
    batch[key].to(device)
    for key in ("pianoroll", "input_ids", "time_signature")
)
# Assumed shapes — TODO confirm against data_utils / collate_fn:
#   melody_grid:      (B, 256, F) pianoroll; an earlier note said F=140, but the
#                     printed model's melody_proj expects 100 input features.
#   harmony_gt:       (B, 256) chord token ids
#   conditioning_vec: (B, C0) time-signature features

In [7]:
# Apply masking to the ground-truth harmony (presumably the same routine used in
# training, since it comes from train_utils). Uses the `stage` variable rather than
# a hardcoded 0, so changing `stage` actually changes which stage is inspected.
visible_harmony, denoising_target, stage_indices = apply_masking(
    harmony_gt,
    mask_token_id,
    total_stages=10,
    curriculum_type=curriculum_type,
    stage=stage,
)

In [8]:
# Inspect the targets produced by masking. Positions without a target print as -100
# (presumably the loss ignore-index).
print('denoising_target:', denoising_target)

denoising_target: tensor([[-100, -100, -100, -100, -100, -100, -100, -100,  297, -100, -100, -100,
         -100, -100, -100, -100, -100,  304, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100,  335, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  129,
         -100, -100, -100, -100, -100,  129, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100,  129, -100, -100, -100, -100,  129,
          297, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100,  304, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  164,
         -

In [9]:
# Single forward pass on the masked input. This is inspection only, so autograd is
# disabled to avoid building (and holding memory for) an unused graph.
with torch.no_grad():
    logits = model(
        conditioning_vec.to(device),
        melody_grid.to(device),
        visible_harmony.to(device),
        stage_indices
    )
# Greedy decode: highest-scoring token id at every position.
output_ids = torch.argmax(logits, dim=-1)

In [10]:
# Label each position of the first sequence in the batch as "<index>:<token>",
# for the model's greedy prediction and for the ground truth.
output_tokens = [
    str(pos) + ':' + tokenizer.ids_to_tokens[tok_id]
    for pos, tok_id in enumerate(output_ids[0].tolist())
]
print('output_tokens')
print(output_tokens)

harmony_gt_tokens = [
    str(pos) + ':' + tokenizer.ids_to_tokens[tok_id]
    for pos, tok_id in enumerate(harmony_gt[0].tolist())
]
print('harmony_gt_tokens')
print(harmony_gt_tokens)

output_tokens
['0:F#:maj', '1:F#:maj', '2:F#:maj', '3:F#:maj', '4:F#:maj', '5:F#:maj', '6:F#:maj', '7:F#:maj', '8:F#:maj', '9:F#:maj', '10:F#:maj', '11:F#:maj', '12:G#:min', '13:G#:min', '14:G#:min', '15:G#:min', '16:C#:maj', '17:C#:maj', '18:C#:maj', '19:C#:maj', '20:C#:maj', '21:C#:maj', '22:C#:maj', '23:C#:maj', '24:F#:maj', '25:F#:maj', '26:F#:maj', '27:F#:maj', '28:F#:maj', '29:F#:maj', '30:F#:maj', '31:F#:maj', '32:F#:maj', '33:F#:maj', '34:F#:maj', '35:F#:maj', '36:F#:maj', '37:F#:maj', '38:F#:maj', '39:F#:maj', '40:F#:maj', '41:F#:maj', '42:F#:maj', '43:F#:maj', '44:C#:maj', '45:C#:maj', '46:C#:maj', '47:C#:maj', '48:E:maj', '49:E:maj', '50:E:maj', '51:E:maj', '52:E:maj', '53:E:maj', '54:E:maj', '55:E:maj', '56:E:maj', '57:E:maj', '58:E:maj', '59:E:maj', '60:E:maj', '61:E:maj', '62:E:maj', '63:E:maj', '64:E:maj', '65:E:maj', '66:E:maj', '67:E:maj', '68:E:maj', '69:E:maj', '70:E:maj', '71:E:maj', '72:E:maj', '73:E:maj', '74:E:maj', '75:E:maj', '76:E:maj', '77:E:maj', '78:E:maj',

In [11]:
# Generate a full harmony from the melody with the 'random' model.
# A fixed torch seed makes re-runs reproducible if generation samples
# (assumed from temperature/strategy — TODO confirm in generate_utils; other RNGs
# such as numpy/random would need seeding too if used there).
generation_seed = 0
torch.manual_seed(generation_seed)
with torch.no_grad():
    generated_harmony = random_progressive_generate(
        model=model,
        melody_grid=melody_grid,
        conditioning_vec=conditioning_vec,
        num_stages=10,
        mask_token_id=tokenizer.mask_token_id,
        temperature=1.0,
        strategy='topk'
    )


In [12]:
print('generated_harmony:', generated_harmony)

# Label each position of the first generated sequence and of the ground truth
# as "<index>:<token>" for side-by-side comparison.
output_tokens = [
    str(pos) + ':' + tokenizer.ids_to_tokens[tok_id]
    for pos, tok_id in enumerate(generated_harmony[0].tolist())
]
print('output_tokens')
print(output_tokens)

harmony_gt_tokens = [
    str(pos) + ':' + tokenizer.ids_to_tokens[tok_id]
    for pos, tok_id in enumerate(harmony_gt[0].tolist())
]
print('harmony_gt_tokens')
print(harmony_gt_tokens)

generated_harmony: tensor([[180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180,  35,  35,
          35,  35,  35,  35,  35,  35,  35,  35,  35,  35, 180, 180, 180, 180,
         180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180,  35,  35,
          35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,
         122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122,
         122, 122,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,
         122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 180, 180,
         180, 180, 180, 180, 180, 180, 180, 180, 180, 180,  35,  35,  35,  35,
          35,  35,  35,  35,  35,  35,  35,  35, 180, 180, 180, 180, 180, 180,
         180, 180, 180, 180, 180, 180, 180, 180, 180, 180,  35,  35,  35,  35,
          35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,
          35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,
          35,  35,  35,  35,  35,

In [13]:
# Load the checkpoint saved as saved_models/base2.pt and put the model in eval mode.
curriculum_type = 'base2'
device_name = 'cuda:1'
# Fall back to CPU when CUDA is missing; previously `device` was left undefined
# in that case, so the model was built with a stale or missing device.
if device_name != 'cpu' and not torch.cuda.is_available():
    print('Selected device not available: ' + device_name + '; falling back to cpu')
    device_name = 'cpu'
device = torch.device(device_name)

model = GridMLMMelHarm(
    chord_vocab_size=len(tokenizer.vocab),
    device=device
)
model_path = 'saved_models/' + curriculum_type + '.pt'
# NOTE(review): torch.load without weights_only=True unpickles arbitrary objects;
# only load trusted checkpoints. Re-enable weights_only=True once confirmed compatible.
# map_location uses the resolved device so a CPU fallback also loads correctly.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint)
model.eval()
model.to(device)

GridMLMMelHarm(
  (condition_proj): Linear(in_features=16, out_features=512, bias=True)
  (melody_proj): Linear(in_features=100, out_features=512, bias=True)
  (harmony_embedding): Embedding(354, 512)
  (dropout): Dropout(p=0.3, inplace=False)
  (stage_embedding): Embedding(10, 64)
  (stage_proj): Linear(in_features=576, out_features=512, bias=True)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.3, inplace=False)
        (

In [14]:
# Generate a full harmony from the melody with the 'base2' model.
# A fixed torch seed makes re-runs reproducible if generation samples
# (assumed from temperature/strategy — TODO confirm in generate_utils; other RNGs
# such as numpy/random would need seeding too if used there).
generation_seed = 0
torch.manual_seed(generation_seed)
with torch.no_grad():
    generated_harmony = structured_progressive_generate(
        model=model,
        melody_grid=melody_grid,
        conditioning_vec=conditioning_vec,
        num_stages=10,
        mask_token_id=tokenizer.mask_token_id,
        temperature=1.0,
        strategy='topk'
    )

In [15]:
print('generated_harmony:', generated_harmony)

# Label each position of the first generated sequence and of the ground truth
# as "<index>:<token>" for side-by-side comparison.
output_tokens = [
    str(pos) + ':' + tokenizer.ids_to_tokens[tok_id]
    for pos, tok_id in enumerate(generated_harmony[0].tolist())
]
print('output_tokens')
print(output_tokens)

harmony_gt_tokens = [
    str(pos) + ':' + tokenizer.ids_to_tokens[tok_id]
    for pos, tok_id in enumerate(harmony_gt[0].tolist())
]
print('harmony_gt_tokens')
print(harmony_gt_tokens)

generated_harmony: tensor([[180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180,
         180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180,
         180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180,
         180, 180, 180, 180, 180, 180, 122, 122, 122, 122, 122, 122, 122, 122,
         122, 122, 122, 122, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180,
         180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180,
         180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180,
         180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180,
         180, 180, 180, 180, 180, 180, 180, 180, 122, 122, 122, 122, 122, 122,
         122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122,
         122, 122, 122, 122,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,
          35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,
         122, 122, 122, 122, 122,