In [1]:
# Install the Hugging Face `transformers` package used by the cells below.
# NOTE(review): `!pip3` runs whichever pip3 is first on the shell PATH; consider
# the `%pip install` magic so the package lands in the running kernel's environment.
!pip3 install transformers



In [4]:
from transformers import PreTrainedTokenizerFast, GPT2LMHeadModel

# Identifier handed to `from_pretrained` for the GPT-2 checkpoint, and the local
# tokenizer definition file.
# NOTE(review): the checkpoint name says "mmmbar" while the tokenizer directory
# says "mmmtrack" - confirm that both belong to the same vocabulary.
model_path = "Milos121/MMM_jsb_mmmbar"
tokenizer_path = '../data/external/Jazz Midi/jsb_mmmtrack/tokenizer.json'

model = GPT2LMHeadModel.from_pretrained(model_path)

# Register the padding token when the tokenizer is created, so that every later
# cell (generation, data collation, training) already sees `pad_token` set
# instead of relying on a call made much further down in the notebook.
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path, pad_token="[PAD]")

In [None]:
import random

# `data_path` is the text file of encoded MIDI token sequences, one sequence per line.
def get_priming_token_sequence(data_path, stop_on_track_end=None, stop_after_n_tokens=None, return_original=False):
    """Pick a random token sequence from a file and cut it down to a priming prefix.

    Args:
        data_path: Path to a text file holding one whitespace-separated token
            sequence per line. Blank lines are ignored.
        stop_on_track_end: Zero-based index of the "TRACK_END" token at which
            the prefix stops (that token is included). ``None`` disables it.
        stop_after_n_tokens: Maximum number of tokens kept in the prefix.
            ``None`` or ``0`` disables the limit.
        return_original: When true, also return the full sampled line.

    Returns:
        The prefix as a single space-joined string, or a
        ``(prefix, original_line)`` tuple when ``return_original`` is true.
        ``original_line`` is returned exactly as read from the file, i.e. it
        may still carry its trailing newline.

    Raises:
        IndexError: If the file contains no non-blank line to sample from.
    """
    # Read inside a context manager so the handle is always closed, and drop
    # blank lines so that an empty priming sequence can never be sampled.
    with open(data_path, "r") as data_file:
        lines = [line for line in data_file if line.strip()]
    token_sequence = random.choice(lines)

    result_tokens = []
    track_end_count = 0
    for token_count, token in enumerate(token_sequence.split(), start=1):
        result_tokens.append(token)

        if token == "TRACK_END":
            # Stop on the requested TRACK_END; the token itself stays in the prefix.
            if stop_on_track_end == track_end_count:
                break
            track_end_count += 1

        # token_count starts at 1, so a limit of None or 0 never matches.
        if token_count == stop_after_n_tokens:
            break

    result = " ".join(result_tokens)
    if return_original:
        return result, token_sequence
    return result

In [7]:
# Draw a random validation sequence and keep a short prefix of it as the prompt:
# at most 20 tokens, stopping earlier if the first "TRACK_END" is reached.
validation_data_path = '../data/external/Jazz Midi/jsb_mmmtrack/token_sequences_valid.txt'
priming_sample, priming_sample_original = get_priming_token_sequence(
    data_path=validation_data_path,
    return_original=True,
    stop_after_n_tokens=20,
    stop_on_track_end=0,
)

In [8]:
# Encode the priming text into a PyTorch tensor of token ids.
input_ids = tokenizer.encode(priming_sample, return_tensors="pt")

# `temperature` only takes effect when sampling is enabled; without
# `do_sample=True`, generate() decodes greedily and the temperature is ignored.
generated_sequence = model.generate(
    input_ids,
    max_length=1000,
    do_sample=True,
    temperature=0.9,
)

In [9]:
# Turn the ids of the first (index 0) generated sequence back into token text.
generated_ids = generated_sequence[0]
decoded_sequence = tokenizer.decode(generated_ids)

In [10]:
import note_seq

# Durations in seconds at 120 BPM, derived from the length of one quarter note.
_SECONDS_PER_QUARTER_120BPM = 60 / 120
NOTE_LENGTH_16TH_120BPM = 0.25 * _SECONDS_PER_QUARTER_120BPM  # a quarter of a quarter note
BAR_LENGTH_120BPM = 4.0 * _SECONDS_PER_QUARTER_120BPM  # four quarter notes per bar

def empty_note_sequence(qpm=120.0, total_time=0.0):
    """Create a NoteSequence that has a tempo and a duration but no notes.

    Args:
        qpm: Tempo stored on the single tempo entry added to the sequence.
        total_time: Value stored in the sequence's ``total_time`` field.

    Returns:
        The new ``NoteSequence`` with ``ticks_per_quarter`` set to
        ``note_seq.constants.STANDARD_PPQ``.
    """
    sequence = note_seq.protobuf.music_pb2.NoteSequence()
    tempo = sequence.tempos.add()
    tempo.qpm = qpm
    sequence.total_time = total_time
    sequence.ticks_per_quarter = note_seq.constants.STANDARD_PPQ
    return sequence


def token_sequence_to_note_sequence(token_sequence, use_program=True, use_drums=True):
    """Convert a token sequence into a ``note_seq`` NoteSequence.

    Recognised tokens are PIECE_START / PIECE_END, TRACK_START / TRACK_END,
    INST=<id|DRUMS>, BAR_START / BAR_END, NOTE_ON=<pitch>, NOTE_OFF=<pitch>,
    TIME_DELTA=<sixteenths>, DENSITY=<...> and [PAD]. Timing assumes 120 BPM:
    each bar starts at ``bar_index * BAR_LENGTH_120BPM`` and every TIME_DELTA
    unit advances time by ``NOTE_LENGTH_16TH_120BPM``.

    Args:
        token_sequence: Either a whitespace-separated string or an iterable of
            token strings.
        use_program: When true, a numeric INST id is also used as the MIDI
            program of the notes that follow; otherwise the default program
            (1) is used for every non-drum track.
        use_drums: When true, notes after ``INST=DRUMS`` are flagged as drums
            (instrument 0, program 0). When false, such notes are emitted on
            instrument 0 as ordinary, non-drum notes with the default program.

    Returns:
        The populated NoteSequence; ``total_time`` is set to the latest note
        end time (it stays 0.0 when no note was produced).

    Raises:
        AssertionError: On a token that is not part of the vocabulary above.
        ValueError: If a numeric field (instrument id, pitch, delta) cannot be
            parsed.
    """
    if isinstance(token_sequence, str):
        token_sequence = token_sequence.split()

    default_program = 1

    note_sequence = empty_note_sequence()

    # Initialise all decoding state up front: generated sequences are not
    # guaranteed to be well-formed, and a NOTE_ON / TIME_DELTA / BAR_END that
    # shows up before its INST / BAR_START / TRACK_START must not crash with a
    # NameError.
    current_instrument = 0
    current_program = default_program
    current_is_drum = False
    current_bar_index = 0
    current_time = 0.0
    current_notes = {}

    for token in token_sequence:

        if token == "PIECE_START":
            pass
        elif token == "PIECE_END":
            print("The end.")
            break
        elif token == "TRACK_START":
            # Every track starts again at the first bar.
            current_bar_index = 0
        elif token == "TRACK_END":
            pass
        elif token.startswith("INST"):
            instrument_name = token.split("=")[-1]
            # Program and drum flag are recomputed for every INST token so that
            # the settings of a previous track (in particular a drum track)
            # never leak into the next one.
            if instrument_name == "DRUMS":
                current_instrument = 0
                current_is_drum = bool(use_drums)
                current_program = 0 if use_drums else default_program
            else:
                current_instrument = int(instrument_name)
                current_is_drum = False
                current_program = current_instrument if use_program else default_program
        elif token == "BAR_START":
            current_time = current_bar_index * BAR_LENGTH_120BPM
            # Open notes are tracked per bar; a NOTE_OFF only closes notes that
            # were started in the same bar.
            current_notes = {}
        elif token == "BAR_END":
            current_bar_index += 1
        elif token.startswith("NOTE_ON"):
            pitch = int(token.split("=")[-1])
            note = note_sequence.notes.add()
            note.start_time = current_time
            # Provisional length of four sixteenths; overwritten by a matching NOTE_OFF.
            note.end_time = current_time + 4 * NOTE_LENGTH_16TH_120BPM
            note.pitch = pitch
            note.instrument = current_instrument
            note.program = current_program
            note.velocity = 80
            note.is_drum = current_is_drum
            current_notes[pitch] = note
        elif token.startswith("NOTE_OFF"):
            pitch = int(token.split("=")[-1])
            if pitch in current_notes:
                current_notes[pitch].end_time = current_time
        elif token.startswith("TIME_DELTA"):
            current_time += float(token.split("=")[-1]) * NOTE_LENGTH_16TH_120BPM
        elif token.startswith("DENSITY="):
            pass
        elif token == "[PAD]":
            pass
        else:
            # Raised explicitly (rather than via `assert`) so that unknown
            # tokens are still rejected when Python runs with -O.
            raise AssertionError(token)

    # Keep the sequence duration consistent with the notes that were added.
    if note_sequence.notes:
        note_sequence.total_time = max(note.end_time for note in note_sequence.notes)

    return note_sequence

In [15]:
def render_token_sequence(token_sequence, use_program=True, use_drums=True):
    """Convert a token sequence to a NoteSequence and plot it with ``note_seq``.

    Args:
        token_sequence: Tokens, forwarded to ``token_sequence_to_note_sequence``.
        use_program: Forwarded unchanged to the converter.
        use_drums: Forwarded unchanged to the converter.
    """
    note_seq.plot_sequence(
        token_sequence_to_note_sequence(
            token_sequence,
            use_program=use_program,
            use_drums=use_drums,
        )
    )

In [16]:
# Convert the generated tokens to notes (INST ids are not used as MIDI programs,
# drum tracks are flagged as drums) and play the result.
note_sequence = token_sequence_to_note_sequence(
    decoded_sequence,
    use_drums=True,
    use_program=False,
)
note_seq.play_sequence(note_sequence)

### Now, let's try transfer learning

In [21]:
from torch.utils.data.dataset import Dataset
import random
import numpy as np
import os
import torch

class TokenSequenceDataset(Dataset):
    """In-memory dataset of fixed-length, padded token sequences.

    Every non-blank line of the input files is encoded with the tokenizer.
    Lines that contain an unknown token or that are longer than ``block_size``
    are skipped; all others are right-padded with the "[PAD]" id up to
    ``block_size``. Each example is a dict with ``input_ids`` and ``labels``,
    both holding the same padded ids.

    NOTE(review): padding positions are part of ``labels``, so a loss computed
    on them also covers the padding - confirm that this is intended.

    Attributes:
        examples: List of ``{"input_ids": LongTensor, "labels": LongTensor}``.
        statistics: Dict with the counters gathered while loading
            (``tokens_count``, ``encoded_lengths``, ``unknown_tokens``,
            ``unknown_tokens_set``, ``unknown_token_lines_count``,
            ``too_long_lines_count``).
    """

    def __init__(self, tokenizer, dataset_paths, block_size, simulate=False):
        """Load, encode and pad all sequences.

        Args:
            tokenizer: Object with ``encode(text) -> list of ids`` and
                ``decode(ids) -> str``; must know "[PAD]" and "[UNK]".
            dataset_paths: Iterable of text files, one token sequence per line.
            block_size: Length every example is padded to; longer lines are
                dropped.
            simulate: When true, only 10 randomly chosen lines are used.

        Raises:
            AssertionError: If one of ``dataset_paths`` is not an existing file.
        """
        pad_token_id = tokenizer.encode("[PAD]")[0]
        unk_token_id = tokenizer.encode("[UNK]")[0]

        # Read all lines from all files; the context manager closes each handle.
        lines = []
        for dataset_path in dataset_paths:
            assert os.path.isfile(dataset_path), f"Input file path {dataset_path} not found"
            with open(dataset_path, "r") as dataset_file:
                lines += dataset_file.readlines()

        # In simulation just use a few samples.
        if simulate:
            random.shuffle(lines)
            lines = lines[:10]

        # Turn lines into training examples. Also gather some statistics.
        self.examples = []
        unknown_tokens = []
        tokens_count = 0
        unknown_token_lines_count = 0
        too_long_lines_count = 0
        encoded_lengths = []
        for line in lines:

            # Skip empty lines.
            line = line.strip()
            if line == "":
                continue

            # Encode the line.
            encoded_line = tokenizer.encode(line)
            encoded_lengths.append(len(encoded_line))
            tokens_count += len(encoded_line)

            # Record the first unknown token of the line, then skip the line.
            if unk_token_id in encoded_line:
                index = encoded_line.index(unk_token_id)
                words = line.split()
                # Prefer the original text of the offending token; fall back to
                # the decoded id when ids and words do not line up one-to-one.
                if index < len(words):
                    token = words[index]
                else:
                    token = tokenizer.decode(encoded_line[index])
                unknown_tokens.append(token)
                unknown_token_lines_count += 1
                continue

            # Skip sequence if it is too long.
            if len(encoded_line) > block_size:
                too_long_lines_count += 1
                continue

            # Right-pad up to block_size.
            padded = np.full((block_size,), pad_token_id, dtype=np.longlong)
            padded[:len(encoded_line)] = encoded_line
            assert len(padded) == block_size

            self.examples.append({
                "input_ids": torch.tensor(padded, dtype=torch.long),
                "labels": torch.tensor(padded, dtype=torch.long),
            })

        # Expose what was gathered instead of silently discarding it.
        self.statistics = {
            "tokens_count": tokens_count,
            "encoded_lengths": encoded_lengths,
            "unknown_tokens": unknown_tokens,
            "unknown_tokens_set": sorted(set(unknown_tokens)),
            "unknown_token_lines_count": unknown_token_lines_count,
            "too_long_lines_count": too_long_lines_count,
        }

    def __len__(self):
        """Return the number of examples that were kept."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return the example dict at position ``i``."""
        return self.examples[i]

In [22]:
from transformers import DataCollatorWithPadding

# Sequence length shared by the collator and both datasets.
BLOCK_SIZE = 768

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=BLOCK_SIZE)

# Training and validation splits are built the same way from their own files.
train_paths = ['../data/external/Jazz Midi/jsb_mmmtrack/token_sequences_train.txt']
dataset_train = TokenSequenceDataset(
    tokenizer=tokenizer, dataset_paths=train_paths, block_size=BLOCK_SIZE, simulate=False
)

valid_paths = ['../data/external/Jazz Midi/jsb_mmmtrack/token_sequences_valid.txt']
dataset_valid = TokenSequenceDataset(
    tokenizer=tokenizer, dataset_paths=valid_paths, block_size=BLOCK_SIZE, simulate=False
)

In [23]:
import os
import torch
from transformers import (
    Trainer,
    TrainingArguments
)
from transformers.data.data_collator import DataCollatorWithPadding

# Make the embedding matrix match the tokenizer's vocabulary size.
model.resize_token_embeddings(len(tokenizer))

# Start from a fully frozen model ...
for param in model.parameters():
    param.requires_grad = False

# ... then re-enable gradients for the last N transformer blocks only.
N = 4
first_trainable_layer = model.config.n_layer - N
for name, param in model.named_parameters():
    if "transformer.h." not in name:
        print(f"Freezing non-transformer layer: {name}")
        continue

    # Names look like "transformer.h.<layer>.<...>"; the third field is the block index.
    layer_number = int(name.split(".")[2])
    if layer_number < first_trainable_layer:
        print(f"Freezing layer {layer_number}: {name}")
    else:
        param.requires_grad = True
        print(f"Unfreezing layer {layer_number}: {name}")

# Report how much of the model is actually being trained.
param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_counts)
trainable_params = sum(count for count, is_trainable in param_counts if is_trainable)
print(f"Trainable parameters: {trainable_params}/{total_params}")

Freezing non-transformer layer: transformer.wte.weight
Freezing non-transformer layer: transformer.wpe.weight
Freezing layer 0: transformer.h.0.ln_1.weight
Freezing layer 0: transformer.h.0.ln_1.bias
Freezing layer 0: transformer.h.0.attn.c_attn.weight
Freezing layer 0: transformer.h.0.attn.c_attn.bias
Freezing layer 0: transformer.h.0.attn.c_proj.weight
Freezing layer 0: transformer.h.0.attn.c_proj.bias
Freezing layer 0: transformer.h.0.ln_2.weight
Freezing layer 0: transformer.h.0.ln_2.bias
Freezing layer 0: transformer.h.0.mlp.c_fc.weight
Freezing layer 0: transformer.h.0.mlp.c_fc.bias
Freezing layer 0: transformer.h.0.mlp.c_proj.weight
Freezing layer 0: transformer.h.0.mlp.c_proj.bias
Freezing layer 1: transformer.h.1.ln_1.weight
Freezing layer 1: transformer.h.1.ln_1.bias
Freezing layer 1: transformer.h.1.attn.c_attn.weight
Freezing layer 1: transformer.h.1.attn.c_attn.bias
Freezing layer 1: transformer.h.1.attn.c_proj.weight
Freezing layer 1: transformer.h.1.attn.c_proj.bias
Free

In [26]:
# Define training arguments
output_path = "../models"
tokenizer.add_special_tokens({'pad_token': '[PAD]'}) # move more up

# Evaluate, log and checkpoint every 500 steps; keep at most two checkpoints and
# reload the best one once training finishes.
training_args = TrainingArguments(
    output_dir=output_path,
    overwrite_output_dir=True,
    evaluation_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    logging_steps=500,
    logging_dir=os.path.join(output_path, "logs"),
    num_train_epochs=3,  # Set as needed
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    load_best_model_at_end=True,
    learning_rate=5e-5,
    weight_decay=0.01
)

# Optimize only the parameters that were left trainable. The hyper-parameters
# are read from `training_args` so that both places cannot drift apart
# (0.01 is also torch's AdamW default weight decay, so the values are unchanged).
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=training_args.learning_rate,
    weight_decay=training_args.weight_decay,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset_train,
    eval_dataset=dataset_valid,
    # Hand the optimizer built above to the Trainer; previously it was created
    # but never used. `None` lets the Trainer create its default LR scheduler.
    optimizers=(optimizer, None),
)


Assigning [PAD] to the pad_token key of the tokenizer
using `logging_steps` to initialize `eval_steps` to 500
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [None]:
# Launch fine-tuning with the `trainer` object; evaluation, logging and
# checkpointing follow the TrainingArguments it was created with.
trainer.train()

***** Running training *****
  Num examples = 27050
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 10146
  Number of trainable parameters = 12609536


Step,Training Loss,Validation Loss


In [None]:

# Persist the fine-tuned weights, and store the tokenizer next to them so the
# directory can be loaded on its own with `from_pretrained`.
finetuned_model_path = os.path.join(output_path, "finetuned_model")
trainer.save_model(finetuned_model_path)
tokenizer.save_pretrained(finetuned_model_path)
print(f"Fine-tuned model saved to {finetuned_model_path}")