In [1]:
import time
import torch
from utils import *
import re
from config import *
from transformers import GPT2Config
import argparse
import os
from tqdm import tqdm
import requests

In [3]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

cuda


In [4]:
# Helper used below to encode ABC text into patches and decode patches back.
patchilizer = Patchilizer()

# GPT-2 configuration for the patch-level decoder (one position per patch).
# PATCH_NUM_LAYERS / PATCH_LENGTH are presumably provided by `from config import *`.
patch_config = GPT2Config(
    num_hidden_layers=PATCH_NUM_LAYERS,
    max_length=PATCH_LENGTH,
    max_position_embeddings=PATCH_LENGTH,
    vocab_size=1,
)

# GPT-2 configuration for the character-level decoder (one position per
# character inside a patch). Token ids seen in this notebook are character
# codes below 128 -- presumably why vocab_size is 128; TODO confirm.
char_config = GPT2Config(
    num_hidden_layers=CHAR_NUM_LAYERS,
    max_length=PATCH_SIZE,
    max_position_embeddings=PATCH_SIZE,
    vocab_size=128,
)

# Combine both decoders into the hierarchical model.
model = TunesFormer(
    patch_config,
    char_config,
    share_weights=SHARE_WEIGHTS,
)

In [7]:
weights_file = 'weights/weights.pth'

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# (and places tensors directly on the device selected earlier).
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(weights_file, map_location=device)

# Skip the attention `bias` / `masked_bias` entries before loading the state
# dict (presumably buffers the current model does not expect -- TODO confirm).
# Raw string: same pattern as before, without the invalid "\." string escape.
skipped_key_pattern = re.compile(r"\.attn.bias|\.attn.masked_bias")
fixed_weights = {
    k: v
    for k, v in checkpoint["model"].items()
    if not skipped_key_pattern.search(k)
}
model.load_state_dict(fixed_weights)

# Move to the selected device and switch to inference mode.
model = model.to(device).eval()


In [10]:
# Read the prompt text from disk. The encoding is pinned so the result does
# not depend on the platform's default locale encoding.
with open("prompt.txt", "r", encoding="utf-8") as f:
    prompt = f.read()


S:2
B:9
E:4
B:9
L:1/8
M:4/4
K:D
 de |"D" 



In [83]:
# Inline prompt: a blank first line followed by six header lines.
# S/E/B look like control codes and L/M/K like ABC header fields -- TODO confirm.
prompt = (
    "\n"
    "S:2\n"
    "E:4\n"
    "B:9\n"
    "L:1/8\n"
    "M:4/4\n"
    "K:D\n"
)

In [84]:
# Output accumulator and run-size settings.
tunes = ""
num_tunes, max_patch = 1, 128

# Sampling settings passed to the generation / sampling calls below.
top_p, top_k, temperature = 0.8, 10, 1.2

# No fixed seed: sampling results differ between runs.
seed = None
show_control_code = True



In [121]:

# Show the repr of the prompt to check its exact newline layout.
prompt

'\nS:2\nE:4\nB:9\nL:1/8\nM:4/4\nK:D\n'

In [85]:
# Encode the prompt into patches (with special patches added), drop the last
# patch, and wrap the rest in a batch dimension of size 1.
encoded_prompt = patchilizer.encode(prompt, add_special_patches=True)
input_patches = torch.tensor([encoded_prompt[:-1]], device=device)

In [86]:
# Each row of input_patches[0] is one encoded line of the prompt, wrapped in
# bos/eos; row 0 is the special patch marking the start of the tune.
print(input_patches.shape)
for patch in input_patches[0]:
    print(f"{patchilizer.patch2bar(patch)}->{patch}")
    

torch.Size([1, 7, 32])
->tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 2], device='cuda:0')
S:2
->tensor([ 1, 83, 58, 50, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       device='cuda:0')
E:4
->tensor([ 1, 69, 58, 52, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       device='cuda:0')
B:9
->tensor([ 1, 66, 58, 57, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       device='cuda:0')
L:1/8
->tensor([ 1, 76, 58, 49, 47, 56, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       device='cuda:0')
M:4/4
->tensor([ 1, 77, 58, 52, 47, 52, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,

In [87]:
# Decode the patches back into text, then keep whatever part of the prompt
# lies beyond that many characters.
# NOTE(review): the slice assumes `prefix` lines up with the start of `prompt`
# character for character -- confirm for prompts with leading blank lines.
prefix = patchilizer.decode(input_patches[0])
print(prefix, "\n")

prefix_length = len(prefix)
remaining_tokens = prompt[prefix_length:]
print(remaining_tokens)

S:2
E:4
B:9
L:1/8
M:4/4
K:D
 





In [88]:
# Start the character sequence with bos, followed by the character codes of
# the leftover prompt text.
leftover_codes = [ord(char) for char in remaining_tokens]
tokens = torch.tensor([patchilizer.bos_token_id, *leftover_codes], device=device)

In [89]:
# Inspect the initial character tokens (bos plus leftover prompt characters).
tokens

tensor([ 1, 10], device='cuda:0')

In [107]:
# Sample the next patch ten times from the same prompt to see the variety of
# continuations; only the first return value is used here.
sampling_options = dict(
    top_p=top_p,
    top_k=top_k,
    temperature=temperature,
    seed=seed,
)
for attempt in range(10):
    sampled_patch, s = model.generate(input_patches, tokens, **sampling_options)
    print(patchilizer.patch2bar(sampled_patch))

 dB |
 F>G |
 dA (3AAA A2 AB |
 de |
 dB |
 AB/c/ |
 d3 c dAFA |
 (A F2) (A FE)DE |
 AG |
 F>E |


In [108]:
## Compute bar embeddings
# Start from input_patches (1 x L x 32), which has special patches at the
# front and back. It was built in an earlier cell as
#   torch.tensor([patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=device)
# so the last special patch has been omitted.
input_patches.shape

torch.Size([1, 7, 32])

In [110]:
# Reshape to (len(input_patches), -1, PATCH_SIZE). The shape displayed below
# is the same as input_patches.shape, so this does not change the layout here.
patches = input_patches.reshape(input_patches.shape[0], -1, PATCH_SIZE)
patches.shape

torch.Size([1, 7, 32])

In [119]:
# Re-check the shape of `patches` (same expression as the end of the reshape cell).
patches.shape

torch.Size([1, 7, 32])

In [111]:
# Run the patch-level decoder over the prompt patches and keep its
# "last_hidden_state" output as the per-patch embedding.
patch_decoder = model.patch_level_decoder
embedding = patch_decoder(patches)["last_hidden_state"]

In [112]:
# Shape of the patch-level embedding for the prompt.
embedding.shape

torch.Size([1, 7, 768])

In [125]:
# Embed a prompt consisting of the single line "K:D".
# The tensor is named `key_patches` rather than `input` so the builtin
# input() is not shadowed, and it is created on `device` like the other
# patch tensors in this notebook.
# Note: this overwrites the `embedding` computed for the full prompt above.
key_patches = torch.tensor(
    [patchilizer.encode("K:D\n", add_special_patches=True)[:-1]],
    device=device,
)
embedding = model.patch_level_decoder(key_patches)["last_hidden_state"]


In [126]:
# Display the embedding computed in the preceding cell.
embedding

tensor([[[ 0.0125,  0.0160,  0.1532,  ...,  0.0690,  0.1152,  0.0471],
         [ 0.2039, -0.7142,  1.0180,  ...,  1.7600,  0.2193, -1.0058]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [165]:
# Load the ABC source of the tune. The encoding is pinned so the result does
# not depend on the platform's default locale encoding.
with open("abcs/whiskey-before-breakfast.abc", encoding="utf-8") as f:
    whiskey = f.read()

In [166]:
# Show the repr of the loaded ABC text, including its blank lines.
whiskey

'\n\nX:1\nL:1/8\nM:4/4\nK:D\n|: "D"DE FG A2 AA | AB AG FE DF | "G"G2 BG "D"F2 AF | "A"ED EF EC B,A, |\n"D"DE FG A2 AA | AB AG FE DF | "G"G2 BG "D"F2 AF | "A"ED EF "D"D2 A2 ::\n"D"A2 d2 d2 dd | f2 d2 B2 A2 | "Em (A)"e2 ef e2 ef | "A7" gf ed cB Ac |\n"D"d2 fd "A"c2 ec | "G"Bc dB "D"AF ED | "G"G2 BG "D"F2 AF | "A"ED EF "D"D2-D2 :|\n\n'

In [167]:
# Encode the whole tune into patches (last patch dropped, batch of one) ...
encoded_whiskey = patchilizer.encode(whiskey, add_special_patches=True)
whiskey_patches = torch.tensor([encoded_whiskey[:-1]], device=device)

# ... and run the patch-level decoder to get one embedding per patch position.
whiskey_decoder_output = model.patch_level_decoder(whiskey_patches)
whiskey_embedding = whiskey_decoder_output["last_hidden_state"]

In [168]:
# Print every patch of the tune as decoded text next to its raw token ids.
for patch in whiskey_patches[0]:
    print(f"{patchilizer.patch2bar(patch)}->{patch}")


->tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 2], device='cuda:0')
X:1
->tensor([ 1, 88, 58, 49, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       device='cuda:0')
L:1/8
->tensor([ 1, 76, 58, 49, 47, 56, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       device='cuda:0')
M:4/4
->tensor([ 1, 77, 58, 52, 47, 52, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       device='cuda:0')
K:D
->tensor([ 1, 75, 58, 68, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       device='cuda:0')
|: "D"DE FG A2 AA |->tensor([  1, 124,  58,  32,  34,  68,  34,  68,  69,  32,  70,  71,  32,  65,
         50,  32,  65,  65,  32, 124,   2,   0,   0,   0,   0,   

In [169]:
# Reset the character sequence to just the bos token.
bos_id = patchilizer.bos_token_id
tokens = torch.tensor([bos_id], device=device)

In [170]:
# One greedy decoding step from the embedding at patch position 1: take the
# argmax of the character-level decoder's output and append it to `tokens`.
# Note: re-running this cell appends another token to the existing `tokens`.
position_one_embedding = whiskey_embedding[0, 1]
next_char_scores = model.char_level_decoder.generate(position_one_embedding, tokens)
new_token = next_char_scores.argmax()
tokens = torch.cat((tokens, new_token.unsqueeze(0)), dim=0)

In [171]:
# Show the original ABC text again for comparison with the generated output.
whiskey

'\n\nX:1\nL:1/8\nM:4/4\nK:D\n|: "D"DE FG A2 AA | AB AG FE DF | "G"G2 BG "D"F2 AF | "A"ED EF EC B,A, |\n"D"DE FG A2 AA | AB AG FE DF | "G"G2 BG "D"F2 AF | "A"ED EF "D"D2 A2 ::\n"D"A2 d2 d2 dd | f2 d2 B2 A2 | "Em (A)"e2 ef e2 ef | "A7" gf ed cB Ac |\n"D"d2 fd "A"c2 ec | "G"Bc dB "D"AF ED | "G"G2 BG "D"F2 AF | "A"ED EF "D"D2-D2 :|\n\n'

In [178]:
def resample_patches(patch_embeddings, top_p, top_k, temperature, seed=None):
    """Sample one character patch for each patch-level embedding.

    For every embedding, characters are drawn one at a time from the
    character-level decoder, starting from a lone bos token, until the eos
    token is drawn or the sequence holds PATCH_SIZE - 1 tokens.

    Args:
        patch_embeddings: iterable of per-patch embeddings, e.g. the rows of
            ``model.patch_level_decoder(...)["last_hidden_state"][0]``.
        top_p: forwarded to ``top_p_sampling``.
        top_k: forwarded to ``top_k_sampling``.
        temperature: forwarded to ``temperature_sampling``.
        seed: forwarded to ``temperature_sampling``; None leaves sampling unseeded.

    Returns:
        A list with one 1-D token tensor (on ``device``) per embedding.
    """
    sampled_patches = []
    for patch_embedding in patch_embeddings:
        token = None
        tokens = torch.tensor([patchilizer.bos_token_id], device=device)
        while token != patchilizer.eos_token_id and len(tokens) < PATCH_SIZE - 1:
            # The embeddings carry autograd history; no gradients are needed
            # for sampling, so skip building a graph for each decoder call.
            with torch.no_grad():
                prob = model.char_level_decoder.generate(patch_embedding, tokens)
            prob = prob.cpu().detach().numpy()
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
            token = temperature_sampling(prob, temperature=temperature, seed=seed)
            tokens = torch.cat([tokens, torch.tensor([token], device=device)], dim=0)
        sampled_patches.append(tokens)
    return sampled_patches


# Re-sample every patch of the tune from its patch-level embedding and decode
# the result back to ABC text.
n_seed = None
tortured_whiskey = resample_patches(
    whiskey_embedding[0],
    top_p=top_p,
    top_k=top_k,
    temperature=temperature,
    seed=n_seed,
)
print(patchilizer.decode(tortured_whiskey))

S:2
B:16
M:4/4
K:G
 FABA F2 ED | ABAG FDEF |"A" ED CB, A,2 A,A, |"A" E2 EF GF EF | 
"D" DD EF A2 AA | 
 AB AG FE DF | 
"G" G2 BG"D" F2 AF | 
"A" ED EF"D" D4 :: 
 ed ef d3 A | 
"G" B2 Bc BA Bc | 
"A" gf ed cB A2 :| 
"D" d2 dd f2 ff | 
"G" Bc dB"D" A2 FA | 
"G" G2 BG"D" F2 AF | 
"A" ED EF"D" D2 D2 :|


In [159]:
# Number of patch positions in the embedding (the range iterated by the resampling loop).
whiskey_embedding.shape[1]

25

In [183]:
# Print the original tune as plain text for side-by-side comparison with the resampled version.
print(whiskey)



X:1
L:1/8
M:4/4
K:D
|: "D"DE FG A2 AA | AB AG FE DF | "G"G2 BG "D"F2 AF | "A"ED EF EC B,A, |
"D"DE FG A2 AA | AB AG FE DF | "G"G2 BG "D"F2 AF | "A"ED EF "D"D2 A2 ::
"D"A2 d2 d2 dd | f2 d2 B2 A2 | "Em (A)"e2 ef e2 ef | "A7" gf ed cB Ac |
"D"d2 fd "A"c2 ec | "G"Bc dB "D"AF ED | "G"G2 BG "D"F2 AF | "A"ED EF "D"D2-D2 :|




In [184]:
# Load the second tune. The encoding is pinned so the result does not depend
# on the platform's default locale encoding.
with open("abcs/st-annes-reel.abc", encoding="utf-8") as f:
    anne = f.read()

# Prepend reference-number and meter header lines to the file's contents.
anne = "X:1\nM:4/4\n" + anne

In [185]:
# Encode the tune into patches (last patch dropped, batch of one) ...
encoded_anne = patchilizer.encode(anne, add_special_patches=True)
anne_patches = torch.tensor([encoded_anne[:-1]], device=device)

# ... and compute one patch-level embedding per patch position.
anne_decoder_output = model.patch_level_decoder(anne_patches)
anne_embedding = anne_decoder_output["last_hidden_state"]

In [188]:
# Re-sample every patch of the tune from its patch-level embedding: for each
# embedding, draw characters from the character-level decoder (starting from
# bos) until eos is drawn or the patch holds PATCH_SIZE - 1 tokens.
n_seed = None
tortured_anne = []
for this_patch in anne_embedding[0]:
    tokens = torch.tensor([patchilizer.bos_token_id], device=device)
    token = None
    while not (token == patchilizer.eos_token_id or len(tokens) >= PATCH_SIZE - 1):
        raw_prob = model.char_level_decoder.generate(this_patch, tokens)
        prob = raw_prob.cpu().detach().numpy()
        # Narrow the distribution with top-p, then top-k, before sampling.
        prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
        prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
        token = temperature_sampling(prob, temperature=temperature, seed=n_seed)
        sampled = torch.tensor([token], device=device)
        tokens = torch.cat([tokens, sampled], dim=0)
    tortured_anne.append(tokens)

# Decode all sampled patches back to ABC text.
print(patchilizer.decode(tortured_anne))

S:2
B:16
K:G
M:4/4
|: A2 FA DAFA |"D" A2 F2 F2 A2 |"G" B2 BG D2 D2 |"D" A2 FADFAF |"G" f2 fgfedB | A2 AFDFAd |1 
"G" B2 BA"A" cABc :| 
"D" d2 d2 d4 :| 
"Em" gfed"A7" cdef | 
"A7" gececege | 
"D" fded"A7" cAde :| 
"D" fdAd fdfa | 
"Em" aggf g2 ef | 
"A7" gfecAce | 
 fdec"D" d4 :|
