In [1]:
import torch

In [7]:
MODEL_PATH = "../models/model.pt"
model = torch.jit.load(MODEL_PATH)

In [87]:
embedding_dim = model.transformer.wte.weight.shape[1]
embedding_dim

512

In [34]:
import miditok # Select a suitable encoding format
from miditoolkit import MidiFile

config = miditok.TokenizerConfig()
config.additional_params = { "base_tokenizer" : 'MIDILike' }

tokenizer = miditok.MMM(config)
example_track_path = '../data/external/Jazz Midi/5To10.mid'
tokens = tokenizer.encode(example_track_path)
# Tokens will now be in a format like a list of integers
input_ids = torch.tensor(tokens)  # Convert tokens to tensor for model input


In [36]:
import torch.nn as nn


In [37]:
max_length = 512

if len(input_ids) < max_length:
    input_ids = torch.cat([input_ids, torch.zeros(max_length - len(input_ids), dtype=torch.long)])
else:
    input_ids = input_ids[:max_length]

In [None]:

class GPT2WithClassificationHead(nn.Module):
    def __init__(self, model, num_classes):
        super(GPT2WithClassificationHead, self).__init__()
        self.model = model
        # Adding a linear classification layer on top of the model's output
        embedding_dim = model.transformer.wte.weight.shape[1]
        self.classification_head = nn.Linear(embedding_dim, num_classes)

    def forward(self, input_ids, attention_mask=None):
        # Pass input through the pre-trained model
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the last hidden state for the classification task (last token in each sequence)
        last_hidden_state = outputs[0][:, -1, :]  # shape: (batch_size, hidden_dim)
        logits = self.classification_head(last_hidden_state)
        return logits


In [19]:
num_classes = 2  # Change as needed for your specific task
model_with_head = GPT2WithClassificationHead(model, num_classes)

for param in model.parameters():
    param.requires_grad = False

# Only train the parameters of the classification head
for param in model_with_head.classification_head.parameters():
    param.requires_grad = True

In [20]:
import torch.optim as optim

# Define the optimizer to only update the parameters in the classification head
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model_with_head.parameters()), lr=1e-5)
criterion = nn.CrossEntropyLoss()  # Use CrossEntropyLoss for classification


In [72]:
!pip3 install transformers==4.24.0

Collecting transformers==4.24.0
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.27.1
    Uninstalling transformers-4.27.1:
      Successfully uninstalled transformers-4.27.1
Successfully installed transformers-4.24.0


In [74]:
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW

In [85]:
import configparser

config = configparser.RawConfigParser()
config.read('../local_config.cfg')

tokens = dict(config.items('TOKENS'))
hf_token = tokens["hf_token"]

In [92]:
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI, DataCollator
from miditok.utils import split_files_for_training
from torch.utils.data import DataLoader
from pathlib import Path

# Creating a multitrack tokenizer configuration, read the doc to explore other parameters
# config = TokenizerConfig(num_velocities=16, use_chords=True, use_programs=True)
# tokenizer = REMI(config)

# Train the tokenizer with Byte Pair Encoding (BPE)
midi_paths = list(Path("/home/julia/WIMU/Orchestrify/data/external/Jazz Midi").glob("**/*.mid"))
tokenizer.train(vocab_size=512, files_paths=midi_paths)
tokenizer.save_params(Path("models", "tokenizer.json"))
# And pushing it to the Hugging Face hub (you can download it back with .from_pretrained)
tokenizer.push_to_hub("juleczka/orchestrify_tokenizer", private=True, token=hf_token)

# Split MIDIs into smaller chunks for training
dataset_chunks_dir = Path("/home/julia/WIMU/Orchestrify/data/processed")
split_files_for_training(
    files_paths=midi_paths,
    tokenizer=tokenizer,
    save_dir=dataset_chunks_dir,
    max_seq_len=1024,
)

# Create a Dataset, a DataLoader and a collator to train a model
dataset = DatasetMIDI(
    files_paths=list(dataset_chunks_dir.glob("**/*.mid")),
    tokenizer=tokenizer,
    max_seq_len=1024,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)
collator = DataCollator(tokenizer.pad_token_id, copy_inputs_as_labels=True)
dataloader = DataLoader(dataset, batch_size=64, collate_fn=collator)

  tokenizer.train(vocab_size=512, files_paths=midi_paths)
  tokenizer.save_params(Path("models", "tokenizer.json"))
No files have been modified since last commit. Skipping to prevent empty commit.
  split_files_for_training(


In [96]:
# Set up for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    for batch in dataloader:
        print(batch.keys())
        inputs = batch['input_ids']
        labels = batch['labels']
        optimizer.zero_grad()

        outputs = model(inputs)
        logits = outputs.logits

        # Calculate loss
        loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")




dict_keys(['input_ids', 'labels', 'attention_mask'])


RuntimeError: forward() is missing value for argument 'argument_2'. Declaration: forward(__torch__.transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel self, Tensor input_ids, ((Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor)) argument_2) -> ((Tensor, ((Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor))))