In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stilus.data.sets import MidiDataset
from torch.utils.data import DataLoader

In [None]:
class TransformerNet(nn.Module):
    """Transformer encoder followed by a linear head.

    Expects input of shape ``(batch, seq_len, d_model)`` and returns a tensor
    of shape ``(batch, out_features)``. The defaults reproduce the original
    hard-coded architecture (5 steps of 32 features -> 5 outputs).

    Args:
        d_model: feature size of every sequence step.
        seq_len: number of steps per sample; fixes the linear head's input size.
        nhead: number of attention heads (must divide ``d_model``).
        num_layers: number of stacked encoder layers.
        out_features: size of the output vector per sample.
    """

    def __init__(self, d_model=32, seq_len=5, nhead=8, num_layers=3, out_features=5):
        super(TransformerNet, self).__init__()
        # nn.TransformerEncoder deep-copies the layer it is given, so the
        # parameters of `encoder0` itself never take part in the forward pass.
        # The attribute is kept only so that state_dict keys stay compatible
        # with checkpoints saved by earlier versions of this class.
        self.encoder0 = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.encoder = nn.TransformerEncoder(self.encoder0, num_layers=num_layers)
        self.fc0 = nn.Linear(d_model * seq_len, out_features)

    def forward(self, x):
        """Encode ``x`` (batch, seq_len, d_model) and project it to (batch, out_features)."""
        # The encoder layers were built with the default batch_first=False and
        # therefore expect (seq, batch, feature). Feeding (batch, seq, feature)
        # directly would make attention mix *different samples of the batch*
        # instead of the steps of one sample, so swap the axes around the call.
        x = x.transpose(0, 1)
        x = self.encoder(x)
        x = x.transpose(0, 1)
        # After the transpose the tensor is non-contiguous: reshape, not view.
        x = x.reshape(-1, self.num_flat_features(x))
        x = self.fc0(x)
        return x

    def num_flat_features(self, x):
        """Return the number of elements per sample (product of all non-batch dims)."""
        num_features = 1
        for dim in x.size()[1:]:  # all dimensions except the batch dimension
            num_features *= dim
        return num_features


In [None]:
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
if cuda_available:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Build the model and move its parameters onto the selected device.
transformer_net = TransformerNet()
transformer_net = transformer_net.to(device)

print(transformer_net)

In [None]:
# Smoke test: push one random batch through the network and compare shapes.
# The tensor is deliberately NOT called `input`, which would shadow the
# Python builtin for the rest of the kernel session.
sample_batch = torch.randn(128, 5, 32).to(device)
with torch.no_grad():  # only the shapes matter here, no autograd graph needed
    out = transformer_net(sample_batch)
print(sample_batch.shape)
print(out.shape)

In [None]:
# Loss: mean absolute error between predictions and targets.
criterion = nn.L1Loss()

# Optimiser: plain SGD with momentum over every parameter of the model.
optimizer = optim.SGD(
    transformer_net.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [None]:
# Training data: loaded from disk and served in shuffled mini-batches of 32.
midi_dataset = MidiDataset("training_data.npy")
dataloader = DataLoader(
    midi_dataset,
    batch_size=32,
    shuffle=True,
)

In [None]:
# Number of mini-batches between two running-loss reports.
LOG_EVERY = 1000

# Make sure dropout is active even if this cell is re-run after the model
# has been switched to eval mode somewhere else.
transformer_net.train()

for epoch in range(60):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        # Each batch is a single tensor: along the last axis, columns 0..31
        # are the model inputs and column 32 is the regression target.
        inputs, labels = data[:, :, 0:32].to(device), data[:, :, 32].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = transformer_net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics: mean loss over the last LOG_EVERY mini-batches
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

print('Finished Training')

In [None]:
# Spot-check: print labels next to predictions for two samples (rows 5 and 6)
# of the first three batches.
# The model is switched to eval mode so dropout does not perturb the printed
# predictions, and gradients are disabled because nothing is being trained.
# The previous train/eval mode is restored afterwards.
was_training = transformer_net.training
transformer_net.eval()
try:
    with torch.no_grad():
        for i, data in enumerate(dataloader, 0):
            if i == 0:
                print( data.shape)
            # columns 0..31 of the last axis are inputs, column 32 is the target
            inputs, labels = data[5:7,:,0:32].to(device), data[5:7,:,32].to(device)
            print("labels:", labels)
            print("pred:", transformer_net(inputs))

            if i == 2:
                break
finally:
    transformer_net.train(was_training)

In [None]:
# Persist only the learned parameters (the state_dict), not the module object.
checkpoint_path = "models/60epochs_transformer_1.0.2.pth"
torch.save(transformer_net.state_dict(), checkpoint_path)

In [None]:
# Rebuild the architecture on the CPU and load the saved weights into it.
# map_location="cpu" lets a checkpoint written from a CUDA model be restored
# on a machine without a GPU; the freshly built model lives on the CPU anyway.
conv_net = TransformerNet()
conv_net.load_state_dict(
    torch.load("./models/60epochs_transformer_1.0.2.pth", map_location="cpu")
)
# Inference mode: disables dropout inside the encoder layers.
conv_net.eval()

In [None]:
# Test data: built with the training set's mean/std (passed positionally)
# and served in order, 64 samples per batch.
midi_test_dataset = MidiDataset(
    "test_data.npy",
    midi_dataset.mean,
    midi_dataset.std,
)
test_dataloader = DataLoader(
    midi_test_dataset,
    batch_size=64,
    shuffle=False,
)

In [None]:
def std_tensor_to_int(pred, mean=None, std=None):
    """Map standardised values back to the original scale as nearest integers.

    Computes ``pred * std + mean`` and rounds to the nearest integer before
    casting. A bare ``.int()`` truncates toward zero, which turns e.g. 60.98
    into 60 and a float round-trip result such as 59.9999 into 59.

    Args:
        pred: tensor of standardised values (presumably produced with
            ``(x - mean) / std`` by ``MidiDataset`` - confirm against it).
        mean: offset to add back; defaults to ``midi_dataset.mean``.
        std: scale to multiply by; defaults to ``midi_dataset.std``.

    Returns:
        An int32 tensor with the same shape as ``pred``.
    """
    if mean is None:
        mean = midi_dataset.mean
    if std is None:
        std = midi_dataset.std
    return ((pred * std) + mean).round().int()

In [None]:
# Print de-normalised labels next to de-normalised predictions for the whole
# test set. Gradients are disabled because this is pure inference.
with torch.no_grad():
    for data in test_dataloader:
        print(data.shape)
        # Use every sample of the batch: the former slice `1:64` silently
        # dropped sample 0 of each batch. Columns 0..31 of the last axis are
        # inputs, column 32 is the target.
        inputs, labels = data[:, :, 0:32], data[:, :, 32]
        print("labels:",  std_tensor_to_int(labels))
        pred = conv_net(inputs)
        print("pred:",  std_tensor_to_int(pred))
   
    