In [1]:
from transformers import BartForConditionalGeneration, BartConfig
from transformers import RobertaTokenizerFast
import torch
from torch.utils.data import DataLoader

from miditok import REMI, TokenizerConfig, TokSequence
from pathlib import Path

from models import MelCAT_base
from dataset_utils import LiveMelCATDataset, MelCATCollator

from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

import os
import numpy as np
import csv

from tqdm import tqdm

import json

import symusic
import json2midi_utils as j2m
import pretty_midi as pm

# Upper bound on sequence length; also used as the size of BART's position embeddings.
MAX_LENGTH = 1024

# Word-level tokenizer whose vocabulary and special-token ids define the BART config below.
roberta_tokenizer_midi = RobertaTokenizerFast.from_pretrained(
    '/media/datadisk/data/pretrained_models/pop_midi_mlm_base/pop_wordlevel_tokenizer'
)
# REMI tokenizer restored from its saved parameter file.
remi_tokenizer = REMI(
    params=Path('/media/datadisk/data/pretrained_models/pop_midi_mlm_base/pop_REMI_BPE_tokenizer.json')
)

bart_config = BartConfig(
    # Vocabulary and special tokens are taken from the word-level tokenizer.
    vocab_size=roberta_tokenizer_midi.vocab_size,
    bos_token_id=roberta_tokenizer_midi.bos_token_id,
    eos_token_id=roberta_tokenizer_midi.eos_token_id,
    pad_token_id=roberta_tokenizer_midi.pad_token_id,
    decoder_start_token_id=roberta_tokenizer_midi.bos_token_id,
    forced_eos_token_id=roberta_tokenizer_midi.eos_token_id,
    # Shared dimensions.
    d_model=256,
    max_position_embeddings=MAX_LENGTH,
    # Encoder stack.
    encoder_layers=8,
    encoder_attention_heads=16,
    encoder_ffn_dim=4096,
    encoder_layerdrop=0.3,
    # Decoder stack (same sizes as the encoder).
    decoder_layers=8,
    decoder_attention_heads=16,
    decoder_ffn_dim=4096,
    decoder_layerdrop=0.3,
    # Regularisation.
    dropout=0.3,
)

run_on_gpu = False

# Resolve the target device once. CUDA is used only when it was requested AND
# is actually available; otherwise everything below runs on the CPU.
dev = torch.device("cuda:0" if run_on_gpu and torch.cuda.is_available() else "cpu")

# Keep the model's own `gpu` argument consistent with the resolved device so the
# CPU fallback does not ask the model for a GPU that is not there.
model = MelCAT_base(bart_config, gpu=0 if dev.type == "cuda" else None).to(dev)

# map_location makes the checkpoint loadable regardless of the device it was
# saved from (e.g. a CUDA checkpoint on a CPU-only machine).
checkpoint = torch.load(
    'saved_models/bart_pop_embeds/bart_pop_embeds.pt',
    map_location=dev,
    weights_only=True,
)

model.load_state_dict(checkpoint)

# Inference only: switch the model to evaluation mode.
model.eval()

# Default values for three streams (MEL / ACC / CHD).
# NOTE(review): presumably MIDI velocities for melody, accompaniment and chords;
# they are not referenced in the cells visible here — confirm against later cells.
DEFAULT_V_MEL = 70
DEFAULT_V_ACC = 50
DEFAULT_V_CHD = 50

# Module-level sampling temperature. The sampling helpers visible in this
# notebook take their own `temperature` argument (default 1.0), so this global
# only matters where a caller passes it explicitly.
temperature = 1

  from .autonotebook import tqdm as notebook_tqdm
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at /media/datadisk/data/pretrained_models/pop_midi_mlm_base/checkpoint-24064 and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at /media/datadisk/data/pretrained_models/chroma_mlm_tiny/checkpoint-14336 and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be abl

In [2]:
def sample_with_temperature(logits, temperature=1.0):
    """Sample one token id per position from temperature-scaled logits.

    Args:
        logits: Tensor of shape ``[..., vocab_size]``. The usual call passes
            ``[batch_size, seq_len, vocab_size]``, but any number of leading
            dimensions (at least one) is accepted.
        temperature: Divisor applied to the logits before the softmax.
            ``0`` selects the highest-scoring token at every position (greedy
            decoding) instead of dividing by zero. Negative values are not
            validated.

    Returns:
        Integer tensor of shape ``[..., 1]`` holding the chosen token id for
        every position (``[batch_size, seq_len, 1]`` for 3-D input).
    """
    # Dividing by zero would yield inf/NaN probabilities and make
    # torch.multinomial fail, so treat zero temperature as plain argmax.
    if temperature == 0:
        return logits.argmax(dim=-1, keepdim=True)

    probs = F.softmax(logits / temperature, dim=-1)

    # torch.multinomial expects a 1-D or 2-D input: merge all leading
    # dimensions into one, sample, then restore them.
    leading_shape = probs.shape[:-1]
    flat_probs = probs.reshape(-1, probs.shape[-1])
    sampled_tokens = torch.multinomial(flat_probs, num_samples=1)
    return sampled_tokens.view(*leading_shape, 1)

def generate_bart_tokens(d, temperature=1.0, max_seq_len=4096, num_bars=1000, bar_token_id=5):
    """Autoregressively sample an accompaniment token sequence.

    Starting from a single BOS token, the global ``model`` is called repeatedly
    with the conditioning inputs and the sequence generated so far; the token
    sampled for the last position is appended each step.

    Args:
        d: Batch dict; ``d['text']``, ``d['melody']`` and ``d['chroma']`` are
            forwarded unchanged to ``model``.
        temperature: Sampling temperature passed to ``sample_with_temperature``.
        max_seq_len: Generation stops once the sequence reaches this length.
        num_bars: Generation stops before appending the bar token that would
            make the bar count exceed this value.
        bar_token_id: Token id counted as a bar marker. Defaults to 5, the value
            previously hard-coded here.
            NOTE(review): presumably the id of 'Bar_None' in
            ``roberta_tokenizer_midi`` — confirm.

    Returns:
        Dict with ``'input_ids'`` and ``'attention_mask'`` tensors of shape
        ``[1, generated_length]`` (BOS included; EOS included when sampled).
    """
    eos_token_id = roberta_tokenizer_midi.eos_token_id
    accomp_input = {
        'input_ids': torch.LongTensor([[roberta_tokenizer_midi.bos_token_id]]),
        'attention_mask': torch.LongTensor([[1]]),
    }

    def _sample_next():
        # Run the model over the sequence so far and keep only the token
        # sampled for the last position, as a [1, 1] tensor.
        logits = model(d['text'], d['melody'], d['chroma'], accomp_input)
        return sample_with_temperature(logits, temperature)[0][-1:]

    def _append(token):
        # Extend ids and mask by one position, on the working device.
        token = token.to(dev)
        accomp_input['input_ids'] = torch.cat(
            (accomp_input['input_ids'].to(dev), token), -1
        )
        accomp_input['attention_mask'] = torch.cat(
            (accomp_input['attention_mask'].to(dev), torch.ones_like(token)), -1
        )

    # Pure inference: do not record an autograd graph for every decoding step.
    with torch.no_grad():
        # The first sampled token is always appended.
        next_token = _sample_next()
        bars_count = int(next_token.item() == bar_token_id)
        _append(next_token)

        while (
            next_token.item() != eos_token_id
            and accomp_input['input_ids'].shape[-1] < max_seq_len
            and num_bars >= bars_count
        ):
            print(accomp_input['input_ids'].shape[-1], 'bars_count:', bars_count, end='\r')
            next_token = _sample_next()
            bars_count += int(next_token.item() == bar_token_id)
            # Drop the bar token that would open bar number num_bars + 1.
            if num_bars < bars_count:
                break
            _append(next_token)
    return accomp_input


In [3]:
# Folder passed to LiveMelCATDataset as the source of MIDI files.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
midifolder = '/media/datadisk/datasets/POP909/aug_folder'

In [4]:
# Dataset over the MIDI folder. max_seq_len reuses MAX_LENGTH (the size of the
# model's position embeddings) so the two limits cannot drift apart.
dataset = LiveMelCATDataset(
    midifolder,
    segment_size=40,
    resolution=4,
    max_seq_len=MAX_LENGTH,
    only_beginning=True,
)
# The collator is built from the per-stream maximum lengths and padding values
# exposed by the dataset.
custom_collate_fn = MelCATCollator(
    max_seq_lens=dataset.max_seq_lengths,
    padding_values=dataset.padding_values,
)
# One item per batch.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    collate_fn=custom_collate_fn,
    drop_last=True,
)



In [5]:
# Walk the loader until `d` holds the batch at index 5 (the sixth batch).
# `i` counts the batches skipped before stopping.
i = 0
for d in dataloader:
    if i < 5:
        i += 1
        continue
    break

0 871-v2#p-1.mid
remi_tokenized_melody: [TokSequence(tokens=['Bar_None', 'Position_0', 'Tempo_71.43', 'Position_16', 'Pitch_92', 'Velocity_79', 'Duration_1.0.4', 'Position_24', 'Pitch_90', 'Velocity_74', 'Duration_0.4.8', 'Position_28', 'Pitch_85', 'Velocity_79', 'Duration_2.1.2', 'Bar_None', 'Position_16', 'Pitch_92', 'Velocity_79', 'Duration_1.0.4', 'Position_24', 'Pitch_90', 'Velocity_79', 'Duration_0.4.8', 'Position_28', 'Pitch_83', 'Velocity_79', 'Duration_2.1.2', 'Bar_None', 'Position_16', 'Pitch_92', 'Velocity_79', 'Duration_1.0.4', 'Position_24', 'Pitch_90', 'Velocity_74', 'Duration_0.4.8', 'Position_28', 'Pitch_85', 'Velocity_79', 'Duration_2.1.2', 'Bar_None', 'Position_16', 'Pitch_92', 'Velocity_79', 'Duration_1.0.4', 'Position_24', 'Pitch_90', 'Velocity_79', 'Duration_0.4.8', 'Position_28', 'Pitch_83', 'Velocity_79', 'Duration_2.1.2', 'Bar_None', 'Position_16', 'Pitch_64', 'Velocity_74', 'Duration_1.0.4', 'Position_24', 'Pitch_61', 'Velocity_79', 'Duration_3.0.2', 'Position_

In [6]:
# Decode the accompaniment and melody ids of the selected batch back to token
# strings, then show them together with the raw chroma ids.
toks, toks_m = (
    roberta_tokenizer_midi.convert_ids_to_tokens(d[stream]['input_ids'][0])
    for stream in ('accomp', 'melody')
)
print(toks, toks_m, d['chroma']['input_ids'], sep='\n')

['Bar_None', 'Position_0', 'Tempo_60x71', 'Position_24', 'Pitch_84', 'Velocity_95', 'Duration_0x2x8', 'Position_26', 'Pitch_60', 'Velocity_42', 'Duration_1x1x4', 'Pitch_63', 'Velocity_42', 'Duration_1x2x4', 'Position_30', 'Pitch_68', 'Velocity_47', 'Duration_1x3x4', 'Bar_None', 'Position_2', 'Pitch_72', 'Velocity_37', 'Duration_0x6x8', 'Pitch_75', 'Velocity_37', 'Duration_0x2x8', 'Position_24', 'Pitch_82', 'Velocity_105', 'Duration_0x2x8', 'Position_26', 'Pitch_67', 'Velocity_68', 'Duration_1x2x4', 'Position_30', 'Pitch_62', 'Velocity_37', 'Duration_0x2x8', 'Bar_None', 'Position_2', 'Pitch_70', 'Velocity_37', 'Duration_0x4x8', 'Pitch_74', 'Velocity_47', 'Duration_0x4x8', 'Position_18', 'Pitch_52', 'Velocity_42', 'Duration_0x6x8', 'Pitch_58', 'Velocity_79', 'Duration_0x4x8', 'Pitch_64', 'Velocity_79', 'Duration_0x2x8', 'Pitch_67', 'Velocity_89', 'Duration_0x2x8', 'Pitch_70', 'Velocity_121', 'Duration_0x2x8', 'Position_20', 'Pitch_55', 'Velocity_26', 'Duration_0x4x8', 'Position_22', 'Pit