In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

Toy Dataset

In [3]:
# Vocabulary of chords: each chord name maps to the integer index fed to the model.
chord_to_idx = {
    'C_major': 0,
    'F_major': 1,
    'G_major': 2,
}

# Melody notes mapped to integer class labels; the values match standard
# MIDI note numbers (C4 = 60, so E4 = 64, C5 = 72, ...).
note_to_idx = {
    'E4': 64,
    'G4': 67,
    'A4': 69,
    'C5': 72,
    'F4': 65,
    'D4': 62,
}

# A simple music piece: each pair is a (chord, melody note) that occurs at
# one time step, listed in playing order.
data = [
    ('C_major', 'E4'),
    ('C_major', 'G4'),
    ('F_major', 'A4'),
    ('G_major', 'G4'),
    ('C_major', 'E4'),
    ('C_major', 'C5'),
    ('F_major', 'F4'),
    ('G_major', 'D4'),
]

In [10]:
# Encode the chord of every (chord, note) pair as its integer index.
chord_indices = [chord_to_idx[chord_name] for chord_name, _note in data]

# The model consumes index tensors laid out as (batch, time), so add a
# leading batch axis: this yields one batch with 8 time steps.
input_seq = torch.tensor(chord_indices).unsqueeze(0)

In [11]:
# Display the encoded chord sequence (one batch of 8 chord indices).
input_seq

tensor([[0, 0, 1, 2, 0, 0, 1, 2]])

In [12]:
# Encode the melody note of every (chord, note) pair as its integer class
# label, then add a leading batch axis so the targets line up with input_seq.
note_indices = [note_to_idx[note_name] for _chord, note_name in data]
target_seq = torch.tensor(note_indices).unsqueeze(0)

In [13]:
# Display the encoded melody targets (one batch of 8 note class labels).
target_seq

tensor([[64, 67, 69, 67, 64, 72, 65, 62]])

LSTM Model

In [15]:
class MelodyLSTM(nn.Module):
  """Scores melody-note classes for every time step of a chord sequence.

  Pipeline: chord index -> embedding vector -> LSTM -> linear layer that
  produces per-step logits over the note classes.

  Args:
    input_dim: number of possible chord types (embedding vocabulary size).
    hidden_dim: size of the internal representation; used for both the
      chord embedding and the LSTM hidden state.
    output_dim: number of possible note outputs (classes per time step).
  """

  def __init__(self, input_dim, hidden_dim, output_dim):
    super().__init__()
    # Lookup table that turns each chord index into a hidden_dim-sized vector.
    self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=hidden_dim)
    # batch_first=True: tensors are laid out as (batch, time, features).
    self.lstm = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim, batch_first=True)
    # Projects the LSTM output at every step to unnormalised note scores.
    self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)

  def forward(self, x):
    """Map chord indices of shape (batch, time) to logits of shape (batch, time, output_dim)."""
    embedded = self.embedding(x)
    # Only the per-step outputs are needed; the final (h_n, c_n) state is dropped.
    step_outputs, _final_state = self.lstm(embedded)
    logits = self.fc(step_outputs)
    return logits

Train

In [16]:
# Seed PyTorch's RNG before the model is built so the random weight
# initialisation (and therefore the training curve) is reproducible when the
# notebook is re-run from a fresh kernel.
SEED = 0
torch.manual_seed(SEED)

input_dim = len(chord_to_idx)  # one embedding row per known chord
hidden_dim = 64                # width of the chord embedding and LSTM state
# Note labels are used directly as class indices (the largest here is 72),
# so the output layer must have more classes than the largest label.
# 128 covers the full MIDI note range 0-127.
output_dim = 128

model = MelodyLSTM(input_dim, hidden_dim, output_dim)
# CrossEntropyLoss takes raw logits; the model applies no softmax itself.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)


In [18]:
# Full-batch training: the whole sequence is a single batch, so every epoch
# is exactly one optimizer step.
num_epochs = 1000
log_interval = 100

for epoch in range(num_epochs):
  output = model(input_seq)

  # Merge the batch and time axes so the loss sees one row of class scores
  # (and one target label) per time step.
  flat_logits = output.view(-1, output_dim)
  flat_targets = target_seq.view(-1)
  loss = loss_fn(flat_logits, flat_targets)

  # Clear gradients left over from the previous step before backpropagating.
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  # Report on the last epoch of every log_interval-sized window.
  if epoch % log_interval == log_interval - 1:
    print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")


Epoch 100, Loss: 0.0055
Epoch 200, Loss: 0.0022
Epoch 300, Loss: 0.0012
Epoch 400, Loss: 0.0008
Epoch 500, Loss: 0.0006
Epoch 600, Loss: 0.0004
Epoch 700, Loss: 0.0003
Epoch 800, Loss: 0.0003
Epoch 900, Loss: 0.0002
Epoch 1000, Loss: 0.0002


Download

In [19]:
# Persist only the learned parameters (the state_dict), not the module object.
torch.save(model.state_dict(), "melody_lstm.pth")

# The browser-download helper only exists inside Google Colab. Guard the
# import so the notebook still runs end-to-end on a local kernel, where the
# saved file is already on disk.
try:
  from google.colab import files
except ImportError:
  print("Not running in Google Colab; model saved to melody_lstm.pth")
else:
  files.download("melody_lstm.pth")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>