In [1]:
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pickle

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
with open('data.pickle', 'rb') as f:
    dataset = pickle.load(f)

In [4]:
train, validation, test = dataset

In [5]:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.lstm = nn.LSTM(input_size=3, hidden_size=128, batch_first=True)
        self.pitch = nn.Linear(128, 128)
        self.step = nn.Linear(128, 1)
        self.duration = nn.Linear(128, 1)

    def forward(self, x, h = None):
        x = x.float()
        x, h = self.lstm(x, h)
        # x = x[:, -1, :] # Get only the last output from the LSTM
        pitch = self.pitch(x)
        step = self.step(x)
        duration = self.duration(x)
        return {'pitch': pitch, 'step': step, 'duration': duration}, h
    
    def init_state(self):
        z1 = torch.zeros(1, 2, 128)
        z2 = torch.zeros(1, 2, 128)
        return (z1, z2)

model = MyModel()
model.to(device)

MyModel(
  (lstm): LSTM(3, 128, batch_first=True)
  (pitch): Linear(in_features=128, out_features=128, bias=True)
  (step): Linear(in_features=128, out_features=1, bias=True)
  (duration): Linear(in_features=128, out_features=1, bias=True)
)

In [6]:
# Print model summary
print(model)

MyModel(
  (lstm): LSTM(3, 128, batch_first=True)
  (pitch): Linear(in_features=128, out_features=128, bias=True)
  (step): Linear(in_features=128, out_features=1, bias=True)
  (duration): Linear(in_features=128, out_features=1, bias=True)
)


In [7]:
# Set number of epochs to train for
num_epochs = 50
batch_size = 64
val_batch_size = 128

train_loader = DataLoader(train, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(validation, shuffle=True, batch_size=val_batch_size)

loss = {
    'pitch': nn.CrossEntropyLoss(),
    'step': nn.MSELoss(),
    'duration': nn.MSELoss()
    }
learning_rate = 0.001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epochs):
    # Train the model for one epoch
    train_loss = 0.000
    
    for batch_idx, data in enumerate(train_loader):
        if data.size()[0] < batch_size:
            continue

        h = None
        data = data.to(device)
        optimizer.zero_grad()
        outputs, h = model(data, h)
        batch_loss = []
        for idx, key in enumerate(loss):
            if key == 'pitch':
                tmp_loss = loss[key](outputs[key].permute(0, 2, 1), data[:, :, idx].long())
            else:
                tmp_loss = loss[key](outputs[key].permute(0, 2, 1), data[:, :, idx].reshape(batch_size, 64, 1).permute(0, 2, 1).float())

            batch_loss.append(tmp_loss)

        final_loss = batch_loss[0] + batch_loss[1] + batch_loss[2]
        train_loss += final_loss.item()
        (state_h, state_c) = h
        state_h = state_h.detach()
        state_c = state_c.detach()
        
        final_loss.backward()
        optimizer.step()

    # Evaluate the model on the validation set
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        
        for batch_idx, data in enumerate(val_loader):
            data = data.to(device)
            if data.size()[0] < val_batch_size:
                continue

            h = None
            batch_loss = []
            outputs, h = model(data)
            # batch_loss = sum(loss[key](outputs[key], targets[key]) for key in loss)
            for idx, key in enumerate(loss):
                if key == 'pitch':
                    tmp_loss = loss[key](outputs[key].permute(0, 2, 1), data[:, :, idx].long())
                else:
                    tmp_loss = loss[key](outputs[key].permute(0, 2, 1), data[:, :, idx].reshape(val_batch_size, 64, 1).permute(0, 2, 1).float())

                batch_loss.append(tmp_loss)
            
            final_loss = batch_loss[0] + batch_loss[1] + batch_loss[2]
            val_loss += final_loss.item()
    model.train()


    # Print the epoch statistics
    print(f'Epoch {epoch}: train loss={train_loss/len(train_loader):.4f}, val loss={val_loss/len(val_loader):.4f}')


Epoch 0: train loss=4.4162, val loss=3.2227
Epoch 1: train loss=3.7094, val loss=2.8887
Epoch 2: train loss=3.4744, val loss=2.7423
Epoch 3: train loss=3.3403, val loss=2.6623
Epoch 4: train loss=3.2262, val loss=2.5705
Epoch 5: train loss=3.1249, val loss=2.4877
Epoch 6: train loss=3.0385, val loss=2.4191
Epoch 7: train loss=2.9667, val loss=2.3517
Epoch 8: train loss=2.9027, val loss=2.3121
Epoch 9: train loss=2.8433, val loss=2.2557
Epoch 10: train loss=2.7824, val loss=2.2186
Epoch 11: train loss=2.7271, val loss=2.1681
Epoch 12: train loss=2.6755, val loss=2.1185
Epoch 13: train loss=2.6253, val loss=2.0769
Epoch 14: train loss=2.5816, val loss=2.0483
Epoch 15: train loss=2.4815, val loss=2.0043
Epoch 16: train loss=2.4930, val loss=1.9743
Epoch 17: train loss=2.4561, val loss=1.9370
Epoch 18: train loss=2.4164, val loss=1.8958
Epoch 19: train loss=2.3792, val loss=1.8770
Epoch 20: train loss=2.3461, val loss=1.8512
Epoch 21: train loss=2.3093, val loss=1.8212
Epoch 22: train loss

In [None]:
# TODO: Create a generate music function