In [None]:
!pip install datasets evaluate transformers[sentencepiece]
!apt install git-lfs
!pip install tensorflow

In [None]:
# Load the raw text dataset. "dataset" is passed as the `path` argument of
# `load_dataset` (presumably a local directory or Hub id — TODO confirm).
from datasets import load_dataset

raw_datasets = load_dataset(path="dataset")

In [None]:
# Load the locally wrapped tokenizer and fix the token length of every training chunk.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("local_wrapped_tokenizer")
context_length = 256

# Split each text into chunks of exactly `context_length` tokens; shorter chunks are discarded.
def tokenize(element):
    """Tokenize a batch of texts into full-length chunks of ``context_length`` tokens.

    Long texts are split into consecutive chunks via ``return_overflowing_tokens``.
    Every chunk whose length differs from ``context_length`` (in practice the
    shorter tail of a text) is dropped, and the number of dropped chunks is printed.

    Args:
        element: a batch dict with a ``"text"`` column, as passed by ``Dataset.map``
            with ``batched=True``.

    Returns:
        ``{"input_ids": [...]}`` containing only the kept chunks.
    """
    encoded = tokenizer(
        element["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    chunks = list(zip(encoded["length"], encoded["input_ids"]))
    kept = [ids for size, ids in chunks if size == context_length]
    dropped = len(chunks) - len(kept)
    print(f"Removed chunks with size less than context_size: {dropped}")
    return {"input_ids": kept}

# Tokenize every split in batches. The original columns (taken from the train split)
# must be removed because `tokenize` returns a different number of rows than it receives.
original_columns = raw_datasets["train"].column_names
tokenized_datasets = raw_datasets.map(tokenize, batched=True, remove_columns=original_columns)

In [None]:
# Build a freshly initialised GPT-2 model (constructed from a config, not from_pretrained)
# using the "gpt2" architecture config resized to this tokenizer.
from transformers import TFGPT2LMHeadModel, AutoConfig

config_overrides = dict(
    vocab_size=len(tokenizer),
    # NOTE(review): GPT-2 sizes its position embeddings with `n_positions`; confirm
    # that `n_ctx` actually limits anything in the installed transformers version.
    n_ctx=context_length,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
config = AutoConfig.from_pretrained("gpt2", **config_overrides)

model = TFGPT2LMHeadModel(config)
# One forward pass on dummy inputs creates the Keras weights so summary() can report them.
model(model.dummy_inputs)
model.summary()

In [None]:
# Causal-LM collator (mlm=False) producing TF tensors. EOS is reused as the pad token,
# presumably because the tokenizer defines no pad token of its own — TODO confirm.
from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="tf")


def _split_to_tf_dataset(split, shuffle):
    """Convert one tokenized split into a batched tf.data pipeline collated by `data_collator`."""
    return tokenized_datasets[split].to_tf_dataset(
        columns=["input_ids", "attention_mask", "labels"],
        collate_fn=data_collator,
        shuffle=shuffle,
        batch_size=8,
    )


tf_train_dataset = _split_to_tf_dataset("train", shuffle=True)
tf_eval_dataset = _split_to_tf_dataset("test", shuffle=False)

In [None]:
# Optimizer with weight decay 0.01 and 1,000 warmup steps; the learning-rate schedule
# spans every training step across all epochs.
from transformers import create_optimizer
import tensorflow as tf

num_epochs = 10
steps_per_epoch = len(tf_train_dataset)
num_train_steps = steps_per_epoch * num_epochs
optimizer, schedule = create_optimizer(
    init_lr=5e-5,
    num_warmup_steps=1_000,
    num_train_steps=num_train_steps,
    weight_decay_rate=0.01,
)
# No loss is passed: the Transformers TF model computes its own LM loss from `labels`.
model.compile(optimizer=optimizer)

# NOTE(review): the global dtype policy only affects layers created after this call.
# `model` was already built above, so this likely does not make its training
# mixed-precision — confirm the intent before relying on fp16 speedups.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

In [None]:
# helper function (generate 3 midi files per epoch)
from transformers import pipeline
from tensorflow.keras.callbacks import Callback
!pip install note_seq
import note_seq

# Seconds per minute; dividing it by a BPM value gives the length of one beat in seconds.
BPM_1_SECOND = 60

# Module-level time-signature placeholders. NOTE: token_sequence_to_note_sequence
# assigns its own local `numerator`/`denominator`, so it never reads these globals.
numerator = denominator = ""

def token_sequence_to_note_sequence(token_sequence, 
                                    use_program=True, 
                                    use_drums=True, 
                                    instrument_mapper=None, 
                                    only_piano=False):
    """Decode a music token sequence into a note_seq ``NoteSequence``.

    Args:
        token_sequence: space-separated token string, or an already-split list of tokens.
        use_program: if True, ``INST=<n>`` tokens set the MIDI program of following notes.
        use_drums: if True, ``INST=DRUMS`` marks following notes as drums (program 0).
        instrument_mapper: optional mapping from instrument token values to program
            numbers (anything accepted by ``int``).
        only_piano: if True, every non-drum note is forced to instrument 0 / program 0.

    Returns:
        The populated ``NoteSequence``.

    Raises:
        ValueError: if a numeric token value (e.g. ``NOTE_ON=x``, ``BPM=x``) is not an int.

    Tokens that arrive before their context is established (``BPM=`` before
    ``TIME_SIGNATURE=``, notes before ``TRACK_START``/``BAR_START``) fall back to
    4/4 at 120 BPM, bar 0 and time 0 instead of raising ``UnboundLocalError``.
    """
    if isinstance(token_sequence, str):
        token_sequence = token_sequence.split()

    note_sequence = empty_note_sequence()

    # Fallback timing for out-of-order tokens (common in model-generated text):
    # 4/4 at 120 BPM, which matches empty_note_sequence()'s default tempo.
    bpm = 120
    numerator, denominator = 4, 4
    pulse_duration, bar_duration = duration_in_sec(bpm, numerator, denominator)

    # Render all notes.
    current_program = 1
    current_is_drum = False
    current_instrument = 0
    track_count = 0
    current_bar_index = 0
    current_time = 0.0
    current_notes = {}  # pitch -> latest note started in the current bar

    for token in token_sequence:
        if token == "PIECE_END":
            print("The end.")
            break
        elif token.startswith("TIME_SIGNATURE="):
            time_signature_str = token.split("=")[-1]
            numerator = int(time_signature_str.split("_")[0])
            denominator = int(time_signature_str.split("_")[-1])
            time_signature = note_sequence.time_signatures.add()
            time_signature.numerator = numerator
            time_signature.denominator = denominator
            # Keep the bar length consistent with the latest time signature.
            pulse_duration, bar_duration = duration_in_sec(bpm, numerator, denominator)
        elif token.startswith("BPM="):
            bpm = int(token.split("=")[-1])
            note_sequence.tempos[0].qpm = bpm
            pulse_duration, bar_duration = duration_in_sec(bpm, numerator, denominator)
        elif token == "TRACK_START":
            current_bar_index = 0
            track_count += 1
        elif token.startswith("INST"):
            instrument = token.split("=")[-1]
            if instrument != "DRUMS" and use_program:
                if instrument_mapper is not None and instrument in instrument_mapper:
                    instrument = instrument_mapper[instrument]
                current_program = int(instrument)
                current_instrument = track_count
                current_is_drum = False
            if instrument == "DRUMS" and use_drums:
                current_instrument = 0
                current_program = 0
                current_is_drum = True
        elif token == "BAR_START":
            current_time = current_bar_index * bar_duration
            current_notes = {}
        elif token == "BAR_END":
            current_bar_index += 1
        elif token.startswith("NOTE_ON"):
            pitch = int(token.split("=")[-1])
            note = note_sequence.notes.add()
            note.start_time = current_time
            # Provisional end time, overwritten by a matching NOTE_OFF in the same bar.
            # NOTE(review): this is `denominator` beats (a whole bar in 4/4) — confirm intended.
            note.end_time = current_time + denominator * pulse_duration
            note.pitch = pitch
            note.instrument = current_instrument
            note.program = current_program
            note.velocity = 80
            note.is_drum = current_is_drum
            current_notes[pitch] = note
        elif token.startswith("NOTE_OFF"):
            pitch = int(token.split("=")[-1])
            if pitch in current_notes:
                current_notes[pitch].end_time = current_time
        elif token.startswith("TIME_DELTA"):
            # TIME_DELTA is counted in quarter-pulse steps (16th notes when the pulse is a quarter).
            current_time += float(token.split("=")[-1]) * 0.25 * pulse_duration
        # Every other token (PIECE_START, TRACK_END, KEYS_START/END, KEY=, DENSITY=,
        # [PAD], unknown tokens) carries no note or timing information and is ignored.

    # Renumber instruments: each distinct (program, is_drum) pair gets its own index,
    # assigned in order of first appearance.
    instrument_ids = {}
    for note in note_sequence.notes:
        note.instrument = instrument_ids.setdefault((note.program, note.is_drum), len(instrument_ids))

    if only_piano:
        for note in note_sequence.notes:
            if not note.is_drum:
                note.instrument = 0
                note.program = 0

    return note_sequence

def duration_in_sec(bpm, numerator, denominator):
    """Return ``(pulse_duration, bar_duration)`` in seconds.

    ``pulse_duration`` is one beat at ``bpm`` (``BPM_1_SECOND / bpm``). ``bar_duration``
    is the length of a bar holding ``numerator`` notes of value ``1/denominator``,
    counted as ``numerator * 4 / denominator`` quarter-note pulses.
    """
    seconds_per_pulse = BPM_1_SECOND / bpm
    quarters_in_bar = numerator * (4 / denominator)
    return seconds_per_pulse, seconds_per_pulse * quarters_in_bar

def empty_note_sequence(qpm=120, total_time=0.0):
    """Create a ``NoteSequence`` with one tempo event at ``qpm`` and the given ``total_time``."""
    sequence = note_seq.protobuf.music_pb2.NoteSequence(total_time=total_time)
    sequence.tempos.add(qpm=qpm)
    return sequence

class GenerateAndSaveCallback(Callback):
    """Keras callback that writes sample MIDI files generated by the model after each epoch.

    At the end of every epoch it runs a ``text-generation`` pipeline on
    ``generation_seed`` ``num_samples`` times. It decodes each generated token string
    with ``token_sequence_to_note_sequence`` and writes it to ``epoch_{epoch}_{i}.mid``
    in the current working directory.

    Args:
        model: the TF language model being trained.
        tokenizer: the tokenizer matching ``model``.
        num_samples: number of MIDI files generated per epoch (default 3).
        max_length: ``max_length`` passed to each pipeline call (default 500).
        device: ``device`` passed to ``transformers.pipeline`` (default 0).
    """

    def __init__(self, model, tokenizer, num_samples=3, max_length=500, device=0):
        # Initialise Keras' own Callback state; the previous __init__ skipped this.
        super().__init__()
        self.set_model(model)
        self.tokenizer = tokenizer
        self.num_samples = num_samples
        self.max_length = max_length
        self.pipeline_device = device
        self.generation_seed = "PIECE_START TIME_SIGNATURE=4_4 BPM=120 TRACK_START INST=0 DENSITY=2 BAR_START NOTE_ON=43"
        # Built lazily and reused. It wraps the same model object, so it always sees
        # the current weights; rebuilding it every epoch is wasted work.
        self._pipe = None

    def _get_pipeline(self):
        """Return the cached text-generation pipeline, creating it on first use."""
        if self._pipe is None:
            self._pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=self.pipeline_device,
            )
        return self._pipe

    def on_epoch_end(self, epoch, logs=None):
        """Generate and save ``num_samples`` MIDI files for the just-finished ``epoch``.

        Sampling is diagnostic only: a failure (e.g. a malformed generated token) is
        reported and skipped instead of aborting the training run.
        """
        try:
            pipe = self._get_pipeline()
        except Exception as exc:
            print(f"Could not build generation pipeline for epoch {epoch}: {type(exc).__name__}: {exc}")
            return
        for i in range(self.num_samples):
            midi_filename = f"epoch_{epoch}_{i}.mid"
            try:
                # Sample explicitly; greedy decoding would make every sample identical.
                outputs = pipe(self.generation_seed, max_length=self.max_length, do_sample=True)
                generated_text = outputs[0]["generated_text"]
                note_sequence = token_sequence_to_note_sequence(generated_text)
                note_seq.note_sequence_to_midi_file(note_sequence, midi_filename)
            except Exception as exc:
                print(f"Failed to generate {midi_filename}: {type(exc).__name__}: {exc}")
                continue
            print(f"Generated MIDI for epoch {epoch} saved as {midi_filename}")


In [None]:
# Save weights every epoch and write sample MIDI files, then train and save the tokenizer.
# ModelCheckpoint is imported from tensorflow.keras to match the Callback base class of
# GenerateAndSaveCallback. The standalone `keras` package and tf.keras can resolve to
# different Keras versions, and mixing their objects in one fit() call can fail.
from tensorflow.keras.callbacks import ModelCheckpoint

generate_and_save_callback = GenerateAndSaveCallback(model, tokenizer)
checkpoint_callback = ModelCheckpoint(
    filepath="check/check_{epoch}",
    save_weights_only=True,
    monitor="val_loss",  # only consulted when save_best_only=True
    save_best_only=False,
    mode="min", 
    save_freq="epoch",
)

# Train; `history.history` records per-epoch "loss" and "val_loss".
history = model.fit(
    tf_train_dataset,
    validation_data=tf_eval_dataset,
    epochs=num_epochs,
    callbacks=[checkpoint_callback, generate_and_save_callback],
)

# Save the tokenizer to save_tokenizer/.
tokenizer.save_pretrained("save_tokenizer")


In [None]:
# plot loss
!pip install matplotlib
import matplotlib.pyplot as plt

# Plot per-epoch training vs. validation loss from the Keras History and save the figure.
train_loss = history.history["loss"]
val_loss = history.history["val_loss"]
# Derive x-axes from the recorded history rather than `num_epochs`, so the plot still
# works if fewer epochs were recorded (e.g. an EarlyStopping callback is added later).
epochs = range(1, len(train_loss) + 1)
val_epochs = range(1, len(val_loss) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_loss, "b", label="Training Loss")
ax.plot(val_epochs, val_loss, "r", label="Validation Loss")
ax.set(title="Training and Validation Loss", xlabel="Epochs", ylabel="Loss")
ax.legend()

# Save before show(): show() may release the figure in non-interactive backends.
fig.savefig("loss_plot.png")
plt.show()