In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("ai-guru/lakhclean_mmmtrack_4bars_d-2048")
model = AutoModelForCausalLM.from_pretrained("ai-guru/lakhclean_mmmtrack_4bars_d-2048")

In [2]:
import note_seq

BPM = 120
NOTE_LENGTH_16TH = 0.25 * 60 / BPM
BAR_LENGTH = 4.0 * 60 / BPM

def token_sequence_to_note_sequence(token_sequence, use_program=True, use_drums=True, instrument_mapper=None, only_piano=False):

    if isinstance(token_sequence, str):
        token_sequence = token_sequence.split()

    note_sequence = empty_note_sequence()

    # Render all notes.
    current_program = 1
    current_is_drum = False
    current_instrument = 0
    track_count = 0
    for token_index, token in enumerate(token_sequence):

        if token == "PIECE_START":
            pass
        elif token == "PIECE_END":
            print("The end.")
            break
        elif token == "TRACK_START":
            current_bar_index = 0
            track_count += 1
            pass
        elif token == "TRACK_END":
            pass
        elif token == "KEYS_START":
            pass
        elif token == "KEYS_END":
            pass
        elif token.startswith("KEY="):
            pass
        elif token.startswith("INST"):
            instrument = token.split("=")[-1]
            if instrument != "DRUMS" and use_program:
                if instrument_mapper is not None:
                    if instrument in instrument_mapper:
                        instrument = instrument_mapper[instrument]
                current_program = int(instrument)
                current_instrument = track_count
                current_is_drum = False
            if instrument == "DRUMS" and use_drums:
                current_instrument = 0
                current_program = 0
                current_is_drum = True
        elif token == "BAR_START":
            current_time = current_bar_index * BAR_LENGTH_120BPM
            current_notes = {}
        elif token == "BAR_END":
            current_bar_index += 1
            pass
        elif token.startswith("NOTE_ON"):
            pitch = int(token.split("=")[-1])
            note = note_sequence.notes.add()
            note.start_time = current_time
            note.end_time = current_time + 4 * NOTE_LENGTH_16TH_120BPM
            note.pitch = pitch
            note.instrument = current_instrument
            note.program = current_program
            note.velocity = 80
            note.is_drum = current_is_drum
            current_notes[pitch] = note
        elif token.startswith("NOTE_OFF"):
            pitch = int(token.split("=")[-1])
            if pitch in current_notes:
                note = current_notes[pitch]
                note.end_time = current_time
        elif token.startswith("TIME_DELTA"):
            delta = float(token.split("=")[-1]) * NOTE_LENGTH_16TH_120BPM
            current_time += delta
        elif token.startswith("DENSITY="):
            pass
        elif token == "[PAD]":
            pass
        else:
            #print(f"Ignored token {token}.")
            pass

    # Make the instruments right.
    instruments_drums = []
    for note in note_sequence.notes:
        pair = [note.program, note.is_drum]
        if pair not in instruments_drums:
            instruments_drums += [pair]
        note.instrument = instruments_drums.index(pair)

    if only_piano:
        for note in note_sequence.notes:
            if not note.is_drum:
                note.instrument = 0
                note.program = 0

    return note_sequence

def empty_note_sequence(qpm=120.0, total_time=0.0):
    note_sequence = note_seq.protobuf.music_pb2.NoteSequence()
    note_sequence.tempos.add().qpm = qpm
    note_sequence.ticks_per_quarter = note_seq.constants.STANDARD_PPQ
    note_sequence.total_time = total_time
    return note_sequence



In [3]:
import os
from datasets import Dataset
# generate a dataset object for model training from the midi files in data by converting them into note sequences
dict_ = {"midi_ids": [], "note_sequences": []}
midi_id = 1
for file in os.listdir("../data/dataset/"):
    print(file)
    if file.endswith(".mid"):
        try:
            note_sequence = note_seq.midi_file_to_note_sequence(os.path.join("../data/dataset/", file))
            dict_["midi_ids"].append(midi_id)
            midi_id += 1
            dict_["note_sequences"].append(note_sequence)
        except:
            pass


01dkc2bonus.mid
01DKC3_Bonus.mid
02dkc2&3-bonusfinished.mid
02main.mid
02snesDKC3_BossV1.1.mid
03dkc2boss.mid
03dkc3_bear.mid
04DKC3cavern.mid
04sento.mid
04snesDKC2_Boss.mid
05bramscrm.mid
05JILost_-_SNES_-_Donkey_Kong_Country_3_-_Baddies_on_Parade.mid




06dkc2scram.mid
06DKC3Factory.mid
07dkc2bram.mid
07dkc3frst.mid
08ddt_gameover.mid
08DKQBrmbl.mid
09dkc3_gameover.mid
09SNES_DKC_Brambles.mid
1-01 Prelude Xg.mid
1-01-Liberi_Fatali.mid
1-02 - To Zanarkand (remix).mid
1-02 - To Zanarkand (version2).mid
1-02 - To Zanarkand (version3).mid
1-02 - To Zanarkand (version4).mid
1-02 - To Zanarkand (version5).mid
1-02 - To Zanarkand.mid
1-02 Opening ~ Bombing Mission Xg.mid
1-02-Balamb_Garden.mid
1-03 - Prelude (Draggor).mid
1-03 - Prelude (version2).mid
1-03 - Prelude (version3).mid
1-03 - Prelude.mid
1-03 Makou Reactor Xg.mid
1-03-Blue_Fields.mid
1-04 - Tidus.mid
1-04 - Tidus_Theme.mid
1-04 Anxious Heart Xg.mid
1-04-Dont_Be_Afraid.mid
1-05 - Otherworld (version2).mid
1-05 - Other_World.mid
1-05 Tifa's Theme Xg.mid
1-05-The_Winner.mid
1-06 Barett's Theme Xg.mid
1-06-Find_Your_Way.mid
1-07 - This is your Story (unsure).mid
1-07 Hurry! Xg.mid
1-07-Seed.mid
1-08 Lurking in the Darkness Xg.mid
1-08-The_Landing.mid
1-09 - Battle_Theme (version2).mi

In [4]:
def get_tracks(note_sequence_notes):
    tracks = [] 
    current_track = []
    current_track_instrument = None
    for note_sequence in note_sequence_notes:
        if not current_track:
            current_track_instrument = note_sequence.program
            current_track.append(note_sequence)
            continue

        if note_sequence.program == current_track_instrument:
            current_track.append(note_sequence)
        else:
            tracks.append(current_track)
            current_track = [note_sequence]
            current_track_instrument = note_sequence.program
    tracks.append(current_track)
    return tracks


In [11]:
# convert the music note sequences into token sequences
def note_sequence_to_token_sequence(note_sequence):
    token_sequence = []
    token_sequence.append("PIECE_START")
    # for each track in the piece generate a token sequence
    for track in get_tracks(note_sequence.notes):
        token_sequence.append("TRACK_START")
        token_sequence.append(f"INST={track[0].program}")
        # for each bar in the track generate a token sequence
        for bar in range(0, int(note_sequence.total_time / BAR_LENGTH)):
            token_sequence.append("BAR_START")
            current_time = 0
            if track[0].start_time > 0:
                delta = track[0].start_time / NOTE_LENGTH_16TH
                token_sequence.append(f"TIME_DELTA={delta}")
            for note in track:
                if note.start_time > current_time:
                    delta = (note.start_time - current_time) / NOTE_LENGTH_16TH
                    token_sequence.append(f"TIME_DELTA={delta}")
                token_sequence.append(f"NOTE_ON={note.pitch}")
                token_sequence.append(f"NOTE_OFF={note.pitch}")
                current_time = note.end_time
            if current_time < note_sequence.total_time:
                delta = (note_sequence.total_time - current_time) / NOTE_LENGTH_16TH
                token_sequence.append(f"TIME_DELTA={delta}")
            token_sequence.append("BAR_END")
        token_sequence.append("TRACK_END")
    token_sequence.append("PIECE_END")
    return ' '.join(token_sequence)
    

In [12]:
dict_["token_sequences"] = [note_sequence_to_token_sequence(note_sequence) for note_sequence in dict_["note_sequences"]]

In [13]:
from datasets import DatasetDict

train_dict = {"midi_ids": dict_["midi_ids"][0:1200], "token_sequences": dict_["token_sequences"][0:1200]}
valid_dict = {"midi_ids": dict_["midi_ids"][1200:], "token_sequences": dict_["token_sequences"][1200:]}

raw_datasets = DatasetDict(
    {
        "train": train_dict,  # .shuffle().select(range(50000)),
        "valid": valid_dict,  # .shuffle().select(range(500))
    }
)

In [18]:
raw_datasets

In [14]:
context_length = 128

def tokenize(element):
    outputs = tokenizer(
        element["content"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    input_batch = []
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length == context_length:
            input_batch.append(input_ids)
    return {"input_ids": input_batch}


tokenized_datasets = raw_datasets.map(
    tokenize, batched=True, remove_columns=raw_datasets["train"].column_names
)

In [15]:
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [16]:
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="codeparrot-ds",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    evaluation_strategy="steps",
    eval_steps=5_000,
    logging_steps=5_000,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1_000,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=5_000,
    fp16=True,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
)

In [3]:
generated_sequence = "PIECE_START"

In [17]:
# Encode the conditioning tokens.
input_ids = tokenizer.encode("PIECE_START", return_tensors="pt")
#print(input_ids)

# Generate more tokens.
eos_token_id = tokenizer.encode("PIECE_END")[0]
temperature = 1.0
generated_ids = model.generate(
    input_ids, 
    max_length=2048,
    do_sample=True,
    temperature=temperature,
    eos_token_id=eos_token_id,
)
generated_sequence = tokenizer.decode(generated_ids[0])
print(generated_sequence)

note_sequence = token_sequence_to_note_sequence(generated_sequence)

synth = note_seq.fluidsynth
note_seq.note_sequence_to_midi_file(note_sequence, "output.mid")