In [1]:
from copy import deepcopy
from pathlib import Path

from miditok import REMI
from miditok.pytorch_data import DatasetMIDI, DataCollator
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, GenerationConfig
from tqdm import tqdm

In [3]:
# Load the trained tokenizer and model checkpoint, then build the test dataset
# from the filtered Maestro test split.
tokenizer_path = Path("tokenizer_filtered.json")
model_path = Path("checkpoint-10000")
test_midi_dir = Path("filtered_midi/Maestro_test")

tokenizer = REMI(params=tokenizer_path)

# Batch collator padding with the tokenizer's PAD token; `copy_inputs_as_labels=True`
# presumably copies `input_ids` into `labels` (for LM loss) — confirm against MidiTok docs.
collator = DataCollator(tokenizer["PAD_None"], copy_inputs_as_labels=True)

model = AutoModelForCausalLM.from_pretrained(model_path)

# Collect both `.mid` and `.midi` files recursively. The list is sorted because glob order
# depends on the filesystem, and the generation cell numbers its outputs (`gen_res/{count}.mid`)
# by this order; sorting makes that file -> output mapping reproducible across machines.
midi_paths_test = sorted(
    path
    for pattern in ("**/*.mid", "**/*.midi")
    for path in test_midi_dir.glob(pattern)
)
# Fail loudly instead of silently generating nothing later on.
assert midi_paths_test, f"No .mid/.midi files found under {test_midi_dir}"

kwargs_dataset = {
    "max_seq_len": 1024,  # maximum number of tokens per dataset sample (the generation prompt)
    "tokenizer": tokenizer,
    # presumably DatasetMIDI prepends BOS / appends EOS to every sample — confirm in MidiTok docs
    "bos_token_id": tokenizer["BOS_None"],
    "eos_token_id": tokenizer["EOS_None"],
}

dataset_test = DatasetMIDI(midi_paths_test, **kwargs_dataset)

In [None]:
# Output directory for the generated MIDI files (created if missing).
(gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)

# Decoding settings for `model.generate`.
# With `num_beams > 1` and `do_sample=True`, Hugging Face runs beam-search *multinomial
# sampling*, not plain sampling; set `num_beams=1` for pure sampling without beams.
generation_config = GenerationConfig(
    max_new_tokens=200,  # extends samples by 200 tokens
    num_beams=3,         # 3 beams combined with sampling (beam-search multinomial sampling)
    do_sample=True,      # sample the next token instead of greedy/deterministic decoding
    temperature=0.87,    # < 1 sharpens the distribution slightly
    top_k=10,            # keep only the 10 most likely tokens...
    top_p=0.5,           # ...then the smallest set whose cumulative probability reaches 0.5
    epsilon_cutoff=3e-4, # drop tokens with probability below 3e-4
    eta_cutoff=1e-3,     # eta (entropy-dependent) truncation sampling
    pad_token_id=tokenizer.pad_token_id,
)

# Pad on the left so that, in every row of a batch, the last position along the time
# dimension holds that sequence's last real token; this lets `generate` extend the whole
# batch at once.
collator.pad_on_left = True
# NOTE(review): `dataset_test` was built with `eos_token_id`, so prompts presumably already
# end with EOS regardless of this collator attribute — confirm continuing after EOS is intended.
collator.eos_token = None
dataloader_test = DataLoader(dataset_test, batch_size=4, collate_fn=collator)
model.eval()
print(f"Vocabulary size: {len(tokenizer)}")

pad_id = tokenizer.pad_token_id
count = 0
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config,
    )

    # Save each sample as one MIDI file with up to three named tracks:
    # the continuation alone, the original prompt, and prompt + continuation.
    for prompt, continuation in zip(batch["input_ids"], res):
        # `generate` returns the (left-padded) prompt followed by the new tokens, so everything
        # after the padded prompt length is newly generated. PAD tokens (the prompt's left
        # padding, and right padding of rows that stopped early) are dropped so they are not
        # decoded and do not inflate the token counts written into the track names.
        prompt_ids = [tok for tok in prompt.tolist() if tok != pad_id]
        generated_ids = [tok for tok in continuation[len(prompt):].tolist() if tok != pad_id]
        full_ids = prompt_ids + generated_ids

        # The continuation's decoded Score is the base file (its track 0 is the continuation).
        midi = tokenizer.decode([deepcopy(generated_ids)])
        if len(midi.tracks) == 0:
            # Nothing decodable was generated: report it instead of mislabelling the other tracks.
            print(f"Sample {count}: no tracks could be decoded from the generated tokens")
        else:
            midi.tracks[0].name = f"Continuation of original sample ({len(generated_ids)} tokens)"

        # Append the first track of the prompt and of prompt + continuation, each with its name.
        for tok_seq, track_name in (
            (prompt_ids, f"Original sample ({len(prompt_ids)} tokens)"),
            (full_ids, "Original sample and continuation"),
        ):
            _midi = tokenizer.decode([deepcopy(tok_seq)])
            if len(_midi.tracks) == 0:
                print(f"Sample {count}: '{track_name}' decodes to no tracks, skipped")
                continue
            track = _midi.tracks[0]
            track.name = track_name
            midi.tracks.append(track)

        midi.dump_midi(gen_results_path / f'{count}.mid')
        # Optionally also save the token ids as JSON:
        # tokenizer.save_tokens([generated_ids, prompt_ids, full_ids], gen_results_path / f'{count}.json')

        count += 1