In [1]:
import os


# Route HTTPS traffic through the local proxy and restrict CUDA to device index 1.
# These must be set before any library reads them (e.g. before CUDA is initialised).
os.environ.update({
    'https_proxy': 'http://172.17.0.1:1081',
    'CUDA_VISIBLE_DEVICES': '1',
})

In [7]:
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import enable_progress_bars
import logging


# Emit INFO-level records so the download status messages logged below are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Enable huggingface_hub's progress bars for the download.
enable_progress_bars()


def download_model_weights(
    weights_filename="weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth",
    models_dir='/models/NotaGen',
    repo_id="ElectricAlexis/NotaGen",
    cache_dir='/models/.cache',
):
    """Return a local path to the model weights, downloading them if needed.

    Args:
        weights_filename: Name of the weights file inside the Hub repository.
        models_dir: Local directory where the weights file is stored.
        repo_id: HuggingFace Hub repository to download from.
        cache_dir: Cache directory handed to ``hf_hub_download``.

    Returns:
        Path of the weights file on the local filesystem.

    Raises:
        RuntimeError: If the download fails. The original exception is kept
            as ``__cause__`` so the full traceback remains available.
    """
    local_weights_path = os.path.join(models_dir, weights_filename)

    # Skip the network round-trip when the file is already present.
    if os.path.isfile(local_weights_path):
        logger.info(f"Model weights already exist at {local_weights_path}")
        return local_weights_path

    logger.info("Downloading model weights from HuggingFace Hub...")
    try:
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=weights_filename,
            local_dir=models_dir,
            local_dir_use_symlinks=False,
            cache_dir=cache_dir,
        )
    except Exception as e:
        logger.error(f"Error downloading model weights: {e}")
        # Chain the original exception instead of discarding its traceback.
        raise RuntimeError(f"Failed to download model weights: {e}") from e

    logger.info(f"Model weights downloaded successfully to {downloaded_path}")
    return downloaded_path

# Fetch the checkpoint (or reuse the local copy); the returned path is the cell output.
download_model_weights()

INFO:__main__:Downloading model weights from HuggingFace Hub...
INFO:__main__:Model weights downloaded successfully to /models/NotaGen/weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth


'/models/NotaGen/weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth'

In [2]:
# Path of the checkpoint to load. It is rebuilt by hand from the same directory and
# file name that download_model_weights() uses, so this cell does not require the
# download cell to have run in the current kernel.
checkpoint_path = os.path.join('/models/NotaGen', 'weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth')
checkpoint_path

'/models/NotaGen/weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth'

In [3]:
import torch
from inference.utils import Patchilizer, NotaGenLMHeadModel
from inference.config import (INFERENCE_WEIGHTS_PATH, PATCH_NUM_LAYERS,
							PATCH_LENGTH, PATCH_SIZE, HIDDEN_SIZE,
							CHAR_NUM_LAYERS, PATCH_LENGTH)
from transformers import GPT2Config
from abctoolkit.utils import Barline_regexPattern
from abctoolkit.transpose import Note_list
from abctoolkit.duration import calculate_bartext_duration


# Extend the imported note list with 'z' and 'x'
# (presumably the ABC rest symbols — confirm against abctoolkit).
Note_list = Note_list + ['z', 'x']

# NOTE(review): the device is hard-coded to CUDA; there is no CPU fallback.
device = torch.device("cuda")

patchilizer = Patchilizer()

# Configuration for the patch-level model; sizes come from inference.config.
patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
                          max_length=PATCH_LENGTH,
                          max_position_embeddings=PATCH_LENGTH,
                          n_embd=HIDDEN_SIZE,
                          num_attention_heads=HIDDEN_SIZE // 64,
                          vocab_size=1)
# Configuration for the character-level model: one position per byte of a patch
# plus one extra, and a 128-entry vocabulary.
byte_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
                         max_length=PATCH_SIZE + 1,
                         max_position_embeddings=PATCH_SIZE + 1,
                         hidden_size=HIDDEN_SIZE,
                         num_attention_heads=HIDDEN_SIZE // 64,
                         vocab_size=128)

model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)

# NOTE(review): torch.load unpickles the file, and this checkpoint was downloaded
# from the Hub. Consider passing weights_only=True if the checkpoint contents
# allow it — confirm before changing.
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
# The weights live under the 'model' key of the checkpoint dict.
model.load_state_dict(checkpoint['model'])
# Cast the parameters to float16 after loading, then switch to evaluation mode.
model = model.to(dtype=torch.float16).to(device)
model.eval()


  from .autonotebook import tqdm as notebook_tqdm
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))


NotaGenLMHeadModel(
  (patch_level_decoder): PatchLevelDecoder(
    (patch_embedding): Linear(in_features=2048, out_features=1280, bias=True)
    (base): GPT2Model(
      (wte): Embedding(1, 1280)
      (wpe): Embedding(1024, 1280)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-19): 20 x GPT2Block(
          (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (char_leve

In [4]:
period = 'Classical'
composer = 'Beethoven, Ludwig van'
instrumentation = 'Keyboard'

# One "%<value>\n" prompt line per conditioning field, in this fixed order.
prompt_lines = [f'%{field}\n' for field in (period, composer, instrumentation)]
prompt_lines

['%Classical\n', '%Beethoven, Ludwig van\n', '%Keyboard\n']

In [5]:
# Inspect the patchilizer's BOS, EOS and special token ids.
patchilizer.bos_token_id, patchilizer.eos_token_id, patchilizer.special_token_id

(1, 2, 0)

In [6]:
# A PATCH_SIZE-long patch: PATCH_SIZE - 1 BOS ids followed by a single EOS id.
bos_patch = [patchilizer.bos_token_id] * (PATCH_SIZE - 1) + [patchilizer.eos_token_id]
bos_patch

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2]

In [7]:
# Split the prompt lines into patches with the patchilizer's metadata routine.
prompt_patches = patchilizer.patchilize_metadata(prompt_lines)
prompt_patches

['%Classical\n\x02', '%Beethoven, Ludw', 'ig van\n\x02', '%Keyboard\n\x02']

In [8]:
# Running list of characters, seeded with the prompt text;
# gen_patch() appends every generated character to it.
byte_list = list(''.join(prompt_lines))
print(''.join(byte_list), end='')

%Classical
%Beethoven, Ludwig van
%Keyboard


In [9]:
# Convert each patch to character codes and right-pad it with the special token id
# up to PATCH_SIZE entries.
# NOTE(review): this rebinds prompt_patches from strings to lists of ints, so the
# cell cannot be re-run without re-running the patchilize cell first.
prompt_patches = [[ord(c) for c in patch] + [patchilizer.special_token_id] * (PATCH_SIZE - len(patch)) for patch
                          in prompt_patches]
# Put the BOS patch in front of the prompt patches.
prompt_patches.insert(0, bos_patch)

prompt_patches

[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2],
 [37, 67, 108, 97, 115, 115, 105, 99, 97, 108, 10, 2, 0, 0, 0, 0],
 [37, 66, 101, 101, 116, 104, 111, 118, 101, 110, 44, 32, 76, 117, 100, 119],
 [105, 103, 32, 118, 97, 110, 10, 2, 0, 0, 0, 0, 0, 0, 0, 0],
 [37, 75, 101, 121, 98, 111, 97, 114, 100, 10, 2, 0, 0, 0, 0, 0]]

In [40]:
# Flatten all patches into a single row on the target device: shape (1, total ids).
input_patches = torch.tensor(prompt_patches, device=device).reshape(1, -1)
input_patches.shape, input_patches.dtype

(torch.Size([1, 80]), torch.int64)

---

In [49]:
# Add a leading dimension, then show how many trailing ids do not fill a whole patch.
# NOTE(review): the displayed value depends on which input_patches is live in the kernel.
patches = input_patches.unsqueeze(0)
patches.shape[-1] % PATCH_SIZE

15

In [None]:
# Number of trailing ids that do not fill a whole patch.
trailing_len = patches.shape[-1] % PATCH_SIZE
if trailing_len:
    # Seed the byte-level sequence with BOS followed by the partial patch.
    tokens = patches[:, :, -trailing_len:].squeeze(0, 1)
    tokens = torch.cat((torch.tensor([model.bos_token_id], device=model.device), tokens), dim=-1)
else:
    # With no partial patch, a `-0:` slice would select the whole sequence,
    # so start from BOS alone instead.
    tokens = torch.tensor([model.bos_token_id], device=model.device)
tokens

tensor([  1,  37, 101, 110, 100,  10,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2], device='cuda:0')

In [52]:
# Drop the trailing partial patch so the length is a multiple of PATCH_SIZE.
# Guarded because a `[:-0]` slice would return an empty tensor when the length
# is already a whole number of patches.
trailing_len = patches.shape[-1] % PATCH_SIZE
if trailing_len:
    patches = patches[:, :, :-trailing_len]
patches

tensor([[[  1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
            1,   2,  37,  67, 108,  97, 115, 115, 105,  99,  97, 108,  10,   2,
            0,   0,   0,   0,  37,  66, 101, 101, 116, 104, 111, 118, 101, 110,
           44,  32,  76, 117, 100, 119, 105, 103,  32, 118,  97, 110,  10,   2,
            0,   0,   0,   0,   0,   0,   0,   0,  37,  75, 101, 121,  98, 111,
           97, 114, 100,  10,   2,   0,   0,   0,   0,   0]]], device='cuda:0')

In [61]:
# Start the byte-level sequence from the BOS id alone (overrides any earlier `tokens`).
tokens =  torch.tensor([model.bos_token_id], device=model.device)
tokens

tensor([1], device='cuda:0')

In [23]:
# Inspect the model's BOS token id.
model.bos_token_id

1

In [34]:
# Reset the byte-level sequence to BOS, then regroup the flat ids into rows of PATCH_SIZE.
tokens =  torch.tensor([model.bos_token_id], device=model.device)
patches = patches.reshape(len(patches), -1, PATCH_SIZE)
patches.shape

torch.Size([1, 5, 16])

In [53]:
# Run the patch-level decoder on the patches without tracking gradients.
with torch.no_grad():
	patch_result = model.patch_level_decoder(patches)
patch_result

BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[ 0.0077, -0.0576, -0.1288,  ..., -0.1088, -0.0108,  0.0047],
         [ 0.2473, -0.6167,  1.8516,  ...,  2.7832,  0.0039, -1.0430],
         [-0.8037,  0.7085, -1.9131,  ...,  0.6143,  0.3135, -1.2129],
         [-0.5723,  1.5752,  1.1904,  ..., -2.0234,  0.5127,  1.6221],
         [ 0.3896,  1.0879, -1.0654,  ...,  1.4385, -1.5654,  0.6265]]],
       device='cuda:0', dtype=torch.float16), past_key_values=((tensor([[[[ 0.1833,  0.1517,  0.2196,  ..., -0.1962,  0.1675,  0.0435],
          [-0.5234, -0.4353, -0.6196,  ...,  0.5596, -0.4863, -0.1309],
          [-0.4915, -0.3892, -0.5776,  ...,  0.5225, -0.4390, -0.1219],
          [-0.5410, -0.4651, -0.6348,  ...,  0.5742, -0.5112, -0.1559],
          [-0.6030, -0.5083, -0.7085,  ...,  0.6431, -0.5601, -0.1659]],

         [[ 0.2507, -0.2192,  0.2104,  ...,  0.2352, -0.0362,  0.2299],
          [-1.0918,  1.0635, -0.6885,  ..., -0.9819, -0.1671, -0.9272],
          [-1

In [54]:
# Keep only the last hidden state of the patch-level output.
encoded_patches = patch_result.last_hidden_state
encoded_patches.shape

torch.Size([1, 5, 1280])

In [55]:
# Feed the last patch's hidden state and the tokens so far to the character-level
# decoder; the result is moved to the CPU as a NumPy array.
with torch.no_grad():
	prob = model.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()
prob.shape

(128,)

In [56]:
from samplings import temperature_sampling


# Sample one token id from `prob`; the temperature is hard-coded to 1.2 here.
token = temperature_sampling(prob, temperature=1.2)
token

2

In [62]:
# Append the sampled id to the byte-level sequence.
# NOTE(review): rebinding `tokens` makes this cell non-idempotent on re-run.
tokens = torch.cat((tokens, torch.tensor([token], device=model.device)), dim=0)
tokens

tensor([1, 2], device='cuda:0')

In [63]:
# Start collecting the generated patch with the first sampled id.
generated_patch = [token]
generated_patch

[2]

In [64]:
# Keep sampling one id at a time until the byte-level sequence holds PATCH_SIZE entries.
while len(tokens) < PATCH_SIZE:
	with torch.no_grad():
		# Same call as the single-step cell: last patch's hidden state plus tokens so far.
		prob = model.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()
		token = temperature_sampling(prob, temperature=1.2)
		tokens = torch.cat((tokens, torch.tensor([token], device=model.device)), dim=0)

		generated_patch.append(token)

generated_patch

[2, 101, 110, 100, 10, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

In [66]:
# Decode the generated ids back to text with the patchilizer.
next_patch = patchilizer.decode([generated_patch])
next_patch

''

In [44]:
# Prepend the character codes of '[r:0/' to the generated ids.
predicted_patch = [ord(c) for c in '[r:0/'] + generated_patch
predicted_patch

[91, 114, 58, 48, 47, 37, 101, 110, 100, 10, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

In [47]:
# Wrap the generated ids in a (1, n) tensor on the target device.
# NOTE(review): this rebinds predicted_patch from a list to a tensor and drops
# the '[r:0/' prefix built in the previous cell.
predicted_patch = torch.tensor([generated_patch], device=device)
predicted_patch

tensor([[ 37, 101, 110, 100,  10,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2]], device='cuda:0')

In [48]:
# Append the predicted patch to the flat input sequence along dim 1.
# NOTE(review): rebinding input_patches makes this cell non-idempotent; each
# re-run appends another patch.
input_patches = torch.cat([input_patches, predicted_patch], dim=1)
input_patches.shape

torch.Size([1, 95])

In [45]:
# Decode the predicted patch back to text.
# NOTE(review): the execution counts show this cell ran while predicted_patch was
# still the list with the '[r:0/' prefix; run top to bottom, predicted_patch is a
# tensor by this point — confirm the intended order.
next_patch = patchilizer.decode([predicted_patch])
next_patch

'[r:0/%end\n'

---

In [36]:
# Module-level generation state; gen_patch() declares all four names as globals.
# All three boolean flags start out cleared.
failure_flag = end_flag = tunebody_flag = False
# No cut point chosen yet; gen_patch() sets it the first time it rebuilds the context.
cut_index = None


In [41]:
import time

from inference.config import TOP_K, TOP_P, TEMPERATURE


def gen_patch(input_patches):
    """Generate one patch, stream its text to stdout, and return the new context.

    Reads and updates the module-level ``failure_flag``, ``end_flag``,
    ``cut_index`` and ``tunebody_flag``, and appends the generated characters
    to the module-level ``byte_list``.

    Args:
        input_patches: Tensor holding the flat id sequence generated so far;
            the code indexes it as (1, n_ids).

    Returns:
        The id tensor to pass to the next call: the input with the new patch
        appended, or a re-encoded shorter context once the sequence reaches
        ``PATCH_LENGTH * PATCH_SIZE`` ids. Returns ``None`` when generation
        should stop: the model produced the BOS/EOS end patch (``end_flag``
        set), ``byte_list`` grew past the size limit (``failure_flag`` set),
        or the context could not be rebuilt because no tune body with at
        least one following line was found.
    """
    global failure_flag, end_flag, cut_index, tunebody_flag

    # Upper bound on the number of generated characters before giving up.
    max_bytes = 102400

    predicted_patch = model.generate(input_patches.unsqueeze(0),
                                     top_k=TOP_K,
                                     top_p=TOP_P,
                                     temperature=TEMPERATURE)
    # The first time the model starts a '[r:' patch, force the patch to begin
    # with '[r:0/' and let the model generate only the rest of it.
    if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'):
        tunebody_flag = True
        r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
        temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
        predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
                                         top_k=TOP_K,
                                         top_p=TOP_P,
                                         temperature=TEMPERATURE)
        predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch

    # A patch that starts with BOS followed by EOS marks the end of the piece.
    if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
        end_flag = True
        return None

    next_patch = patchilizer.decode([predicted_patch])

    # Stream the decoded text; the apple marks each patch boundary in the output.
    for char in next_patch:
        byte_list.append(char)
        print(char, end='')
    print('🍎', end='')

    # Replace everything after the first EOS id with the special token id.
    patch_end_flag = False
    for j in range(len(predicted_patch)):
        if patch_end_flag:
            predicted_patch[j] = patchilizer.special_token_id
        if predicted_patch[j] == patchilizer.eos_token_id:
            patch_end_flag = True

    predicted_patch = torch.tensor([predicted_patch], device=device)  # (1, 16)
    input_patches = torch.cat([input_patches, predicted_patch], dim=1)  # (1, 16 * patch_len)

    if len(byte_list) > max_bytes:
        failure_flag = True
        return None

    # Context is full: rebuild it from the metadata plus the tail of the tune body.
    if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
        print('Stream generating...')
        abc_code = ''.join(byte_list)
        abc_lines = abc_code.split('\n')

        # Find the first tune-body line. This must `break`, not return:
        # returning here made the rebuild below unreachable.
        tunebody_index = None
        for i, line in enumerate(abc_lines):
            if line.startswith(('[r:', '[V:')):
                tunebody_index = i
                break
        if tunebody_index is None or tunebody_index == len(abc_lines) - 1:
            return None

        metadata_lines = abc_lines[:tunebody_index]
        tunebody_lines = abc_lines[tunebody_index:]

        # Restore the newlines removed by split(); the last line gets one only
        # if the text itself ended with a newline.
        metadata_lines = [line + '\n' for line in metadata_lines]
        if not abc_code.endswith('\n'):
            tunebody_lines = [line + '\n' for line in tunebody_lines[:-1]] + [tunebody_lines[-1]]
        else:
            tunebody_lines = [line + '\n' for line in tunebody_lines]

        # Choose the cut point once and reuse it on every later rebuild.
        if cut_index is None:
            cut_index = len(tunebody_lines) // 2

        abc_code_slice = ''.join(metadata_lines + tunebody_lines[-cut_index:])
        input_patches = patchilizer.encode_generate(abc_code_slice)

        # Flatten the encoded patches into one (1, n_ids) tensor.
        input_patches = [item for sublist in input_patches for item in sublist]
        input_patches = torch.tensor([input_patches], device=device)
        input_patches = input_patches.reshape(1, -1)

    return input_patches


In [17]:
# Check the current context length before generating.
input_patches.shape

torch.Size([1, 80])

In [28]:
# Generate a single patch and inspect the resulting context shape.
# NOTE(review): gen_patch() can return None, in which case `.shape` raises AttributeError.
input_patches = gen_patch(input_patches)
input_patches.shape

next_patch='iano" snm="Pno."'
iano" snm="Pno."

torch.Size([1, 224])

In [42]:
# Generate patch by patch until gen_patch() returns None.
# The module-level end_flag / failure_flag record why generation stopped.
while input_patches is not None:
	input_patches = gen_patch(input_patches)

%end
🍎%%score { ( 1 3 🍎) | ( 2 4 ) }
🍎L:1/8
🍎Q:1/4=71
🍎M:3/4
🍎K:C
🍎V:1 treble nm="P🍎iano" snm="Pno."🍎
🍎V:3 treble 
🍎V:2 bass 
🍎V:4 bass 
🍎[r:0/62][V:1]"^A🍎dagio sostenuto"🍎!pp! (3E,A,C (3E🍎CA, (3E,A,C🍎|[V:2][A,,,A,,]2🍎 z2 z2|
🍎[r:1/61][V:1](3E🍎CA, (3E,A,C (3E,🍎A,C🍎|[V:2][G,,,G,,]2🍎 z2 z2|
🍎[r:2/60][V:1]!<(🍎! (3F,A,C (3FCA,🍎 (3F,_B,D!<)!🍎|[V:2][F,,,F,,]2🍎 z2 [D,,,D,,]2|
🍎[r:3/59][V:1]!p!🍎!>(! (3E,A,B, (3🍎E,A,B, (3E,^G,B,🍎!>)!🍎|[V:2][E,,,E,,]2🍎 z2 z2|
🍎[r:4/58][V:1]!pp🍎! z2 z2 E>E🍎|[V:2](3E,A,C (3🍎E,A,C (3E,A,C🍎|[V:3][I:staff +🍎1] A,,2[I:staff 🍎-1] z2 z2|
🍎[r:5/57][V:1]E4 🍎E>E🍎|[V:2](3E,B,D (3🍎E,B,D (3E,B,D🍎|[V:3][I:staff +🍎1] ^G,,2[I:staff🍎 -1] z2 z2|
🍎[r:6/56][V:1](E2🍎 F2 D2🍎|[V:2](3E,A,C (3🍎F,A,D (3F,A,B,🍎|[V:3][I:staff +🍎1] A,,2[I:staff 🍎-1] z2[I:staff +🍎1] D,,2|
🍎[r:7/55][V:1]E4 🍎G2🍎|[V:2](3E,G,C (3🍎E,G,C (3F,G,B,🍎|[V:3]G,,2 z2 G,🍎,2|
🍎[r:8/54][V:1]C2)🍎 z2 z2🍎|[V:2](3E,G,C (3🍎E,G,C (3E,G,C🍎|[V:3]C,2 z2 z2|🍎
🍎[r:9/53][V:1]z2 🍎z2 _E>E🍎|[V:2](3_E,G,C (🍎3E,G,C (3E,G,C🍎|[V:3]C,2 z2 z2|🍎
