In [1]:
import math
import os
import pickle
import random
from datetime import datetime

import numpy as np
import torch
from matplotlib import pyplot as plt
from music21 import converter
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import trange, tqdm
from IPython.display import display, Audio

from utils import decode_midi

In [2]:
class MusicRNN(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, rnn_type="lstm", num_layers=1, dropout=0.0):
        assert rnn_type in ["lstm", "gru"]

        super().__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        self.dropout = dropout

        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        if rnn_type == "lstm":
            self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True) 
        else:
            self.rnn = nn.GRU(embedding_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
    
        self.hidden_state = None

    def forward(self, x, reset_hidden_state=True):
        x = self.embeddings(x)
        if reset_hidden_state:
            x, _ = self.rnn(x)
        else:
            x, self.hidden_state = self.rnn(x, self.hidden_state)
        x = self.fc(x)
        return x

    def generate_sequence(self, beam_width=1, seq_start=None, max_length=1024, **kwargs):
        if not seq_start:
            seq_start = [BOS_TOKEN]
        seq = seq_start.copy()
        with torch.no_grad():
            # Generate k most likely tokens
            prev_top_seq, next_top_seq = [], []
            prev_top_seq = [(seq, 0.0)]
            
            kpos_next_tokens = self._generate_next_token(
                candidates=torch.LongTensor(prev_top_seq).to(device),
                reset_hidden=True,
                **kwargs
            )
            
            next_top_seq.extend(kpos_next_tokens)
            
            while len(seq) <= max_length:
                for sequence in prev_top_seq:
                    kpos_next_tokens = self._generate_next_token(
                        candidates=torch.LongTensor([kpos_next_tokens]).to(device),
                        reset_hidden=False,
                        **kwargs
                    )
                    
                    next_top_seq.extend(kpos_next_tokens)
                
                next_top_seq.sort(reverse=True, key=lambda tup: tup[1])
                prev_top_seq = next_top_seq[:self.beam_width]
                next_top_seq = []
                
                seq.append(next_token)
        
        n = min(len(prev_top_seq), max_length)
        return prev_top_seq[:n]
    
    def _generate_next_token(self, candidates, reset_hidden=False, temp=1.0, topk=5, argmax=False):
        # The model expects a batch input, so we add a fake batch dimension.
        # Removed the fake batch dimension (.unsqueeze(0)) because we have batch of beam_width tokens
        import pdb;
        pdb.set_trace()
        model_input = np.array([tup[0] for tup in candidates]) # Previous sequences: [(seq , score), (seq2, score), (seq3, score)]
        # Then, we need to remove the fake batch dimension from the output.
        # Also removed the (.squeeze(0)) for similar reasons
        model_output = self(model_input, reset_hidden)
        
        # Apply softmax to top beam_width tokens
        
        for i in range(len(candidates)):
            next_token_probs = [(candidates[i].first + token, candidates[i].second + np.log(score)) for score in F.softmax(model_output[:, i] / temp, dim=0)]
        
        if argmax:
            # Keep top beam_width tokens
            kpos_next_tokens = torch.topk(next_token_probs, self.beam_width)
        else:
            # TO-DO: Implement beam search for this case too
            top_tokens = torch.topk(next_token_probs, topk)
            top_indices = top_tokens.indices
            top_probs = top_tokens.values
            top_probs /= torch.sum(top_probs)
            kpos_next_tokens = np.random.choice(top_indices.cpu().numpy(), p=top_probs.cpu().numpy())
        return kpos_next_tokens
    
    def prune_notes():
        pass

In [3]:
DATA_ROOT = "data/pop_pickle"
N_SAMPLES = 909
VOCAB_SIZE = 390
BOS_TOKEN = VOCAB_SIZE - 2
PAD_TOKEN = VOCAB_SIZE - 1

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

Device: cuda


In [5]:
dataset = []
for i in range(N_SAMPLES):
    file_name = str(i + 1).zfill(3) + ".pickle"
    path = os.path.join(DATA_ROOT, file_name)
    with open(path, "rb") as f:
        seq = pickle.load(f)
        seq_tensor = torch.LongTensor(seq)
    dataset.append(seq_tensor)

In [6]:
N_VAL = N_TEST = int(0.05 * 909)
N_TRAIN = 909 - (N_VAL + N_TEST)

train_songs = dataset[:N_TRAIN]
val_songs = dataset[N_TRAIN:N_TRAIN+N_VAL]
test_songs = dataset[N_TRAIN+N_VAL:]

print(f"Train: {len(train_songs)} \t Val: {len(val_songs)} \t Test: {len(test_songs)}")

Train: 819 	 Val: 45 	 Test: 45


In [7]:
FIRST_N = 125
IDX = 10

primer = test_songs[IDX][:FIRST_N].tolist()

In [8]:
lstm = MusicRNN(
    embedding_dim=64,
    hidden_dim=256,
    vocab_size=VOCAB_SIZE,
    rnn_type="lstm",
    num_layers=2,
    dropout=0.5,
).to(device)

sd = torch.load("models/lstm(64,256),lr=0.001,bsz=256,nepochs=250,sl=1024,dropout=0.5,t=2020-11-13_14:30.pt")["model_state_dict"]
lstm.load_state_dict(sd)

<All keys matched successfully>

In [9]:
primer = test_songs[IDX][:FIRST_N].tolist()
continuation = lstm.generate_sequence(primer, temp=1.0, topk=128)
display_audio(continuation)

TypeError: not a sequence