In [1]:
from transformers import BartForConditionalGeneration, BartConfig
from transformers import RobertaTokenizerFast
import torch
from torch.utils.data import DataLoader

from models import MelCAT_base
from dataset_utils import LiveMelCATDataset, MelCATCollator

from torch.nn import CrossEntropyLoss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the word-level MIDI tokenizer from its local directory.
midi_tokenizer_dir = '/media/datadisk/data/pretrained_models/midi_mlm_tiny/midi_wordlevel_tokenizer'
roberta_tokenizer_midi = RobertaTokenizerFast.from_pretrained(midi_tokenizer_dir)

# Show the vocabulary size; the BART config below reuses it as `vocab_size`.
print(roberta_tokenizer_midi.vocab_size)

170


In [3]:
# Special-token ids come straight from the MIDI tokenizer so that the
# seq2seq model and the tokenizer agree on pad / start / end symbols.
special_token_ids = dict(
    pad_token_id=roberta_tokenizer_midi.pad_token_id,
    bos_token_id=roberta_tokenizer_midi.bos_token_id,
    eos_token_id=roberta_tokenizer_midi.eos_token_id,
    decoder_start_token_id=roberta_tokenizer_midi.bos_token_id,
    forced_eos_token_id=roberta_tokenizer_midi.eos_token_id,
)

# Encoder and decoder stacks are configured symmetrically.
stack_kwargs = {}
for side in ('encoder', 'decoder'):
    stack_kwargs[f'{side}_layers'] = 4
    stack_kwargs[f'{side}_attention_heads'] = 4
    stack_kwargs[f'{side}_ffn_dim'] = 256
    stack_kwargs[f'{side}_layerdrop'] = 0.2

bart_config = BartConfig(
    vocab_size=roberta_tokenizer_midi.vocab_size,
    max_position_embeddings=4096,
    d_model=256,
    dropout=0.2,
    **special_token_ids,
    **stack_kwargs,
)

# Plain BART alternative, kept for reference:
# model = BartForConditionalGeneration(bart_config)
model = MelCAT_base(bart_config)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at /media/datadisk/data/pretrained_models/midi_mlm_tiny/checkpoint-5120 and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at /media/datadisk/data/pretrained_models/chroma_mlm_tiny/checkpoint-14336 and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


initialized


In [4]:
# Freeze the pretrained text / chroma / MIDI encoders: none of their
# parameters will receive gradients during training.
for pretrained_encoder in (model.text_encoder, model.chroma_encoder, model.midi_encoder):
    pretrained_encoder.requires_grad_(False)

In [6]:
# Only the BART seq2seq model and the text LSTM are handed to the optimizer.
# NOTE(review): if MelCAT_base owns other trainable sub-modules they are not
# updated by this optimizer — confirm against the model definition.
params = [
    *model.bart_model.parameters(),
    *model.text_lstm.parameters(),
]
optimizer = torch.optim.AdamW(params, lr=1e-3)

In [7]:
# Folder holding the MIDI files that the dataset reads from.
midifolder = '/media/datadisk/datasets/GiantMIDI-PIano/midis_v1.2/midis'
# Alternative location:
# midifolder = '/media/datadisk/data/Giant_PIano/'
segment_size = 40
dataset = LiveMelCATDataset(midifolder, segment_size=segment_size)

# The collator takes its per-stream length limits and padding values
# from the dataset itself.
collator_kwargs = {
    'max_seq_lens': dataset.max_seq_lengths,
    'padding_values': dataset.padding_values,
}
custom_collate_fn = MelCATCollator(**collator_kwargs)



In [8]:
# Batches of 4 items, assembled by the custom collator defined above.
dataloader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=custom_collate_fn,
)

In [9]:
# Pull a single batch to exercise the pipeline end to end.
batch_iterator = iter(dataloader)
b = next(batch_iterator)

idx: 0
Ismagilov, Timur, Spring Sketches, 2QxuHQoT5Dk.mid
idx: 1
Gurlitt, Cornelius, Frühlingsblumen, Op.215, WD6wHfUb-kU.mid
idx: 2
Singelée, Jean Baptiste, Fantaisie sur des motifs de 'La sonnambula', Op.39, AcaSiJG7mkU.mid
idx: 3
Simpson, Daniel Léo, Kleine Klavierstücke No.9 in F major, R4z8vPF1Hto.mid


In [10]:
# One line per field, listing the text, melody and chroma tensor shapes.
for field in ('input_ids', 'attention_mask'):
    print(*(b[stream][field].shape for stream in ('text', 'melody', 'chroma')))

torch.Size([4, 40]) torch.Size([4, 276]) torch.Size([4, 417])
torch.Size([4, 40]) torch.Size([4, 276]) torch.Size([4, 417])


In [11]:
def _shift_right(tensor, first_value):
    """Return a new tensor equal to ``tensor`` moved one step right along dim 1.

    Column 0 of the result is filled with ``first_value`` and the last column
    of ``tensor`` is dropped, so the shape, dtype and device are unchanged.
    Expects a 2-D tensor.
    """
    start_column = tensor.new_full((tensor.size(0), 1), first_value)
    return torch.cat([start_column, tensor[:, :-1]], dim=1)


# Decoder input: the accompaniment shifted by one position, with the BOS
# token in front and attention switched on for that first position.
shifted_accomp = {
    'input_ids': _shift_right(b['accomp']['input_ids'], roberta_tokenizer_midi.bos_token_id),
    'attention_mask': _shift_right(b['accomp']['attention_mask'], 1),
}

In [12]:
# Forward pass; `y` holds the logits returned by the model.
y = model(
    b['text'],
    b['melody'],
    b['chroma'],
    shifted_accomp,
)

in forward


torch.Size([4, 40, 768])
torch.Size([4, 40, 256])
torch.Size([4, 276, 256])
torch.Size([4, 417, 256])
torch.Size([4, 2133])
torch.Size([4, 694, 256])
torch.Size([4, 694, 256])


In [13]:
# Show the logits' shape first, then the tensor itself.
for shown in (y.shape, y):
    print(shown)

torch.Size([4, 2133, 170])
tensor([[[-7.8665e-01,  2.8517e+00, -1.7277e-02,  ..., -2.5381e-01,
           7.0461e-01,  2.2743e-01],
         [-4.6441e-03, -1.9576e-01, -8.8439e-03,  ...,  1.4149e-01,
           3.6048e-01,  2.9292e-01],
         [ 2.9452e-02, -2.3207e-01, -2.2551e-02,  ..., -1.6979e-01,
          -1.6077e-01, -4.0275e-02],
         ...,
         [-2.9329e-01,  4.1008e-01,  6.9159e-01,  ..., -7.8541e-01,
           2.7483e-01, -3.8374e-01],
         [ 4.5838e-01, -1.8517e-01,  5.0738e-01,  ..., -3.5898e-01,
          -2.3847e-01,  2.3564e-02],
         [-3.3708e-01, -1.8842e-01,  2.0431e-02,  ...,  2.2826e-01,
           2.8365e-01,  2.7119e-02]],

        [[-9.5609e-01,  3.1644e+00, -1.2974e-01,  ..., -3.3767e-01,
           7.8777e-01,  3.9883e-01],
         [ 8.7844e-02, -1.1841e-01, -4.4279e-03,  ...,  3.0918e-01,
           5.1439e-01,  2.3475e-01],
         [-1.2745e-01, -9.3867e-02, -1.5588e-01,  ..., -3.1439e-01,
          -4.8095e-03, -2.7512e-01],
         ...

In [14]:
# Teacher-forcing loss: the decoder was fed `shifted_accomp` (the accompaniment
# moved one step right behind a BOS token), so the logits at position t are
# scored against the original, UNSHIFTED accompaniment token at position t.
logits = y  # Shape: [batch_size, seq_len, vocab_size]

# NOTE: despite the variable name, these are the unshifted accompaniment ids
# (the labels); the name is kept so other cells that read it keep working.
target_ids_shifted = b['accomp']['input_ids']

# Flatten logits to [batch_size * seq_len, vocab_size] and targets to
# [batch_size * seq_len]. `reshape` is used instead of `view` so this also
# works when the tensors are not contiguous in memory.
logits_flat = logits.reshape(-1, logits.size(-1))
target_flat = target_ids_shifted.reshape(-1)
print(logits_flat.shape)
print(target_flat.shape)

# Compute the cross-entropy loss; positions whose target is the padding
# token are ignored.
loss_fct = CrossEntropyLoss(ignore_index=roberta_tokenizer_midi.pad_token_id)
loss = loss_fct(logits_flat, target_flat)

torch.Size([8532, 170])
torch.Size([8532])


In [15]:
# Single optimisation step on the batch processed above.
# The order matters: clear stale gradients, back-propagate, then update.
optimizer.zero_grad()  # reset gradients of the parameters held by the optimizer
loss.backward()  # back-propagate the cross-entropy loss computed in the previous cell
optimizer.step()  # apply the AdamW update to the optimizer's parameters