In [1]:
from miditok import REMI, TokenizerConfig
from miditoolkit import MidiFile
import os
from tqdm import tqdm
import pickle
from contextlib import nullcontext
import torch
import numpy as np
from model import GPTConfig, GPT

# Checkpoint files to evaluate. os.listdir order is filesystem-dependent, so sort
# it to make the processing (and printed output) order reproducible.
model_list = sorted(os.listdir("out"))

In [3]:
# For every checkpoint in `out/` whose name contains "GAN", continue the same
# prompt MIDI (test/IN/crt.mid) and write the result(s) to test/OUT/CRT.
# Output name: "<checkpoint stem>_<midi_name>" (plus "_s<i>" when num_samples > 1).

# --- sampling configuration (shared by every checkpoint) ---
out_dir = 'out'
num_samples = 1          # continuations generated per checkpoint
max_new_tokens = 1024    # tokens generated after the prompt
temperature = 0.8
top_k = 200
seed = 1337
# Fall back to CPU instead of failing when no GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# --- tokenizer and prompt: identical for every checkpoint, so built once ---
REMIconfig = TokenizerConfig(nb_velocities=16, use_chords=False, use_programs=True)
tokenizer = REMI(REMIconfig)

midis_path = "test/IN"
midi_name = "crt.mid"
midi = MidiFile(os.path.join(midis_path, midi_name))
# Add a leading batch dimension of 1 for model.generate.
tokens = torch.tensor(tokenizer(midi).ids, device=device)[None, ...]

output_dir = "test/OUT/CRT"
os.makedirs(output_dir, exist_ok=True)  # dump() does not create missing directories

for model_name in model_list:
    if "GAN" not in model_name:
        continue
    print(model_name)

    # Re-seed before building each model so every checkpoint starts generation
    # from the same RNG state, keeping outputs comparable across checkpoints.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, model_name)
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' key prefix so the keys match the plain GPT module.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    # load_state_dict copied the weights into the model; release the checkpoint copy.
    del checkpoint, state_dict

    model.eval()
    model.to(device)

    # run generation; keep every sample (previously only the last one survived)
    with torch.no_grad(), ctx:
        samples = [
            model.generate(tokens, max_new_tokens, temperature=temperature, top_k=top_k).squeeze().cpu().numpy()
            for _ in range(num_samples)
        ]

    # splitext keeps dotted stems intact (split('.')[0] would truncate them).
    ckpt_stem = os.path.splitext(model_name)[0]
    for sample_idx, y in enumerate(samples):
        converted_back_midi = tokenizer(y)
        # With a single sample keep the original file name for compatibility.
        suffix = "" if num_samples == 1 else f"_s{sample_idx}"
        converted_back_midi.dump(os.path.join(output_dir, f"{ckpt_stem}{suffix}_{midi_name}"))

GAN100k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 176.31it/s]


GAN10k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 179.43it/s]


GAN110k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 176.16it/s]


GAN120k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 179.43it/s]


GAN130k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 178.12it/s]


GAN140k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:07<00:00, 137.04it/s]


GAN150k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 177.77it/s]


GAN160k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 179.36it/s]


GAN170k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 177.77it/s]


GAN180k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 179.89it/s]


GAN20k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 179.79it/s]


GAN30k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 179.74it/s]


GAN40k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 179.52it/s]


GAN50k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 179.85it/s]


GAN60k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 179.24it/s]


GAN70k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 179.64it/s]


GAN80k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:07<00:00, 139.32it/s]


GAN90k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 178.22it/s]


new_GAN10k.pt
number of parameters: 85.15M


100%|██████████| 1024/1024 [00:05<00:00, 179.63it/s]
