In [1]:
from harmony_tokenizers_m21 import ChordSymbolTokenizer, PitchClassTokenizer, MelodyPitchTokenizer, MergedMelHarmTokenizer
from data_utils import StructGPTMelHarmDataset, GenCollator
from torch.utils.data import DataLoader
import torch
from transformers import AutoConfig, GPT2LMHeadModel
import numpy as np

import heapq
import torch
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
test_dir = '/media/maindisk/maximos/data/hooktheory_test'

cstok = ChordSymbolTokenizer()
pctok = PitchClassTokenizer()
meltok = MelodyPitchTokenizer()
tokenizer = MergedMelHarmTokenizer(meltok, cstok)

test_dataset = StructGPTMelHarmDataset(test_dir, tokenizer, max_length=512, num_bars=16, return_harmonization_labels=True)

collator = GenCollator(tokenizer)

trainloader = DataLoader(test_dataset, batch_size=1, shuffle=True, collate_fn=collator)

In [3]:
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer.vocab),
    n_positions=512,
    n_layer=8,
    n_head=8,
    pad_token_id=tokenizer.vocab[tokenizer.pad_token],
    bos_token_id=tokenizer.vocab[tokenizer.bos_token],
    eos_token_id=tokenizer.vocab[tokenizer.eos_token],
    n_embd=512
)

model = GPT2LMHeadModel(config)

model_path = 'saved_models/gpt/ChordSymbolTokenizer/ChordSymbolTokenizer.pt'

device_name = 'cuda:0'
if device_name == 'cpu':
    device = torch.device('cpu')
else:
    if torch.cuda.is_available():
        device = torch.device(device_name)
    else:
        print('Selected device not available: ' + device_name)

checkpoint = torch.load(model_path, map_location=device_name, weights_only=True)
model.load_state_dict(checkpoint)
model.eval()
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(547, 512)
    (wpe): Embedding(512, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=1536, nx=512)
          (c_proj): Conv1D(nf=512, nx=512)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=2048, nx=512)
          (c_proj): Conv1D(nf=512, nx=2048)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=547, bias=False)
)

In [63]:
batch = next(iter(trainloader))

In [123]:
class SearchNode:
    def __init__(self, tokens, logprob, heuristic, parent=None):
        self.tokens = tokens  # shape: (1, seq_len)
        self.logprob = logprob
        self.heuristic = heuristic
        self.parent = parent
        self.tried_tokens = set()  # token IDs already expanded

    def __lt__(self, other):
        return (self.logprob + self.heuristic) > (other.logprob + other.heuristic)
# end class SearchNode

class AStar:
    def __init__(self, model, tokenizer, input_ids, constraint_ids, max_length=512, beam_width=10, lookahead_k=5):
        self.model = model
        self.tokenizer = tokenizer
        self.input_ids = input_ids
        self.constraint_ids = constraint_ids.tolist()
        self.max_length = max_length
        self.beam_width = beam_width
        self.lookahead_k = lookahead_k
        self.eos_token_id = tokenizer.eos_token_id
        self.constraint_tokens_breakdown()
    # end init

    def constraint_tokens_breakdown(self):
        tokens = self.tokenizer.convert_ids_to_tokens(self.constraint_ids)
        bar_count = 0
        chord_tokens = []
        i = 0
        while i < len(tokens):
            tok = tokens[i]
            if 'bar' in tok or 'fill' in tok or '/m' in tok or '<h>' in tok:
                if 'bar' in tok:
                    bar_count += 1
                i += 1
            else:
                # we should have arrived in a position token
                position_token = tokens[i]
                # the remaining tokens should be the chord
                i += 1
                while i < len(tokens) and \
                    'fill' not in tokens[i] and \
                     'bar' not in tokens[i] and \
                     '</s>' not in tokens[i]:
                    chord_tokens.append( tokens[i] )
                    i += 1
                break
        self.constraint_bar = bar_count
        self.position_token = position_token
        self.chord_tokens = chord_tokens
    # end constraint_tokens_breakdown

    def constraint_checker(self, input_tokens):
        tokens = self.tokenizer.convert_ids_to_tokens(input_tokens.tolist())
        start_harmonization_index = tokens.index('<h>')
        tokens = tokens[start_harmonization_index:]
        print(f'constraint_checker: {tokens}')
        
        bar_count = 0
        found = False
        i = 0
        while i < len(tokens):
            tok = tokens[i]
            if tok == "<bar>":
                bar_count += 1
            elif bar_count == self.constraint_bar and tok == self.position_token:
                j = 0
                found = True
                while j < len(self.chord_tokens):
                    if i + j + 1 < len(tokens) and tokens[i + j + 1] != self.chord_tokens[j]:
                        found = False
                    break
                break
            i += 1
        print(f'checker: {found or bar_count < self.constraint_bar}')
        return found or bar_count < self.constraint_bar  # Only reject if we're past bar 3 and it's missing
    # end constraint_checker

    def expand_node(self, node):
        with torch.no_grad():
            output = model(node.tokens.to(model.device), return_dict=True)
            logits = output.logits[:, -1, :]

            # Mask out already visited tokens before softmax
            if node.tried_tokens:
                mask = torch.full_like(logits, 0.0)
                mask[:, list(node.tried_tokens)] = float('-inf')
                logits = logits + mask

            probs = F.log_softmax(logits, dim=-1)
            topk_probs, topk_ids = torch.topk(probs, self.lookahead_k, dim=-1)

        new_nodes = []
        for i in range(self.lookahead_k):
            token_id = topk_ids[0, i].item()
            token_prob = topk_probs[0, i].item()
            node.tried_tokens.add(token_id)

            new_tokens = torch.cat([node.tokens.to(model.device), topk_ids[0, i].unsqueeze(0).unsqueeze(0)], dim=-1).to(model.device)

            if not self.constraint_checker(new_tokens[0]):
                continue

            new_logprob = node.logprob + token_prob
            new_node = SearchNode(new_tokens, new_logprob, 0.0, parent=node)
            new_nodes.append(new_node)
        return new_nodes
    # end expand_node

    def decode(self):
        initial_node = SearchNode(tokens=self.input_ids, logprob=0.0, heuristic=0.0)
        open_set = [initial_node]
        finished = []
        while open_set:
            current = heapq.heappop(open_set)

            if current.tokens.shape[-1] >= self.max_length or (
                self.eos_token_id and current.tokens[0, -1].item() == self.eos_token_id
            ):
                finished.append(current)
                continue

            # Expand current node
            children = self.expand_node(current)

            if children:
                for child in children:
                    heapq.heappush(open_set, child)
            else:
                # No valid expansions – backtrack to unvisited options
                back = current.parent
                while back:
                    # Re-expand from back with unvisited options
                    more_options = self.expand_node(back)
                    if more_options:
                        for opt in more_options:
                            heapq.heappush(open_set, opt)
                        break
                    back = back.parent

            # Prune open set
            open_set = sorted(open_set, reverse=True)[:self.beam_width]

        if not finished:
            raise RuntimeError("No valid sequence could be generated under constraints.")

        best = sorted(finished, key=lambda x: x.logprob + x.heuristic, reverse=True)[0]
        return best.tokens
    # end decode
# end class AStar

In [90]:
all_ids = batch['input_ids']
print(all_ids)
melody_end_index = all_ids[0].tolist().index( tokenizer.vocab['</m>'] )
harmony_start_index = all_ids[0].tolist().index( tokenizer.vocab['<h>'] )
print(melody_end_index)
constraint_ids = all_ids[0][melody_end_index:]
print(constraint_ids)
# start_harmony_position = np.where( all_ids == harmony_start_index )[0][0]
input_ids = all_ids[0][:(harmony_start_index+1)]
input_ids = input_ids.reshape(1, -1)
print(input_ids.shape)

tensor([[  2,   6, 183,  98,  58, 102,  65, 104,  66, 106,  65, 110,  63, 114,
          61, 116,  63, 118,  61, 120,  60, 122,  58, 126,  58,   6,  98,  58,
         102,  61, 104,  63, 106,  61, 110,  60, 114,  58, 116,  60, 118,  58,
         120,  56, 122,  58,   6,  98,  58, 102,  65, 104,  66, 106,  65, 110,
          63, 114,  61, 116,  63, 118,  61, 120,  60, 122,  58, 126,  58,   6,
          98,  58, 100,  60, 102,  61, 104,  63, 106,  61, 110,  60, 114,  58,
         118,  56, 122,  58,   6,  98,  58, 102,  65, 104,  66, 106,  65, 110,
          63, 114,  61, 116,  63, 118,  61, 120,  60, 122,  58, 126,  58,   6,
          98,  58, 102,  61, 104,  63, 106,  61, 110,  60, 114,  58, 116,  60,
         118,  58, 120,  56, 122,  58,   6,  98,  58, 102,  65, 104,  66, 106,
          65, 110,  63, 114,  61, 116,  63, 118,  61, 120,  60, 122,  58, 126,
          58,   6,  98,  61, 102,  63, 104,  61, 106,  60, 110,  58, 112,  56,
         114,  58,   8,   6,   9,   6,  98, 461,   9

In [124]:
astar = AStar( model, tokenizer, input_ids, constraint_ids, max_length=512, beam_width=10, lookahead_k=5 )

In [112]:
print(astar.constraint_bar)
print(astar.position_token)
print(astar.chord_tokens)

2
position_0x00
['A:min']


In [125]:
generated_ids = astar.decode()
generated_tokens = []

constraint_checker: ['<h>', '<bar>']
checker: True
constraint_checker: ['<h>', 'position_2x00']
checker: True
constraint_checker: ['<h>', 'position_0x00']
checker: True
constraint_checker: ['<h>', 'position_3x00']
checker: True
constraint_checker: ['<h>', 'position_1x00']
checker: True
constraint_checker: ['<h>', 'position_1x00', 'A:min']
checker: True
constraint_checker: ['<h>', 'position_1x00', 'F:maj']
checker: True
constraint_checker: ['<h>', 'position_1x00', 'D:min']
checker: True
constraint_checker: ['<h>', 'position_1x00', 'C:maj']
checker: True
constraint_checker: ['<h>', 'position_1x00', 'F:maj7']
checker: True
constraint_checker: ['<h>', 'position_1x00', 'F:maj7', 'position_2x00']
checker: True
constraint_checker: ['<h>', 'position_1x00', 'F:maj7', '<bar>']
checker: True
constraint_checker: ['<h>', 'position_1x00', 'F:maj7', 'position_3x00']
checker: True
constraint_checker: ['<h>', 'position_1x00', 'F:maj7', 'position_1x50']
checker: True
constraint_checker: ['<h>', 'positio

In [121]:
print(generated_ids)

tensor([[  2,   6, 183,  98,  58, 102,  65, 104,  66, 106,  65, 110,  63, 114,
          61, 116,  63, 118,  61, 120,  60, 122,  58, 126,  58,   6,  98,  58,
         102,  61, 104,  63, 106,  61, 110,  60, 114,  58, 116,  60, 118,  58,
         120,  56, 122,  58,   6,  98,  58, 102,  65, 104,  66, 106,  65, 110,
          63, 114,  61, 116,  63, 118,  61, 120,  60, 122,  58, 126,  58,   6,
          98,  58, 100,  60, 102,  61, 104,  63, 106,  61, 110,  60, 114,  58,
         118,  56, 122,  58,   6,  98,  58, 102,  65, 104,  66, 106,  65, 110,
          63, 114,  61, 116,  63, 118,  61, 120,  60, 122,  58, 126,  58,   6,
          98,  58, 102,  61, 104,  63, 106,  61, 110,  60, 114,  58, 116,  60,
         118,  58, 120,  56, 122,  58,   6,  98,  58, 102,  65, 104,  66, 106,
          65, 110,  63, 114,  61, 116,  63, 118,  61, 120,  60, 122,  58, 126,
          58,   6,  98,  61, 102,  63, 104,  61, 106,  60, 110,  58, 112,  56,
         114,  58,   8,   6,   9,   6,  98, 461,   9

In [122]:
for i in generated_ids[0]:
    generated_tokens.append( tokenizer.ids_to_tokens[ int(i) ].replace(' ','x') )
print(generated_tokens)

['<s>', '<bar>', 'ts_4x4', 'position_0x00', 'P:69', 'position_0x50', 'P:76', 'position_0x75', 'P:77', 'position_1x00', 'P:76', 'position_1x50', 'P:74', 'position_2x00', 'P:72', 'position_2x25', 'P:74', 'position_2x50', 'P:72', 'position_2x75', 'P:71', 'position_3x00', 'P:69', 'position_3x50', 'P:69', '<bar>', 'position_0x00', 'P:69', 'position_0x50', 'P:72', 'position_0x75', 'P:74', 'position_1x00', 'P:72', 'position_1x50', 'P:71', 'position_2x00', 'P:69', 'position_2x25', 'P:71', 'position_2x50', 'P:69', 'position_2x75', 'P:67', 'position_3x00', 'P:69', '<bar>', 'position_0x00', 'P:69', 'position_0x50', 'P:76', 'position_0x75', 'P:77', 'position_1x00', 'P:76', 'position_1x50', 'P:74', 'position_2x00', 'P:72', 'position_2x25', 'P:74', 'position_2x50', 'P:72', 'position_2x75', 'P:71', 'position_3x00', 'P:69', 'position_3x50', 'P:69', '<bar>', 'position_0x00', 'P:69', 'position_0x25', 'P:71', 'position_0x50', 'P:72', 'position_0x75', 'P:74', 'position_1x00', 'P:72', 'position_1x50', 'P:7