# Recurrent Neural Network for Next-Bit Prediction

There are many tutorials on Recurrent Neural Networks. This notebook shows that RNN computations can be conveniently factored according to the <b>refresh-extract</b> paradigm, shared by other primitives such as pseudorandom number generators.
The basic (vanilla) implementation is then evaluated on the task of predicting the next bit of a Markov chain (a model used for studying the outputs of random generators) and compared with the best theoretical accuracy.


The code is essentially equivalent to a network built out of $\texttt{tf.contrib.rnn.BasicRNNCell}$ iterated through the $\texttt{tf.nn.static_rnn}$ method.

## Refresh-Extract Pattern

RNN computation can be factored into two steps:  
<ul>
    <li>$\texttt{Refresh}$: refreshing the internal state with an input</li>
    <li>$\texttt{Extract}$: extracting the output from the current internal state</li>
</ul>

In this they are similar to cryptographic random number generators.

$$
\begin{align}
S_{t-1},X_{t} & \overset{refresh}{\longrightarrow} S_t \\
S_{t} & \overset{extract}{\longrightarrow} Y_t
\end{align}
$$

As for implementation details, $\texttt{Refresh}$ merges the input with the current state and transforms it through a hidden layer, while $\texttt{Extract}$ transforms the state into a distribution used to sample the output $Y_t$ (it is convenient to work with the distribution when training). In the code below, $\texttt{Extract}$ returns unnormalized scores; the softmax that turns them into a distribution is applied inside the loss.

A forward pass is a sequence of $T$ steps, each executing a "refresh" followed by an "extract" operation. It starts from an initial state and consumes $T$ inputs.

In [1]:
import tensorflow as tf

xavier = tf.contrib.layers.xavier_initializer()

## Refreshing ##

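def refresh(state,new_input):
    # variables are created in the graph-construction cell; reuse=True only looks them up
    with tf.variable_scope('refresh',reuse=True):
        W = tf.get_variable('W')
        b = tf.get_variable('b')
        # merge state and input, then apply a single tanh hidden layer
        return tf.nn.tanh(tf.matmul(tf.concat([state,new_input],axis=-1),W)+b)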
    
## Extraction ##
    
def extract(state):
    with tf.variable_scope('extract',reuse=True):
        W = tf.get_variable('W')
        b = tf.get_variable('b')
        # scores bounded to [-1,1] by tanh; they are used as logits by the loss
        return tf.nn.tanh(tf.matmul(state,W)+b)
    
## Forward Pass ##
    
def eval_forward(state0,X_list):
    states = []
    outputs = []
    state = state0
    for x in X_list:
        state = refresh(state,tf.to_float(x))
        states.append(state)
        output = extract(state)
        outputs.append(output)
    return outputs,states
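
The refresh-extract factoring does not depend on TensorFlow. The following is a minimal NumPy sketch of the same forward pass, useful for checking shapes by hand. The parameter values, the sizes, and the names `refresh_np`, `extract_np` and `forward_np` are illustrative assumptions, not part of the notebook's graph.

```python
import numpy as np

rng = np.random.default_rng(0)
dim_state, dim_in, dim_out, batch, T = 10, 2, 2, 4, 5  # illustrative sizes

# Hypothetical randomly initialised parameters (not the trained ones).
W_refresh = rng.normal(scale=0.1, size=(dim_state + dim_in, dim_state))
b_refresh = np.zeros(dim_state)
W_extract = rng.normal(scale=0.1, size=(dim_state, dim_out))
b_extract = np.zeros(dim_out)

def refresh_np(state, new_input):
    """Merge state and input, then apply one tanh layer."""
    merged = np.concatenate([state, new_input], axis=-1)
    return np.tanh(merged @ W_refresh + b_refresh)

def extract_np(state):
    """Map the state to per-class scores, squashed by tanh as in the cell above."""
    return np.tanh(state @ W_extract + b_extract)

def forward_np(state0, inputs):
    """Run refresh then extract once per input; return all outputs and states."""
    states, outputs = [], []
    state = state0
    for x in inputs:
        state = refresh_np(state, x)
        states.append(state)
        outputs.append(extract_np(state))
    return outputs, states

bits = rng.integers(0, 2, size=(batch, T))
one_hot = np.eye(dim_in)[bits]                     # (batch, T, dim_in)
inputs = [one_hot[:, t] for t in range(T)]         # T arrays of (batch, dim_in)
outputs, states = forward_np(np.zeros((batch, dim_state)), inputs)
```

Each element of `outputs` has shape `(batch, dim_out)` and each element of `states` has shape `(batch, dim_state)`, mirroring the lists returned by `eval_forward`.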

## Likelihood

To train an RNN we want to maximize the likelihood of the output sequence $\{Y_t\}_t$.
Denoting by $\{\hat{y}_t\}_t$ the predicted distributions, we can write the log-likelihood as 

$$\sum_{t} \log p(Y_t \mid X_{\le t}) = \sum_{t}\mathsf{OneHot}(Y_t)\cdot \log \hat{y}_t$$
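
As a small worked instance of this formula, the sketch below evaluates the log-likelihood for a toy sequence in plain Python. The helper names `softmax` and `log_likelihood` and the score values are made up for illustration; the scores stand in for the outputs of $\texttt{Extract}$.

```python
import math

def softmax(scores):
    """Turn unnormalized scores into a probability distribution."""
    m = max(scores)                          # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def log_likelihood(targets, scores_per_step, num_classes=2):
    """Sum over time of OneHot(Y_t) . log(y_hat_t)."""
    total = 0.0
    for y, scores in zip(targets, scores_per_step):
        y_hat = softmax(scores)
        one_hot = [1.0 if k == y else 0.0 for k in range(num_classes)]
        total += sum(o * math.log(p) for o, p in zip(one_hot, y_hat))
    return total

targets = [1, 0, 1]                                  # hypothetical observed bits
scores = [[0.0, 0.0], [0.2, -0.2], [-0.1, 0.1]]      # hypothetical per-step scores
ll = log_likelihood(targets, scores)
mean_nll = -ll / len(targets)                        # the quantity minimized in training
```

The dot product with the one-hot vector simply picks out the log-probability assigned to the observed bit, which is what a sparse cross-entropy loss computes without materializing the one-hot vector.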


Through the state, each RNN output depends on all previous data points. When fitting the model, we instead optimize a truncated $T$-step pass

$$
S_{t-1},X_{t}\quad \underbrace{\overset{Refresh-Extract}{\longrightarrow\ldots\longrightarrow}}_{T-fold}\quad S_{t+T-1},X_{t+T}
$$

which mitigates the issues that arise when optimizing long chained expressions (vanishing/exploding gradients). 

More precisely, during the $i$-th step (counting from $i=0$) we optimize the loss (negative log-likelihood) of the data generated by the forward pass
from time $i\cdot T+1$ to time $(i+1)\cdot T$. The very first state is initialized with zeros, and the final state of each step is fed in as the initial state of the next one. The graph for $T=5$ is shown below (note it reflects the refresh-extract logic).
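
The indexing of the truncated windows can be sketched in a few lines of plain Python. The helper names `window_bounds` and `windows` are hypothetical and serve only to illustrate the time ranges stated above.

```python
def window_bounds(i, T):
    """1-based inclusive time range covered by optimization step i (i = 0, 1, ...)."""
    return i * T + 1, (i + 1) * T

def windows(sequence, T):
    """Yield consecutive, non-overlapping length-T chunks; an incomplete tail is dropped."""
    for start in range(0, len(sequence) - T + 1, T):
        yield sequence[start:start + T]

T = 5
stream = list(range(1, 13))          # time stamps 1..12 standing in for data points
chunks = list(windows(stream, T))    # two full windows; time stamps 11 and 12 are dropped
```

With `T = 5`, step 0 covers times 1 to 5 and step 1 covers times 6 to 10, so gradients never flow across more than five refresh operations.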

In [2]:
from IPython.display import Image
from IPython.core.display import HTML 
Image(url= "./graph.png")

In [3]:
tf.reset_default_graph()

## Global params ##

dim_target_out = 2 # target out-dimension 
num_step = 5 # backpropagation length 
dim_state = 10
batch_size = 250

## Placeholders ##

X = tf.placeholder(dtype=tf.int32,shape=(batch_size,num_step),name='X')
Y = tf.placeholder(dtype=tf.int32,shape=(batch_size,num_step),name='Y')

## Rearrange inputs ##
X_enc = tf.one_hot(X,dim_target_out,axis=-1,name='X-onehot') # shape = (batch_size,num_step,dim_target_out)
X_list = tf.unstack(X_enc,axis=1,name='X-list') # list of num_step tensors of shape (batch_size,dim_target_out)

## Refreshing params ##

with tf.variable_scope('refresh'):
    W = tf.get_variable('W',shape=(dim_state+dim_target_out,dim_state),initializer=xavier)
    b = tf.get_variable('b',shape=(dim_state,),initializer=tf.constant_initializer(0))

## Extraction params ##
with tf.variable_scope('extract'):
    W = tf.get_variable('W',shape=(dim_state,dim_target_out),initializer=xavier)
    b = tf.get_variable('b',shape=(dim_target_out,),initializer=tf.constant_initializer(0))
    
## State initialization ##

state0 = tf.zeros(shape=(batch_size,dim_state),name='state0')
    
## Forward pass ##
outputs,states = eval_forward(state0,X_list)
state = states[-1]


## Likelihood / accuracy ##

with tf.variable_scope('loss'):
    outputs_t = tf.stack(outputs,axis=1,name='outputs')
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=Y,logits=outputs_t))
    acc = tf.reduce_mean(tf.to_float(tf.equal(tf.to_int64(Y),tf.argmax(outputs_t,-1))))
    
with tf.variable_scope('optimize'):
    optimizer = tf.train.AdamOptimizer(1e-2).minimize(loss)

## Test on a Markov Chain

To illustrate the RNN in action, let's evaluate it on a specific Markov chain on bits, in which the bias of the next bit depends on the parity of the last 3 bits. In the example below the best theoretical accuracy for predicting the next bit is $0.55$. A simple RNN model easily reaches about $0.53$; getting closer to the limit would likely need a more complicated architecture, since the network essentially needs to learn parity, which is known to be a little tricky.
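
The theoretical limits for this chain can be checked by enumerating the possible windows of the last 3 bits. The sketch below assumes the transition rule described above (next bit is 1 with probability 0.55 after an odd-parity window and 0.45 after an even-parity one); the names `p_one`, `best_accuracy`, `best_loss` and `chance_loss` are illustrative.

```python
import math
from itertools import product

def p_one(window):
    """Probability that the next bit is 1, given the last bits (assumed rule)."""
    return 0.55 if sum(window) % 2 == 1 else 0.45

def entropy(p):
    """Cross-entropy (in nats) of a predictor that knows the true probability p."""
    return -(p * math.log(p) + (1 - p) * math.log(1 - p))

w = 3
states = list(product((0, 1), repeat=w))             # all 8 windows of 3 bits

# Accuracy of always guessing the more likely bit, per window.
per_state_acc = [max(p_one(s), 1 - p_one(s)) for s in states]
per_state_nll = [entropy(p_one(s)) for s in states]

best_accuracy = sum(per_state_acc) / len(states)     # about 0.55
best_loss = sum(per_state_nll) / len(states)         # about 0.6881
chance_loss = math.log(2)                            # about 0.6931, uniform guessing
```

Every window gives the same per-state value, so the averages do not depend on how often each window occurs. The gap between the chance-level loss and the best achievable loss is only about 0.005 nats, which is why the cross-entropy moves so little during training.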

In [4]:
## data generation ##

from collections import deque
from itertools import islice
import numpy as np

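def sequence(n,w=2):
    # yields n bits; the next bit is 1 with probability 0.55 if the sum of
    # the last w bits is odd, and with probability 0.45 otherwise

    tmp = deque()
    tmp.extend(0 for _ in range(w)) # window of the last w bits, initially zeros

    for i in range(n):
        if sum(tmp)%2==1:
            threshold = 0.55
        else:
            threshold = 0.45
        out = int( np.random.rand() < threshold )
        # slide the window by one bit
        tmp.append(out)
        tmp.popleft()
        yield out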

print('Bias %s'% (sum(i for i in sequence(10000))/10000 - 0.5))

n_batch = 20000
data = sequence(n_batch*batch_size*num_step,3)

Bias -0.0006999999999999784


In [5]:
with tf.Session() as sess:
    #writer = tf.summary.FileWriter('./', sess.graph)
    sess.run(tf.global_variables_initializer())
    state_ = np.zeros(shape=(batch_size,dim_state))
    total_acc = 0
    total_loss = 0
    for i in range(n_batch):
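        # take the next batch_size*num_step bits of the stream
        X_ = np.array(list(islice(data,batch_size*num_step)))
        # targets are the inputs shifted by one; the bit after the last input is
        # not available yet, so the final target is a random placeholder
        Y_ = np.concatenate([X_[1:],[np.random.randint(0,2)]],0)
        # rows are consecutive length-num_step chunks of the stream, so the state
        # fed back for row j does not come from the bits just before row j
        X_ = X_.reshape(batch_size,num_step)
        Y_ = Y_.reshape(batch_size,num_step)
        loss_,state_,acc_,_=sess.run([loss,state,acc,optimizer],
                           feed_dict={X:X_,Y:Y_,state0:state_})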
        total_acc += acc_
        total_loss += loss_
        if i%1000==0:
            print('Loss=%s, Accuracy=%s'%(total_loss/(i+1),total_acc/(i+1)))

Loss=0.7056767344474792, Accuracy=0.5040000081062317
Loss=0.6926051188301254, Accuracy=0.5149626385617804
Loss=0.6922789198169108, Accuracy=0.5192911554163304
Loss=0.6921216523357329, Accuracy=0.5202937694896423
Loss=0.6920359474782555, Accuracy=0.5210699334692223
Loss=0.6919847836830072, Accuracy=0.5218194770565082
Loss=0.6919466833376205, Accuracy=0.5225002509894877
Loss=0.6919065237964771, Accuracy=0.5232907309223628
Loss=0.6918831282802678, Accuracy=0.5240205985846303
Loss=0.6918596496662555, Accuracy=0.5247305866354453
Loss=0.6918480030990413, Accuracy=0.5252302781145235
Loss=0.6918263386058955, Accuracy=0.525697519574184
Loss=0.6918035694663003, Accuracy=0.5260344316622444
Loss=0.6917789162750309, Accuracy=0.5263206226790321
Loss=0.6917551364798894, Accuracy=0.5266128430646808
Loss=0.6917335394620657, Accuracy=0.5268571973144989
Loss=0.6917238229811545, Accuracy=0.5269983137676801
Loss=0.6917095350262754, Accuracy=0.5271617916445684
Loss=0.6916987781531015, Accuracy=0.52729341819
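
In the training loop above, each batch is a contiguous block of the stream reshaped into rows, so row $j$ of one batch is not followed in time by row $j$ of the next batch, although that row's final state is what gets fed back. One common alternative, sketched here under the assumption that the whole stream fits in memory, is to split the stream into `batch_size` parallel tracks and cut every track into windows. The function name `aligned_batches` is hypothetical, and this layout was not used to produce the results printed above.

```python
import numpy as np

def aligned_batches(stream, batch_size, num_step):
    """Yield (X, Y) pairs in which row j of consecutive batches is contiguous in time."""
    stream = np.asarray(stream)
    track_len = (len(stream) - 1) // batch_size    # keep one extra bit for the last target
    n_batches = track_len // num_step
    used = batch_size * track_len
    inputs = stream[:used].reshape(batch_size, track_len)
    targets = stream[1:used + 1].reshape(batch_size, track_len)   # inputs shifted by one
    for i in range(n_batches):
        cols = slice(i * num_step, (i + 1) * num_step)
        yield inputs[:, cols], targets[:, cols]

toy = np.arange(25)    # stand-in for the bit stream; values double as time stamps
batches = list(aligned_batches(toy, batch_size=3, num_step=4))
```

With this layout the state carried from one batch to the next belongs to the bits that immediately precede each row, and every target is a real next bit rather than a random placeholder.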

## Using the TensorFlow API

To use the TensorFlow API, replace the code under "## Forward pass ##" with the following:

In [41]:
with tf.variable_scope('rnn',initializer=xavier): # the initializer matters: the cell's variables inherit it

    cell = tf.contrib.rnn.BasicRNNCell(dim_state) 
    outputs, state = tf.nn.static_rnn(cell, X_list, initial_state=state0)
outputs = [extract(o) for o in outputs]
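
The equivalence with $\texttt{refresh}$ holds up to the order in which input and state are concatenated. To my understanding, `BasicRNNCell` concatenates `[inputs, state]` before multiplying by its kernel, whereas `refresh` uses `[state, new_input]`; this ordering is an assumption about the library internals and worth checking against the installed version. The NumPy sketch below shows that the two conventions give the same update once the rows of the weight matrix are reordered. All values are random and purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
dim_state, dim_in, batch = 10, 2, 4    # illustrative sizes

W = rng.normal(size=(dim_state + dim_in, dim_state))   # rows ordered [state, input]
b = rng.normal(size=dim_state)
state = rng.normal(size=(batch, dim_state))
x = rng.normal(size=(batch, dim_in))

# refresh-style update: concat([state, input]) @ W
h_refresh = np.tanh(np.concatenate([state, x], axis=-1) @ W + b)

# Assumed cell-style update: concat([input, state]) with the rows of W reordered to match.
W_cell = np.concatenate([W[dim_state:], W[:dim_state]], axis=0)
h_cell = np.tanh(np.concatenate([x, state], axis=-1) @ W_cell + b)

same = np.allclose(h_refresh, h_cell)
```

Since both weight matrices are learned from a random initialization, the row order has no effect on what the model can represent.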

In [32]:
# one-pass to visualize the computational graph
with tf.Session() as sess:
    writer = tf.summary.FileWriter('./', sess.graph)
    sess.run(tf.global_variables_initializer())
    state_ = np.zeros((batch_size,dim_state))
    train_loss,state_,acc_,_=sess.run([loss,state,acc,optimizer],
                           feed_dict={X:np.ones(shape=(batch_size,num_step)),
                                      Y:np.ones(shape=(batch_size,num_step)),
                                      state0:state_})
    state0 = state_
    print(train_loss,acc_)
    writer.close()

0.4845379 1.0
