In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%pdb off

Automatic pdb calling has been turned OFF


In [2]:
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid , save_image

In [3]:
from constants import *

In [4]:
import matplotlib.pyplot as plt

### Extracting Nietzche Corpus

In [5]:
from fastai.io import *

In [6]:
PATH='../data/nietzsche/'

In [7]:
get_data("https://s3.amazonaws.com/text-datasets/nietzsche.txt", f'{PATH}nietzsche.txt')
text = open(f'{PATH}nietzsche.txt').read()
print('corpus length:', len(text))

corpus length: 600893


In [8]:
text[:400]

'PREFACE\n\n\nSUPPOSING that Truth is a woman--what then? Is there not ground\nfor suspecting that all philosophers, in so far as they have been\ndogmatists, have failed to understand women--that the terrible\nseriousness and clumsy importunity with which they have usually paid\ntheir addresses to Truth, have been unskilled and unseemly methods for\nwinning a woman? Certainly she has never allowed herself '

In [9]:
chars = sorted(list(set(text)))
vocab_size = len(chars)+1
print('total chars:', vocab_size)

total chars: 85


In [10]:
chars.insert(0, "\0")

''.join(chars[1:-6])

'\n !"\'(),-.0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxy'

In [11]:
char_indices = {c: i for i, c in enumerate(chars)}
indices_char = {i: c for i, c in enumerate(chars)}

In [12]:
idx = [char_indices[c] for c in text]

idx[:10]

[40, 42, 29, 30, 25, 27, 29, 1, 1, 1]

In [13]:
''.join(indices_char[i] for i in idx[:70])

'PREFACE\n\n\nSUPPOSING that Truth is a woman--what then? Is there not gro'

In [14]:
def idx2char(idx):
    return ''.join(indices_char[i] for i in idx)

### One hot encoding

In [15]:
def one_hot(a,c): 
    return np.eye(c)[a]

In [16]:
def hot2char(idx):
    val,idx = idx.max(dim=-1)
    return idx2char(idx)

### Dataset

In [17]:
class TextDataset(torch.utils.data.Dataset):
    """Face Landmarks dataset."""

    def __init__(self, text, idx, vocab_size, timesteps):
        self.vocab_size = vocab_size
        self.text = text
        self.dataset = idx
        self.data_length = len(idx)
        self.timesteps = timesteps

    def __len__(self):
        return (self.data_length // self.timesteps)

    def __getitem__(self, idx):
        start = idx*self.timesteps
        x = self.dataset[start:start+self.timesteps]
        y = self.dataset[start+1:start+self.timesteps+1]
#         return one_hot(x, vocab_size), np.array(y)
        return np.array(x), np.array(y)


In [18]:
batch_size = 64
timesteps = 64
md = TextDataset(text, idx, vocab_size, timesteps)
# md = MusicDataset(h5_file='concat_corpus.h5', set_type='train', json_file='concat_corpus.json', timesteps=timesteps, root_dir=CONCAT_DIR)

In [19]:
train_loader = torch.utils.data.DataLoader(md, batch_size=batch_size)

### Dataset sanity test

In [20]:
train_iter = enumerate(train_loader)

In [21]:
i, (x, y) = next(train_iter)
i2, (x2, y2) = next(train_iter)

In [22]:
md.dataset[:10]

[40, 42, 29, 30, 25, 27, 29, 1, 1, 1]

In [23]:
idx2char(md.dataset[:10])

'PREFACE\n\n\n'

In [24]:
idx2char(x2[0])

's bow! And perhaps also the arrow, the duty, and, who\nknows? THE'

In [25]:
# hot2char(x2[0])

### Model

In [26]:
cuda_enabled = torch.cuda.is_available()

In [27]:
def repackage_var(h):
    """Wraps h in new Variables, to detach them from their history."""
    return Variable(h.data) if type(h) == Variable else tuple(repackage_var(v) for v in h)

### AWD-LSTM

In [28]:
from fastai.lm_rnn import *

In [29]:
m = get_language_model(md.vocab_size, emb_sz=10, nhid=128, nlayers=2, pad_token=None)

### Training

In [30]:
# m = StatefulLSTM(md.vocab_size, n_hidden=256, n_factors=10, bs=batch_size, nl=2, bidirectional=False)
if cuda_enabled:
    m = m.cuda()

In [31]:
train_op = torch.optim.Adam(m.parameters(), lr=1e-2, betas=(0.8, 0.99))

In [32]:
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
display_step = 100
training_steps = 50
for step in range(training_steps):
# for step in tqdm(range(training_steps)):
    for i, (data,target) in enumerate(train_loader):
        data, target = torch.autograd.Variable(data.long()), torch.autograd.Variable(target.long())
        if cuda_enabled:
            data, target = data.cuda(), target.cuda()
        m.zero_grad()
#         forward = m(data)
        out = m(data)
        forward = out[0]
#         pdb.set_trace()
#         print(forward.shape)
#         print(forward[0].shape, forward[1][0].shape, forward[1][1].shape)
        loss = loss_fn(forward, target.view(-1))
        loss.backward()
        train_op.step()
        if ((i+1) % display_step == 0):
            print(f'Iteration: {i+1} Loss: {loss.data[0]}')
    print(f'Step: {step} Loss: {loss.data[0]}')

torch.Size([4096, 85]) torch.Size([64, 64, 128]) torch.Size([64, 64, 10])

### Saving model

In [33]:
model_path = f'{OUT_DIR}/../models/nietzsche_awdlstm_rnn_t64.h5'

In [44]:
torch.save(m.state_dict(), model_path)

In [34]:
if cuda_enabled:
    m.load_state_dict(torch.load(model_path))
else:
    m.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))

### Generate music

Need to have unknown state 0?

In [35]:
timesteps = md.timesteps

In [36]:
list(np.arange(10))

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [37]:
%pdb on

Automatic pdb calling has been turned ON


In [45]:
def generate_sequence(song, seq_length, sample_prob=True, one_hot=False):
    full_song = song.tolist()
    # generate music!
#     m.reset_all_hidden(batch_size)
#     m.reset()
    for i in range(seq_length):
        if one_hot:
            seed = np.array([one_hot(full_song[-timesteps:], vocab_size)])
        else:
            seed = np.array([full_song[-timesteps:]])
#         print(idx2char(full_song[-timesteps:])))
        # Use our RNN for prediction using our seed! 
        seed_v = torch.autograd.Variable(torch.from_numpy(seed).long())
        if cuda_enabled:
            seed_v = seed_v.cuda()
#         predict_probs = m(seed_v)[0][-1].exp()
        
        predict_probs = m(seed_v)
        last_it_probs = predict_probs[0][-1].exp()

        if sample_prob:
            # Define output vector for our generated song by sampling from our predicted probability distribution        
            sampled_note = torch.multinomial(last_it_probs, 1).data[0]
            print(sampled_note)
            full_song.append(sampled_note)
        else:
            # With multi output model, use only the last prediction. As it is predicting to n timesteps
            v, idx = torch.max(last_it_probs, 0)
            full_song.append(idx.data[0])
    return full_song
    



In [46]:
def get_x_input(partial):
    _, _, _, seq = partial
    input = seq[-timesteps:]
    input_var = torch.autograd.Variable(torch.LongTensor([input]))
    if cuda_enabled:
        input_var = input_var.cuda()
    return input_var

# song = string
# seq_length = generated song length
# beam_size = what to choose from
def beam_search(song, seq_length, beam_size):    
    full_song = song.tolist()
#     m.reset_all_hidden(batch_size)
    partial_sequences = [(0, 0, [], full_song)]
    m.eval()

    for i in range(seq_length):
        partial_sequences = find_partials(partial_sequences, beam_size)
        
    final_sequence = partial_sequences[0][3]
    return final_sequence
    
def find_partials(partial_sequences, beam_size, sample_prob=True):
    partial_next = []
    for partial in partial_sequences:
        it, tot_p, p_list, seq = partial
        x_input = get_x_input(partial)

        predict_probs = m(x_input)
        # last_it_probs = torch.exp(predict_probs[-(it+1):]) # this is to predict the last few iterations
#         pdb.set_trace()
        last_it_probs = predict_probs[0][-1].exp()
        
        
        if sample_prob:
            # Define output vector for our generated song by sampling from our predicted probability distribution        
            idxs = torch.multinomial(last_it_probs, beam_size)
            top = last_it_probs[idxs]
        else:
            top, idxs = torch.topk(last_it_probs, beam_size, 1)

        for i in range(beam_size):
            prob = top.data[i]
            idx = idxs.data[i]
            new_p_list = p_list+[prob]
            partial_next.append((it+1, np.mean(new_p_list), new_p_list, seq+[idx]))

    partial_sequences = sorted(partial_next, key=lambda x: x[1], reverse=True)[:3]
    return partial_sequences

In [47]:
def random_choice(top, idxs):
    return np.random.choice(
      idxs.data.numpy().reshape(-1), 
      1,
      p=(top/top.sum()).data.numpy().reshape(-1)
    )

In [48]:
random_seed = random.randint(0, len(md.dataset)//2)
song_seed = md.dataset[random_seed:random_seed+md.timesteps]
# generated_idxs = generate_sequence(song_seed, 500)

In [49]:
''.join(indices_char[i] for i in song_seed)

'any\none, nor after, either; he places himself generally too far '

In [50]:
bs_gen_idxs = beam_search(np.array(song_seed), 500, 3)

In [None]:
gen_idxs = generate_sequence(np.array(song_seed), 500)

RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1518244421288/work/torch/lib/THC/generic/THCStorage.c:36

> [0;32m<ipython-input-45-a0367d650400>[0m(23)[0;36mgenerate_sequence[0;34m()[0m
[0;32m     21 [0;31m        [0;32mif[0m [0msample_prob[0m[0;34m:[0m[0;34m[0m[0m
[0m[0;32m     22 [0;31m            [0;31m# Define output vector for our generated song by sampling from our predicted probability distribution[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 23 [0;31m            [0msampled_note[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mmultinomial[0m[0;34m([0m[0mlast_it_probs[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m.[0m[0mdata[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m            [0mprint[0m[0;34m([0m[0msampled_note[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m            [0mfull_song[0m[0;34m.[0m[0mappend[0m[0;34m([0m[0msampled_note[0m[0;34m)[0m[0;34m[0m[0m
[0m


In [93]:
idx2char(gen_idxs)

NameError: name 'gen_idxs' is not defined

In [44]:
idx2char(bs_gen_idxs)

'nds, that to make unhappy and to make bad are just as\nlittle coue\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\

### Beam search end - Decoding time

In [None]:
import decode

In [None]:
def decode_output(output_idx):
    idx2token = md.concat_json['idx_to_token']
    token_list = list(map(lambda x: idx2token.get(str(x), ''), output_idx))
    return decode_token(token_list)

def decode_token(token_list):
    if (token_list[0] != START_DELIM):
        token_list.insert(0, START_DELIM)
    token_str = ''.join(token_list)
    with open(f'{SCRATCH_DIR}/utf_to_txt.json', 'r') as f:
        utf_to_txt = json.load(f)
    score, stream = decode.decode_string(utf_to_txt, token_str)
    return token_str, score, stream

# test = [idx2token[f'{x}'] for x in seq_arr]; test

In [None]:
song_seed = md.dataset[:md.timesteps]
# generated_idxs = generate_sequence(song_seed, 500)

In [None]:
bs_gen_idxs = beam_search(np.array(song_seed), 500, 3)

In [None]:
''.join(indices_char[i] for i in bs_gen_idxs)

In [None]:
gen_idxs = generate_sequence(song_seed, 500)