In [None]:
import os

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader

import stilus.models as m
from stilus.data.sets import MidiDataset

In [None]:
# Detect the accelerator once; later cells read `device` and `cuda_available`.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
print('Using device:', device)

# Seed torch's global RNG before building the nets so their initial weights
# (and anything else later drawn from torch's global RNG) are reproducible
# across Restart & Run All. GPU kernels may still add nondeterminism.
SEED = 0
torch.manual_seed(SEED)

# Every architecture under comparison; each one is trained and tested below.
nets = [
    m.ConvNet_1_0_0(),
    m.ConvNet_1_0_1(),
    m.ConvNet_1_0_2(),
    m.ConvNet_1_0_3(),
    m.TransformerNet_1_0_0(),
    m.TransformerNet_1_0_1(),
    m.TransformerNet_1_0_2(),
]

print(nets)

In [None]:
# Test that tensor shapes are correct: push one random batch through every net
# and print the input/output shapes.
# NOTE(review): the (128, 5, 32) layout is assumed to match what the nets expect
# from MidiDataset batches — confirm against the model definitions.
SHAPE_CHECK_SIZE = (128, 5, 32)
for net in nets:
    print(type(net).__name__)
    # Build the batch on the same device as the net's own weights. The nets are
    # not moved anywhere before this cell, so sending the batch to `device`
    # (cuda when available) would raise a device-mismatch error.
    first_param = next(net.parameters(), None)
    net_device = first_param.device if first_param is not None else torch.device('cpu')
    # Named sample_input (not `input`) to avoid shadowing the builtin.
    sample_input = torch.randn(*SHAPE_CHECK_SIZE, device=net_device)
    with torch.no_grad():  # shape check only; no autograd graph needed
        out = net(sample_input)
    print("input:", sample_input.shape)
    print("output:", out.shape)

In [None]:
# Train every net with PyTorch Lightning, then run its test loop.
MIN_EPOCHS = 30  # Trainer runs at least this many epochs before early stopping may end training
# Follow the device detected in the setup cell instead of always forcing CPU.
NUM_GPUS = 1 if cuda_available else 0

for net in nets:
    net_name = type(net).__name__

    # A fresh callback per net, so no early-stopping state carries over between runs.
    # NOTE(review): monitors a metric named 'loss' — the nets must log it; confirm.
    early_stopping = EarlyStopping('loss')

    print('Starting to train:', net_name)
    trainer = pl.Trainer(min_epochs=MIN_EPOCHS, gpus=NUM_GPUS, early_stop_callback=early_stopping)
    trainer.fit(net)
    trainer.test(net)
    print('Finished training', net_name)

In [None]:
# Persist each trained net's weights as model_weights/tmp_<ClassName>.pth.
WEIGHTS_DIR = "model_weights"
# torch.save does not create missing parent directories, so ensure it exists.
os.makedirs(WEIGHTS_DIR, exist_ok=True)
for net in nets:
    model_pth_name = os.path.join(WEIGHTS_DIR, "tmp_" + type(net).__name__ + ".pth")
    torch.save(net.state_dict(), model_pth_name)

In [None]:
# Rebuild ConvNet 1.0.0 on the active device and restore a saved checkpoint into it.
# map_location remaps the stored tensors onto `device`, so a checkpoint written on
# a different device still loads here.
net = m.ConvNet_1_0_0().to(device)
checkpoint_state = torch.load("./model_weights/30epochs_1.0.0.pth", map_location=device)
net.load_state_dict(checkpoint_state)
net.eval()

In [None]:
# Held-out data, normalized with the mean/std carried by the loaded model's
# `midi_dataset` (presumably the training statistics — confirm).
midi_test_dataset = MidiDataset(
    "test_data.npy",
    net.midi_dataset.mean,
    net.midi_dataset.std,
)
test_dataloader = DataLoader(
    midi_test_dataset,
    shuffle=False,  # keep file order so printed predictions line up with their labels
    batch_size=64,
)

In [None]:
def std_tensor_to_int(pred, mean=None, std=None):
    """Undo standardization on ``pred`` and convert the result to integers.

    Computes ``pred * std + mean`` and rounds to the nearest integer before
    casting with ``.int()``. Rounding matters because the float round-trip
    ``(x - mean) / std * std + mean`` can land just below an integer
    (e.g. 63.99999), which a bare ``.int()`` would truncate to 63.

    Args:
        pred: standardized tensor (a model output or a label batch).
        mean: offset to add back; defaults to ``net.midi_dataset.mean``.
        std: scale to multiply back; defaults to ``net.midi_dataset.std``.

    Returns:
        An int32 tensor holding the de-standardized, rounded values.
    """
    # Fall back to the statistics of the globally loaded `net` (original behavior).
    if mean is None:
        mean = net.midi_dataset.mean
    if std is None:
        std = net.midi_dataset.std
    return ((pred * std) + mean).round().int()

In [None]:
# Print de-standardized labels next to the loaded net's predictions, batch by batch.
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    # each batch unpacks as (inputs, labels)
    for inputs, labels in test_dataloader:
        print("labels:", std_tensor_to_int(labels))
        # `net` was moved to `device` when it was loaded, so the batch must follow it.
        # The prediction is brought back to CPU so it is de-standardized on the same
        # device as the labels.
        pred = net(inputs.to(device)).cpu()
        print("pred:", std_tensor_to_int(pred))
   
    