In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

## Data Loader

In [2]:
import numpy as np
from collections import namedtuple

In [3]:
class DataLoader(object):
    def __init__(self, data, batch_size=8):
        assert isinstance(data, str) and len(data) > 1
        assert 0 < batch_size < len(data)
        self._data = self._preprocess_data(data)
        self._batch_size = batch_size
        self.curr_pos = 0
        
        all_chars = sorted(set(self._data))
        self._char_to_index = {ch:idx for idx, ch in enumerate(all_chars)}
        self._index_to_char = {idx:ch for ch, idx in self._char_to_index.items()}
    
    @property
    def num_examples(self):
        return len(self._data)-1
    
    @property
    def vocab_size(self):
        return len(self._char_to_index)
    
    @property
    def vocab(self):
        return set(self._char_to_index.keys())
    
    def next_batch(self):
        data = self._data[self.curr_pos:self.curr_pos+self._batch_size+1]
        self.curr_pos += self._batch_size
        if self.curr_pos >= self.num_examples:
            self.curr_pos = 0
        data = self._one_hot(data)
        return data[:-1, :], data[1:, :]
    
    def _one_hot(self, s):
        one_hot = np.zeros((len(s), self.vocab_size))
        idx = [self._char_to_index[ch] for ch in s]
        one_hot[np.arange(len(s)), idx] = 1
        return one_hot
    
    def decode(self, one_hot):
        if len(one_hot.shape) == 1:
            one_hot = one_hot.reshape(1, -1)
        return ''.join(map(self.decode_char, one_hot))
    
    def decode_char(self, one_hot):
        idx = np.argmax(one_hot)
        return self._index_to_char[idx]
    
    def __iter__(self):
        while True:
            yield next(self)
            if self.curr_pos == 0:
                break
    
    def __next__(self):
        return self.next_batch()
    
    def _preprocess_data(self, data):
        return data.lower()

In [4]:
def read_data():
    with open('data/input.txt', 'r') as fp:
        return fp.read()

## RNN model

### Helper modules

In [5]:
class Layer:
    def forward(self, *args, **kwrgs):
        raise NotImplementedError
    
    def backward(self, *args, **kwrgs):
        raise NotImplementedError
    
    def __call__(self, *args, **kwrgs):
        return self.forward(*args, **kwrgs)

    
class Affine(Layer):
    def forward(self, x, w, b):
        self.cache = (x, w, b)
        return x.dot(w) + b
    
    def backward(self, dout):
        x, w, b = self.cache
        db = dout.sum(axis=0)
        dw = x.T.dot(dout)
        dx = dout.dot(w.T)
        return dx, dw, db

    
class Tanh(Layer):
    def forward(self, x):
        out = np.tanh(x)
        self.cache = out
        return out
    
    def backward(self, dout):
        out = self.cache
        return (1-out*out)*dout

    
class CrossEntropy(Layer):
    def forward(self, logits, target):
        if len(logits.shape) == 1:
            logits = np.expand_dims(logits, 0)
        if len(target.shape) == 1:
            target = np.expand_dims(target, 0)
        target = np.argmax(target, axis=1)
        logits = logits.copy()
        logits -= np.max(logits, axis=1)
        unnormalized_probs = np.exp(logits)
        probs = unnormalized_probs / np.sum(unnormalized_probs, axis=1, keepdims=True)
        correct_class_probs = probs[np.arange(len(logits)), target]
        self.cache = (probs, target)
        return np.mean(-np.log(correct_class_probs))
    
    def backward(self):
        probs, target = self.cache
        dlogits = probs.copy()
        dlogits[np.arange(len(dlogits)), target] -= 1
        return dlogits / dlogits.shape[0]

In [6]:
class RNNCell(Layer):
    def forward(self, x, hidden_state_prev, params):
        assert len(x.shape) == 2
        affine_hidden, affine_input, affine_output = Affine(), Affine(), Affine()
        tanh = Tanh()
        hidden_state_raw = affine_hidden(hidden_state_prev, params['h2h'], params['h2h_b'])
        hidden_state_raw += affine_input(x, params['i2h'], params['i2h_b'])
        hidden_state = tanh(hidden_state_raw)
        logits = affine_output(hidden_state, params['h2o'], params['h2o_b'])
        self.cache = (affine_hidden, affine_input, affine_output, tanh, params)
        return hidden_state, logits
    
    def backward(self, dnext_hidden_state, dlogits):
        affine_hidden, affine_input, affine_output, tanh, params = self.cache
        dparams = {}
        dhidden_state, dparams['h2o'], dparams['h2o_b'] = affine_output.backward(dlogits)
        dhidden_state = dhidden_state + dnext_hidden_state
        dhidden_state_raw = tanh.backward(dhidden_state)
        dhidden_state_prev, dparams['h2h'], dparams['h2h_b'] = affine_hidden.backward(dhidden_state_raw)
        dx, dparams['i2h'], dparams['i2h_b'] = affine_input.backward(dhidden_state_raw)
        return dx, dhidden_state_prev, dparams


class RNN(Layer):
    def forward(self, hidden_state, x, params):
        num_inputs = len(x)
        logits = []
        self.cache = []
        for i in range(num_inputs):
            rnn_cell = RNNCell()
            hidden_state, _logits = rnn_cell(np.expand_dims(x[i], 0), hidden_state, params)
            logits.append(_logits)
            self.cache.append(rnn_cell)
        self.cache = (self.cache, params)
        return hidden_state, logits
    
    def backward(self, dlogits):
        rnn_cells, params = self.cache
        dparams = {k: np.zeros_like(v) for k, v in params.items()}
        dnext_hidden_state = 0
        while len(rnn_cells) > 0:
            rnn_cell = rnn_cells.pop()
            _, dnext_hidden_state, _dparams = rnn_cell.backward(dnext_hidden_state, dlogits.pop())
            for param_name, grad_value in _dparams.items():
                dparams[param_name] += grad_value
        return dparams

In [22]:
def rnn_training_step(rnn, hidden_state, x, y, params):
    hidden_state, logits = rnn(hidden_state, x, params)
    dlogits = []
    loss = 0
    for i, l in enumerate(logits):
        criterion = CrossEntropy()
        loss += criterion(l, y[i])
        dlogits.append(criterion.backward())
    loss /= len(x)
    dparams = rnn.backward(dlogits)
    return loss, hidden_state, dparams


def sample(rnn, hidden_state, input, params, n=100):
    one_hot = []
    while n > 0:
        if len(input.shape) == 1:
            input = np.expand_dims(input, 0)
        hidden_state, logits = rnn(hidden_state, input, params)
        logits = logits[0].squeeze()
        probs = logits_to_probs(logits)
        idx = np.random.choice(len(logits), p=probs)
        one_hot_char = np.zeros_like(logits)
        one_hot_char[idx] = 1
        one_hot.append(one_hot_char)
        input = one_hot_char
        n -= 1
    return np.asarray(one_hot)


def logits_to_probs(logits):
    logits = logits.copy()
    logits -= np.max(logits)
    unnormalized_probs = np.exp(logits)
    return unnormalized_probs / np.sum(unnormalized_probs)


def init_params(vocab_size, hidden_size, std=0.01):
    init_weights = lambda size: np.random.randn(*size)*std
    init_bias = lambda num_outputs: np.zeros((num_outputs,))
    params = {
        'i2h': init_weights((vocab_size, hidden_size)),
        'h2h': init_weights((hidden_size, hidden_size)),
        'h2o': init_weights((hidden_size, vocab_size)),
        'i2h_b': init_bias(hidden_size),
        'h2h_b': init_bias(hidden_size),
        'h2o_b': init_bias(vocab_size)
    }
    return params

sample(rnn, hidden_state, x[0], params, n=10)

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 

### Optimizer

In [23]:
class Optimizer:
    def step(self, *args, **kwrgs):
        raise NotImplementedError


class Adam(Optimizer):
    def __init__(self, params, lr=1e-3, beta1=.9, beta2=.99, eps=1e-8):
        self.params = params
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.dx = {k: np.zeros_like(v) for k, v in params.items()}
        self.dx2 = {k: np.zeros_like(v) for k, v in params.items()}
    
    def step(self, grads):
        for param_name, grad_value in grads.items():
            param_value = self.params[param_name]
            m = self.beta1*self.dx[param_name] + (1-self.beta1)*grad_value
            v = self.beta2*self.dx2[param_name] + (1-self.beta2)*(grad_value*grad_value)
            self.dx[param_name] = m
            self.dx2[param_name] = v
            self.params[param_name] = param_value - self.lr*m / np.sqrt(v + self.eps)
        return self.params

### Training

In [38]:
data = read_data()
data = 'i will work at deep mind how are you doing bold fuck!'
data_loader = DataLoader(data[:10000], batch_size=5)

hidden_size, vocab_size = 100, data_loader.vocab_size
rnn = RNN()
params = init_params(vocab_size=vocab_size, hidden_size=hidden_size)
optimizer = Adam(params, lr=1e-3)

num_epochs = 100000

it = 0
for epoch in range(num_epochs):
    hidden_state = np.zeros((1, hidden_size))
    for x, y in data_loader:
        if it % 10000 == 0:
            one_hot = sample(rnn, hidden_state, x[0], params, n=60)
            print(data_loader.decode(one_hot))
        loss, hidden_state, dparams = rnn_training_step(rnn, hidden_state, x, y, params)
        if it % 1000 == 0:
            print('it: {}, loss: {}'.format(it, loss))
        optimizer.step(dparams)
        it += 1

wkdcwaapkdmmbtmwrpighdgopgrka!gh wrowhnfcgffiwgre!tkmwwhaiff
it: 0, loss: 3.0910644653429156
it: 1000, loss: 0.08377164816841934
it: 2000, loss: 0.0006755729377005408
it: 3000, loss: 2.1757301470474312e-05
it: 4000, loss: 6.479699640840998e-06
it: 5000, loss: 3.9689606252081366e-06
it: 6000, loss: 3.123754695730528e-06
it: 7000, loss: 3.0324899583466215e-06
it: 8000, loss: 2.648846852522479e-06
it: 9000, loss: 1.3871291031618162e-06
 work at deep mind how are you doing bold fuck!ri deep mind 
it: 10000, loss: 1.242838972457879e-06
it: 11000, loss: 1.6113929055848675e-06
it: 12000, loss: 2.06054816064e-06
it: 13000, loss: 1.5422487566677646e-06
it: 14000, loss: 1.1692492915650097e-06
it: 15000, loss: 7.916838854739989e-07
it: 16000, loss: 7.58351926236691e-07
it: 17000, loss: 7.635637944135057e-07
it: 18000, loss: 8.804640845422465e-07
it: 19000, loss: 8.918848269927499e-07
 at deep mind how are you doing bold fuck!rt deep mind how a
it: 20000, loss: 5.186986966812195e-07
it: 21000, los

 doing bold fuck!re doep mind how are you doing bold fuck!rt
it: 180000, loss: 5.25792464199203e-08
it: 181000, loss: 5.5879501080795897e-08
it: 182000, loss: 5.962151143274982e-08
it: 183000, loss: 7.280594959897637e-08
it: 184000, loss: 7.95203360868413e-08
it: 185000, loss: 4.8763496802735365e-08
it: 186000, loss: 4.700783763739764e-08
it: 187000, loss: 6.821057755682711e-08
it: 188000, loss: 1.0242259016770494e-07
it: 189000, loss: 8.395278599715527e-08
g bold fuck!at deep mind how are you doing bold fuck!rt deep
it: 190000, loss: 6.937992731888226e-08
it: 191000, loss: 4.945182687528491e-08
it: 192000, loss: 5.259596085928617e-08
it: 193000, loss: 5.612368723918216e-08
it: 194000, loss: 6.855426728851195e-08
it: 195000, loss: 7.492426971359805e-08
it: 196000, loss: 4.596056252182153e-08
it: 197000, loss: 4.427857679378064e-08
it: 198000, loss: 6.42483323994625e-08
it: 199000, loss: 9.661143548714996e-08
d fuck!rt deep mind how are you doing bold fuck!rn deep mind
it: 200000, loss:

KeyboardInterrupt: 

In [30]:
print(data[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You
