In [1]:
import numpy as np
import pickle
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from matplotlib.pyplot import imshow
import matplotlib.pyplot as plt
from PIL import Image

%matplotlib inline

np.random.seed(3333)
torch.manual_seed(3333)

<torch._C.Generator at 0x7f58c85ae210>

In [2]:
## DATALOADER FOR MUSIC FILES
from torch.utils.data import DataLoader, Dataset

class MusicDataset(Dataset):
    def __init__(self, data_file, sequence_length, data_augmentation):
        super(MusicDataset, self).__init__()
        self.sequence_length = sequence_length
        self.data_augmentation = data_augmentation
        
        with open(data_file, 'rb') as f:
            self.id_to_sheet = pickle.load(f)
            self.data = pickle.load(f)
        
        self.data = [x.astype(np.float32) for x in self.data]
        self.start_token = np.zeros((1, 130), dtype=np.float32)
        self.start_token[:, 0] = 1.0
        
        # pad all sequences to desired sequence length
        self.mask_lengths = []
        for i, x in enumerate(self.data):
            if len(x) < self.sequence_length:
                s = x.shape
                self.data[i] = np.zeros((self.sequence_length, s[1]), dtype=np.float32)
                self.data[i][:s[0], :] = x
                self.mask_lengths.append(s[0])
            else:
                self.mask_lengths.append(self.sequence_length)

    # data augmentation = shifting of chords +/- 1 octave
    # there are 4 modes per chord, so we shift in multiples of 4
    def augment_sequence(self, sequence, offset):
        sequence[:, 1:128] = np.roll(sequence[:, 1:128], offset)  # melody
        sequence[:, 143:191] = np.roll(sequence[:, 143:191], offset*4)  # chord
        
    def process_sequence(self, seq):
        seq_chords_rhythm = seq[:, 130:]
        seq_melody = np.vstack((self.start_token, seq[:, :130]))
        
        seq_input = np.hstack((seq_melody[:-1, :], seq_chords_rhythm))
        seq_output = seq_melody[1:]
        
        return (seq_input, seq_output)
                    
    def __len__(self):
        return sys.maxsize
        
    def __getitem__(self, index):
        data_index = np.random.randint(0, len(self.data))
        data_point = self.data[data_index]
        if self.data_augmentation:
            self.augment_sequence(data_point, np.random.randint(-12, 13))
        
        data_range = np.random.randint(0, len(data_point) - self.sequence_length + 1)
        seq = data_point[data_range:(data_range + self.sequence_length)]
        
        seq_chords_rhythm = seq[:, 130:]
        seq_melody = np.vstack((self.start_token, seq[:, :130]))
        
        seq_input = np.hstack((seq_melody[:-1, :], seq_chords_rhythm))
        seq_lengths = np.asarray(self.mask_lengths[data_index], dtype=np.float32)
        seq_output = seq_melody[1:]
        
        if self.mask_lengths[data_index] < self.sequence_length:
            seq_input[self.mask_lengths[data_index], :] = 0.0
        
        # return original sequence, target sequence, and original sequence lengths
        return (
            seq_input,
            seq_output,
            seq_lengths
        )

In [122]:
class LSTM(nn.Module):
    def __init__(self, input_dim_1, input_dim_2, hidden_dim, batch_size, output_dim, num_layers_bi=2, num_layers_lstm=2, inference=False):
        super(LSTM, self).__init__()
        self.input_dim_1 = input_dim_1
        self.input_dim_2 = input_dim_2
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.output_dim = output_dim
        self.num_layers_lstm = num_layers_lstm
        self.num_layers_bi = num_layers_bi
        self.inference = inference
        
        self.hidden_bi, self.hidden_lstm = self.init_hidden()
        
        # Bidirectional LSTM
        self.lstm_bi = nn.LSTM(
            self.input_dim_1,
            self.hidden_dim,
            self.num_layers_bi,
            batch_first=True,
            bidirectional=True
        )
        
        # Second pair of LSTM layers
        self.lstm = nn.LSTM(self.hidden_dim * 2 + self.input_dim_2, self.hidden_dim, self.num_layers_lstm, batch_first=True)
        
        # Last dense layer
        self.dense = nn.Linear(self.hidden_dim, self.output_dim)

    def init_hidden(self):
        # This is what we'll initialise our hidden state as
        return (
            (torch.zeros(self.num_layers_bi * 2, self.batch_size, self.hidden_dim),
                torch.zeros(self.num_layers_bi * 2, self.batch_size, self.hidden_dim)),
            (torch.zeros(self.num_layers_lstm, self.batch_size, self.hidden_dim),
                torch.zeros(self.num_layers_lstm, self.batch_size, self.hidden_dim))
               )
    
    def get_bi_output(self, input_X):
        # get chord and rhythm info from input
        input_X_1 = input_X[:, :, self.input_dim_2:]
        out_bi, _ = self.lstm_bi(input_X_1)
        
        return out_bi
    
    def process_lstm_sequence(self, bi_part, melody_part, temperature=1.0, propagate_hidden=False):
        concat_X_2 = torch.cat((melody_part, bi_part), 2)
            
        if propagate_hidden:
            out_lstm, self.hidden_lstm = self.lstm(concat_X_2, self.hidden_lstm)
        else:
            out_lstm, self.hidden_lstm = self.lstm(concat_X_2)
        
        # Generate sequence predictions
        X_pred = self.dense(out_lstm)
        
        # Apply nonlinearities
        X_softmax = F.log_softmax(X_pred / temperature, dim=2)
        
        return X_softmax

    def forward(self, input_X, lengths_X, max_length, temperature=1.0, propagate_hidden=False):
        # Forward pass through LSTM layer
        # shape of lstm_out: [batch_size, input_size, hidden_dim]
        # shape of self.hidden: (a, b), where a and b both have shape (num_layers, batch_size, hidden_dim).
        
        # get chord and rhythm info from input
        input_X_1 = input_X[:, :, self.input_dim_2:]
        
        if not self.inference:
            packed_X_1 = torch.nn.utils.rnn.pack_padded_sequence(input_X_1, lengths_X, batch_first=True)
        else:
            packed_X_1 = input_X_1
            
        if propagate_hidden:
            out_bi, self.hidden_bi = self.lstm_bi(packed_X_1, self.hidden_bi)
        else:
            out_bi, self.hidden_bi = self.lstm_bi(packed_X_1)
        
        if not self.inference:
            unpacked_X_1, _ = torch.nn.utils.rnn.pad_packed_sequence(out_bi, batch_first=True, total_length=max_length)
        else:
            unpacked_X_1 = out_bi
        
        # Concat output of bi-LSTM with melody part of input
        input_X_2 = input_X[:, :, :self.input_dim_2]
        concat_X_2 = torch.cat((input_X_2, unpacked_X_1), 2)
        
        # Feed through second pair of LSTM layers
        if not self.inference:
            packed_X_2 = torch.nn.utils.rnn.pack_padded_sequence(concat_X_2, lengths_X, batch_first=True)
        else:
            packed_X_2 = concat_X_2
            
        if propagate_hidden:
            out_lstm, self.hidden_lstm = self.lstm(packed_X_2, self.hidden_lstm)
        else:
            out_lstm, self.hidden_lstm = self.lstm(packed_X_2)
        
        if not self.inference:
            unpacked_X_2, _ = torch.nn.utils.rnn.pad_packed_sequence(out_lstm, batch_first=True, total_length=max_length)
        else:
            unpacked_X_2 = out_lstm
        
        # Generate sequence predictions
        X_pred = self.dense(unpacked_X_2)
        
        # Apply nonlinearities
        X_softmax = F.log_softmax(X_pred / temperature, dim=2)
        
        return X_softmax

model = LSTM(input_dim_1=62, input_dim_2=130, hidden_dim=512, batch_size=1, output_dim=62)

In [3]:
md = MusicDataset('../data/processed_sheets_numpy/data.pkl', 100, data_augmentation=True)

In [4]:
for k, v in md.id_to_sheet.items():
    if 'Menken' in v:
        print(k, v)

3333 Alan Menken, Stephen Schwartz - Colors Of The Wind.json
3391 Alan Menken, Howard Ashman - Skid Row (Downtown).json
4710 Alan Menken, Tim Rice - A WHOLE NEW WORLD.json
5040 Alan Menken, Howard Ashman - Beauty and the Beast.json
5156 Alan Menken, Howard Ashman - Under the Sea.json


In [6]:
md.id_to_sheet[4444]

'Oakley Haldeman, Gene Autry - Here Comes Santa Claus.json'

In [23]:
data_train_loader = DataLoader(md, batch_size=100, shuffle=False)

In [24]:
X = next(iter(data_train_loader))

In [48]:
model(X[0], X[2], 100, propagate_hidden=True)

tensor([[[-4.1127, -4.1692, -4.1670,  ..., -4.1395, -4.0613, -4.1231],
         [-4.1121, -4.1686, -4.1671,  ..., -4.1395, -4.0616, -4.1237],
         [-4.1104, -4.1684, -4.1670,  ..., -4.1384, -4.0633, -4.1237],
         ...,
         [-4.1115, -4.1692, -4.1665,  ..., -4.1407, -4.0611, -4.1240],
         [-4.1120, -4.1691, -4.1666,  ..., -4.1403, -4.0611, -4.1235],
         [-4.1117, -4.1685, -4.1669,  ..., -4.1399, -4.0613, -4.1238]]],
       grad_fn=<LogSoftmaxBackward>)

## Inference

In [130]:
model = LSTM(
        input_dim_1=62,
        input_dim_2=130,
        hidden_dim=512,
        batch_size=1,
        output_dim=130,
        num_layers_bi=2,
        num_layers_lstm=2,
        inference=True
    )

model.load_state_dict(torch.load('../chord_rhythm_melody_lstm/test/epoch_60000.model'))
model.eval()

LSTM(
  (lstm_bi): LSTM(62, 512, num_layers=2, batch_first=True, bidirectional=True)
  (lstm): LSTM(1154, 512, num_layers=2, batch_first=True)
  (dense): Linear(in_features=512, out_features=130, bias=True)
)

In [157]:
# INFERENCE LOOP
model.hidden_bi, model.hidden_lstm = model.init_hidden()
X_in = torch.FloatTensor(md.process_sequence(md.data[2322])[0]).unsqueeze(0)
bi_output = model.get_bi_output(X_in)

model.hidden_bi, model.hidden_lstm = model.init_hidden()

lstm_out = model.process_lstm_sequence(bi_output[:, 0:1, :], torch.FloatTensor(md.start_token).unsqueeze(0), temperature=0.9, propagate_hidden=True)
next_char = torch.multinomial(torch.exp(lstm_out)[0, 0], 1)
print(midi_pitch_to_pitch(int(next_char.detach().numpy())))
next_one_hot = torch.FloatTensor(np.zeros((1, 1, 130), dtype=np.float32))
next_one_hot[0, 0, next_char] = 1.0

for k in range(1, bi_output.shape[1]):
    lstm_out = model.process_lstm_sequence(bi_output[:, k:k+1, :], next_one_hot, temperature=0.9, propagate_hidden=True)
    next_char = torch.multinomial(torch.exp(lstm_out)[0, 0], 1)
    print(midi_pitch_to_pitch(int(next_char.detach().numpy())))
    next_one_hot[:, :, :] = 0.0
    next_one_hot[0, 0, next_char] = 1.0

('barline', 0, 0)
('F', 0, 4)
('G', 0, 4)
('A', 0, 4)
('barline', 0, 0)
('A', 1, 4)
('C', 0, 5)
('D', 0, 5)
('barline', 0, 0)
('D', 1, 5)
('G', 0, 4)
('barline', 0, 0)
('F', 1, 4)
('barline', 0, 0)
('D', 0, 5)
('F', 0, 4)
('barline', 0, 0)
('D', 1, 4)
('barline', 0, 0)
('G', 0, 4)
('D', 1, 4)
('barline', 0, 0)
('D', 0, 4)
('barline', 0, 0)
('F', 0, 4)
('A', 1, 3)
('barline', 0, 0)
('C', 1, 4)
('D', 1, 4)
('E', 0, 4)
('F', 1, 4)
('barline', 0, 0)
('G', 1, 4)
('D', 1, 5)
('barline', 0, 0)
('D', 0, 5)
('D', 1, 5)
('barline', 0, 0)
('G', 1, 4)
('F', 0, 5)
('barline', 0, 0)
('D', 1, 5)
('C', 1, 5)
('barline', 0, 0)
('G', 1, 4)
('E', 0, 5)
('barline', 0, 0)
('D', 1, 5)
('C', 1, 5)
('barline', 0, 0)
('F', 1, 4)
('G', 1, 4)
('A', 1, 4)
('barline', 0, 0)
('B', 0, 4)
('C', 1, 5)
('D', 1, 5)
('barline', 0, 0)
('F', 0, 5)
('D', 1, 5)
('barline', 0, 0)
('D', 0, 5)
('C', 0, 5)
('barline', 0, 0)
('F', 0, 4)
('barline', 0, 0)
('F', 0, 4)
('barline', 0, 0)
('F', 0, 4)
('barline', 0, 0)
('F', 0, 4)
('ba

In [112]:
next_one_hot

array([[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]], dtype=float32)

In [94]:
bi_output[:, 0:1, :].shape

torch.Size([1, 1, 1024])

In [90]:
X_in.shape

torch.Size([1, 165, 192])

In [152]:
class Constants:
    PITCH_CLASSES = [["C", 0], ["C", 1], ["D", 0], ["D", 1], ["E", 0], ["F", 0], ["F", 1], ["G", 0], ["G", 1], ["A", 0], ["A", 1], ["B", 0]]
    
def pitch_to_midi_pitch(step, alter, octave):
    """!@brief Convert MusicXML pitch representation to MIDI pitch number.

        @param step Which root note it is (e.g. C, D,...)
        @param alter If the pitch was altered (sharp or flat)
        @param octave The octave that the pitch is in

        @return The MIDI pitch representation of the input
    """
    pitch_class = 0
    if step == 'C':
        pitch_class = 0
    elif step == 'D':
        pitch_class = 2
    elif step == 'E':
        pitch_class = 4
    elif step == 'F':
        pitch_class = 5
    elif step == 'G':
        pitch_class = 7
    elif step == 'A':
        pitch_class = 9
    elif step == 'B':
        pitch_class = 11
    else:
        # Raise exception for unknown step (ex: 'Q')
        raise Exception('Unable to parse pitch step ' + step)

    pitch_class = (pitch_class + int(alter)) % 12
    midi_pitch = (12 + pitch_class) + (int(octave) * 12)
    return midi_pitch

def midi_pitch_to_pitch(midi_pitch):
    if midi_pitch == 128:
        return ('rest', 0, 0)
    elif midi_pitch == 129:
        return ('barline', 0, 0)
    octave = midi_pitch // 12 - 1
    pitch_class = Constants.PITCH_CLASSES[midi_pitch % 12]
    step, alter = pitch_class[0], pitch_class[1]

    return (step, alter, octave)

In [153]:
midi_pitch_to_pitch(128)

('rest', 0, 0)