# Prediction
This notebook details the pipeline for next-chord prediction.

In [1]:
# Useful starting lines
%matplotlib inline
%load_ext autoreload
%autoreload 2

## Load the data 
When loading the chord dataset, we can choose (via `key_mode`) whether to keep sections in a major key, a minor key, or both.

In [2]:
from load_data import load_train_test_sentences, all_composers, non_noisy_composers

In [35]:
# Choose which composers to train on and which to test on
composers = all_composers
#composers = non_noisy_composers
test_composers = ['Pleyel']

train_sentences, test_sentences = load_train_test_sentences(composers, test_composers, key_mode='MAJOR')
print(len(test_sentences))

11


## Apply Word2Vec
Learn a vector embedding of the chords from the training sentences. The Word2Vec hyperparameters are set in the cell below.

In [36]:
from gensim.models import Word2Vec
from load_data import get_chord_sentences

In [37]:
# Ignore chords whose total frequency in the training sentences is lower than this
min_count = 5
# Dimension of the embedding space
size = 20
# Maximum distance between the focus chord and a context chord
window = 2
# Training algorithm: 0 for CBOW, 1 for skip-gram
sg = 1
# Number of iterations (epochs) over the corpus.
# Named `n_iter` rather than `iter` so the built-in `iter()` is not shadowed for later cells.
n_iter = 500

# The first argument has to be a list of lists of words (here, lists of chord symbols).
# NOTE: `size` / `iter` are gensim 3.x keyword names (renamed `vector_size` / `epochs` in gensim 4.0).
w2v_model = Word2Vec(train_sentences, min_count=min_count, size=size, window=window, sg=sg, iter=n_iter)

In [38]:
# Chords kept in the Word2Vec vocabulary (those meeting the `min_count` threshold)
trained_vocab = w2v_model.wv.vocab
trained_vocab.keys()

dict_keys(['I:MAJ', 'V:MAJ', 'IV:MAJ', '#IV:DIM', 'II:MAJ', 'VI:MIN', 'bVII:MAJ', 'VII:DIM', 'III:MAJ', 'VI:MAJ', 'II:MIN', '#I:DIM', 'V:MIN', 'III:DIM', 'II:DIM', 'IV:MIN', 'I:DIM', '#V:DIM', 'VII:MAJ', 'III:MIN', 'I:MIN', 'VII:MIN', 'bIII:MAJ', 'bVI:MAJ', '#II:DIM', 'VI:DIM', 'bbVII:MAJ', 'I:AUG', 'IV:AUG', 'V:AUG', '#VI:DIM', 'bVII:MIN', '#II:MAJ', 'II:AUG', '#V:MAJ', 'bVI:AUG', 'bIII:AUG', 'bII:MAJ', '#IV:MIN', 'bVII:AUG', '#IV:MAJ', 'bV:MAJ', 'bIV:MAJ', 'bVII:DIM', 'V:DIM', 'bI:MAJ', 'IV:DIM', 'bIII:MIN', '#I:MAJ', 'bI:MIN', 'bVI:MIN', '#VI:MAJ', 'bIII:DIM', '#VII:DIM', '#I:MIN', '#II:MIN', 'III:AUG', '#III:MAJ', 'bII:MIN', '#V:MIN', '##IV:DIM', '#III:DIM'])

## Predict
Train the LSTM predictor on the same dataset as the Word2Vec model, then evaluate it on the held-out test dataset.

In [39]:
from lstm import LSTMPredictor
import torch
import torch.nn as nn
import torch.optim as optim

### Train the predictor

In [40]:
# Seed PyTorch's RNG so that training runs are repeatable.
# NOTE(review): this assumes LSTMPredictor draws its randomness (weight init, any shuffling)
# from torch's RNG; if `learn` uses Python's `random`, seed that too.
seed = 0
torch.manual_seed(seed)

# Second positional argument of LSTMPredictor.
# NOTE(review): its meaning is defined in lstm.py (presumably the LSTM hidden size) — confirm there.
lstm_size_arg = 15
learning_rate = 0.001
# Number of training epochs passed to `learn` (the training log reports epochs 0 and 1 for 2)
n_epochs = 2

lstm_predictor = LSTMPredictor(w2v_model, lstm_size_arg)
optimiser = optim.Adam(lstm_predictor.parameters(), lr=learning_rate)

# Training takes a couple minutes
lstm_predictor.learn(train_sentences, optimiser, n_epochs)

Starting epoch 0
Iteration 5000 : average loss = 2.481371134388447
Iteration 10000 : average loss = 2.696560618257523
Iteration 15000 : average loss = 2.3028311231821776
Iteration 20000 : average loss = 2.0471320151239634
Iteration 25000 : average loss = 2.1376013145715
Iteration 30000 : average loss = 1.7969899514466525
Iteration 35000 : average loss = 1.3552037417069078
Iteration 40000 : average loss = 1.5395047161892057
Iteration 45000 : average loss = 1.8766031429111958
Iteration 50000 : average loss = 2.259413107815385
Iteration 55000 : average loss = 1.6999046997725964
Iteration 60000 : average loss = 1.4278137263149022
Iteration 65000 : average loss = 1.9280818780869247
Closing epoch 0 

Starting epoch 1
Iteration 5000 : average loss = 2.0709528877168895
Iteration 10000 : average loss = 2.586650201436877
Iteration 15000 : average loss = 2.1332288523107765
Iteration 20000 : average loss = 1.9335909777104854
Iteration 25000 : average loss = 2.0599059022992847
Iteration 30000 : ave

### Test the predictor

In [41]:
# Evaluate the trained predictor on the held-out test sentences
(accuracy_total,
 accuracy_by_chord,
 occurrences_by_chord) = lstm_predictor.test(test_sentences)

report = (
    ('Total accuracy:', accuracy_total),
    ('Accuracy by chord\n', accuracy_by_chord),
    ('Occurrences by chord\n', occurrences_by_chord),
)
for heading, value in report:
    print(heading, value)

Total accuracy: 0.5343383584589615
Accuracy by chord
 {'V:MAJ': 0.33507853403141363, 'I:MAJ': 0.875, 'IV:MAJ': 0.896551724137931, 'II:MIN': 0.1794871794871795, 'VI:MIN': 0.2916666666666667, 'VII:DIM': 0.0, 'III:MIN': 0.0, 'I:MIN': 0.0, 'VI:MAJ': 0.0, '#IV:DIM': 0.0, 'II:MAJ': 0.0, 'IV:MIN': 0.0, 'III:MAJ': 0.0, '#II:DIM': 0.0, '#I:DIM': 0.0, 'I:AUG': 0.0, '#V:DIM': 0.0, 'VI:DIM': 0.0, 'III:DIM': 0.0, 'II:DIM': 0.0}
Occurrences by chord
 {'V:MAJ': 191, 'I:MAJ': 216, 'IV:MAJ': 58, 'II:MIN': 39, 'VI:MIN': 24, 'VII:DIM': 20, 'III:MIN': 2, 'I:MIN': 4, 'VI:MAJ': 7, '#IV:DIM': 6, 'II:MAJ': 5, 'IV:MIN': 1, 'III:MAJ': 4, '#II:DIM': 3, '#I:DIM': 8, 'I:AUG': 2, '#V:DIM': 4, 'VI:DIM': 1, 'III:DIM': 1, 'II:DIM': 1}
