In [9]:
import torch
import torch.nn as nn
import torch.optim as optim

Load Dataset

In [2]:
from bisect import bisect_right

from music21 import corpus, chord, note, stream

# Load one chorale: the first score yielded by the corpus iterator
chorale = next(corpus.chorales.Iterator())

# Get soprano melody (first part of the score).
# .flatten() replaces the deprecated .flat property.
soprano = chorale.parts[0].flatten().notes

# Get chordified version of the chorale (all voices combined into chords)
chords = chorale.chordify().flatten().getElementsByClass('Chord')

# Materialise the chords once, together with their offsets, so that each melody
# note can be matched with a binary search instead of rescanning every chord.
# Offsets are read while iterating, in the same order the original scan used.
chord_timeline = [(c.offset, c) for c in chords]
chord_offsets = [entry[0] for entry in chord_timeline]

# Pair each melody note with the latest chord that starts at or before it
pairs = []
for s_note in soprano:
    # .notes can also yield chord objects, which have no nameWithOctave
    if not isinstance(s_note, note.Note):
        continue

    # Index of the last chord whose offset is <= the note's offset (-1 if none)
    chord_pos = bisect_right(chord_offsets, s_note.offset) - 1
    if chord_pos < 0:
        continue  # the note starts before the first chord

    chord_at_time = chord_timeline[chord_pos][1]
    chord_name = chord_at_time.pitchedCommonName  # like "G-major triad"
    note_name = s_note.nameWithOctave            # like "E4"
    pairs.append((chord_name, note_name))

print(pairs[:10])


[('G-major triad', 'G4'), ('G-major triad', 'G4'), ('D-major triad', 'D5'), ('G-major triad', 'B4'), ('D-major triad', 'A4'), ('E-minor triad', 'G4'), ('C-major triad', 'G4'), ('F#-diminished triad', 'A4'), ('G-major triad', 'B4'), ('D-major triad', 'A4')]


  return self.iter().getElementsByClass(classFilterList)


In [3]:
# Number of (chord, note) pairs collected in `pairs`
len(pairs)

46

In [4]:
from bisect import bisect_right

from music21 import corpus, chord, note

# Store all (chord, melody) pairs here
all_pairs = []

# Loop over all chorales in the corpus
for chorale in corpus.chorales.Iterator():

    try:
        # Get melody (soprano) part.
        # .flatten() replaces the deprecated .flat property.
        soprano = chorale.parts[0].flatten().notesAndRests.stream()

        # Get chords by combining all parts
        chords = chorale.chordify().flatten().getElementsByClass('Chord')

        # Materialise chords and their offsets once per chorale so every melody
        # note is matched by binary search rather than a scan over all chords.
        chord_timeline = [(c.offset, c) for c in chords]
        chord_offsets = [entry[0] for entry in chord_timeline]

        # Extract (chord, note) pairs based on timing
        for s_note in soprano.notes:
            # .notes already excludes rests; this drops anything that is not a
            # single Note (e.g. a chord written in the soprano part).
            if not isinstance(s_note, note.Note):
                continue

            # Last chord starting at or before the note (-1 when there is none)
            chord_pos = bisect_right(chord_offsets, s_note.offset) - 1
            if chord_pos < 0:
                continue

            matching_chord = chord_timeline[chord_pos][1]
            chord_name = matching_chord.pitchedCommonName  # e.g., 'G-major triad'
            note_name = s_note.nameWithOctave              # e.g., 'E4'
            all_pairs.append((chord_name, note_name))

    except Exception as e:
        # Best effort: a chorale that fails to parse is reported and skipped.
        # Pairs appended before the failure are kept.
        print("Skipping chorale due to error:", e)


  return self.iter().getElementsByClass(classFilterList)


In [5]:
import csv

# Write the pairs as a two-column CSV with a header row.
# newline="" is required by the csv module: the writer emits its own line
# terminators, and without it some platforms translate them a second time and
# produce blank rows. The encoding is pinned so the file does not depend on
# the machine's locale.
with open("chord_melody_pairs.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["Chord", "Note"])
    writer.writerows(all_pairs)


In [6]:
# Offer the CSV as a browser download when running on Google Colab.
# Outside Colab the google.colab package does not exist; the file is already on
# local disk, so the download step is skipped instead of aborting "Run All".
try:
    from google.colab import files
except ImportError:
    print("google.colab not available; chord_melody_pairs.csv was left on local disk")
else:
    files.download("chord_melody_pairs.csv")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [7]:
data = all_pairs

# Vocabularies: every distinct chord label and note name, in sorted order
unique_chords = sorted(set(pair[0] for pair in data))
unique_notes = sorted(set(pair[1] for pair in data))

# Forward lookups: label -> its position in the sorted vocabulary
chord_to_idx = dict(zip(unique_chords, range(len(unique_chords))))
note_to_idx = dict(zip(unique_notes, range(len(unique_notes))))

# Reverse lookups (position -> label), used to decode indices back to names
idx_to_chord = dict(enumerate(unique_chords))
idx_to_note = dict(enumerate(unique_notes))

print("Chord to idx:", chord_to_idx)
print("Note to idx:", note_to_idx)


Chord to idx: {'A#-diminished seventh chord': 0, 'A#-diminished triad': 1, 'A#-major-third diminished tetrachord': 2, 'A-all-interval tetrachord': 3, 'A-augmented triad': 4, 'A-diminished seventh chord': 5, 'A-diminished triad': 6, 'A-dominant seventh chord': 7, 'A-dominant-eleventh': 8, 'A-dominant-ninth': 9, 'A-flat-ninth pentachord': 10, 'A-half-diminished seventh chord': 11, 'A-incomplete dominant-seventh chord': 12, 'A-incomplete major-seventh chord': 13, 'A-incomplete minor-seventh chord': 14, 'A-lydian tetrachord': 15, 'A-major pentachord': 16, 'A-major pentatonic': 17, 'A-major seventh chord': 18, 'A-major triad': 19, 'A-major-diminished tetrachord': 20, 'A-major-second major tetrachord': 21, 'A-major-second minor tetrachord': 22, 'A-minor seventh chord': 23, 'A-minor triad': 24, 'A-minor trichord': 25, 'A-minor-augmented tetrachord': 26, 'A-minor-diminished ninth chord': 27, 'A-minor-diminished tetrachord': 28, 'A-minor-ninth chord': 29, 'A-minor-second diminished tetrachord':

In [10]:
# Encode every chord label as its vocabulary index.
# A comprehension keeps the loop variables local: a plain `for chord, _ in data`
# loop would leave `chord` bound to a string, shadowing the music21 `chord`
# module, and would overwrite IPython's `_`.
input_seq = [chord_to_idx[chord_name] for chord_name, _note_name in data]

# PyTorch's sequence models expect input with (at least) two dimensions.
# unsqueeze(0) adds the batch axis: one batch holding len(data) time steps.
input_seq = torch.tensor(input_seq).unsqueeze(0)

In [11]:
# Display the encoded chord sequence (leading dimension of size 1 from unsqueeze(0))
input_seq

tensor([[347, 347, 181,  ..., 230, 266, 224]])

In [12]:
# Encode every melody note name as its vocabulary index.
# A comprehension keeps the loop variables local: a plain `for _, note in data`
# loop would leave `note` bound to a string, shadowing the music21 `note` module
# that the data-loading cells use for `note.Note`.
target_seq = [note_to_idx[note_name] for _chord_name, note_name in data]

# Add the batch axis so the targets line up with input_seq: shape (1, len(data))
target_seq = torch.tensor(target_seq).unsqueeze(0)

In [13]:
# Display the encoded note sequence (leading dimension of size 1 from unsqueeze(0))
target_seq

tensor([[40, 40, 22,  ..., 40, 31, 28]])

LSTM Model

In [14]:
class MelodyLSTM(nn.Module):
  """Embedding -> LSTM -> linear head that maps chord indices to note logits.

  Args:
    input_dim: number of possible chord types (rows of the embedding table).
    hidden_dim: size of the internal representation; used both as the
      embedding width and as the LSTM hidden size.
    output_dim: number of possible note outputs (logits per time step).
  """

  def __init__(self, input_dim, hidden_dim, output_dim):
    super().__init__()
    # Lookup table turning each chord index into a hidden_dim-sized vector
    self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=hidden_dim)
    # batch_first=True: the LSTM consumes (batch, time, features) tensors
    self.lstm = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim, batch_first=True)
    # Projects every LSTM output vector onto the note logits
    self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)

  def forward(self, x):
    """Return note logits for every position of the chord-index tensor `x`."""
    embedded = self.embedding(x)
    # Only the per-step outputs are needed; the final (h, c) state is discarded
    per_step, _final_state = self.lstm(embedded)
    logits = self.fc(per_step)
    return logits

Train

In [15]:
# Seed the RNG so weight initialisation, and therefore the loss curve, is
# repeatable across "Restart & Run All".
torch.manual_seed(0)

input_dim = len(chord_to_idx)   # one embedding row per chord label
hidden_dim = 64
# One logit per note class. The targets are note_to_idx indices, and
# CrossEntropyLoss requires every target index to be < the number of logits.
# 128 is kept as a floor so checkpoints keep the shape they have always had,
# but the head grows if the note vocabulary is ever larger than that.
output_dim = max(128, len(note_to_idx))

model = MelodyLSTM(input_dim, hidden_dim, output_dim)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)


In [16]:
NUM_EPOCHS = 1000   # full-batch gradient steps over the whole sequence
LOG_EVERY = 100     # print the loss once per this many epochs

model.train()
for epoch in range(NUM_EPOCHS):
  optimizer.zero_grad()
  output = model(input_seq)
  # CrossEntropyLoss takes (N, C) logits and (N,) class indices, so batch and
  # time are folded together. The class count is read from the logits
  # themselves rather than from the global `output_dim`, which can be stale if
  # cells were re-run out of order.
  logits = output.reshape(-1, output.size(-1))
  loss = loss_fn(logits, target_seq.reshape(-1))
  loss.backward()
  optimizer.step()
  if (epoch + 1) % LOG_EVERY == 0:
    print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")


Epoch 100, Loss: 0.9153
Epoch 200, Loss: 0.5644
Epoch 300, Loss: 0.4408
Epoch 400, Loss: 0.2594
Epoch 500, Loss: 0.1834
Epoch 600, Loss: 0.1351
Epoch 700, Loss: 0.1108
Epoch 800, Loss: 0.1753
Epoch 900, Loss: 0.0597
Epoch 1000, Loss: 0.0408


Download

In [18]:
# Persist only the learned parameters (state_dict), not the whole module object
torch.save(model.state_dict(), "music21_lstm.pth")

# Offer the checkpoint as a browser download when running on Google Colab.
# Outside Colab the google.colab package does not exist; the file is already on
# local disk, so the download step is skipped instead of raising.
try:
  from google.colab import files
except ImportError:
  print("google.colab not available; music21_lstm.pth was left on local disk")
else:
  files.download("music21_lstm.pth")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>