In [1]:
import sys
import torch
import os

sys.path.append("../src")
from models.transformer_model import TransformerModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the first CUDA device when one is present; otherwise fall back to the
# CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# The checkpoint is expected one directory above the current working directory.
# Previously used checkpoint: "mt-H8-L8-epoch80.ckpt".
ckpt_path = os.path.join(os.getcwd(), "..", "13230.ckpt")

# Restore the trained model and make sure it lives on `device`.
model = TransformerModel.load_from_checkpoint(ckpt_path, map_location=device).to(device)



In [11]:
from pathlib import Path
from miditok import Structured, TokSequence, REMI
from miditok.utils import get_midi_programs
from miditoolkit import MidiFile

# Earlier experiments used the Structured tokenizer with saved parameters:
# tokenizer = Structured()
# tokenizer.load_params(Path("../data/tokenizer_params.json"))

# REMI tokenizer constructed with the library defaults (no arguments passed).
# NOTE(review): the ids it produces must match the vocabulary the loaded
# checkpoint was trained with — confirm the defaults are the right ones.
tokenizer = REMI()


def generate(name, seed, sample_len, **kwargs):
    """Sample a token sequence from the model and write it out as a MIDI file.

    Args:
        name: Output path handed to ``dump`` for the decoded MIDI object.
        seed: Sequence of integer token ids used as the prompt.
        sample_len: Forwarded to the model's ``generate`` as ``max_length``.
        **kwargs: Passed through unchanged to ``model.transformer.generate``
            (for example ``do_sample=True, top_p=0.95``).

    Returns:
        None. The generated sequence is only written to ``name``.
    """
    # Sampling should not run with dropout active. Nothing else in the notebook
    # switches this model to eval mode, so do it here (it is idempotent).
    model.eval()

    # Explicit integer dtype: an empty seed would otherwise become a float tensor.
    input_ids = torch.tensor(seed, dtype=torch.long).unsqueeze(0).to(device)
    gen_ids = model.transformer.generate(
        input_ids=input_ids,
        max_length=sample_len,
        eos_token_id=tokenizer.vocab["EOS_None"],
        pad_token_id=tokenizer.vocab["PAD_None"],
        **kwargs
    ).tolist()[0]  # only one prompt was submitted, so keep the first row

    # Wrap the ids and let the tokenizer turn them back into a MIDI object,
    # which is then written to `name`.
    seq = TokSequence(ids=gen_ids)
    tokenizer([seq]).dump(name)

In [None]:
import miditoolkit
import matplotlib.pyplot as plt
import numpy as np

# Parse the MIDI file whose notes are going to be drawn.
midi_obj = miditoolkit.midi.parser.MidiFile('/home/mm/midi/snes_clean/ff2airsh.mid')

# Candidate colormap names (kept available for per-instrument colouring).
cmaps = [
    'Blues', 'Greens', 'Reds', 'Purples', 'Oranges', 'Greys',
    'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
    'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn',
]

# Gather every note once, then derive the time span covered by the piece.
all_notes = [note for instrument in midi_obj.instruments for note in instrument.notes]
min_time = min(note.start for note in all_notes)
max_time = max(note.end for note in all_notes)

# One row per MIDI pitch (0-127), one column per tick of the covered span.
piano_roll = np.zeros((128, max_time - min_time))

# Paint each note with its instrument's 1-based index; 0 stays "silence".
# Later instruments overwrite earlier ones where notes overlap.
for i, instrument in enumerate(midi_obj.instruments):
    for note in instrument.notes:
        first_col = note.start - min_time
        last_col = note.end - min_time
        piano_roll[note.pitch, first_col:last_col] = i + 1

# Draw the roll with low pitches at the bottom.
fig, ax = plt.subplots(figsize=(20, 8))
ax.imshow(piano_roll, aspect='auto', origin='lower')
ax.set_title('Piano roll of the MIDI file')
ax.set_xlabel('Time')
ax.set_ylabel('Pitch')
plt.show()



In [4]:
import os

# Source directory of MIDI files (MAESTRO 2004, judging by the path).
midi_dir = "/home/mm/midi/maestro/2004/"
# os.listdir returns entries in arbitrary order; sort them so the list (and
# anything indexed from it) is identical on every run.
midi_files = [os.path.join(midi_dir, f) for f in sorted(os.listdir(midi_dir))]

In [None]:
# Display the tokenizer's vocabulary as the cell output.
# NOTE(review): this renders the whole vocabulary; showing its size or a small
# slice would keep the saved notebook shorter.
tokenizer.vocab

In [13]:
# Sample from the one-token prompt [1] with max_length 500 and nucleus sampling
# (top_p=0.95), writing the result to "test1.mid".
# NOTE(review): presumably id 1 is a start-of-sequence token — confirm against
# tokenizer.vocab. No random seed is set, so the output differs between runs.
generate("test1.mid", [1], 500, do_sample=True, top_p=0.95)

In [15]:
import miditoolkit

# Parse a MIDI file and tokenize it; the resulting token sequences are the
# cell's displayed output.
midi_obj = miditoolkit.midi.parser.MidiFile("/home/mm/midi/snes_clean/ff2airsh.mid")
tokenizer(midi_obj)

[TokSequence(tokens=['Bar_None', 'Position_0', 'Pitch_67', 'Velocity_99', 'Duration_2.0.8', 'Position_16', 'Pitch_79', 'Velocity_99', 'Duration_1.0.8', 'Position_24', 'Pitch_79', 'Velocity_99', 'Duration_1.0.8', 'Bar_None', 'Position_0', 'Pitch_77', 'Velocity_99', 'Duration_1.0.8', 'Position_8', 'Pitch_75', 'Velocity_99', 'Duration_1.0.8', 'Position_16', 'Pitch_74', 'Velocity_99', 'Duration_1.0.8', 'Position_24', 'Pitch_72', 'Velocity_99', 'Duration_1.0.8', 'Bar_None', 'Position_0', 'Pitch_70', 'Velocity_99', 'Duration_2.0.8', 'Position_16', 'Pitch_72', 'Velocity_99', 'Duration_1.0.8', 'Position_24', 'Pitch_74', 'Velocity_99', 'Duration_1.0.8', 'Bar_None', 'Position_0', 'Pitch_72', 'Velocity_99', 'Duration_7.0.4', 'Bar_None', 'Position_24', 'Pitch_70', 'Velocity_99', 'Duration_1.0.8', 'Bar_None', 'Position_0', 'Pitch_72', 'Velocity_99', 'Duration_7.0.4', 'Bar_None', 'Position_24', 'Pitch_67', 'Velocity_99', 'Duration_1.0.8', 'Bar_None', 'Position_0', 'Pitch_74', 'Velocity_99', 'Duratio

In [None]:
# Hand-written prompt: a bar with a "C" chord, a bar with an "Am" chord, and
# the bar marker that opens a third bar. Each chord is followed by a time
# shift of 16 (presumably one full bar of steps — confirm).
# NOTE(review): `Event` and `midi_encoder` are only bound by a later cell of
# this notebook, so this cell fails on a fresh top-to-bottom run.
seed = [
    event
    for chord_name in ("C", "Am")
    for event in (
        Event(Event.BAR),
        Event(Event.CHORD, midi_encoder.chord_encoding.encode_event(chord_name)),
        Event(Event.TIME_SHIFT, 16),
    )
] + [Event(Event.BAR)]

# Model input ids: start token first, then every event's code shifted past
# the encoder's reserved ids.
seed_ids = [midi_encoder.token_sos]
seed_ids += [
    midi_encoder.encoding.encode_event(event) + midi_encoder.num_reserved_ids
    for event in seed
]

In [None]:
# NOTE(review): `generate` as defined earlier in this notebook takes
# (name, seed, sample_len); this call supplies only two positional arguments
# and would raise TypeError. It looks like a leftover from an older helper with
# a (seed, length) signature — confirm and update or delete this cell.
generate(seed_ids, 128)

In [None]:
# Prime generation with the first 64 ids loaded from a MIDI file.
# NOTE(review): `load_ids` is not defined in the visible notebook, and this
# two-argument call does not match the (name, seed, sample_len) signature of
# `generate` defined earlier — looks like a stale cell; confirm.
f = "/home/mm/midi/Beatles/Revolver/EleanorRigby.mid"
generate(load_ids(f)[:64], 128)

In [None]:
# Prime generation with the first 32 ids loaded from a MIDI file.
# NOTE(review): `load_ids` is not defined in the visible notebook, and this
# two-argument call does not match the (name, seed, sample_len) signature of
# `generate` defined earlier — looks like a stale cell; confirm.
f = "../midi/Beatles/PastMasters1/FromMeToYou.mid"
generate(load_ids(f)[:32], 256)

In [None]:
# Load a MIDI file as ids and play it back.
# NOTE(review): neither `play_ids` nor `load_ids` is defined in the visible
# notebook — this cell depends on state from elsewhere; confirm.
play_ids(load_ids("/home/mm/midi/snes/CloudMan.mid"))

In [None]:
import note_seq
from preprocess import Event
from note_seq.sequences_lib import steps_per_bar_in_quantized_sequence
import torch.nn.functional as F

# Neither `reload` nor the `preprocess` module object is bound by the imports
# in this cell (only `Event` is imported from the module), so a fresh kernel
# would raise NameError here. Bring both into scope explicitly.
from importlib import reload

import preprocess

# Pick up edits made to preprocess.py without restarting the kernel.
reload(preprocess)
# Rebind Event to the reloaded class so it is the same one the encoder uses.
from preprocess import Event

midi_encoder = preprocess.MIDIMetricEncoder()


class SequenceGenerator:
    """Builds an event list bar by bar, from a seed and/or by sampling a model.

    The generator keeps a running step counter so that bar markers can be
    inserted whenever a time shift reaches a bar boundary.
    """

    def __init__(self, model, device, max_seq, midi_encoder, seed):
        """Store the collaborators and encode the seed sequence into events.

        Args:
            model: Callable model; ``gen_one_event`` indexes its output as
                ``[0, -1, :]`` (assumes (batch, sequence, vocabulary) logits —
                TODO confirm).
            device: Torch device the input ids are moved to.
            max_seq: Stored on the instance; not used by any method here.
            midi_encoder: Encoder providing ``token_sos``, ``num_reserved_ids``,
                ``encoding``, ``encode_events``, ``decode_ids`` and
                ``encode_note_sequence_to_events``.
            seed: Quantized sequence used both to derive the steps per bar and
                as seed material.
        """
        self.model = model
        self.device = device
        self.current_step = 0  # steps elapsed since the start of the sequence
        self.max_seq = max_seq
        self.events = []  # events collected so far, in order
        self.midi_encoder = midi_encoder
        self.steps_per_bar = int(steps_per_bar_in_quantized_sequence(seed))
        self.seed = midi_encoder.encode_note_sequence_to_events(seed)
        print("steps per bar", self.steps_per_bar)

    def append_event(self, event):
        """Append ``event``, inserting bar markers where time crosses a bar line.

        * BAR events are dropped: bar markers are produced here from the step
          counter instead.
        * TIME_SHIFT events advance ``current_step`` by their value. A shift
          that reaches or crosses one or more bar boundaries is split into
          ``TIME_SHIFT(steps up to the boundary)``, ``BAR``, and so on for the
          remainder, so the pieces always sum to the original value.
        * Every other event is appended unchanged.
        """
        if event.event_type == Event.BAR:
            return
        if event.event_type != Event.TIME_SHIFT:
            self.events.append(event)
            return

        remaining = event.event_value
        to_boundary = self.steps_per_bar - self.current_step % self.steps_per_bar
        if remaining < to_boundary:
            # Stays inside the current bar: keep the event object as it is.
            self.current_step += max(remaining, 0)
            self.events.append(event)
            return

        # Walk boundary by boundary so that both the emitted shifts and the
        # step counter account for the full value of the original event.
        while remaining > 0:
            to_boundary = self.steps_per_bar - self.current_step % self.steps_per_bar
            chunk = min(remaining, to_boundary)
            self.events.append(Event(Event.TIME_SHIFT, chunk))
            self.current_step += chunk
            remaining -= chunk
            if self.current_step % self.steps_per_bar == 0:
                self.events.append(Event(Event.BAR))

    def event_ids(self):
        """Return model input ids for the collected events.

        The list starts with the encoder's start token; each event id is
        shifted by ``num_reserved_ids``.
        """
        return [self.midi_encoder.token_sos] + [
            self.midi_encoder.encoding.encode_event(event)
            + self.midi_encoder.num_reserved_ids
            for event in self.events
        ]

    def gen_one_event(self, temperature=1.0, top_p=0.9, top_k=10):
        """Sample one event from the model and append it via ``append_event``.

        Args:
            temperature: Divisor applied to the logits before the softmax.
            top_p: Candidates whose cumulative probability (in descending
                order) exceeds this value are discarded.
            top_k: At most this many of the most probable candidates are kept.
        """
        self.model.eval()
        x = torch.LongTensor(self.event_ids()).unsqueeze(0).to(self.device)
        # Pure inference: no need to build an autograd graph for sampling.
        with torch.no_grad():
            logits = self.model(x)[0, -1, :]
        probs = F.softmax(logits / temperature, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        to_remove = torch.cumsum(sorted_probs, dim=-1) > top_p
        to_remove[0] = False  # always include the top result
        to_remove[top_k:] = True  # at most include the top K results
        indices_to_remove = sorted_indices[to_remove]
        # Zero the discarded candidates; multinomial samples from the rest.
        probs.scatter_(0, indices_to_remove, 0.0)
        c = torch.multinomial(probs, 1)
        e = self.midi_encoder.encoding.decode_event(
            c.item() - self.midi_encoder.num_reserved_ids
        )
        self.append_event(e)

    def seed_one_bar(self):
        """Copy seed events into ``events`` until the step counter advanced one bar.

        NOTE(review): the counter advances by one per copied event regardless
        of its type, reading always restarts at ``seed[0]``, and a seed shorter
        than ``steps_per_bar`` raises IndexError — confirm this is intended.
        """
        start_step = self.current_step
        i = 0
        while self.current_step < self.steps_per_bar + start_step:
            self.events.append(self.seed[i])
            self.current_step += 1
            i += 1

    def gen_one_bar(self):
        """Sample events until the step counter has advanced by one bar.

        NOTE(review): the loop only ends once sampled TIME_SHIFT events move
        the counter far enough; there is no cap on the number of samples.
        """
        start_step = self.current_step
        while self.current_step < start_step + self.steps_per_bar:
            self.gen_one_event()

    def to_sequence(self):
        """Encode the collected events to ids and decode them with the encoder.

        Returns whatever ``midi_encoder.decode_ids`` yields (presumably a note
        sequence — confirm).
        """
        ids = self.midi_encoder.encode_events(self.events)
        return self.midi_encoder.decode_ids(ids)

In [None]:
# Load a MIDI file through the encoder and build a generator seeded with the
# first element of the result.
# NOTE(review): `mt` and `max_seq` are not defined in the visible notebook, and
# the path points at a Colab Drive mount — this cell relies on state from
# another session; confirm.
ns = midi_encoder.load_midi("/content/drive/MyDrive/midi/final_fantasy/ff1cast3.mid")
gen = SequenceGenerator(
    model=mt, device=device, max_seq=max_seq, midi_encoder=midi_encoder, seed=ns[0]
)
# Optional: copy seed material into the generator, one call per bar.
# gen.seed_one_bar()
# gen.seed_one_bar()
# gen.seed_one_bar()
# gen.seed_one_bar()
# Show the first 50 encoded seed events as the cell output.
gen.seed[:50]

In [None]:
# Sample events until the generator has advanced by one bar. Re-running this
# cell extends the same `gen` instance by another bar each time.
gen.gen_one_bar()
# Optional inspection / playback of the result so far:
# gen.events
# play_sequence(gen.to_sequence())

In [None]:
gen_dir = "/content/drive/MyDrive/gen/final_fantasy_bar_fixed_e600_t0.9"

# Create the output directory from Python instead of shelling out; like
# `mkdir -p`, this is a no-op when the directory already exists.
os.makedirs(gen_dir, exist_ok=True)

# NOTE(review): `midi_dir` comes from an earlier cell, and `mt`, `load_ids`,
# `utils` and `note_sequence_to_midi_file` are not defined in the visible
# notebook — confirm they are available before running.
# Sorted so that files are processed in the same order on every run.
for f in sorted(os.listdir(midi_dir)):
    print(f)
    midi_file = os.path.join(midi_dir, f)
    gen_file = os.path.join(gen_dir, f)
    seed = load_ids(midi_file)
    if seed is None:
        # Nothing usable was loaded for this file; move on to the next one.
        continue
    # Prime the model with the first 200 ids of the file.
    gen_ids = utils.sample(
        model=mt,
        sample_length=512,
        prime_sequence=seed[:200],
        device=device,
        top_p=0.9,
    )
    # NOTE(review): this rebinds `gen`, which an earlier cell uses for the
    # SequenceGenerator instance.
    gen = midi_encoder.decode_ids(gen_ids)
    # play_sequence(gen, synth=fluidsynth)
    note_sequence_to_midi_file(gen, gen_file)