In [1]:
import os

import numpy as np
import miditok
import torch

from nanogpt_model import GPTConfig, GPT

Load model

In [None]:
# Checkpoint directory and the device inference will run on (GPU when available).
out_dir = "out"
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Resuming model from {out_dir}")
# Alternative checkpoint file: ckpt_pre_trained.pt
ckpt_path = os.path.join(out_dir, "ckpt_beethoven.pt")
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(ckpt_path, map_location=device)

# Rebuild the model from the `model_args` stored inside the checkpoint.
checkpoint_model_args = checkpoint["model_args"]
gptconf = GPTConfig(**checkpoint_model_args)
model = GPT(gptconf)

# Rename, in place, every weight key that starts with "_orig_mod." by dropping
# that prefix (NOTE(review): presumably added by torch.compile -- confirm).
state_dict = checkpoint["model"]
unwanted_prefix = "_orig_mod."
prefixed_keys = [name for name in state_dict if name.startswith(unwanted_prefix)]
for name in prefixed_keys:
    state_dict[name[len(unwanted_prefix):]] = state_dict.pop(name)
model.load_state_dict(state_dict)

# Bookkeeping values saved next to the weights.
iter_num = checkpoint["iter_num"]
best_val_loss = checkpoint["best_val_loss"]

# Move to the target device and switch to evaluation mode.
model.to(device)
model.eval()

Load unique tokens used (this reduces the vocab size)

In [None]:
# Read the saved array of token ids from disk and wrap it in a tensor.
unique_tokens_array = np.load("unique_tokens.npy")
unique_tokens = torch.from_numpy(unique_tokens_array)

# The reduced vocabulary has one entry per saved token id.
vocab_size = len(unique_tokens)
print("vocab_size", vocab_size)
unique_tokens

Define map functions

In [None]:
# Forward table: original token id -> position in `unique_tokens`.
# Inverse table: position in `unique_tokens` -> original token id.
token_id_list = unique_tokens.tolist()
tokens_mapping = {token_id: position for position, token_id in enumerate(token_id_list)}
tokens_unmapping = dict(enumerate(token_id_list))
print(tokens_mapping)
print(tokens_unmapping)

def map_tokens(tokens, tokens_mapping):
    """Translate token ids through ``tokens_mapping``.

    Args:
        tokens: Tensor of token ids. Any shape is accepted, including
            0-d and empty tensors.
        tokens_mapping: Dict whose keys are the ids found in ``tokens`` and
            whose values are the ids to substitute.

    Returns:
        A new tensor with the same shape, dtype and device as ``tokens``
        holding the substituted ids. ``tokens`` itself is not modified.

    Raises:
        KeyError: If an id in ``tokens`` is not a key of ``tokens_mapping``.
    """
    # One .tolist() call converts every id at once instead of calling
    # .item() per element, and flattening first makes any shape work.
    flat_ids = tokens.reshape(-1).tolist()
    mapped_ids = [tokens_mapping[token_id] for token_id in flat_ids]
    mapped_tokens = torch.tensor(mapped_ids, dtype=tokens.dtype, device=tokens.device)
    return mapped_tokens.reshape(tokens.shape)

def unmap_tokens(mapped_tokens, tokens_unmapping):
    """Translate mapped token ids back through ``tokens_unmapping``.

    Args:
        mapped_tokens: Tensor of mapped ids. Any shape is accepted,
            including 0-d and empty tensors.
        tokens_unmapping: Dict whose keys are the ids found in
            ``mapped_tokens`` and whose values are the ids to restore.

    Returns:
        A new tensor with the same shape, dtype and device as
        ``mapped_tokens`` holding the restored ids. ``mapped_tokens`` itself
        is not modified.

    Raises:
        KeyError: If an id in ``mapped_tokens`` is not a key of
            ``tokens_unmapping``.
    """
    # One .tolist() call converts every id at once instead of calling
    # .item() per element, and flattening first makes any shape work.
    flat_ids = mapped_tokens.reshape(-1).tolist()
    restored_ids = [tokens_unmapping[mapped_id] for mapped_id in flat_ids]
    unmapped_tokens = torch.tensor(
        restored_ids, dtype=mapped_tokens.dtype, device=mapped_tokens.device
    )
    return unmapped_tokens.reshape(mapped_tokens.shape)

Initialize tokenizer

In [None]:
# Settings for the REMI tokenizer.
# NOTE(review): presumably these must match the settings used when the
# training data was tokenized -- confirm.
TOKENIZER_PARAMS = dict(
    special_tokens=["PAD", "BOS", "EOS", "MASK"],
    use_tempos=True,
    use_programs=True,
    one_token_stream_for_programs=True,
    use_time_signatures=True,
)
tokenizer_config = miditok.TokenizerConfig(**TOKENIZER_PARAMS)
tokenizer = miditok.REMI(tokenizer_config)

Generate 5 tracks from scratch, providing only the BOS (beginning-of-sequence) token (id 1) as the prompt

In [None]:
# One single-token prompt per track: every row holds only the BOS id.
bos_token_id = 1
num_tracks = 5
inputs = torch.tensor([[bos_token_id]] * num_tracks, dtype=torch.int64).to(device)

# Rewrite each row in place with its ids translated through tokens_mapping.
for row in range(inputs.shape[0]):
    inputs[row] = map_tokens(inputs[row], tokens_mapping)
print(inputs.shape)
inputs

Load a MIDI file to generate a variation of it, truncating it at the desired point (running this cell overwrites the `inputs` built in the previous cell)

In [None]:
input_midi_path = os.path.join("path", "to", "midi_file.mid")

# Tokenize the file, then keep only the leading ids; change the limit below
# to truncate the tokenized ids at a different point.
max_prompt_tokens = 200
encoded_midi = tokenizer.encode(input_midi_path)
input_token_ids = encoded_midi.ids[:max_prompt_tokens]

# Optionally save the truncated input midi to listen to it
# input_midi = tokenizer.decode(input_token_ids)
# input_midi.dump_midi("input.mid")

# Translate the ids through tokens_mapping, add a leading batch dimension
# and move the prompt to the target device.
raw_prompt = torch.tensor(input_token_ids, dtype=torch.int64)
mapped_prompt = map_tokens(raw_prompt, tokens_mapping)
inputs = mapped_prompt.unsqueeze(0).to(device)
print(inputs.shape)
inputs

Generate outputs

In [None]:
# Generate up to a total of 1024 tokens (prompt included); otherwise pass the
# number of new tokens to generate directly as max_new_tokens.
total_tokens = 1024
new_token_budget = total_tokens - inputs.shape[1]
outputs = model.generate(inputs, max_new_tokens=new_token_budget, temperature=1.0)
print(outputs.shape)

# Rewrite each generated row in place with its ids translated back through
# tokens_unmapping.
for row in range(outputs.shape[0]):
    outputs[row] = unmap_tokens(outputs[row], tokens_unmapping)

Save outputs as midi files

In [None]:
# Decode every generated row and write it to its own numbered MIDI file.
for track_index, track_tokens in enumerate(outputs):
    output_midi = tokenizer.decode(track_tokens.tolist())
    output_midi.dump_midi(f"output_{track_index}.mid")