In [None]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import stilus.models as m
from stilus.data.sets import MidiDataset

In [None]:
# Fix the RNG so weight initialisation and DataLoader shuffling repeat across runs.
# torch.manual_seed seeds the CPU and CUDA generators; full determinism on GPU may
# additionally require cuDNN settings (not configured here).
SEED = 0
torch.manual_seed(SEED)

cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
print('Using device:', device)

# Every architecture under comparison; the training and saving cells iterate over `nets`.
# Instances are created in list order, same as listing the constructors inline.
net_classes = [
    m.ConvNet_1_0_0, m.ConvNet_1_0_1, m.ConvNet_1_0_2, m.ConvNet_1_0_3,
    m.TransformerNet_1_0_0, m.TransformerNet_1_0_1, m.TransformerNet_1_0_2,
]
nets = [net_cls().to(device) for net_cls in net_classes]

print(nets)

In [None]:
# Smoke test: push one random batch through every net to confirm the forward pass runs
# and to show the output shape each architecture produces.
# NOTE(review): (128, 5, 32) mirrors the training batches (batch_size=128, input slice
# data[:, :, 0:32]); the meaning of the middle axis (5) is not visible here — confirm.
# Named `sample_input` so the builtin `input` is not shadowed for the rest of the session.
sample_input = torch.randn(128, 5, 32, device=device)
with torch.no_grad():  # forward only; no autograd graph needed
    for net in nets:
        print(type(net).__name__)
        out = net(sample_input)
        print(sample_input.shape)
        print(out.shape)

In [None]:
epochs: int = 20  # full passes over the training set, applied to each net in `nets`

In [None]:
# Training data. Its .mean / .std are reused later to build the test set and to
# map normalised values back to integers.
midi_dataset = MidiDataset("training_data.npy")
dataloader = DataLoader(
    dataset=midi_dataset,
    batch_size=128,
    shuffle=True,  # reshuffle every epoch
)

In [None]:
# Train every net in `nets` for `epochs` passes with SGD + momentum on an L1 loss.
LOG_EVERY = 1000  # mini-batches between intermediate loss reports
criterion = nn.L1Loss()  # has no parameters/state, so one instance serves every net

for net in nets:
    net.train()  # ensure training-mode behaviour for any dropout/batch-norm layers
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    print("Starting to train ", type(net).__name__)
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0  # reset every LOG_EVERY batches
        epoch_loss = 0.0    # accumulated over the whole epoch
        n_batches = 0
        for i, data in enumerate(dataloader, 0):
            # Each batch is indexed as a 3-D tensor: positions 0..31 of the last axis
            # are the model input, position 32 is the target.
            inputs, labels = data[:, :, 0:32].to(device), data[:, :, 32].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            batch_loss = loss.item()
            running_loss += batch_loss
            epoch_loss += batch_loss
            n_batches += 1
            if i % LOG_EVERY == LOG_EVERY - 1:  # print every LOG_EVERY mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / LOG_EVERY))
                running_loss = 0.0

        # Per-epoch summary, so progress is visible even when an epoch has fewer
        # than LOG_EVERY mini-batches (the periodic print above would never fire).
        if n_batches:
            print('[%d] epoch mean loss: %.3f' % (epoch + 1, epoch_loss / n_batches))

    print('Finished Training', type(net).__name__)

In [None]:
# Save each trained net's weights as model_weights/<epochs>epochs<ClassName>.pth.
# The name records `epochs`, i.e. one run of the training cell. Re-running that cell
# keeps training the same nets, and the file name will then under-report the total.
weights_dir = Path("model_weights")
weights_dir.mkdir(parents=True, exist_ok=True)  # torch.save does not create missing dirs
for net in nets:
    model_pth_name = str(weights_dir / f"{epochs}epochs{type(net).__name__}.pth")
    torch.save(net.state_dict(), model_pth_name)

In [None]:
# Reload one trained ConvNet_1_0_0 for inspection on the test set.
# NOTE(review): this file name does not follow the <epochs>epochs<ClassName>.pth pattern
# the save cell writes — confirm the checkpoint exists.
# torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint_path = "./model_weights/30epochs_1.0.0.pth"
net = m.ConvNet_1_0_0().to(device)
state_dict = torch.load(checkpoint_path, map_location=device)
net.load_state_dict(state_dict)
net.eval()

In [None]:
# Test data, built with the *training* set's mean/std so it is presumably normalised the
# same way as the data the nets were trained on.
train_mean, train_std = midi_dataset.mean, midi_dataset.std
midi_test_dataset = MidiDataset("test_data.npy", train_mean, train_std)
test_dataloader = DataLoader(
    midi_test_dataset,
    batch_size=64,
    shuffle=False,  # fixed order so printed predictions are comparable between runs
)

In [None]:
def std_tensor_to_int(pred, mean=None, std=None):
    """Undo mean/std normalisation and convert to the nearest integer.

    Computes ``pred * std + mean`` and rounds the result to int32.

    Args:
        pred: tensor in normalised units (a model output or a label).
        mean: offset to add back; defaults to ``midi_dataset.mean`` (training set).
        std: scale to multiply by; defaults to ``midi_dataset.std`` (training set).

    Returns:
        ``torch.int32`` tensor of ``pred * std + mean`` rounded to the nearest integer.
    """
    # Resolve defaults at call time so a re-created `midi_dataset` is picked up.
    if mean is None:
        mean = midi_dataset.mean
    if std is None:
        std = midi_dataset.std
    # Round before casting: .int() truncates toward zero, so a float round-off value
    # such as 63.99996 would otherwise come out as 63 instead of 64.
    return (pred * std + mean).round().int()

In [None]:
# Print de-normalised targets next to the loaded net's predictions for every test batch.
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for data in test_dataloader:
        print(data.shape)
        # Same split as training: positions 0..31 of the last axis are the input and
        # position 32 is the target. Take the whole batch; the old data[1:64] slice
        # dropped the first sample of every batch.
        inputs, labels = data[:, :, 0:32].to(device), data[:, :, 32]
        print("labels:", std_tensor_to_int(labels))
        # Move back to CPU to match `labels`; the dataset mean/std are presumably
        # CPU-side as well.
        pred = net(inputs).cpu()
        print("pred:", std_tensor_to_int(pred))
   
    