In [58]:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import math
from tqdm import tqdm

In [59]:
CUDA = False

In [60]:
def mcuda(tensor):
    if CUDA:
        tensor = tensor.cuda()
    return tensor

In [61]:
def one_hot_char(i, n):
    x = torch.zeros(n)
    x[i] = 1
    return x.reshape(1, -1)

In [62]:
class RNNCell(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super(RNNCell, self).__init__()
        self.Wxh = Parameter(mcuda(torch.Tensor(in_dim, hidden_dim)))
        self.Whh = Parameter(mcuda(torch.Tensor(hidden_dim, hidden_dim)))

        self.bh = Parameter(mcuda(torch.Tensor(hidden_dim)))
        self.activation = nn.Tanh()
        self.reset_parameters()
    
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.Whh.size(1))
        self.Wxh.data.uniform_(-stdv, stdv)
        self.Whh.data.uniform_(-stdv, stdv)
        self.bh.data.zero_()
    
    def forward(self, input, state):
        preact = (
            self.bh +
            torch.matmul(input, self.Wxh) +
            torch.matmul(state, self.Whh)
        )
        return self.activation(preact)
    
class CharRNN(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(CharRNN, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(0.2)
        self.Why = Parameter(mcuda(torch.Tensor(self.hidden_size, self.vocab_size)))
        self.by = Parameter(mcuda(torch.Tensor(self.vocab_size)))
        self.reset_parameters()
        self.init_rnns()
    
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.Why.size(1))
        self.Why.data.uniform_(-stdv, stdv)
        self.by.data.zero_()

    def init_rnns(self):
        self.cell1 = RNNCell(self.vocab_size, self.hidden_size)
        self.cell2 = RNNCell(self.hidden_size, self.hidden_size)
        self.cell3 = RNNCell(self.hidden_size, self.hidden_size)

    def init_states(self):
        return [
            mcuda(torch.zeros(self.hidden_size)),
            mcuda(torch.zeros(self.hidden_size)),
            mcuda(torch.zeros(self.hidden_size))
        ]

    def forward(self, input, state):
        for x in input:
            x = mcuda(one_hot_char(x, self.vocab_size))
            state1 = self.cell1(x, state[0])
            state2 = self.cell2(state1, state[1])
            state3 = self.cell3(state2, state[2])
            state3 = self.dropout(state3)
        
        preds = nn.LogSoftmax(dim=1)(torch.matmul(state3, self.Why) + self.by)
        

        return preds, [state1, state2, state3]


In [63]:
with open("data/tinyshakespeare.txt") as f:
    data = f.read()

corpus = data
chars = set(corpus)
vocab_size = len(chars)
c2i = {c: i for i, c in enumerate(chars)}
i2c = {i: c for i,c in enumerate(chars)}

encoded_data = [c2i[char] for char in corpus]

In [64]:
import random
chunk_len = 200

def random_chunk():
    start_index = random.randint(0, len(corpus) - chunk_len)
    end_index = start_index + chunk_len + 1
    return corpus[start_index:end_index]

print(random_chunk())

more shows off
Your wonder: but yet speak; first, you, my liege,
Comes it not something near?

LEONTES:
Her natural posture!
Chide me, dear stone, that I may say indeed
Thou art Hermione; or rather, th


In [65]:
def char_map_index(x):
    return [c2i[char] for char in x]

In [66]:
def random_training_set():    
    chunk = random_chunk()
    inp = char_map_index(chunk[:-1])
    target = char_map_index(chunk[1:])
    return inp, target

In [67]:
def evaluate(model, prime_str='A', predict_len=100, temperature=0.8):
    model.eval()
    hidden = model.init_states()
    prime_input = char_map_index(prime_str)
    predicted = prime_str

    # Use priming string to "build up" hidden state
    for p in range(len(prime_str) - 1):
        _, hidden = model([prime_input[p]], hidden)
    inp = prime_input[-1]
    
    for p in range(predict_len):
        output, hidden = model([inp], hidden)
        
        # Sample from the network as a multinomial distribution
        output_dist = output.data.view(-1).div(temperature).exp()
        top_i = torch.multinomial(output_dist, 1)[0]
        
        # Add predicted character to string and use as next input
        predicted_char = i2c[top_i.item()]
        predicted += predicted_char
        inp = char_map_index(predicted_char)

    return predicted

In [68]:
import time, math

def time_since(since):
    s = time.time() - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

In [69]:
def train(inp, target):
    decoder.train()
    hidden = decoder.init_states()
    decoder.zero_grad()
    loss = 0

    for c in range(chunk_len):
        output, hidden = decoder([inp[c]], hidden)
        loss += criterion(output, one_hot_char(target[c], vocab_size))

    loss.backward()
    torch.nn.utils.clip_grad_norm_(decoder.parameters(), 5.0)
    decoder_optimizer.step()

    return loss.item() / chunk_len

In [70]:
n_epochs = 20000
print_every = 4000
hidden_size = 150
lr = 0.003

decoder = CharRNN(vocab_size, hidden_size)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

start = time.time()
all_losses = []
#loss_avg = 0

for epoch in range(1, n_epochs + 1):
    loss = train(*random_training_set())       
    #loss_avg += loss

    if epoch % print_every == 0:
        print('[%s (%d %d%%) %.4f]' % (time_since(start), epoch, epoch / n_epochs * 100, loss))
        print(evaluate(decoder,'Citizen:', 100), '\n')


[6m 19s (4000 20%) 2.0249]
Citizen: I dee sceath: I her to thou for be stonce
so men she sirter the tou, cransow's browe youn's them se 

[12m 45s (8000 40%) 2.2972]
Citizen:
You do you afen
bed bandess he asfore,
Pome fayens iry but in will you, but,
And siven for you disp 

[19m 13s (12000 60%) 1.9375]
Citizen: but as the that cell your have bucking a vorise aud in all the rewer.

LACHOLGS:
Whit with ta thee  

[25m 41s (16000 80%) 1.8572]
Citizen:
To mes us ernour have sto the crince to have ont ho to the been,
Parst;
In steasht seemmed in thy f 

[32m 6s (20000 100%) 2.1926]
Citizen: the will head apeit is me even onstrelf
To with which yad you are the snands to mant thy dar mayns  



In [72]:
print(corpus[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [74]:
print(evaluate(decoder, corpus[:100], 1000))

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You mard be casst thing hard in the man! is lett all in and po's, is far with he not mate this!
Are a bond now all be me not misscord of of this man,
I she oft my the sie! is the glaw your ke death and not mart: sir beens mad my to me seithous not the fewling mand toal do the brile so musent noble of three:

QUENIO:
Whom the pleat and 'ud comanf you moust,
I' thoughts trind thou conning mape, whes thou us it.

ILINIUSHe therefore now beain cioved y to the hones our love we'tisher and in the wirpe:
The woule?

MARENIUSHAR:
Bot un of to is is donoted not blotelent me minger At nhear af the conges that I we here and tweo hy be him I man ould michmen: wrused of shwer me: I woo lord his not same the eve in some.
Whee, wive:
This it ut make not call enturs hould mar: sir more ind hang and cother mis with her nent not sting the greant off the censne im healing,
Es that us sosther of so yhe him lor