In [31]:
import torch
import torch.nn as nn
from models import GridMLMMelHarm
from GridMLM_tokenizers import CSGridMLMTokenizer
from data_utils import CSGridMLMDataset, CSGridMLM_collate_fn
from torch.utils.data import DataLoader
from train_utils import apply_masking
from generate_utils import random_progressive_generate, structured_progressive_generate

In [32]:
# Validation data: tokenizer, dataset and a non-shuffled loader over val_dir.
val_dir = '/media/maindisk/maximos/data/hooktheory_all12_test'
batchsize = 1

tokenizer = CSGridMLMTokenizer(fixed_length=256)
# The third positional argument (512) is forwarded unchanged.
# NOTE(review): its meaning is defined by CSGridMLMDataset - confirm there.
val_dataset = CSGridMLMDataset(val_dir, tokenizer, 512)
valloader = DataLoader(
    dataset=val_dataset,
    batch_size=batchsize,
    shuffle=False,
    collate_fn=CSGridMLM_collate_fn,
)

In [33]:
# Cache the tokenizer's mask token id and initialise the stage index to 0.
mask_token_id, stage = tokenizer.mask_token_id, 0

In [34]:
# Build the model on the requested device and restore the trained weights.
curriculum_type = 'random'
device_name = 'cuda:1'

requested_device = torch.device(device_name)
cuda_unusable = requested_device.type == 'cuda' and (
    not torch.cuda.is_available()
    # An index past the visible GPU count (e.g. 'cuda:1' on a one-GPU box)
    # is just as unusable as having no CUDA at all.
    or (requested_device.index or 0) >= torch.cuda.device_count()
)
if cuda_unusable:
    # Without a fallback `device` would stay undefined and the model
    # construction below would raise NameError.
    print('Selected device not available: ' + device_name + ' - falling back to cpu')
    device = torch.device('cpu')
else:
    device = requested_device

model = GridMLMMelHarm(
    chord_vocab_size=len(tokenizer.vocab),
    device=device
)
model_path = 'saved_models/' + curriculum_type + '.pt'
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (or pass weights_only=True if the installed torch supports it).
# map_location follows the resolved device so the CPU fallback can still load
# a checkpoint that was saved from a GPU.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint)
model.eval()
# Last expression of the cell: moves the model and displays its structure.
model.to(device)

GridMLMMelHarm(
  (condition_proj): Linear(in_features=16, out_features=512, bias=True)
  (melody_proj): Linear(in_features=100, out_features=512, bias=True)
  (harmony_embedding): Embedding(354, 512)
  (dropout): Dropout(p=0.3, inplace=False)
  (stage_embedding): Embedding(10, 64)
  (stage_proj): Linear(in_features=576, out_features=512, bias=True)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.3, inplace=False)
        (

In [35]:
# Grab the first validation batch. An empty loader now fails right here with
# StopIteration instead of silently leaving `batch` undefined for later cells.
batch = next(iter(valloader))

  return self.iter().getElementsByClass(classFilterList)


In [36]:
# Move the melody grid, reference harmony ids and conditioning vector of the
# batch to the model device.
# Shapes noted by the original author: pianoroll (B, 256, 140),
# input_ids (B, 256), time_signature (B, C0) - TODO confirm against the tokenizer.
melody_grid, harmony_gt, conditioning_vec = (
    batch[field].to(device)
    for field in ("pianoroll", "input_ids", "time_signature")
)

In [37]:
# Run random_progressive_generate on the batch melody; the arguments are
# collected first so they can be inspected or tweaked before the call.
generation_kwargs = dict(
    model=model,
    melody_grid=melody_grid,
    conditioning_vec=conditioning_vec,
    num_stages=10,
    mask_token_id=tokenizer.mask_token_id,
    temperature=1.0,
    strategy='topk',
)
generated_harmony = random_progressive_generate(**generation_kwargs)


In [38]:
# Decode the first generated sequence and the first ground-truth sequence into
# "position:token" strings so they can be compared by eye.
# Comprehensions keep the loop variables out of the notebook namespace.
print('generated_harmony:', generated_harmony)
output_tokens = [
    str(position) + ':' + tokenizer.ids_to_tokens[token_id]
    for position, token_id in enumerate(generated_harmony[0].tolist())
]
print('output_tokens')
print(output_tokens)

harmony_gt_tokens = [
    str(position) + ':' + tokenizer.ids_to_tokens[token_id]
    for position, token_id in enumerate(harmony_gt[0].tolist())
]
print('harmony_gt_tokens')
print(harmony_gt_tokens)

generated_harmony: tensor([[ 93,  93,  93,  93,  93,  93,  93,  93, 151, 151, 151, 151, 151, 151,
         151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151,
         151, 151, 151, 151, 296, 296, 296, 296, 296,  93, 296, 296,   6,   6,
           6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
           6,   6,   6,   6,   6,   6,   6,   6,  93,  93,  93,  93,  93,  93,
          93,  93, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151,
         151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151,  93,  93,
          93,  93,  93,  93,  93,  93,  93,  93,  93,  93,  93,  93,  93,  93,
         180, 180, 180, 180, 180, 180, 180, 180, 238, 238, 238, 238, 238, 238,
         238, 238,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,

In [39]:
# Re-create the model on the requested device and reload the trained weights
# (same steps as the earlier model-loading cell).
curriculum_type = 'random'
device_name = 'cuda:1'

requested_device = torch.device(device_name)
cuda_unusable = requested_device.type == 'cuda' and (
    not torch.cuda.is_available()
    # An index past the visible GPU count (e.g. 'cuda:1' on a one-GPU box)
    # is just as unusable as having no CUDA at all.
    or (requested_device.index or 0) >= torch.cuda.device_count()
)
if cuda_unusable:
    # Without a fallback `device` would stay undefined and the model
    # construction below would raise NameError.
    print('Selected device not available: ' + device_name + ' - falling back to cpu')
    device = torch.device('cpu')
else:
    device = requested_device

model = GridMLMMelHarm(
    chord_vocab_size=len(tokenizer.vocab),
    device=device
)
model_path = 'saved_models/' + curriculum_type + '.pt'
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (or pass weights_only=True if the installed torch supports it).
# map_location follows the resolved device so the CPU fallback can still load
# a checkpoint that was saved from a GPU.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint)
model.eval()
# Last expression of the cell: moves the model and displays its structure.
model.to(device)

GridMLMMelHarm(
  (condition_proj): Linear(in_features=16, out_features=512, bias=True)
  (melody_proj): Linear(in_features=100, out_features=512, bias=True)
  (harmony_embedding): Embedding(354, 512)
  (dropout): Dropout(p=0.3, inplace=False)
  (stage_embedding): Embedding(10, 64)
  (stage_proj): Linear(in_features=576, out_features=512, bias=True)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.3, inplace=False)
        (

In [40]:
# Run random_progressive_generate again with the freshly reloaded model; the
# arguments are collected first so they can be inspected before the call.
generation_kwargs = dict(
    model=model,
    melody_grid=melody_grid,
    conditioning_vec=conditioning_vec,
    num_stages=10,
    mask_token_id=tokenizer.mask_token_id,
    temperature=1.0,
    strategy='topk',
)
generated_harmony = random_progressive_generate(**generation_kwargs)

In [41]:
# Decode the first generated sequence and the first ground-truth sequence into
# "position:token" strings so they can be compared by eye.
# Comprehensions keep the loop variables out of the notebook namespace.
print('generated_harmony:', generated_harmony)
output_tokens = [
    str(position) + ':' + tokenizer.ids_to_tokens[token_id]
    for position, token_id in enumerate(generated_harmony[0].tolist())
]
print('output_tokens')
print(output_tokens)

harmony_gt_tokens = [
    str(position) + ':' + tokenizer.ids_to_tokens[token_id]
    for position, token_id in enumerate(harmony_gt[0].tolist())
]
print('harmony_gt_tokens')
print(harmony_gt_tokens)

generated_harmony: tensor([[ 93,  93,  93,  93,  93,  93,  93,  93, 151, 151, 151, 151, 151, 151,
         151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151,
         151, 151, 151, 151, 296, 296, 296, 296, 296,  93, 296, 296,   6,   6,
           6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
           6,   6,   6,   6,   6,   6,   6,   6,  93,  93,  93,  93,  93,  93,
          93,  93, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151,
         151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151,  93,  93,
          93,  93,  93,  93,  93,  93,  93,  93,  93,  93,  93,  93,  93,  93,
         180, 180, 180, 180, 180, 180, 180, 180, 238, 238, 238, 238, 238, 238,
         238, 238,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,

In [42]:
import os
# Collect every MusicXML file found below val_dir.
# A comprehension avoids binding the global `_`, which would otherwise shadow
# IPython's "last output" variable for the rest of the session.
# NOTE(review): os.walk yields entries in filesystem order, so the position of
# a file in data_files (used for indexing in a later cell) may differ between
# machines - consider sorting if a stable index is needed.
musicxml_suffixes = ('.xml', '.mxl', '.musicxml')
data_files = [
    os.path.join(dirpath, filename)
    for dirpath, _dirnames, filenames in os.walk(val_dir)
    for filename in filenames
    if filename.endswith(musicxml_suffixes)
]
print(len(data_files))

10486


In [43]:
# Tokenise one validation file, selected by its position in data_files.
example_index = 1473
encoded = tokenizer.encode(data_files[example_index])

In [44]:
# Show which fields the tokenizer returned for the encoded file.
print(encoded.keys())

dict_keys(['input_tokens', 'input_ids', 'pianoroll', 'time_signature', 'attention_mask', 'skip_steps', 'melody_part', 'ql_per_quantum'])


In [45]:
# Wrap the single encoded example as a batch of one (leading dimension of 1).
# NOTE(review): unlike the DataLoader-based cell, these tensors are not moved
# to `device`; presumably the generate helper handles placement - confirm.
melody_grid = torch.tensor(encoded['pianoroll'], dtype=torch.float).unsqueeze(0)
conditioning_vec = torch.tensor(encoded['time_signature'], dtype=torch.float).unsqueeze(0)
# Token ids are integers: torch.long (previously float) makes .tolist() yield
# ints, which are valid keys/indices for the vocabulary lookup done later.
harmony_gt = torch.tensor(encoded['input_ids'], dtype=torch.long).unsqueeze(0)

In [46]:
# Run structured_progressive_generate on the single encoded example; the
# arguments are collected first so they can be inspected before the call.
generation_kwargs = dict(
    model=model,
    melody_grid=melody_grid,
    conditioning_vec=conditioning_vec,
    num_stages=10,
    mask_token_id=tokenizer.mask_token_id,
    temperature=1.0,
    strategy='topk',
)
generated_harmony = structured_progressive_generate(**generation_kwargs)

In [47]:
# Map the first generated sequence and the first ground-truth sequence from
# ids back to token strings. The position index was unused, so it is dropped.
output_tokens = [
    tokenizer.ids_to_tokens[token_id]
    for token_id in generated_harmony[0].tolist()
]
print('output_tokens')
print(output_tokens)

harmony_gt_tokens = [
    tokenizer.ids_to_tokens[token_id]
    for token_id in harmony_gt[0].tolist()
]
print('harmony_gt_tokens')
print(harmony_gt_tokens)

output_tokens
['<nc>', '<nc>', '<nc>', '<nc>', '<nc>', '<nc>', '<nc>', '<nc>', '<nc>', '<nc>', '<nc>', '<nc>', '<nc>', '<nc>', '<nc>', '<nc>', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'G:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D:maj', 'D

In [48]:
from music21 import harmony, stream, metadata, chord, meter
import mir_eval
import numpy as np
from copy import deepcopy

def _repeat_chord_at_bar_starts(measured_part, flat_part, element_class):
    """Repeat the previous ``element_class`` element at the start of bars lacking one.

    For each ``stream.Measure`` in ``measured_part`` that has no
    ``element_class`` element at measure offset 0, the last ``element_class``
    element of ``flat_part`` whose offset lies before the bar's offset is
    deep-copied and inserted at offset 0 of that measure. Bars with no earlier
    element are left untouched.
    """
    for m in measured_part.getElementsByClass(stream.Measure):
        bar_offset = m.offset
        has_chord = any(isinstance(el, element_class) and el.offset == 0. for el in m)
        if has_chord:
            continue
        # Find the elements positioned before this measure in the flat part.
        prev_chords = [el for el in flat_part.recurse().getElementsByClass(element_class)
                       if el.offset < bar_offset]
        if prev_chords:
            m.insert(0.0, deepcopy(prev_chords[-1]))


def overlay_generated_harmony(melody_part, generated_chords, ql_per_16th, skip_steps):
    """Build a two-part score: a copy of the melody plus a part of generated chords.

    Args:
        melody_part: music21 stream holding the melody; it is deep-copied, so
            the caller's object is not modified.
        generated_chords: sequence of chord labels, one per grid step. Entries
            equal to "<pad>" or "<nc>" are skipped, as are consecutive repeats
            of the last inserted label. Other labels are decoded with
            ``mir_eval.chord.encode``.
        ql_per_16th: factor that converts a step index into a stream offset
            (presumably quarter lengths per grid step - confirm with the tokenizer).
        skip_steps: number of steps added to each index before the conversion.

    Returns:
        A ``stream.Score`` with the measured melody copy and the measured
        chords part, both inserted at offset 0.

    Labels that cannot be decoded or turned into a chord are reported with
    ``print`` and skipped.
    """
    # create a part for chords in midi format
    chords_part = stream.Part()
    # Create deep copy of flat melody part
    harmonized_part = deepcopy(melody_part)

    # Remove old chord symbols. Snapshot the matches first so the removal does
    # not depend on how the live iterator reacts to mutation, and remove
    # recursively so symbols nested inside sub-streams are also dropped.
    old_symbols = list(harmonized_part.recurse().getElementsByClass(harmony.ChordSymbol))
    for el in old_symbols:
        harmonized_part.remove(el, recurse=True)

    # Label of the last chord actually inserted, used to collapse repeats
    last_chord_symbol = None

    for i, mir_chord in enumerate(generated_chords):
        if mir_chord in ("<pad>", "<nc>"):
            continue
        if mir_chord == last_chord_symbol:
            continue

        offset = (i + skip_steps) * ql_per_16th

        # Decode mir_eval chord symbol to a music21 chord object
        try:
            r, t, _ = mir_eval.chord.encode(mir_chord, reduce_extended_chords=True)
            # root + indices of active bitmap entries, shifted up by 48
            pcs = r + np.where(t > 0)[0] + 48
            c = chord.Chord(pcs.tolist())
            # Result is unused, but the call is kept as a validity check:
            # chords this conversion rejects are skipped, as before.
            harmony.chordSymbolFromChord(c)
        except Exception as e:
            print(f"Skipping invalid chord {mir_chord} at step {i}: {e}")
            continue

        chords_part.insert(offset, c)
        last_chord_symbol = mir_chord

    # Convert flat melody part to one with measures
    harmonized_with_measures = harmonized_part.makeMeasures()

    # Repeat previous chord symbol at start of bars with no chord symbol.
    # NOTE(review): chord symbols are removed above and none are inserted into
    # the melody part, so this pass currently appears to find nothing to repeat.
    _repeat_chord_at_bar_starts(harmonized_with_measures, harmonized_part, harmony.ChordSymbol)

    # Convert flat chords part to one with measures
    chords_with_measures = chords_part.makeMeasures()

    # Repeat previous chord at start of bars with no chord
    _repeat_chord_at_bar_starts(chords_with_measures, chords_part, chord.Chord)

    # Create final score with chords and melody
    score = stream.Score()
    score.insert(0, harmonized_with_measures)
    score.insert(0, chords_with_measures)

    return score
# end overlay_generated_harmony

def save_harmonized_score(score, title="Harmonized Piece", out_path="harmonized.xml"):
    """Give ``score`` fresh metadata carrying ``title`` and write it to ``out_path``.

    Any metadata previously assigned to ``score.metadata`` is replaced. The
    score is written through ``score.write`` using the 'musicxml' format.
    """
    score_metadata = metadata.Metadata()
    score_metadata.title = title
    score.metadata = score_metadata
    score.write('musicxml', fp=out_path)
# end save_harmonized_score

In [49]:
# Combine the encoded melody with the decoded chord tokens and export the score.
score = overlay_generated_harmony(
    melody_part=encoded['melody_part'],
    generated_chords=output_tokens,
    ql_per_16th=encoded['ql_per_quantum'],
    skip_steps=encoded['skip_steps'],
)
save_harmonized_score(score, out_path="harmonized_output.mxl")
