In [6]:
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

# **Initial Model**

In [111]:
class MusicGenRNN(nn.Module):
  def __init__(self, hidden_size=512, num_layers=1, bias=True):
    super(MusicGenRNN, self).__init__()
    # input # pitch, step and duration 
    # 128 is number of pitch possiblities
    self.one_hot_size = 128 + 2 # account for <BOS> and <EOS> this will be the second last and last features of the one hot
    self.pitch_embedding_size = 128 + 2
    self.hidden_size = hidden_size
    # size of embedding plus the 2 cts values step and duration
    self.input_size = self.pitch_embedding_size + 2 
    # size of discrete one hot plus the 2 cts values step and duration
    self.output_size = self.one_hot_size + 2 

    # identiy matrix for generating one-hot vectors
    self.ident = torch.eye(self.one_hot_size) 
    self.pitch_embedding = nn.Linear(self.one_hot_size, self.pitch_embedding_size, bias=False)

    #self.rnn = nn.LSTM(input_size, hidden_size, num_layers, bias, batch_first=True, dropout)
    self.rnn = nn.LSTM(self.input_size, hidden_size, num_layers, bias=bias, batch_first=True, dropout=0)
    # a fully-connect layer that outputs a distribution over the next token, given the RNN output
    self.decoder = nn.Linear(hidden_size, self.output_size)

  def forward(self, input, hidden_in=None):
    inp_pitch = input[:, :, 0]
    inp_step = input[:, :, 1]
    inp_duration = input[:, :, 2]
    inp_pitch = inp_pitch.long()
    # generate one-hot vector for discrete part of input
    one_hot_pitch = self.ident[inp_pitch].float()
    # embed the pitch to make it cts
    embedded_pitch = self.pitch_embedding(one_hot_pitch)
    # make inp = batch_size x sequence_length x 132
    inp = torch.concat((embedded_pitch, inp_step.reshape(*inp_step.shape, 1), inp_duration.reshape(*inp_duration.shape, 1)), dim=2)
    output, hidden_out = self.rnn(inp, hidden_in) # get the next output and hidden state
    output = self.decoder(output) # predict distribution over next tokens
    return output, hidden_out

# **Trainning Over Fitting**

In [30]:
import pickle
from google.colab import drive


drive.mount('/content/gdrive')
path_to_data = '/content/gdrive/My Drive/University/Year 4/CSC413/Project/data.pickle'

with open(path_to_data, 'rb') as f:
    dataset = pickle.load(f)

train_set, validation_set, test_set = dataset

print(train_set.shape)
#print(train_set[0])

# train contains N samples, of 64 length sequences, each sequence token is length 3.
# token in sequence is [pitch, step, duration]

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
(3603, 64, 3)


In [126]:
def train(model, train_data, valid_data, batch_size=32, weight_decay=0.0,
           learning_rate=0.001, num_epochs=7, checkpoint_path=None):
  # get loss function, CE has softmax built in
  criterion_pitch = nn.CrossEntropyLoss()
  criterion_step = nn.MSELoss()
  criterion_duration = nn.MSELoss()
  # get optimizer
  optimizer = optim.Adam(model.parameters(),
                          lr=learning_rate,
                          weight_decay=weight_decay)
  # get dataloader, load training data
  train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=True)
  # learning curve information for plotting
  iters, iter_at_epoch, losses, train_acc, val_acc = [], [], [], [], []
  num_iters = 0
  # iterate the given number of epochs
  for epoch in range(num_epochs):
    # shuffling data done automatically by data loader
    for batch_of_sequences in iter(train_loader): # iterate through all data in loader
    # batch_of_sequences is of batch size
      # account for smaller last batch
      if batch_of_sequences.size()[0] < batch_size:
        continue
      # compute forward and backward pass
      model.train() # ensute model in train mode

      # add <BOS>=128 and <EOS>=129 terms, pitch values range 0-127 inclusive
      BOS = torch.tensor([128, 0.0, 0.0] * batch_size).reshape(batch_size, 1, 3)
      EOS = torch.tensor([129, 0.0, 0.0] * batch_size).reshape(batch_size, 1, 3)
      # input shape is batch_size x seqeunce_length=64+1 x token size = 3
      input = torch.concat((BOS, batch_of_sequences), dim=1) # <EOS> never input
      out, _ = model(input.float())
      # out = batch_size x sequence_size x 132 [0-129 pitch probs, 130 step, 131 duration]
      # 129 as it is the pitch tokens from 0-127 + 2 for <BOS> and <EOS>
      out_pitch = out[:, :, 0:130]
      out_step = out[:, :, 130]
      out_duration = out[:, :, 131]
      targets = torch.concat((batch_of_sequences, EOS), dim=1) # <BOS> never output
      targets = targets.float()
      targets_pitch = targets[:, :, 0]
      targets_step = targets[:, :, 1]
      targets_duration = targets[:, :, 2]
      # shape must be batch_size x # classes=130 x sequence_length
      loss_pitch = criterion_pitch(out_pitch.reshape(batch_size, -1, out_pitch.shape[1]), targets_pitch.long())
      loss_step = criterion_step(out_step, targets_step)
      loss_duration = criterion_duration(out_duration, targets_duration)
      total_loss = loss_pitch + loss_step + loss_duration 
      total_loss.backward()
      optimizer.step()
      optimizer.zero_grad()
      # gather plotting data
      num_iters += 1
      losses.append(float(total_loss) / batch_size)
      iters.append(num_iters)
    # --- epoch ended ---
    # # check point model
    # if (checkpoint_path is not None) and num_iters > 0:
    #   torch.save(model.state_dict(), checkpoint_path.format(num_iters))
    # # track learning curve info
    # iter_at_epoch.append(num_iters)
    # train_acc.append(get_accuracy(model, train_data))
    # val_acc.append(get_accuracy(model, valid_data))
    # # report accuracies on train and validation set
      print("Epoch %d. Iter %d. [Val Acc %.0f%%] [Train Acc %.0f%%, Loss %f]" % (epoch, num_iters,0,0,float(total_loss.detach().numpy())))
      # print("Epoch %d. Iter %d. [Val Acc %.0f%%] [Train Acc %.0f%%, Loss %f]" % (epoch,
      #     num_iters, val_acc[-1] * 100, train_acc[-1] * 100, float(total_loss.detach().numpy())))
  return iters, losses, iter_at_epoch, train_acc, val_acc



def plot_learning_curve(iters, losses, iter_at_epoch, train_accs, val_accs):
    """
    Plot the learning curve.
    """
    plt.title("Learning Curve: Loss per Iteration")
    plt.plot(iters, losses, label="Train")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.show()
    plt.title("Learning Curve: Accuracy per Iteration")
    plt.plot(iter_at_epoch, train_accs, label="Train")
    plt.plot(iter_at_epoch, val_accs, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Accuracy")
    plt.legend(loc='best')
    plt.show()


In [129]:
model = MusicGenRNN()
train(model, train_set[:2], validation_set, num_epochs=120, batch_size=2)

Epoch 0. Iter 1. [Val Acc 0%] [Train Acc 0%, Loss 5.298226]
Epoch 1. Iter 2. [Val Acc 0%] [Train Acc 0%, Loss 5.228052]
Epoch 2. Iter 3. [Val Acc 0%] [Train Acc 0%, Loss 5.165895]
Epoch 3. Iter 4. [Val Acc 0%] [Train Acc 0%, Loss 5.105360]
Epoch 4. Iter 5. [Val Acc 0%] [Train Acc 0%, Loss 5.070100]
Epoch 5. Iter 6. [Val Acc 0%] [Train Acc 0%, Loss 5.079233]
Epoch 6. Iter 7. [Val Acc 0%] [Train Acc 0%, Loss 5.033733]
Epoch 7. Iter 8. [Val Acc 0%] [Train Acc 0%, Loss 5.025095]
Epoch 8. Iter 9. [Val Acc 0%] [Train Acc 0%, Loss 5.015596]
Epoch 9. Iter 10. [Val Acc 0%] [Train Acc 0%, Loss 4.997922]
Epoch 10. Iter 11. [Val Acc 0%] [Train Acc 0%, Loss 4.969443]
Epoch 11. Iter 12. [Val Acc 0%] [Train Acc 0%, Loss 4.927563]
Epoch 12. Iter 13. [Val Acc 0%] [Train Acc 0%, Loss 4.875422]
Epoch 13. Iter 14. [Val Acc 0%] [Train Acc 0%, Loss 4.944679]
Epoch 14. Iter 15. [Val Acc 0%] [Train Acc 0%, Loss 4.804642]
Epoch 15. Iter 16. [Val Acc 0%] [Train Acc 0%, Loss 4.820263]
Epoch 16. Iter 17. [Val Acc

([1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108,
  109,
  110,
  111,
  112,
  113,
  114,
  115,
  116,
  117,
  118,
  119,
  120],
 [2.649113178253174,
  2.614025831222534,
  2.5829477310180664,
  2.552680015563965,
  2.5350499153137207,
  2.539616346359253,
  2.516866445541382,
  2.512547731399536,
  2.507798194885254,
  2.4989612102508545,
  2.4847216606140137,
  2.4637813568115234,
  2.437711