In [50]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%pdb off

Automatic pdb calling has been turned OFF


In [2]:
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid , save_image

In [3]:
from constants import *

In [4]:
import matplotlib.pyplot as plt

### Extracting the Nietzsche corpus

In [5]:
from fastai.io import *

In [6]:
# Directory (relative to the notebook) that holds the corpus file; it is used
# as an f-string prefix below, so the trailing slash is required.
PATH='../data/nietzsche/'

In [7]:
# get_data comes from the fastai.io star import; presumably it downloads the
# URL to the given path when needed — confirm against the installed fastai.
get_data("https://s3.amazonaws.com/text-datasets/nietzsche.txt", f'{PATH}nietzsche.txt')
# Read the whole corpus through a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
with open(f'{PATH}nietzsche.txt') as corpus_file:
    text = corpus_file.read()
print('corpus length:', len(text))

corpus length: 600893


In [8]:
text[:400]

'PREFACE\n\n\nSUPPOSING that Truth is a woman--what then? Is there not ground\nfor suspecting that all philosophers, in so far as they have been\ndogmatists, have failed to understand women--that the terrible\nseriousness and clumsy importunity with which they have usually paid\ntheir addresses to Truth, have been unskilled and unseemly methods for\nwinning a woman? Certainly she has never allowed herself '

In [9]:
# Vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
# One extra slot is counted here; the next cell prepends a "\0" entry to `chars`.
vocab_size = 1 + len(chars)
print('total chars:', vocab_size)

total chars: 85


In [10]:
# Put the "\0" character at index 0. The guard keeps this cell idempotent:
# re-running it must not insert a second "\0" and shift every character index.
if not chars or chars[0] != "\0":
    chars.insert(0, "\0")

''.join(chars[1:-6])

'\n !"\'(),-.0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxy'

In [11]:
# Two lookup tables over the vocabulary: position -> character, and its inverse.
indices_char = dict(enumerate(chars))
char_indices = {ch: pos for pos, ch in indices_char.items()}

In [12]:
# Encode the whole corpus as a list of integer character ids.
idx = list(map(char_indices.__getitem__, text))

idx[:10]

[40, 42, 29, 30, 25, 27, 29, 1, 1, 1]

In [13]:
''.join(indices_char[i] for i in idx[:70])

'PREFACE\n\n\nSUPPOSING that Truth is a woman--what then? Is there not gro'

In [14]:
def idx2char(idx):
    """Decode a sequence of character ids into a string.

    Args:
        idx: iterable of integer ids (list, NumPy array or 1-D tensor); every
            id must be a key of the module-level ``indices_char`` table.

    Returns:
        str: the characters for ``idx`` joined in order ('' for an empty input).

    Raises:
        KeyError: if an id is not present in ``indices_char``.
    """
    # int() normalises whatever iteration yields. On PyTorch versions where
    # iterating a tensor produces 0-dim tensors, those would otherwise miss the
    # dict lookup because they do not hash like plain ints.
    return ''.join(indices_char[int(i)] for i in idx)

### One hot encoding

In [15]:
def one_hot(a, c):
    """One-hot encode class ids.

    Args:
        a: integer id, or array-like of ids, used to index the identity matrix.
        c: number of classes (width of each one-hot row).

    Returns:
        numpy.ndarray: the rows of the ``c`` x ``c`` identity matrix selected by ``a``.
    """
    identity = np.eye(c)
    return identity[a]

In [16]:
def hot2char(idx):
    """Decode per-character score rows into text via argmax over the last axis.

    Args:
        idx: tensor whose last dimension holds one score per character id.

    Returns:
        str: ``idx2char`` applied to the argmax indices.
    """
    _, best = idx.max(dim=-1)
    return idx2char(best)

### Dataset

In [17]:
class TextDataset(torch.utils.data.Dataset):
    """Character-level language-modelling dataset over an encoded corpus.

    The encoded corpus is cut into consecutive, non-overlapping windows of
    ``timesteps`` ids. Item ``i`` is the pair ``(x, y)`` where ``y`` is ``x``
    shifted one position to the right, i.e. the next-character targets.
    """

    def __init__(self, text, idx, vocab_size, timesteps):
        """
        Args:
            text: raw corpus string (stored; not otherwise used by this class).
            idx: sequence of integer character ids encoding ``text``.
            vocab_size: number of distinct ids (stored for consumers).
            timesteps: length of every input/target window.
        """
        self.vocab_size = vocab_size
        self.text = text
        self.dataset = idx
        self.data_length = len(idx)
        self.timesteps = timesteps

    def __len__(self):
        # Every window needs one more id after it for the shifted target, so
        # only (data_length - 1) ids can serve as inputs. Without the -1, a
        # corpus whose length is an exact multiple of `timesteps` produces a
        # final target that is one element short, which breaks batching.
        return max(self.data_length - 1, 0) // self.timesteps

    def __getitem__(self, idx):
        """Return ``(x, y)`` as NumPy arrays for window number ``idx``."""
        start = idx * self.timesteps
        stop = start + self.timesteps
        x = self.dataset[start:stop]
        y = self.dataset[start + 1:stop + 1]
        return np.array(x), np.array(y)


In [18]:
# Training hyper-parameters: windows per batch and characters per window.
batch_size = 64
timesteps = 64
# Wrap the encoded corpus (`idx`) in the windowed dataset defined above.
md = TextDataset(text, idx, vocab_size, timesteps)
# Commented-out alternative dataset; MusicDataset is not defined in the visible cells.
# md = MusicDataset(h5_file='concat_corpus.h5', set_type='train', json_file='concat_corpus.json', timesteps=timesteps, root_dir=CONCAT_DIR)

In [19]:
# Default DataLoader settings: no shuffling, so batches are drawn in index order.
# NOTE(review): StatefulLSTM carries its hidden state from one batch to the next,
# but row r of consecutive batches holds windows r and r + batch_size, which are
# not adjacent text. Confirm this layout is intended for stateful training.
train_loader = torch.utils.data.DataLoader(md,
    batch_size=batch_size)

### Dataset sanity test

In [20]:
# Manual iterator over (batch index, (inputs, targets)) pairs, for inspection only.
train_iter = enumerate(train_loader)

In [21]:
# Pull the first two batches; x*/y* are the input and shifted-target id batches.
i, (x, y) = next(train_iter)
i2, (x2, y2) = next(train_iter)

In [22]:
md.dataset[:10]

[40, 42, 29, 30, 25, 27, 29, 1, 1, 1]

In [23]:
idx2char(md.dataset[:10])

'PREFACE\n\n\n'

In [24]:
idx2char(x2[0])

's bow! And perhaps also the arrow, the duty, and, who\nknows? THE'

In [25]:
# hot2char(x2[0])

### Model

In [26]:
# True when a CUDA device is usable; every .cuda() call in this notebook is gated on it.
cuda_enabled = torch.cuda.is_available()

In [27]:
def repackage_var(h):
    """Detach hidden state(s) from the autograd graph that produced them.

    Args:
        h: a tensor/Variable, or an arbitrarily nested iterable of them
            (e.g. the ``(h, c)`` pair returned by an LSTM).

    Returns:
        The detached tensor (moved to the GPU when ``cuda_enabled`` is set),
        or a tuple mirroring the nesting of ``h`` with every leaf detached.
    """
    # isinstance (rather than an exact type comparison) also matches plain
    # tensors and Variable subclasses. With `type(h) == Variable`, those fall
    # through to the recursive branch below and get iterated element by element.
    if isinstance(h, torch.autograd.Variable):
        v = h.detach()
        return v.cuda() if cuda_enabled else v
    return tuple(repackage_var(v) for v in h)

In [63]:
class StatefulLSTM(torch.nn.Module):
    """Three stacked LSTMs over character embeddings that keep their hidden
    state between ``forward`` calls.

    Args:
        scale_size: vocabulary size; sizes both the embedding table and the output layer.
        n_hidden: hidden units per direction in every LSTM.
        n_factors: embedding dimension.
        bs: batch size used to allocate the initial hidden state.
        nl: number of layers inside each of the three LSTMs.
        bidirectional: run every LSTM in both directions.
    """

    def __init__(self, scale_size, n_hidden, n_factors, bs, nl, bidirectional=False):
        super().__init__()
        self.scale_size = scale_size
        self.nl = nl
        self.bidirectional = bidirectional
        self.bmult = 2 if bidirectional else 1
        self.embedding = torch.nn.Embedding(scale_size, n_factors)

        # A bidirectional LSTM emits n_hidden * bmult features per step, so that
        # is the input width of the next stage. The hidden size itself stays
        # n_hidden, matching the states built by init_hidden(). With
        # bidirectional=False (bmult == 1) these sizes equal the previous ones.
        rnn_out = n_hidden * self.bmult
        self.rnn1 = torch.nn.LSTM(n_factors, n_hidden, nl, dropout=0.5, batch_first=True, bidirectional=bidirectional)
        self.rnn2 = torch.nn.LSTM(rnn_out, n_hidden, nl, dropout=0.5, batch_first=True, bidirectional=bidirectional)
        self.rnn3 = torch.nn.LSTM(rnn_out, n_hidden, nl, dropout=0.5, batch_first=True, bidirectional=bidirectional)

        if cuda_enabled:
            self.rnn1 = self.rnn1.cuda()
            self.rnn2 = self.rnn2.cuda()
            self.rnn3 = self.rnn3.cuda()

        # weight_norm returns the module it was given, so bn1/bn2/bn3 are aliases
        # of rnn1/rnn2/rnn3. The aliases are kept because they appear in the
        # state_dict keys of saved checkpoints. Only the first layer's forward
        # weights ('weight_hh_l0' / 'weight_ih_l0') are re-parameterised.
        self.bn1 = nn.utils.weight_norm(self.rnn1, 'weight_hh_l0')
        self.bn1 = nn.utils.weight_norm(self.bn1, 'weight_ih_l0')
        self.bn2 = nn.utils.weight_norm(self.rnn2, 'weight_hh_l0')
        self.bn2 = nn.utils.weight_norm(self.bn2, 'weight_ih_l0')
        self.bn3 = nn.utils.weight_norm(self.rnn3, 'weight_hh_l0')
        self.bn3 = nn.utils.weight_norm(self.bn3, 'weight_ih_l0')

        # pytorch rnn does not currently work with batchnorm
        self.l_out = torch.nn.Linear(rnn_out, scale_size)
        self.n_hidden = n_hidden
        self.reset_all_hidden(bs)
        self.bs = bs

    def forward(self, notes):
        """Run one batch of id sequences through the stack.

        Args:
            notes: integer id tensor with the batch on dimension 0 (batch_first).

        Returns:
            Log-probabilities flattened to ``(batch * steps, scale_size)``.
        """
        bs = notes.shape[0]
        # Stored states are sized for a specific batch; re-allocate when it changes.
        if self.h1[0].size(1) != bs:
            self.reset_all_hidden(bs)
        emb = self.embedding(notes)
        outp1, h1 = self.rnn1(emb, self.h1)
        outp2, h2 = self.rnn2(outp1, self.h2)
        outp3, h3 = self.rnn3(outp2, self.h3)
        # Keep the final states for the next call, detached from this batch's graph.
        self.h1 = repackage_var(h1)
        self.h2 = repackage_var(h2)
        self.h3 = repackage_var(h3)
        return torch.nn.functional.log_softmax(self.l_out(outp3), dim=-1).view(-1, self.scale_size)

    def reset_all_hidden(self, bs):
        """Replace the stored state of all three LSTMs with zeros for batch size ``bs``."""
        self.h1 = self.init_hidden(bs)
        self.h2 = self.init_hidden(bs)
        self.h3 = self.init_hidden(bs)

    def init_hidden(self, bs):
        """Return a zeroed ``(h, c)`` pair shaped ``(nl * bmult, bs, n_hidden)``."""
        h1 = torch.autograd.Variable(torch.zeros(self.nl*self.bmult, bs, self.n_hidden))
        h2 = torch.autograd.Variable(torch.zeros(self.nl*self.bmult, bs, self.n_hidden))
        if cuda_enabled:
            return (h1.cuda(), h2.cuda())
        return h1, h2

### Training

In [64]:
# Build the model over the dataset's vocabulary and move it to the GPU when available.
m = StatefulLSTM(md.vocab_size, n_hidden=256, n_factors=10, bs=batch_size, nl=2, bidirectional=False)
if cuda_enabled:
    m = m.cuda()

In [65]:
# Adam over every model parameter.
train_op = torch.optim.Adam(m.parameters(), lr=1e-2)

In [66]:
# NLLLoss expects log-probabilities, which is what StatefulLSTM.forward returns (log_softmax).
loss_fn = torch.nn.NLLLoss()

In [67]:
display_step = 100   # print the running loss every `display_step` batches
training_steps = 50  # number of full passes over train_loader

# beam_search() further down switches the model to eval mode; make sure dropout
# is active again whenever this cell is (re-)run.
m.train()
for step in range(training_steps):
    for i, (data, target) in enumerate(train_loader):
        data, target = torch.autograd.Variable(data.long()), torch.autograd.Variable(target.long())
        if cuda_enabled:
            data, target = data.cuda(), target.cuda()
        m.zero_grad()
        forward = m(data)
        # forward is flattened to (batch * steps, vocab); flatten the targets to match.
        loss = loss_fn(forward, target.view(-1))
        loss.backward()
        train_op.step()
        if (i + 1) % display_step == 0:
            # view(-1)[0] reads the scalar whether the loss is a 1-element or a 0-dim tensor.
            print(f'Iteration: {i+1} Loss: {float(loss.data.view(-1)[0])}')
    print(f'Step: {step} Loss: {float(loss.data.view(-1)[0])}')

Iteration: 100 Loss: 3.1001973152160645
Step: 0 Loss: 3.00748348236084
Iteration: 100 Loss: 3.0935916900634766
Step: 1 Loss: 3.005417585372925
Iteration: 100 Loss: 3.099705696105957
Step: 2 Loss: 3.003783941268921
Iteration: 100 Loss: 3.1032121181488037
Step: 3 Loss: 3.0042672157287598
Iteration: 100 Loss: 3.10378360748291
Step: 4 Loss: 3.004380702972412
Iteration: 100 Loss: 3.1036312580108643
Step: 5 Loss: 3.005208730697632
Iteration: 100 Loss: 3.102313756942749
Step: 6 Loss: 3.0050153732299805
Iteration: 100 Loss: 3.0706777572631836
Step: 7 Loss: 2.828390598297119
Iteration: 100 Loss: 2.497222661972046
Step: 8 Loss: 2.387152671813965
Iteration: 100 Loss: 2.342616081237793
Step: 9 Loss: 2.4073591232299805
Iteration: 100 Loss: 2.271070718765259
Step: 10 Loss: 2.2498061656951904
Iteration: 100 Loss: 2.1978893280029297
Step: 11 Loss: 2.1730973720550537
Iteration: 100 Loss: 2.1283209323883057
Step: 12 Loss: 2.062391519546509
Iteration: 100 Loss: 2.0576000213623047
Step: 13 Loss: 1.9684389

### Saving model

In [None]:
# OUT_DIR is not defined in the visible cells — presumably it comes from
# `from constants import *`; verify.
# NOTE(review): the file is written with torch.save below, so despite the .h5
# suffix it is not an HDF5 file.
model_path = f'{OUT_DIR}/../models/nietzsche_stackedlstm_rnn_t64.h5'

In [None]:
# Persist only the parameters (state_dict), not the module object itself.
torch.save(m.state_dict(), model_path)

In [None]:
# Reload the checkpoint; without CUDA, map every stored tensor onto its CPU storage.
load_kwargs = {} if cuda_enabled else {'map_location': lambda storage, loc: storage}
m.load_state_dict(torch.load(model_path, **load_kwargs))

### Generate text

Open question: do we need to reserve index 0 as an explicit "unknown" state?

In [68]:
# Window length used by the generation helpers below; taken from the dataset.
timesteps = md.timesteps

In [69]:
list(np.arange(10))

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [70]:
def generate_sequence(song, seq_length, sample_prob=True, one_hot=False):
    """Extend a seed sequence one character id at a time with the global model ``m``.

    Args:
        song: seed ids (list, NumPy array or anything ``np.asarray`` accepts).
        seq_length: number of ids to generate.
        sample_prob: if True, sample the next id from the predicted
            distribution; otherwise take the most probable id.
        one_hot: if True, feed one-hot rows instead of raw ids.
            NOTE(review): the model in this notebook embeds integer ids, so
            confirm this flag is still needed.

    Returns:
        list: the seed ids followed by ``seq_length`` generated ids (plain ints).
    """
    # np.asarray lets callers pass either a list slice or an array.
    full_song = np.asarray(song).tolist()
    # Disable dropout while generating (same as beam_search does).
    m.eval()
    for _ in range(seq_length):
        window = full_song[-timesteps:]
        if one_hot:
            # The `one_hot` flag shadows the module-level one_hot() helper inside
            # this function, so the encoding is built directly from the identity matrix.
            seed = np.array([np.eye(vocab_size)[window]])
        else:
            seed = np.array([window])
        seed_v = torch.autograd.Variable(torch.from_numpy(seed).long())
        if cuda_enabled:
            seed_v = seed_v.cuda()
        # The model returns log-probabilities for every step; the last row is
        # the prediction for the character that follows the window.
        last_probs = m(seed_v).exp()[-1]

        if sample_prob:
            next_id = torch.multinomial(last_probs, 1)
        else:
            _, next_id = torch.max(last_probs, 0)
        # view(-1)[0] + int() yields a plain Python int whether the result is a
        # 1-element or a 0-dim tensor.
        full_song.append(int(next_id.data.view(-1)[0]))
    return full_song
    



In [72]:
def get_x_input(partial):
    """Build the model input for one beam-search hypothesis.

    Args:
        partial: 4-tuple whose last element is the hypothesis' id sequence.

    Returns:
        A ``(1, n)`` LongTensor Variable holding the last ``timesteps`` ids of
        that sequence, moved to the GPU when ``cuda_enabled`` is set.
    """
    _it, _score, _probs, seq = partial
    window = torch.LongTensor([seq[-timesteps:]])
    input_var = torch.autograd.Variable(window)
    return input_var.cuda() if cuda_enabled else input_var

def beam_search(song, seq_length, beam_size):
    """Extend a seed sequence with beam search over the global model ``m``.

    Args:
        song: seed ids (list, NumPy array or anything ``np.asarray`` accepts).
        seq_length: number of ids to append.
        beam_size: number of candidates expanded per hypothesis; forwarded to
            ``find_partials``.

    Returns:
        list: id sequence of the first hypothesis returned by the final
        ``find_partials`` call (the seed itself when ``seq_length`` is 0).

    Side effects:
        Resets the model's stored hidden state and leaves ``m`` in eval mode.
    """
    # np.asarray lets callers pass either a list slice or an array.
    full_song = np.asarray(song).tolist()
    m.reset_all_hidden(batch_size)
    m.eval()

    # Each hypothesis is (steps taken, score, per-step probabilities, id sequence).
    partial_sequences = [(0, 0, [], full_song)]
    for _ in range(seq_length):
        partial_sequences = find_partials(partial_sequences, beam_size)

    return partial_sequences[0][3]
    
def find_partials(partial_sequences, beam_size, sample_prob=True):
    """Expand every hypothesis by one character and keep the best ``beam_size``.

    Args:
        partial_sequences: list of hypotheses, each a tuple
            ``(steps, score, probs, seq)`` where ``probs`` are the per-step
            probabilities chosen so far and ``seq`` is the id sequence.
        beam_size: candidates drawn per hypothesis and hypotheses kept overall.
        sample_prob: if True, draw the candidates by sampling without
            replacement from the predicted distribution; otherwise take the
            ``beam_size`` most probable ids.

    Returns:
        list: at most ``beam_size`` extended hypotheses, sorted by descending
        score (the mean of their per-step probabilities).
    """
    # NOTE(review): every m(...) call updates the model's stored hidden state,
    # so hypotheses expanded in this loop share that state — confirm intended.
    partial_next = []
    for partial in partial_sequences:
        it, _, p_list, seq = partial
        x_input = get_x_input(partial)

        # The model returns log-probabilities for every step; the last row is
        # the prediction for the next character.
        last_it_probs = m(x_input)[-1].exp()

        if sample_prob:
            idxs = torch.multinomial(last_it_probs, beam_size)
            top = last_it_probs[idxs]
        else:
            # last_it_probs is 1-D, so the only valid dimension is 0.
            top, idxs = torch.topk(last_it_probs, beam_size, 0)

        for k in range(beam_size):
            # Convert to plain Python scalars so scores and sequences never hold tensors.
            prob = float(top.data[k])
            next_id = int(idxs.data[k])
            new_p_list = p_list + [prob]
            partial_next.append((it + 1, np.mean(new_p_list), new_p_list, seq + [next_id]))

    # Prune to the requested beam width (previously a hard-coded 3).
    return sorted(partial_next, key=lambda cand: cand[1], reverse=True)[:beam_size]

In [73]:
def random_choice(top, idxs):
    """Draw one entry of ``idxs`` with probability proportional to ``top``.

    Args:
        top: tensor of non-negative weights, one per candidate.
        idxs: tensor of candidate ids, with as many elements as ``top``.

    Returns:
        numpy.ndarray: shape ``(1,)`` array holding the chosen id.
    """
    # .cpu() is required before .numpy() when the tensors live on the GPU; it
    # is a no-op for CPU tensors.
    candidates = idxs.data.cpu().numpy().reshape(-1)
    weights = (top / top.sum()).data.cpu().numpy().reshape(-1)
    return np.random.choice(candidates, 1, p=weights)

In [74]:
# No explicit `import random` appears in the visible cells, so import it here
# rather than relying on a name leaked by a star import.
import random

# Despite its name, `random_seed` is a random start offset into the corpus
# (drawn from its first half), not an RNG seed.
random_seed = random.randint(0, len(md.dataset)//2)
song_seed = md.dataset[random_seed:random_seed+md.timesteps]

In [75]:
''.join(indices_char[i] for i in song_seed)

' maturing, and perfecting--the Greeks, for\ninstance, were a nati'

In [76]:
# Continue the seed by 500 characters using beam search with beam_size=3.
bs_gen_idxs = beam_search(np.array(song_seed), 500, 3)

In [77]:
# Continue the same seed by 500 characters, sampling each next character
# (sample_prob defaults to True).
gen_idxs = generate_sequence(np.array(song_seed), 500)

In [78]:
idx2char(gen_idxs)

' maturing, and perfecting--the Greeks, for\ninstance, were a nation of his own came\ndouss. Thing and mo not insluence formalaged to accentain fan moes not be too, beft uws fering to them."--Tut doolge for the dight hence the scirit that it is sqown\nwilling that\nthis\npuch mewhin caith and the condsalness (for vewt," have causes. The sresent danger.\nMheen who conctents to when a man with the fellow the sotenological e crain of advance gh everything and belevalable, leligion of sartinility.\nI sacctine twat consers belongem sympathy) and marmar be vinws of eterna'

In [79]:
idx2char(bs_gen_idxs)

" maturing, and perfecting--the Greeks, for\ninstance, were a nation of their personality and therefore, without that their personality in their own existence and of the existence of the religion of the weaker of the consciousness of their own existence and of their intellect of their personality of one's own exaggeration of the consciousness of one's religious and existence, that is to say, in the religion of morality of the uncondition of the most of which there are that there are that therefore, that is not himself become, and in the greatest of their weakn"

### Beam search end - Decoding time

In [None]:
import decode

In [None]:
def decode_output(output_idx):
    """Translate model output ids to tokens and hand them to ``decode_token``.

    Args:
        output_idx: iterable of ids; each is looked up by its ``str()`` form
            and ids without an entry map to ``''``.

    Returns:
        Whatever ``decode_token`` returns for the resulting token list.
    """
    # NOTE(review): TextDataset in this notebook does not set `concat_json`;
    # presumably this helper targets a different dataset class — confirm.
    idx2token = md.concat_json['idx_to_token']
    token_list = [idx2token.get(str(x), '') for x in output_idx]
    return decode_token(token_list)

def decode_token(token_list):
    """Join tokens into one string and decode it with ``decode.decode_string``.

    Args:
        token_list: sequence of token strings; it is not modified.

    Returns:
        tuple: ``(token_str, score, stream)`` where ``token_str`` is the joined
        tokens (prefixed with ``START_DELIM`` when not already present) and
        ``score, stream`` are the two values returned by ``decode.decode_string``.
    """
    # No explicit `import json` appears in the visible cells, so import locally.
    import json

    # Work on a copy so the caller's list is not mutated, and treat an empty
    # list like any other input that lacks the start delimiter.
    tokens = list(token_list)
    if not tokens or tokens[0] != START_DELIM:
        tokens.insert(0, START_DELIM)
    token_str = ''.join(tokens)
    with open(f'{SCRATCH_DIR}/utf_to_txt.json', 'r') as f:
        utf_to_txt = json.load(f)
    score, stream = decode.decode_string(utf_to_txt, token_str)
    return token_str, score, stream

# test = [idx2token[f'{x}'] for x in seq_arr]; test

In [None]:
# Deterministic seed: the first `timesteps` ids of the corpus (a plain list slice).
song_seed = md.dataset[:md.timesteps]
# generated_idxs = generate_sequence(song_seed, 500)

In [None]:
# Beam-search continuation (500 characters, beam_size=3) of the corpus-start seed.
bs_gen_idxs = beam_search(np.array(song_seed), 500, 3)

In [None]:
''.join(indices_char[i] for i in bs_gen_idxs)

In [None]:
# song_seed is a plain list here and generate_sequence calls .tolist() on its
# argument, so wrap it in an array exactly like the earlier generation cell does.
gen_idxs = generate_sequence(np.array(song_seed), 500)