In [1]:
from harmony_tokenizers_m21 import ChordSymbolTokenizer, PitchClassTokenizer, MelodyPitchTokenizer, MergedMelHarmTokenizer
from data_utils import StructGPTMelHarmDataset, GenCollator
from torch.utils.data import DataLoader
import torch
from transformers import AutoConfig, GPT2LMHeadModel
import numpy as np

import heapq
import torch
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
test_dir = '/media/maindisk/maximos/data/hooktheory_test'

cstok = ChordSymbolTokenizer()
pctok = PitchClassTokenizer()
meltok = MelodyPitchTokenizer()
tokenizer = MergedMelHarmTokenizer(meltok, cstok)

test_dataset = StructGPTMelHarmDataset(test_dir, tokenizer, max_length=512, num_bars=16, return_harmonization_labels=True)

collator = GenCollator(tokenizer)

trainloader = DataLoader(test_dataset, batch_size=1, shuffle=True, collate_fn=collator)

In [3]:
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer.vocab),
    n_positions=512,
    n_layer=8,
    n_head=8,
    pad_token_id=tokenizer.vocab[tokenizer.pad_token],
    bos_token_id=tokenizer.vocab[tokenizer.bos_token],
    eos_token_id=tokenizer.vocab[tokenizer.eos_token],
    n_embd=512
)

model = GPT2LMHeadModel(config)

model_path = 'saved_models/gpt/ChordSymbolTokenizer/ChordSymbolTokenizer.pt'

# device_name = 'cuda:0'
device_name = 'cpu'
if device_name == 'cpu':
    device = torch.device('cpu')
else:
    if torch.cuda.is_available():
        device = torch.device(device_name)
    else:
        print('Selected device not available: ' + device_name)

checkpoint = torch.load(model_path, map_location=device_name, weights_only=True)
model.load_state_dict(checkpoint)
model.eval()
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(547, 512)
    (wpe): Embedding(512, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=1536, nx=512)
          (c_proj): Conv1D(nf=512, nx=512)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=2048, nx=512)
          (c_proj): Conv1D(nf=512, nx=2048)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=547, bias=False)
)

In [4]:
batch = next(iter(trainloader))

  return self.iter().getElementsByClass(classFilterList)


In [5]:
class SearchNode:
    def __init__(self, tokens, logprob, heuristic, parent=None):
        self.tokens = tokens  # shape: (1, seq_len)
        self.logprob = logprob
        self.heuristic = heuristic
        self.parent = parent
        self.tried_tokens = set()  # token IDs already expanded

    def __lt__(self, other):
        return (self.logprob + self.heuristic) > (other.logprob + other.heuristic)
# end class SearchNode

class AStar:
    def __init__(self, model, tokenizer, input_ids, constraint_ids, max_length=512, beam_width=10, lookahead_k=5):
        self.model = model
        self.tokenizer = tokenizer
        self.input_ids = input_ids
        self.constraint_ids = constraint_ids.tolist()
        self.max_length = max_length
        self.beam_width = beam_width
        self.lookahead_k = lookahead_k
        self.eos_token_id = tokenizer.eos_token_id
        self.eos_token = tokenizer.eos_token
        self.constraint_tokens_breakdown()
    # end init

    def constraint_tokens_breakdown(self):
        tokens = self.tokenizer.convert_ids_to_tokens(self.constraint_ids)
        bar_count = 0
        chord_tokens = []
        i = 0
        while i < len(tokens):
            tok = tokens[i]
            if 'bar' in tok or 'fill' in tok or '/m' in tok or '<h>' in tok:
                if 'bar' in tok:
                    bar_count += 1
                i += 1
            else:
                # we should have arrived in a position token
                position_token = tokens[i]
                # the remaining tokens should be the chord
                i += 1
                while i < len(tokens) and \
                    'fill' not in tokens[i] and \
                     'bar' not in tokens[i] and \
                     '</s>' not in tokens[i]:
                    chord_tokens.append( tokens[i] )
                    i += 1
                break
        self.constraint_bar = bar_count
        self.position_token = position_token
        self.chord_tokens = chord_tokens
    # end constraint_tokens_breakdown

    def constraint_checker(self, input_tokens):
        tokens = self.tokenizer.convert_ids_to_tokens(input_tokens.tolist())
        start_harmonization_index = tokens.index('<h>')
        tokens = tokens[start_harmonization_index:]
        print(f'constraint_checker: {tokens}')
        
        bar_count = 0
        found = False
        i = 0
        while i < len(tokens):
            tok = tokens[i]
            if tok == "<bar>":
                bar_count += 1
            elif bar_count == self.constraint_bar and tok == self.position_token:
                j = 0
                found = True
                while j < len(self.chord_tokens):
                    if i + j + 1 < len(tokens):
                        if tokens[i + j + 1] != self.chord_tokens[j]:
                            found = False
                            break
                    else:
                        found = False
                        break
                    j += 1
                # break
                # no break, we need to keep counting bars to check premature ending
            i += 1
        # if the sequence has reached eos and the constraint has not been met, it should fail
        if tokens[-1] == self.eos_token:
            condition = found
            print(f'checker EOS: {condition}')
        elif len(tokens) == self.max_length:
            condition = False
            print(f'checker PRE: {condition}')
        else:
            condition = found or bar_count < self.constraint_bar  # Only reject if we're past bar 3 and it's missing
            print(f'checker PRE: {condition}')
        return condition
    # end constraint_checker

    def expand_node(self, node):
        with torch.no_grad():
            output = model(node.tokens.to(model.device), return_dict=True)
            logits = output.logits[:, -1, :]

            # Mask out already visited tokens before softmax
            if node.tried_tokens:
                mask = torch.full_like(logits, 0.0)
                mask[:, list(node.tried_tokens)] = float('-inf')
                logits = logits + mask

            probs = F.log_softmax(logits, dim=-1)
            topk_probs, topk_ids = torch.topk(probs, self.lookahead_k, dim=-1)

        new_nodes = []
        for i in range(self.lookahead_k):
            token_id = topk_ids[0, i].item()
            token_prob = topk_probs[0, i].item()
            node.tried_tokens.add(token_id)

            new_tokens = torch.cat([node.tokens.to(model.device), topk_ids[0, i].unsqueeze(0).unsqueeze(0)], dim=-1).to(model.device)

            if not self.constraint_checker(new_tokens[0]):
                continue

            new_logprob = node.logprob + token_prob
            new_node = SearchNode(new_tokens, new_logprob, 0.0, parent=node)
            new_nodes.append(new_node)
        return new_nodes
    # end expand_node

    def decode(self):
        initial_node = SearchNode(tokens=self.input_ids, logprob=0.0, heuristic=0.0)
        open_set = [initial_node]
        finished = []
        while open_set:
            current = heapq.heappop(open_set)

            # if current.tokens.shape[-1] >= self.max_length or (
            #     self.eos_token_id and current.tokens[0, -1].item() == self.eos_token_id
            # ):
            #     finished.append(current)
            #     continue
            if current.tokens[0, -1].item() == self.eos_token_id and \
                self.constraint_checker(current.tokens):
                finished.append(current)
                continue
            if current.tokens.shape[-1] >= self.max_length:
                continue

            # Expand current node
            children = self.expand_node(current)

            if children:
                for child in children:
                    heapq.heappush(open_set, child)
            else:
                # No valid expansions – backtrack to unvisited options
                back = current.parent
                while back:
                    # Re-expand from back with unvisited options
                    more_options = self.expand_node(back)
                    if more_options:
                        for opt in more_options:
                            heapq.heappush(open_set, opt)
                        break
                    back = back.parent

            # Prune open set
            open_set = sorted(open_set, reverse=True)[:self.beam_width]

        if not finished:
            raise RuntimeError("No valid sequence could be generated under constraints.")

        best = sorted(finished, key=lambda x: x.logprob + x.heuristic, reverse=True)[0]
        return best.tokens, finished
    # end decode
# end class AStar

In [6]:
all_ids = batch['input_ids']
print(all_ids)
melody_end_index = all_ids[0].tolist().index( tokenizer.vocab['</m>'] )
harmony_start_index = all_ids[0].tolist().index( tokenizer.vocab['<h>'] )
print(melody_end_index)
constraint_ids = all_ids[0][melody_end_index:]
print(constraint_ids)
# start_harmony_position = np.where( all_ids == harmony_start_index )[0][0]
input_ids = all_ids[0][:(harmony_start_index+2)]
input_ids = input_ids.reshape(1, -1)
print(input_ids.shape)

tensor([[  2,   6, 183,  98,  54, 110,  56, 126,  56,   6,  98,  56, 122,  51,
           6,  98,  54, 110,  56, 126,  56,   6,  98,  56, 122,  49,   6,  98,
          54, 110,  56, 122,  49,   6,  98,  49, 122,  51, 126,  44,   6,  98,
          44,   6,  98,  44,   6,  98,  54, 110,  56, 126,  56,   6,  98,  56,
         122,  51,   6,  98,  54, 110,  56, 126,  56,   6,  98,  56, 122,  61,
           6,  98,  58, 110,  54, 122,  49,   6,  98,  49, 122,  51, 126,  56,
           6,  98,  56,   6,  98,  56,   8,   6,   9,   6,   9,   6,   9,   6,
           9,   6,   9,   6,   9,   6,   9,   6,   9,   6,   9,   6,   9,   6,
           9,   6,   9,   6,   9,   6,  98, 344,   9,   6,   9,   6,   9,   3,
           7,   6,  98, 403,   6,  98, 403,   6,  98, 286,   6,  98, 286,   6,
          98, 344,   6,  98, 344,   6,  98, 200,   6,  98, 200,   6,  98, 403,
           6,  98, 403,   6,  98, 286,   6,  98, 286,   6,  98, 344,   6,  98,
         344,   6,  98, 200,   6,  98, 200,   3]])
9

In [10]:
astar = AStar( model, tokenizer, input_ids, constraint_ids, max_length=512, beam_width=10, lookahead_k=2 )

In [11]:
print(astar.constraint_bar)
print(astar.position_token)
print(astar.chord_tokens)

14
position_0x00
['F:maj']


In [12]:
generated_ids, finished = astar.decode()
generated_tokens = []

constraint_checker: ['<h>', '<bar>', 'position_0x00']
checker PRE: True
constraint_checker: ['<h>', '<bar>', '<bar>']
checker PRE: True
constraint_checker: ['<h>', '<bar>', '<bar>', 'position_0x00']
checker PRE: True
constraint_checker: ['<h>', '<bar>', '<bar>', '<bar>']
checker PRE: True
constraint_checker: ['<h>', '<bar>', '<bar>', '<bar>', '<bar>']
checker PRE: True
constraint_checker: ['<h>', '<bar>', '<bar>', '<bar>', 'position_0x00']
checker PRE: True
constraint_checker: ['<h>', '<bar>', '<bar>', '<bar>', 'position_0x00', 'F:maj']
checker PRE: True
constraint_checker: ['<h>', '<bar>', '<bar>', '<bar>', 'position_0x00', 'D:min']
checker PRE: True
constraint_checker: ['<h>', '<bar>', '<bar>', '<bar>', 'position_0x00', 'D:min', '<bar>']
checker PRE: True
constraint_checker: ['<h>', '<bar>', '<bar>', '<bar>', 'position_0x00', 'D:min', 'position_3x50']
checker PRE: True
constraint_checker: ['<h>', '<bar>', '<bar>', '<bar>', 'position_0x00', 'D:min', 'position_3x50', 'F:maj']
checker P

RuntimeError: No valid sequence could be generated under constraints.

In [None]:
print(generated_ids.shape)

torch.Size([1, 512])


In [None]:
print(len(finished))
print(finished[0].tokens)

15
tensor([[  2,   6, 183,  98,  51, 102,  51, 106,  49, 110,  46, 114,  49, 118,
          49, 122,  51,   6,  98,  49, 106,  49, 114,  49,   6,  98,  51, 102,
          51, 106,  49, 110,  46, 114,  49, 118,  49, 122,  51,   6,  98,  49,
         106,  49, 114,  49,   6,  98,  51, 102,  51, 106,  49, 110,  46, 114,
          49, 118,  49, 122,  51,   6,  98,  61, 102,  58, 106,  54, 110,  56,
         114,  51,   6,  98,  51, 102,   4, 106,  51, 110,   4, 114,  51, 118,
          51, 122,   4, 126,  51,   6,  98,  51, 102,  51, 106,  51, 114,  51,
         122,  49,   6,  98,  56, 102,  58, 110,  68, 112,  70, 118,  68, 126,
          66,   6,  98,  65, 114,  61, 118,  58, 122,  54, 126,  56,   6,  98,
          56, 122,  54, 126,  51,   6,  98,  51, 114,  54, 118,  54, 122,  56,
         126,  56,   6,  98,  58, 106,  54, 110,  51, 114,  49,   6,  98,  49,
         110,  49, 114,  61, 118,  58, 122,  54, 126,  56,   6,  98,  56, 106,
          54, 110,  56, 122,  54, 126,  56,   6, 

In [None]:
for i in generated_ids[0]:
    generated_tokens.append( tokenizer.ids_to_tokens[ int(i) ].replace(' ','x') )
print(generated_tokens[harmony_start_index:])

['<h>', '<bar>', 'position_0x50', 'A:min', 'position_0x75', 'E:min', 'position_1x25', 'G:maj', '<bar>', 'position_0x25', 'C:maj', 'position_1x50', 'A:min', 'position_1x75', 'E:min', 'position_3x00', 'C:maj', 'position_2x50', 'F:maj', 'position_3x00', 'D:min', 'position_2x50', 'G:maj', 'position_3x00', 'A:min', 'position_2x50', 'F:maj', 'position_3x75', 'D:min', 'position_5x50', 'A:min', 'position_5x00', 'F:maj', 'position_5x00', 'D:min', 'position_1x00', 'F:maj', 'position_5x50', 'G:maj', 'position_3x50', 'D:min', 'position_3x00', 'A:min', 'position_5x50', 'F:maj', 'position_3x50', 'C:maj', 'position_3x00', 'D:min', 'position_3x50', 'F:maj', 'position_4x00', 'A:min', 'position_4x50', 'G:maj', 'position_5x25', 'D:min', 'G:maj', 'position_3x00', 'C:maj', 'position_0x25', 'E:min', 'position_1x00', 'D:min', 'position_1x75', 'A:min', 'position_2x00', 'D:min', '<bar>', 'position_2x25', 'G:maj', '<bar>', '<bar>', 'position_0x75', 'A:min', '<bar>', 'position_2x00', 'D:min', 'position_2x50', 'E

In [None]:
t = generated_tokens[harmony_start_index:]
line = []
for i in range(len(t)):
    line.append(t[i])
    if i+1 < len(t) and 'bar' in t[i+1]:
        print(line)
        line = []
print(line)

['<h>']
['<bar>', 'position_0x50', 'A:min', 'position_0x75', 'E:min', 'position_1x25', 'G:maj']
['<bar>', 'position_0x25', 'C:maj', 'position_1x50', 'A:min', 'position_1x75', 'E:min', 'position_3x00', 'C:maj', 'position_2x50', 'F:maj', 'position_3x00', 'D:min', 'position_2x50', 'G:maj', 'position_3x00', 'A:min', 'position_2x50', 'F:maj', 'position_3x75', 'D:min', 'position_5x50', 'A:min', 'position_5x00', 'F:maj', 'position_5x00', 'D:min', 'position_1x00', 'F:maj', 'position_5x50', 'G:maj', 'position_3x50', 'D:min', 'position_3x00', 'A:min', 'position_5x50', 'F:maj', 'position_3x50', 'C:maj', 'position_3x00', 'D:min', 'position_3x50', 'F:maj', 'position_4x00', 'A:min', 'position_4x50', 'G:maj', 'position_5x25', 'D:min', 'G:maj', 'position_3x00', 'C:maj', 'position_0x25', 'E:min', 'position_1x00', 'D:min', 'position_1x75', 'A:min', 'position_2x00', 'D:min']
['<bar>', 'position_2x25', 'G:maj']
['<bar>']
['<bar>', 'position_0x75', 'A:min']
['<bar>', 'position_2x00', 'D:min', 'position_2x5