In [1]:
from harmony_tokenizers_m21 import ChordSymbolTokenizer, PitchClassTokenizer, MelodyPitchTokenizer, MergedMelHarmTokenizer
from data_utils import StructGPTMelHarmDataset, GenCollator
from torch.utils.data import DataLoader
import torch
from transformers import AutoConfig, GPT2LMHeadModel
import numpy as np

import heapq
import torch
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
test_dir = '/media/maindisk/maximos/data/hooktheory_test'

cstok = ChordSymbolTokenizer()
pctok = PitchClassTokenizer()
meltok = MelodyPitchTokenizer()
tokenizer = MergedMelHarmTokenizer(meltok, cstok)

test_dataset = StructGPTMelHarmDataset(test_dir, tokenizer, max_length=512, num_bars=16, return_harmonization_labels=True)

collator = GenCollator(tokenizer)

trainloader = DataLoader(test_dataset, batch_size=1, shuffle=True, collate_fn=collator)

In [3]:
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer.vocab),
    n_positions=512,
    n_layer=8,
    n_head=8,
    pad_token_id=tokenizer.vocab[tokenizer.pad_token],
    bos_token_id=tokenizer.vocab[tokenizer.bos_token],
    eos_token_id=tokenizer.vocab[tokenizer.eos_token],
    n_embd=512
)

model = GPT2LMHeadModel(config)

model_path = 'saved_models/gpt/ChordSymbolTokenizer/ChordSymbolTokenizer.pt'

# device_name = 'cuda:0'
device_name = 'cpu'
if device_name == 'cpu':
    device = torch.device('cpu')
else:
    if torch.cuda.is_available():
        device = torch.device(device_name)
    else:
        print('Selected device not available: ' + device_name)

checkpoint = torch.load(model_path, map_location=device_name, weights_only=True)
model.load_state_dict(checkpoint)
model.eval()
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(547, 512)
    (wpe): Embedding(512, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=1536, nx=512)
          (c_proj): Conv1D(nf=512, nx=512)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=2048, nx=512)
          (c_proj): Conv1D(nf=512, nx=2048)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=547, bias=False)
)

In [165]:
batch = next(iter(trainloader))

In [160]:
class SearchNode:
    def __init__(self, tokens, logprob, heuristic, parent=None):
        self.tokens = tokens  # shape: (1, seq_len)
        self.logprob = logprob
        self.heuristic = heuristic
        self.parent = parent
        self.tried_tokens = set()  # token IDs already expanded

    def __lt__(self, other):
        return (self.logprob + self.heuristic) > (other.logprob + other.heuristic)
# end class SearchNode

class AStar:
    def __init__(self, model, tokenizer, input_ids, constraint_ids, max_length=512, beam_width=10, lookahead_k=5):
        self.model = model
        self.tokenizer = tokenizer
        self.input_ids = input_ids
        self.constraint_ids = constraint_ids.tolist()
        self.max_length = max_length
        self.beam_width = beam_width
        self.lookahead_k = lookahead_k
        self.eos_token_id = tokenizer.eos_token_id
        self.eos_token = tokenizer.eos_token
        self.constraint_tokens_breakdown()
    # end init

    def constraint_tokens_breakdown(self):
        tokens = self.tokenizer.convert_ids_to_tokens(self.constraint_ids)
        bar_count = 0
        chord_tokens = []
        i = 0
        while i < len(tokens):
            tok = tokens[i]
            if 'bar' in tok or 'fill' in tok or '/m' in tok or '<h>' in tok:
                if 'bar' in tok:
                    bar_count += 1
                i += 1
            else:
                # we should have arrived in a position token
                position_token = tokens[i]
                # the remaining tokens should be the chord
                i += 1
                while i < len(tokens) and \
                    'fill' not in tokens[i] and \
                     'bar' not in tokens[i] and \
                     '</s>' not in tokens[i]:
                    chord_tokens.append( tokens[i] )
                    i += 1
                break
        self.constraint_bar = bar_count
        self.position_token = position_token
        self.chord_tokens = chord_tokens
    # end constraint_tokens_breakdown

    def consistency_checker(self, tokens):
        consistent = True
        current_bar_time = None
        for i in range( 1, len(tokens), 1 ):
            # no melody-related tokens here
            if 'P:' in tokens[i] or 'rest' in tokens[i] or \
                'fill' in tokens[i] or '<s>' in tokens[i] \
                or 'ts_' in tokens[i]:
                consistent = False
                break
            # no two consequtive position tokens
            if 'position_' in tokens[i] and 'position_' in tokens[i-1]:
                consistent = False
                break
            # only chord / pc after position
            if 'position_' in tokens[i-1] and (
                'bar' in tokens[i] or '</s>' in tokens[i] or 
                '<h>' in tokens[i] or '</m>' in tokens[i]
            ):
                consistent = False
                break
            # only position, bar, </m>, </s> and <h> after bar
            if 'bar' in tokens[i-1]:
                if 'position_' in tokens[i]:
                    current_bar_time = float( tokens[i].split('_')[-1].replace('x', '.') )
                elif 'bar' not in tokens[i] and '</m>' not in tokens[i] and \
                    '</s>' not in tokens[i] and '<h>' not in tokens[i]:
                    consistent = False
                    break
            # time should increase within bar
            if 'position_' in tokens[i] and 'bar' not in tokens[i-1]:
                if float( tokens[i].split('_')[-1].replace('x', '.') ) <= current_bar_time:
                    consistent = False
                    break
                else:
                    current_bar_time = float( tokens[i].split('_')[-1].replace('x', '.') )
        return consistent
        # end consistency_checker

    def constraint_checker(self, input_tokens):
        tokens = self.tokenizer.convert_ids_to_tokens(input_tokens.tolist())
        start_harmonization_index = tokens.index('<h>')
        tokens = tokens[start_harmonization_index:]
        print(f'constraint_checker: {tokens}')
        print(f'num_bars: {tokens.count('<bar>')} - num_tokens: {len(tokens)}')
        if not self.consistency_checker(tokens):
            print('inconsistent')
            return False
        
        bar_count = 0
        found = False
        i = 0
        while i < len(tokens):
            tok = tokens[i]
            if tok == "<bar>":
                bar_count += 1
            elif bar_count == self.constraint_bar and tok == self.position_token:
                j = 0
                found = True
                while j < len(self.chord_tokens):
                    if i + j + 1 < len(tokens):
                        if tokens[i + j + 1] != self.chord_tokens[j]:
                            found = False
                            break
                    else:
                        break
                    j += 1
                # break
                # no break, we need to keep counting bars to check premature ending
            i += 1
        # if the sequence has reached eos and the constraint has not been met, it should fail
        if tokens[-1] == self.eos_token:
            condition = found
            # print(f'checker EOS: {condition}')
        elif len(tokens) == self.max_length:
            condition = False
            # print(f'checker PRE: {condition}')
        else:
            condition = found or bar_count < self.constraint_bar  # Only reject if we're past bar 3 and it's missing
            # print(f'checker PRE: {condition}')
        return condition
    # end constraint_checker

    def expand_node(self, node):
        with torch.no_grad():
            output = model(node.tokens.to(model.device), return_dict=True)
            logits = output.logits[:, -1, :]

            # Mask out already visited tokens before softmax
            if node.tried_tokens:
                mask = torch.full_like(logits, 0.0)
                mask[:, list(node.tried_tokens)] = float('-inf')
                logits = logits + mask

            probs = F.log_softmax(logits, dim=-1)
            topk_probs, topk_ids = torch.topk(probs, self.lookahead_k, dim=-1)

        new_nodes = []
        for i in range(self.lookahead_k):
            token_id = topk_ids[0, i].item()
            token_prob = topk_probs[0, i].item()
            node.tried_tokens.add(token_id)

            new_tokens = torch.cat([node.tokens.to(model.device), topk_ids[0, i].unsqueeze(0).unsqueeze(0)], dim=-1).to(model.device)

            if not self.constraint_checker(new_tokens[0]):
                continue
            
            new_logprob = node.logprob + token_prob
            print('new_logprob:', new_logprob)
            new_node = SearchNode(new_tokens, new_logprob, 0.0, parent=node)
            new_nodes.append(new_node)
        return new_nodes
    # end expand_node

    def decode(self):
        initial_node = SearchNode(tokens=self.input_ids, logprob=0.0, heuristic=0.0)
        open_set = [initial_node]
        finished = []
        while open_set:
            current = heapq.heappop(open_set)

            # if current.tokens.shape[-1] >= self.max_length or (
            #     self.eos_token_id and current.tokens[0, -1].item() == self.eos_token_id
            # ):
            #     finished.append(current)
            #     continue
            if current.tokens[0, -1].item() == self.eos_token_id and \
                self.constraint_checker(current.tokens[0]):
                finished.append(current)
                continue
            if current.tokens.shape[-1] >= self.max_length:
                continue

            # Expand current node
            print('children')
            children = self.expand_node(current)

            if children:
                for child in children:
                    heapq.heappush(open_set, child)
            else:
                # No valid expansions – backtrack to unvisited options
                back = current.parent
                while back:
                    # Re-expand from back with unvisited options
                    print('parents')
                    more_options = self.expand_node(back)
                    if more_options:
                        for opt in more_options:
                            heapq.heappush(open_set, opt)
                        break
                    back = back.parent

            # Prune open set
            open_set = sorted(open_set, reverse=False)[:self.beam_width]
            print('finished:', len(finished))
            # just keep the first one found
            if len(finished) >= 1:
                break

        if not finished:
            raise RuntimeError("No valid sequence could be generated under constraints.")

        best = sorted(finished, key=lambda x: x.logprob + x.heuristic, reverse=True)[0]
        return best.tokens, finished
    # end decode
# end class AStar

In [166]:
all_ids = batch['input_ids']
print(all_ids)
melody_end_index = all_ids[0].tolist().index( tokenizer.vocab['</m>'] )
harmony_start_index = all_ids[0].tolist().index( tokenizer.vocab['<h>'] )
print(melody_end_index)
constraint_ids = all_ids[0][melody_end_index:]
print(constraint_ids)
# start_harmony_position = np.where( all_ids == harmony_start_index )[0][0]
input_ids = all_ids[0][:(harmony_start_index+2)]
input_ids = input_ids.reshape(1, -1)
print(input_ids.shape)

tensor([[  2,   6, 186,  98,   4, 110,  56, 122,  53, 134,   4,   6,  98,   4,
         106,  56, 118,  53, 122,  51, 130,  53, 134,  51, 142,  53,   6,  98,
          51, 110,  48, 122,  45, 130,  41, 142,  39,   6,  98,  39, 110,   4,
         126,   4, 130,  53, 134,  53, 142,  51,   6,  98,  53, 106,  46, 110,
          46, 118,  46, 126,   4, 130,  53, 134,  53, 142,  51,   6,  98,  53,
         106,  49, 110,  49, 122,  49, 134,  49,   6,  98,  54, 110,  53, 122,
          51, 130,  49, 142,  48,   6,  98,  48, 106,  56, 110,  55, 118,  56,
         122,  58, 130,  56, 134,  55, 142,  56,   6,  98,  53, 110,  56, 122,
          53, 134,   4,   6,  98,   4, 106,  56, 118,  53, 122,  51, 130,  53,
         134,  51, 142,  53,   6,  98,  51, 110,  48, 122,  45, 130,  41, 142,
          39,   6,  98,  39, 110,   4, 126,   4, 130,  53, 134,  53, 142,  51,
           6,  98,  53, 106,  46, 114,  46, 122,  46, 130,  53, 134,  53, 142,
          51,   6,  98,  53, 106,  49, 110,  49, 118

In [167]:
astar = AStar( model, tokenizer, input_ids, constraint_ids, max_length=512, beam_width=20, lookahead_k=10 )

In [168]:
print(astar.constraint_bar)
print(astar.position_token)
print(astar.chord_tokens)

13
position_0x00
['A:min']


In [169]:
generated_ids, finished = astar.decode()
generated_tokens = []

children
constraint_checker: ['<h>', '<bar>', 'position_0x00']
num_bars: 1 - num_tokens: 3
new_logprob: -0.4387596845626831
constraint_checker: ['<h>', '<bar>', '<bar>']
num_bars: 2 - num_tokens: 3
new_logprob: -1.4307745695114136
constraint_checker: ['<h>', '<bar>', 'position_3x00']
num_bars: 1 - num_tokens: 3
new_logprob: -2.792074680328369
constraint_checker: ['<h>', '<bar>', 'position_1x50']
num_bars: 1 - num_tokens: 3
new_logprob: -4.241106033325195
constraint_checker: ['<h>', '<bar>', 'position_4x50']
num_bars: 1 - num_tokens: 3
new_logprob: -4.481249809265137
constraint_checker: ['<h>', '<bar>', 'position_2x00']
num_bars: 1 - num_tokens: 3
new_logprob: -5.006075859069824
constraint_checker: ['<h>', '<bar>', 'position_4x00']
num_bars: 1 - num_tokens: 3
new_logprob: -5.595470428466797
constraint_checker: ['<h>', '<bar>', 'position_2x50']
num_bars: 1 - num_tokens: 3
new_logprob: -5.6255693435668945
constraint_checker: ['<h>', '<bar>', 'position_3x50']
num_bars: 1 - num_tokens: 3
ne

KeyboardInterrupt: 

In [151]:
print(generated_ids.shape)

torch.Size([1, 272])


In [152]:
print(len(finished))
print(finished[0].tokens)

2
tensor([[  2,   6, 183,  98,   4, 106,  49, 110,  48, 114,  49, 122,  56, 126,
          53,   6,  98,  53, 102,  49, 110,  46, 118,   4, 122,  49, 126,  46,
           6,  98,  52, 106,  51, 110,  49, 126,  49,   6,  98,   4, 102,  45,
         106,  49, 110,  45, 114,  49, 122,  51,   6,  98,   4, 106,  49, 110,
          51, 114,  53, 122,  54, 126,  53,   6,  98,  53, 102,   4, 106,  49,
         110,  56, 118,  54, 126,  54,   6,  98,  53, 114,  53, 118,  49, 122,
          46, 126,  49,   6,  98,  49, 114,  49, 118,   4,   6,  98,   4, 106,
          49, 110,  48, 114,  49, 122,  56, 126,  53,   6,  98,  53, 102,  49,
         110,  46, 118,   4, 122,  49, 126,  46,   6,  98,  52, 106,  51, 110,
          49, 122,   4,   6,  98,   4, 102,  45, 106,  49, 110,  45, 114,  52,
         122,  51,   6,  98,   4, 106,  49, 110,  51, 114,  53, 122,  54, 126,
          53,   6,  98,  53, 102,   4, 106,  49, 110,  56, 118,  54, 122,  53,
         126,  54,   6,  98,  53, 114,   4, 118,  

In [170]:
for i in generated_ids[0]:
    generated_tokens.append( tokenizer.ids_to_tokens[ int(i) ].replace(' ','x') )
print(generated_tokens[harmony_start_index:])

['E:maj', '<bar>', 'position_0x00', 'A:min7', '<bar>', 'position_0x00', 'F:maj', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'E:maj', '<bar>', 'position_0x00', 'A:min7', '<bar>', 'position_0x00', 'F:maj', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'E:maj', '</s>', '<s>', '<bar>', 'ts_4x4', 'position_0x00', '<rest>', 'position_1x00', 'P:60', 'position_1x50', 'P:59', 'position_2x00', 'P:60', 'position_3x00', 'P:67', 'position_3x50', 'P:64', '<bar>', 'position_0x00', 'P:64', 'position_0x50', 'P:60', 'position_1x50', 'P:57', 'position_2x50', '<rest>', 'position_3x00', 'P:60', 'position_3x50', 'P:57', '<bar>', 'position_0x00', 'P:63', 'position_1x00', 'P:62', 'position_1x50', 'P:60', 'position_3x50', 'P:60', '<bar>', 'position_0x00', '<rest>', 'position_0x50', 'P:56', 'position_1x00', 'P:60', 'position_1x50', 'P:56', 'position_2x00', 'P:60', 'position_3x00', 'P:62', '<bar>', 'position_0x00', '<rest>', 'position_1x00', 'P:60', 'position_1x50', 'P:62', 'posi

In [171]:
t = generated_tokens[harmony_start_index:]
line = []
for i in range(len(t)):
    line.append(t[i])
    if i+1 < len(t) and 'bar' in t[i+1]:
        print(line)
        line = []
print(line)

['E:maj']
['<bar>', 'position_0x00', 'A:min7']
['<bar>', 'position_0x00', 'F:maj']
['<bar>', 'position_0x00', 'C:maj']
['<bar>', 'position_0x00', 'E:maj']
['<bar>', 'position_0x00', 'A:min7']
['<bar>', 'position_0x00', 'F:maj']
['<bar>', 'position_0x00', 'C:maj']
['<bar>', 'position_0x00', 'E:maj', '</s>', '<s>']
['<bar>', 'ts_4x4', 'position_0x00', '<rest>', 'position_1x00', 'P:60', 'position_1x50', 'P:59', 'position_2x00', 'P:60', 'position_3x00', 'P:67', 'position_3x50', 'P:64']
['<bar>', 'position_0x00', 'P:64', 'position_0x50', 'P:60', 'position_1x50', 'P:57', 'position_2x50', '<rest>', 'position_3x00', 'P:60', 'position_3x50', 'P:57']
['<bar>', 'position_0x00', 'P:63', 'position_1x00', 'P:62', 'position_1x50', 'P:60', 'position_3x50', 'P:60']
['<bar>', 'position_0x00', '<rest>', 'position_0x50', 'P:56', 'position_1x00', 'P:60', 'position_1x50', 'P:56', 'position_2x00', 'P:60', 'position_3x00', 'P:62']
['<bar>', 'position_0x00', '<rest>', 'position_1x00', 'P:60', 'position_1x50', 