In [None]:
from transformers import TFGPT2LMHeadModel, TFGPT2LMHeadModel, AutoConfig
from transformers import pipeline

%pip install tokenizers
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import Whitespace

# load tokenizer
from transformers import AutoTokenizer
tokenizer_path = "./local_wrapped_tokenizer"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# Build a GPT-2 configuration whose vocabulary size and BOS/EOS ids come
# from the custom tokenizer loaded above.
# NOTE(review): recent GPT2Config versions size the position embeddings from
# `n_positions`; `n_ctx` may only be stored as an extra attribute. Confirm
# against the checkpoint before changing it, since the saved weights must
# match the built model's shapes.
context_length = 256
config = AutoConfig.from_pretrained(
    "gpt2",
    n_ctx=context_length,
    vocab_size=len(tokenizer),
    eos_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
)

# Instantiate the TF model from that config and load weights from the
# "check/check_1" checkpoint path.
model = TFGPT2LMHeadModel(config)
model.load_weights("check/check_1")

# Text-generation pipeline around the model/tokenizer pair, on device 0.
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, device=0)

!pip install note_seq
import note_seq
# Seconds per minute: dividing it by a tempo in beats per minute yields the
# length of one beat in seconds (used by duration_in_sec).
BPM_1_SECOND = 60

# NOTE(review): these module-level placeholders are NOT read by
# token_sequence_to_note_sequence, which assigns its own local
# numerator/denominator from TIME_SIGNATURE= tokens. They are kept only in
# case another notebook cell references them.
numerator = denominator = ""

def token_sequence_to_note_sequence(token_sequence,
                                    use_program=True,
                                    use_drums=True,
                                    instrument_mapper=None,
                                    only_piano=False):
    """Decode a generated token sequence into a note_seq NoteSequence.

    Args:
        token_sequence: A string of whitespace-separated tokens (as returned
            by the text-generation pipeline) or an already-split list of
            tokens.
        use_program: If True, ``INST=<n>`` tokens other than ``INST=DRUMS``
            set the MIDI program (``int(<n>)``) of subsequent notes.
        use_drums: If True, ``INST=DRUMS`` marks subsequent notes as drums on
            program 0.
        instrument_mapper: Optional dict mapping the raw string after
            ``INST=`` to a replacement value, which is then passed to
            ``int()``.
        only_piano: If True, every non-drum note is moved to instrument 0 /
            program 0 after decoding.

    Returns:
        A NoteSequence holding the decoded notes, the tempo, any time
        signatures, and ``total_time`` set to the latest note end time.

    Decoding stops at the first ``PIECE_END`` token; tokens that carry no
    rendering information are ignored. Until the sequence supplies its own
    ``TIME_SIGNATURE=``/``BPM=`` tokens, 4/4 and the NoteSequence's default
    tempo are used so that partial or oddly ordered output still decodes.

    Raises:
        ValueError: if a numeric token payload (pitch, BPM, program, time
            signature, time delta) cannot be parsed.
    """
    if isinstance(token_sequence, str):
        token_sequence = token_sequence.split()

    note_sequence = empty_note_sequence()

    # Defaults: previously these names only existed after a TIME_SIGNATURE /
    # BPM token, so a BPM-first sequence raised UnboundLocalError and a
    # sequence without BPM raised NameError at the first BAR_START/NOTE_ON.
    numerator = 4
    denominator = 4
    bpm = note_sequence.tempos[0].qpm
    pulse_duration, bar_duration = duration_in_sec(bpm, numerator, denominator)

    # Render all notes.
    current_program = 1
    current_is_drum = False
    current_instrument = 0
    track_count = 0
    # Position state, initialised so NOTE_ON/BAR_START before
    # TRACK_START/BAR_START no longer raise NameError.
    current_bar_index = 0
    current_time = 0.0
    current_notes = {}
    for token in token_sequence:

        if token == "PIECE_END":
            print("The end.")
            break
        elif token.startswith("TIME_SIGNATURE="):
            time_signature_str = token.split("=")[-1]
            numerator = int(time_signature_str.split("_")[0])
            denominator = int(time_signature_str.split("_")[-1])
            time_signature = note_sequence.time_signatures.add()
            time_signature.numerator = numerator
            time_signature.denominator = denominator
            # Keep durations in sync if the time signature arrives after
            # (or changes after) the BPM token.
            pulse_duration, bar_duration = duration_in_sec(
                bpm, numerator, denominator
            )
        elif token.startswith("BPM="):
            bpm = int(token.split("=")[-1])
            note_sequence.tempos[0].qpm = bpm
            pulse_duration, bar_duration = duration_in_sec(
                bpm, numerator, denominator
            )
        elif token == "TRACK_START":
            # Each track restarts bar counting from the top of the piece.
            current_bar_index = 0
            track_count += 1
        elif token.startswith("INST"):
            instrument = token.split("=")[-1]
            if instrument != "DRUMS" and use_program:
                if instrument_mapper is not None and instrument in instrument_mapper:
                    instrument = instrument_mapper[instrument]
                current_program = int(instrument)
                current_instrument = track_count
                current_is_drum = False
            if instrument == "DRUMS" and use_drums:
                current_instrument = 0
                current_program = 0
                current_is_drum = True
        elif token == "BAR_START":
            current_time = current_bar_index * bar_duration
            current_notes = {}
        elif token == "BAR_END":
            current_bar_index += 1
        elif token.startswith("NOTE_ON"):
            pitch = int(token.split("=")[-1])
            note = note_sequence.notes.add()
            note.start_time = current_time
            # Provisional length of `denominator` pulses; a matching NOTE_OFF
            # later in the same bar overwrites the end time.
            note.end_time = current_time + denominator * pulse_duration
            note.pitch = pitch
            note.instrument = current_instrument
            note.program = current_program
            note.velocity = 80
            note.is_drum = current_is_drum
            current_notes[pitch] = note
        elif token.startswith("NOTE_OFF"):
            pitch = int(token.split("=")[-1])
            if pitch in current_notes:
                current_notes[pitch].end_time = current_time
        elif token.startswith("TIME_DELTA"):
            # TIME_DELTA values are in quarter-pulse units.
            current_time += float(token.split("=")[-1]) * 0.25 * pulse_duration
        # PIECE_START, TRACK_END, KEYS_START/KEYS_END, KEY=, DENSITY=, [PAD]
        # and any unrecognised token carry no rendering information.

    # Make the instruments right: one instrument index per distinct
    # (program, is_drum) pair, numbered in order of first appearance.
    instrument_index = {}
    for note in note_sequence.notes:
        pair = (note.program, note.is_drum)
        note.instrument = instrument_index.setdefault(pair, len(instrument_index))

    if only_piano:
        for note in note_sequence.notes:
            if not note.is_drum:
                note.instrument = 0
                note.program = 0

    # total_time was left at 0.0 before; make it reflect the decoded content.
    note_sequence.total_time = max(
        (note.end_time for note in note_sequence.notes), default=0.0
    )

    return note_sequence

def duration_in_sec(bpm, numerator, denominator):
    """Return ``(pulse_duration, bar_duration)`` in seconds.

    ``pulse_duration`` is ``BPM_1_SECOND / bpm``, the length of one beat at
    the given tempo. ``bar_duration`` scales it by the number of quarter
    notes in a ``numerator``/``denominator`` bar, i.e.
    ``numerator * 4 / denominator``.
    """
    seconds_per_pulse = BPM_1_SECOND / bpm
    quarters_in_bar = (4 / denominator) * numerator
    return seconds_per_pulse, seconds_per_pulse * quarters_in_bar

def empty_note_sequence(qpm=120, total_time=0.0):
    """Create a NoteSequence with a single tempo entry and no notes.

    Args:
        qpm: Tempo (quarter notes per minute) stored in ``tempos[0]``.
        total_time: Initial value of the sequence's ``total_time`` field.
    """
    sequence = note_seq.protobuf.music_pb2.NoteSequence()
    tempo = sequence.tempos.add()
    tempo.qpm = qpm
    sequence.total_time = total_time
    return sequence

# Prompt that opens a 4/4, 120 BPM piece on program 0 with a first note.
seed = "PIECE_START TIME_SIGNATURE=4_4 BPM=120 TRACK_START INST=0 DENSITY=2 BAR_START NOTE_ON=43"

# Generate one continuation, decode it, and write it to gggN.mid.
for i in range(1):
    generations = pipe(seed, max_length=100)
    first_piece = generations[0]["generated_text"]
    note_sequence = token_sequence_to_note_sequence(first_piece)
    synth = note_seq.midi_synth.synthesize
    note_seq.note_sequence_to_midi_file(note_sequence, f"ggg{i}.mid")
