In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

## Data Loader

In [11]:
import numpy as np
from collections import namedtuple

In [88]:
Batch = namedtuple('Batch', ['features', 'labels'])


class DataLoader(object):
    """Serve a lower-cased string as one-hot next-character training batches.

    Each batch pairs the characters of a window (``features``) with the
    characters that directly follow them (``labels``); both are one-hot
    encoded over the vocabulary found in the text.
    """

    def __init__(self, data, batch_size=8):
        """Validate the inputs, normalise the text and build the vocabulary.

        Args:
            data: Source text; must be a ``str`` of at least two characters.
            batch_size: How far the read position advances per batch; must
                satisfy ``0 < batch_size < len(data)``.

        Raises:
            AssertionError: If either argument violates the constraints above.
        """
        assert isinstance(data, str) and len(data) > 1
        assert 0 < batch_size < len(data)
        self._data = self._preprocess_data(data)
        self._batch_size = batch_size
        self.curr_pos = 0

        # Sorting the distinct characters makes the mapping deterministic.
        self._char_to_index = {}
        self._index_to_char = {}
        for position, symbol in enumerate(sorted(set(self._data))):
            self._char_to_index[symbol] = position
            self._index_to_char[position] = symbol

    @property
    def num_examples(self):
        """Number of (character, next character) pairs in the text."""
        total_chars = len(self._data)
        return total_chars - 1

    @property
    def vocab_size(self):
        """Number of distinct characters in the preprocessed text."""
        return len(self._char_to_index)

    @property
    def vocab(self):
        """Set of the distinct characters in the preprocessed text."""
        return set(self._char_to_index)

    def next_batch(self):
        """Return the next ``Batch`` and advance (or wrap) the read position.

        The window holds one character more than ``batch_size`` so that every
        feature row has a label; the final window of a pass may be shorter.
        """
        start = self.curr_pos
        advanced = start + self._batch_size
        window = self._data[start:advanced + 1]
        # Wrap to the beginning once no further complete pair remains.
        self.curr_pos = advanced if advanced < self.num_examples else 0
        encoded = self._one_hot(window)
        return Batch(features=encoded[:-1, :], labels=encoded[1:, :])

    def _one_hot(self, s):
        """One-hot encode ``s`` into an array of shape (len(s), vocab_size)."""
        length = len(s)
        columns = [self._char_to_index[symbol] for symbol in s]
        encoded = np.zeros((length, self.vocab_size))
        encoded[np.arange(length), columns] = 1
        return encoded

    def decode(self, one_hot):
        """Turn a one-hot row, or a 2-D stack of rows, back into a string."""
        is_single_row = len(one_hot.shape) == 1
        rows = one_hot.reshape(1, -1) if is_single_row else one_hot
        return ''.join(self.decode_char(row) for row in rows)

    def decode_char(self, one_hot):
        """Return the character whose index is the arg-max of ``one_hot``."""
        return self._index_to_char[np.argmax(one_hot)]

    def __iter__(self):
        """Yield batches until the read position wraps back to zero."""
        keep_going = True
        while keep_going:
            yield next(self)
            # Evaluated when the consumer asks for the following batch.
            keep_going = self.curr_pos != 0

    def __next__(self):
        """Alias for ``next_batch`` so that ``next(loader)`` works."""
        return self.next_batch()

    def _preprocess_data(self, data):
        """Normalise the raw text by lower-casing it."""
        lowered = data.lower()
        return lowered

In [89]:
def read_data(path='data/input.txt', encoding=None):
    """Read a text file and return its entire contents as one string.

    Args:
        path: File to read. Defaults to ``'data/input.txt'``, the location
            that used to be hard-coded, so ``read_data()`` still works.
        encoding: Text encoding passed to ``open``. ``None`` (the default)
            keeps the platform default, exactly as before.

    Returns:
        The file contents as a ``str``.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(path, 'r', encoding=encoding) as fp:
        return fp.read()

## RNN model

### Helper modules

In [146]:
class Layer:
    """Base class for differentiable building blocks.

    Subclasses implement ``forward`` and ``backward``; calling an instance is
    shorthand for calling its ``forward`` method.
    """

    def forward(self, *args, **kwargs):
        """Compute the layer output. Must be overridden by subclasses."""
        raise NotImplementedError

    def backward(self, *args, **kwargs):
        """Compute the layer gradients. Must be overridden by subclasses."""
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``forward``."""
        result = self.forward(*args, **kwargs)
        return result

class Affine(Layer):
    """Affine map ``x.dot(w) + b`` with a matching backward pass."""

    def forward(self, x, w, b):
        """Return ``x.dot(w) + b`` and remember the operands for ``backward``."""
        self.cache = (x, w, b)
        projected = x.dot(w)
        return projected + b

    def backward(self, dout):
        """Back-propagate ``dout`` through the most recent ``forward`` call.

        Returns:
            Tuple ``(dx, dw, db)`` of gradients with respect to the input,
            the weights and the bias.
        """
        x, w, _ = self.cache
        dx = dout.dot(w.T)
        dw = x.T.dot(dout)
        # The bias is shared by every row, so its gradient sums over rows.
        db = dout.sum(axis=0)
        return dx, dw, db

class Tanh(Layer):
    """Element-wise hyperbolic tangent activation."""

    def forward(self, x):
        """Return ``tanh(x)``, keeping the result for use in ``backward``."""
        self.cache = np.tanh(x)
        return self.cache

    def backward(self, dout):
        """Scale ``dout`` by the local derivative ``1 - tanh(x)**2``."""
        activation = self.cache
        local_grad = 1 - activation*activation
        return local_grad * dout

class CrossEntropy(Layer):
    """Softmax followed by the mean cross-entropy against one-hot targets."""

    def forward(self, logits, target):
        """Return the mean negative log-likelihood of the target classes.

        Args:
            logits: 2-D array of unnormalised scores, one row per example and
                one column per class.
            target: 2-D array with the same layout; the target class of each
                row is taken as the arg-max along axis 1.

        Returns:
            Scalar mean, over the rows, of ``-log(softmax(logits)[target])``.
        """
        target = np.argmax(target, axis=1)
        # keepdims=True keeps the row maxima as a column vector so that each
        # row is shifted by its own maximum. Without it the subtraction only
        # broadcasts correctly when there is a single row.
        shifted = logits - np.max(logits, axis=1, keepdims=True)
        # Work in log space so that a vanishing probability cannot turn the
        # loss into log(0).
        log_norm = np.log(np.sum(np.exp(shifted), axis=1, keepdims=True))
        log_probs = shifted - log_norm
        probs = np.exp(log_probs)
        self.cache = (probs, target)
        return np.mean(-log_probs[np.arange(len(logits)), target])

    def backward(self):
        """Return the gradient of the mean loss with respect to the logits.

        Uses the probabilities and targets cached by the latest ``forward``.
        """
        probs, target = self.cache
        dlogits = probs.copy()
        # d(-log p_target)/dlogits = softmax - one_hot(target)
        dlogits[np.arange(len(dlogits)), target] -= 1
        # forward() averages over the rows, so the gradient is averaged too.
        return dlogits / dlogits.shape[0]

### RNN

In [229]:
class RNN(Layer):
    """Single-layer tanh RNN unrolled over one sequence of one-hot rows.

    ``forward`` creates a fresh set of layers for every time step and stores
    them on the instance; ``backward`` consumes them in reverse order
    (back-propagation through time over the given sequence only).
    """

    def forward(self, hidden_state, features, labels, params):
        """Run the network over the sequence and accumulate the loss.

        Args:
            hidden_state: Hidden state fed into the first time step.
            features: Array whose rows are the inputs, one per time step.
            labels: Array whose rows are the targets, one per time step.
            params: Dict holding ``i2h``, ``h2h``, ``h2o`` and the matching
                ``*_b`` biases.

        Returns:
            Tuple ``(mean per-step loss, hidden state after the last step)``.
        """
        self.params = params
        w_in, b_in = params['i2h'], params['i2h_b']
        w_rec, b_rec = params['h2h'], params['h2h_b']
        w_out, b_out = params['h2o'], params['h2o_b']
        num_steps = features.shape[0]

        # One list per layer kind; entry t is the layer used at time step t.
        self.affine_hidden, self.affine_input = [], []
        self.tanh, self.affine_output = [], []
        self.loss = []

        total_loss = 0
        for step in range(num_steps):
            # Each step is processed as a batch containing a single row.
            x = np.expand_dims(features[step], 0)
            y = np.expand_dims(labels[step], 0)

            input_proj, hidden_proj, output_proj = Affine(), Affine(), Affine()
            activation = Tanh()
            criterion = CrossEntropy()

            from_input = input_proj(x, w_in, b_in)
            from_hidden = hidden_proj(hidden_state, w_rec, b_rec)
            hidden_state = activation(from_input + from_hidden)
            step_logits = output_proj(hidden_state, w_out, b_out)
            total_loss += criterion(step_logits, y)

            self.affine_hidden.append(hidden_proj)
            self.affine_input.append(input_proj)
            self.tanh.append(activation)
            self.affine_output.append(output_proj)
            self.loss.append(criterion)
        return total_loss / len(features), hidden_state

    def backward(self):
        """Back-propagate through every step cached by the last ``forward``.

        The cached layers are popped while walking backwards, so the caches
        are empty afterwards. No extra scaling is applied here, so the result
        is the gradient of the *sum* of the per-step losses, whereas
        ``forward`` reports their mean.

        Returns:
            Dict with the same keys as ``params`` holding the accumulated
            gradients.
        """
        grad_from_future = 0
        grads = {name: np.zeros_like(value) for name, value in self.params.items()}
        while self.affine_hidden:
            dlogits = self.loss.pop().backward()

            dhidden, dw_out, db_out = self.affine_output.pop().backward(dlogits)
            # The hidden state feeds both this step's output and the next step.
            dhidden = dhidden + grad_from_future
            grads['h2o'] += dw_out
            grads['h2o_b'] += db_out

            dhidden_raw = self.tanh.pop().backward(dhidden)

            # The gradient with respect to the one-hot input is not needed.
            _, dw_in, db_in = self.affine_input.pop().backward(dhidden_raw)
            grads['i2h'] += dw_in
            grads['i2h_b'] += db_in

            grad_from_future, dw_rec, db_rec = self.affine_hidden.pop().backward(dhidden_raw)
            grads['h2h'] += dw_rec
            grads['h2h_b'] += db_rec

        return grads
            
def init_params(vocab_size, hidden_size, std=0.01):
    """Create the RNN parameter dict with Gaussian weights and zero biases.

    Args:
        vocab_size: Size of the one-hot input and output vectors.
        hidden_size: Number of hidden units.
        std: Factor applied to the standard-normal weight samples.

    Returns:
        Dict with weights ``i2h``, ``h2h``, ``h2o`` followed by the biases
        ``i2h_b``, ``h2h_b``, ``h2o_b``.
    """
    weight_shapes = (
        ('i2h', (vocab_size, hidden_size)),
        ('h2h', (hidden_size, hidden_size)),
        ('h2o', (hidden_size, vocab_size)),
    )
    bias_sizes = (
        ('i2h_b', hidden_size),
        ('h2h_b', hidden_size),
        ('h2o_b', vocab_size),
    )
    # Weights are drawn in this fixed order from NumPy's global generator.
    params = {name: np.random.randn(*shape)*std for name, shape in weight_shapes}
    for name, size in bias_sizes:
        params[name] = np.zeros((size,))
    return params


def sample(rnn, hidden_state, start_char):
    """Placeholder for sampling from a trained RNN; not implemented yet.

    The notebook only contained the bare signature, which is a syntax error
    and stops the whole cell from being defined. Raising explicitly keeps the
    cell importable and makes the missing piece obvious when it is called.

    Args:
        rnn: Network to sample from (unused).
        hidden_state: Initial hidden state (unused).
        start_char: Seed for the sample (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError('sample() has not been implemented yet')

In [248]:
class Optimizer:
    """Base class for parameter-update rules."""

    def step(self, *args, **kwargs):
        """Apply one update. Must be overridden by subclasses."""
        raise NotImplementedError


class Adam(Optimizer):
    """Adam-style optimizer built on running gradient moments.

    As written, the update applies no bias correction to the moment estimates
    and adds ``eps`` inside the square root.
    """

    def __init__(self, params, lr=1e-3, beta1=.9, beta2=.99, eps=1e-3):
        """Store the hyper-parameters and zero-initialise both moments.

        Args:
            params: Dict of parameter arrays; its entries are replaced by
                ``step``.
            lr: Step size.
            beta1: Decay rate of the running mean of the gradients.
            beta2: Decay rate of the running mean of the squared gradients.
            eps: Constant added to the second moment before the square root.
        """
        self.params = params
        self.lr, self.eps = lr, eps
        self.beta1, self.beta2 = beta1, beta2
        self.dx = {}
        self.dx2 = {}
        for name, value in params.items():
            self.dx[name] = np.zeros_like(value)
            self.dx2[name] = np.zeros_like(value)

    def step(self, grads):
        """Update every parameter that has an entry in ``grads``.

        Args:
            grads: Dict mapping parameter names to gradient arrays.
        """
        for name in grads:
            gradient = grads[name]
            current = self.params[name]
            first_moment = self.beta1*self.dx[name] + (1-self.beta1)*gradient
            second_moment = self.beta2*self.dx2[name] + (1-self.beta2)*(gradient*gradient)
            self.dx[name], self.dx2[name] = first_moment, second_moment
            # A new array is stored under the same key; the dict itself is
            # shared with whoever passed it in.
            self.params[name] = current - self.lr*first_moment / np.sqrt(second_moment + self.eps)

In [266]:
# Train on the first 100 characters of the corpus, five characters per batch.
data = read_data()
data_loader = DataLoader(data[:100], batch_size=5)

hidden_size, vocab_size = 100, data_loader.vocab_size
rnn = RNN()
params = init_params(vocab_size=vocab_size, hidden_size=hidden_size)
# A stray warm-up `rnn.forward(..., *next(data_loader), ...)` call used to sit
# here. Its result was discarded, but it advanced the loader, so the first
# epoch started after the first batch. It has been removed.
optimizer = Adam(params, lr=1e-3)

num_epochs = 10000
it = 0
for epoch in range(num_epochs):
    # Every pass over the text starts from a zero hidden state.
    hidden_state = np.zeros((1, hidden_size))
    losses = []
    for x, y in data_loader:
        # The hidden state is carried from one batch to the next.
        loss, hidden_state = rnn(hidden_state, x, y, params)
        losses.append(loss)
        if it % 100 == 0:
            print('it: {}, loss: {}'.format(it, loss))
        grads = rnn.backward()
        # The optimizer writes the new arrays back into the shared `params`
        # dict, so the next forward pass picks them up.
        optimizer.step(grads)
        it += 1
    #print('epoch: {}, loss: {}'.format(epoch, np.mean(losses)))

it: 0, loss: 3.2576426012210336
it: 100, loss: 3.2298292198352443
it: 200, loss: 3.251130540862038
it: 300, loss: 3.196794440759586
it: 400, loss: 3.1269258917975713
it: 500, loss: 2.94407508107704
it: 600, loss: 2.1365516068773025
it: 700, loss: 1.3182259383038677
it: 800, loss: 0.9968493556733904
it: 900, loss: 0.7923963624947317
it: 1000, loss: 0.6251716142606709
it: 1100, loss: 0.4958127794885815
it: 1200, loss: 0.38073855354151565
it: 1300, loss: 0.287894490211315
it: 1400, loss: 0.22909915270624484
it: 1500, loss: 0.1823449287386685
it: 1600, loss: 0.12241460449523407
it: 1700, loss: 0.07907456837143093
it: 1800, loss: 0.06964367933669009
it: 1900, loss: 0.05054010390215904
it: 2000, loss: 0.037459008791289626
it: 2100, loss: 0.03384241858416666
it: 2200, loss: 0.020679670625307595
it: 2300, loss: 0.02126603626774248
it: 2400, loss: 0.014916389786293569
it: 2500, loss: 0.013637400660654627
it: 2600, loss: 0.01273527926959669
it: 2700, loss: 0.012141979462648896
it: 2800, loss: 0.

KeyboardInterrupt: 