In [9]:
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pickle

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [11]:
with open('data.pickle', 'rb') as f:
    dataset = pickle.load(f)

In [12]:
train, validation, test = dataset

In [13]:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.lstm = nn.LSTM(input_size=3, hidden_size=128, batch_first=True)
        self.pitch = nn.Linear(128, 128)
        self.step = nn.Linear(128, 1)
        self.duration = nn.Linear(128, 1)

    def forward(self, x, h = None):
        x = x.float()
        x, h = self.lstm(x, h)
        # x = x[:, -1, :] # Get only the last output from the LSTM
        pitch = self.pitch(x)
        step = self.step(x)
        duration = self.duration(x)
        return {'pitch': pitch, 'step': step, 'duration': duration}, h
    
    def init_state(self):
        z1 = torch.zeros(1, 2, 128)
        z2 = torch.zeros(1, 2, 128)
        return (z1, z2)

model = MyModel()
model.to(device)

MyModel(
  (lstm): LSTM(3, 128, batch_first=True)
  (pitch): Linear(in_features=128, out_features=128, bias=True)
  (step): Linear(in_features=128, out_features=1, bias=True)
  (duration): Linear(in_features=128, out_features=1, bias=True)
)

In [14]:
# Print model summary
print(model)

MyModel(
  (lstm): LSTM(3, 128, batch_first=True)
  (pitch): Linear(in_features=128, out_features=128, bias=True)
  (step): Linear(in_features=128, out_features=1, bias=True)
  (duration): Linear(in_features=128, out_features=1, bias=True)
)


In [17]:
# Set number of epochs to train for
num_epochs = 50
batch_size = 64
val_batch_size = 128

train_loader = DataLoader(train, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(validation, shuffle=True, batch_size=val_batch_size)

loss = {
    'pitch': nn.CrossEntropyLoss(),
    'step': nn.MSELoss(),
    'duration': nn.MSELoss()
    }
learning_rate = 0.001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epochs):
    # Train the model for one epoch
    train_loss = 0.000
    h = None
    
    for batch_idx, data in enumerate(train_loader):
        if data.size()[0] < batch_size:
            continue

        data = data.to(device)
        optimizer.zero_grad()
        outputs, h = model(data, h)
        (state_h, state_c) = h
        state_h = state_h.detach()
        state_c = state_c.detach()
        h = (state_h, state_c)
        batch_loss = []
        for idx, key in enumerate(loss):
            if key == 'pitch':
                tmp_loss = loss[key](outputs[key].permute(0, 2, 1), data[:, :, idx].long())
            else:
                tmp_loss = loss[key](outputs[key].permute(0, 2, 1), data[:, :, idx].reshape(batch_size, 64, 1).permute(0, 2, 1).float())

            batch_loss.append(tmp_loss)

        final_loss = batch_loss[0] + batch_loss[1] + batch_loss[2]
        train_loss += final_loss.item()
        
        
        final_loss.backward()
        optimizer.step()

    # Evaluate the model on the validation set
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        h = None
        
        for batch_idx, data in enumerate(val_loader):
            data = data.to(device)
            if data.size()[0] < val_batch_size:
                continue

            batch_loss = []
            outputs, h = model(data)

            (state_h, state_c) = h
            state_h = state_h.detach()
            state_c = state_c.detach()
            h = (state_h, state_c)

            # batch_loss = sum(loss[key](outputs[key], targets[key]) for key in loss)
            for idx, key in enumerate(loss):
                if key == 'pitch':
                    tmp_loss = loss[key](outputs[key].permute(0, 2, 1), data[:, :, idx].long())
                else:
                    tmp_loss = loss[key](outputs[key].permute(0, 2, 1), data[:, :, idx].reshape(val_batch_size, 64, 1).permute(0, 2, 1).float())

                batch_loss.append(tmp_loss)
            
            final_loss = batch_loss[0] + batch_loss[1] + batch_loss[2]
            val_loss += final_loss.item()
            
    model.train()


    # Print the epoch statistics
    print(f'Epoch {epoch}: train loss={train_loss/len(train_loader):.4f}, val loss={val_loss/len(val_loader):.4f}')


Epoch 0: train loss=1.6698, val loss=1.8062
Epoch 1: train loss=1.6050, val loss=1.8123
Epoch 2: train loss=1.5939, val loss=1.8089
Epoch 3: train loss=1.5827, val loss=1.8060
Epoch 4: train loss=1.5713, val loss=1.8025
Epoch 5: train loss=1.5589, val loss=1.7965
Epoch 6: train loss=1.4419, val loss=1.7914
Epoch 7: train loss=1.5360, val loss=1.7876
Epoch 8: train loss=1.5253, val loss=1.7843
Epoch 9: train loss=1.5131, val loss=1.7877
Epoch 10: train loss=1.4995, val loss=1.7754
Epoch 11: train loss=1.4876, val loss=1.7774
Epoch 12: train loss=1.4761, val loss=1.7647
Epoch 13: train loss=1.4642, val loss=1.7580
Epoch 14: train loss=1.4504, val loss=1.7551
Epoch 15: train loss=1.4415, val loss=1.7530
Epoch 16: train loss=1.4297, val loss=1.7514
Epoch 17: train loss=1.4186, val loss=1.7419
Epoch 18: train loss=1.4070, val loss=1.7400
Epoch 19: train loss=1.3972, val loss=1.7349
Epoch 20: train loss=1.3842, val loss=1.7342
Epoch 21: train loss=1.3739, val loss=1.7347
Epoch 22: train loss

In [16]:
# TODO: Create a generate music function