In [None]:
import random
from pathlib import Path

import numpy as np
from mido import Message, MidiFile, MidiTrack, MetaMessage

import stilus.midi.utils as utl
import stilus.models as m

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stilus.data.sets import MidiDataset
from torch.utils.data import DataLoader

In [None]:
# Import a MIDI file and derive the base name used to label generated output.
file_path = "midi/training/mozart/mz_545_3.mid"
# pathlib copes with any path separator and strips only the final suffix,
# so a name such as "a.b.mid" keeps its "a.b" stem instead of collapsing to "a".
file_name = Path(file_path).name
file_name_no_ext = Path(file_path).stem

mid = MidiFile(file_path)
# Last expression of the cell: displays the derived base name.
file_name_no_ext

In [None]:
# Dump every message of every track, with a header line per track.
# (To inspect note events only, filter on event.type == "note_on".)
for track_index, track in enumerate(mid.tracks):
    print(f"Track {track_index}: {track.name}")
    for event in track:
        print(event)

In [None]:
# Convert the parsed MIDI file into a time-series array.
# NOTE(review): the meaning of the numeric arguments (5, 5, 8) is defined by
# stilus.midi.utils.convert_midi_to_time_series and is not visible here.
tensor = utl.convert_midi_to_time_series(mid, 5, 5, 8)
print(tensor.shape)

# Preview the 32 columns that end just before column n.
n = 136
preview = tensor[:, n - 32:n]
print(preview)

In [None]:
version = "version_5"
epochs = "4"
checkpoint_path = f"tb_logs/TransformerNet_1_0_2_mozart/{version}/checkpoints/epoch={epochs}.ckpt"
# load_from_checkpoint is called on the class, not on a freshly built instance:
# no throwaway model is constructed, and PyTorch Lightning 2.x rejects instance calls.
# NOTE(review): assumes the stilus models are LightningModules — confirm.
# To load the convolutional model instead:
# net = m.ConvNet_1_0_2.load_from_checkpoint(f"tb_logs/ConvNet_1_0_2_mozart/{version}/checkpoints/epoch={epochs}.ckpt")
net = m.TransformerNet_1_0_2.load_from_checkpoint(checkpoint_path)
# Put the model in evaluation mode; as the cell's last expression its return value is displayed.
net.eval()

In [None]:
# Tell the network which directory its data loaders should read from.
# NOTE(review): presumably this also makes net.midi_dataset (and its mean/std,
# read by the standardisation helpers in this notebook) available — confirm in stilus.models.
net.set_data_path("data/mozart")

In [None]:
def std_tensor_to_int(pred, net):
    """Undo standardisation and convert the result to integers.

    Args:
        pred: numpy array of standardised values.
        net: object exposing ``midi_dataset.mean`` and ``midi_dataset.std``,
            the statistics used for standardisation.

    Returns:
        Integer numpy array holding ``pred * std + mean`` rounded to the
        nearest integer.
    """
    dataset = net.midi_dataset
    restored = (pred * dataset.std) + dataset.mean
    # Round before casting: a bare astype(int) truncates toward zero, so a
    # value like 65.9 (or 59.9999 from float error in a round trip) would
    # land one step too low.
    return np.rint(restored).astype(int)

In [None]:
def int_to_std_tensor(input, net):
    """Standardise raw values with the dataset statistics held by ``net``.

    Args:
        input: values to standardise (e.g. a numpy array).
        net: object exposing ``midi_dataset.mean`` and ``midi_dataset.std``.

    Returns:
        ``(input - mean) / std``.
    """
    stats = net.midi_dataset
    centered = input - stats.mean
    return centered / stats.std

In [None]:
def imagine_midi(time_series, net, steps, window=64):
    """Extend a seed time series by repeatedly feeding the model its own output.

    The seed is standardised, then for each step the last ``window`` columns
    are passed to ``net`` and the transposed prediction is appended as new
    column(s). The whole series (seed plus predictions) is finally converted
    back to integers.

    Args:
        time_series: 2-D numpy array; time runs along axis 1.
        net: callable model; also provides the dataset statistics used by
            ``int_to_std_tensor`` / ``std_tensor_to_int``.
        steps: number of prediction steps to run.
        window: number of most recent columns fed to the model per step.
            Defaults to 64, the value previously hard-coded.

    Returns:
        Integer numpy array containing the seed followed by the predictions.

    Raises:
        ValueError: if ``window`` is not positive.
    """
    if window <= 0:
        raise ValueError("window must be a positive integer")

    std_time_series = int_to_std_tensor(time_series, net)

    # Pure inference: no need to record an autograd graph for every step.
    with torch.no_grad():
        for _ in range(steps):
            # A negative-index slice takes the last `window` columns, or the
            # whole series when it is shorter. The former `len-64:len` form
            # wrapped around for short seeds and silently dropped context.
            context = std_time_series[:, -window:].astype("float32")
            tensor_in = torch.unsqueeze(torch.from_numpy(context), 0)
            pred = net(tensor_in)
            # NOTE(review): assumes pred is (1, rows) so that .T yields one new column — confirm.
            numpy_pred = pred.detach().numpy().T
            std_time_series = np.concatenate((std_time_series, numpy_pred), axis=1)

    return std_tensor_to_int(std_time_series, net)

In [None]:
# Seed the model with the first 136 columns and let it generate 161 more steps.
seed_len = 136
generated_series = imagine_midi(tensor[:, :seed_len], net, 161)
# Show the first 36 generated columns, then the overall shape.
print(generated_series[:, seed_len:seed_len + 36])
print(generated_series.shape)

In [None]:
def write_midi_from_series(generated_series, out_path=None):
    """Write a note time series to a single-track MIDI file.

    Each column of ``generated_series`` is one time step lasting
    ``ticks_per_beat // 8`` ticks. Every positive entry in a column becomes a
    ``note_on`` message with velocity 100; entries <= 0 are treated as
    silence. No ``note_off`` messages are emitted.

    Args:
        generated_series: 2-D array of note numbers, indexed ``[row, step]``.
        out_path: destination file. When ``None`` (the default) the path is
            built from the notebook globals ``file_name_no_ext``, ``net.name``
            and ``version``, as before.
    """
    outfile = MidiFile()
    step_size = outfile.ticks_per_beat // 8

    track = MidiTrack()
    outfile.tracks.append(track)
    track.append(Message('program_change', program=12))

    # Ticks elapsed since the last emitted message.
    delta = 0

    # Iterating over the transpose walks the series step by step and also
    # copes with an empty series (indexing generated_series[0] would raise).
    for column in np.asarray(generated_series).T:
        for value in column:
            note = int(value)
            if note > 0:
                # MIDI data bytes must lie in 0..127; the model output is
                # unbounded, so clamp rather than let mido raise ValueError.
                note = min(note, 127)
                track.append(Message('note_on', note=note, velocity=100, time=delta))
                delta = 0

        delta += step_size

    track.append(MetaMessage('end_of_track'))
    if out_path is None:
        out_path = 'midi/weight_analysis/' + file_name_no_ext + "_" + net.name + "_" + version + '.mid'
    print("creating file: ", out_path)
    outfile.save(out_path)

In [None]:
# Render the generated series to a MIDI file; the function prints the output path.
write_midi_from_series(generated_series)