<h1><center>Adaptive Computation Time for Recurrent Neural Networks</center></h1>


<h13><center>Alex Graves 2017 </center></h13>
<h13><center>DeepMind</center></h13>

<h13><center> Presentation to Enlitic by Carson Lam  </center></h13>

[Arxiv link](https://arxiv.org/pdf/1603.08983.pdf) 

[GitHub](https://github.com/zphang/adaptive-computation-time-pytorch)

## Background: Vanilla Recurrent Neural Network (RNN)

<img src="https://miro.medium.com/max/1168/1*SBeZvxdxhnL5zqvKfyLGRQ.png" width =500>

<img src="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQLBJa0HXfrvlgQIZgNkzHWRquWdWAf-gGws6UwUnHFqzk8DSLH&s">

Vertical dotted lines represent the separation between time steps.

s_t is the RNN state at timestep t. s_t = S(s_{t-1}, x_t) is the state transition operation, i.e. one layer of a recurrent neural network. s_t is then passed on to the next timestep, t + 1.

In other writings s_t is often called h_t, for hidden state. Here h is reserved for another value.

<img src="http://karpathy.github.io/assets/rnn/charseq.jpeg" width =500>

## One example of a state transition operation is the Gated Recurrent Unit

$$ z_{t} = \sigma(W^{z} x_{t} + U^{z} s_{t-1}) $$

z (update gate): when z ~ 1, the previous state is copied over to the current state,
when z ~ 0 the updated state replaces the previous state 

$$ r_{t} = \sigma(W^{r} x_{t} + U^{r} s_{t-1}) $$

r (reset gate): the closer r is to 1, the more the previous state is used to inform the updated state

$$  \tilde{s_{t}} = tanh(W^{s} x_{t} + r_{t} \odot U^{s} s_{t-1} ) $$

$$ s_{t} = z_{t} \odot s_{t-1} + (1 - z_{t}) \odot \tilde{s_{t}} = S(s_{t - 1}, x_{t}) $$
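The four GRU equations above can be sketched directly in NumPy. This is a minimal illustration, not the implementation used later in this notebook: the weight shapes, the random initialization and the omission of bias terms (the equations above have none) are assumptions made for the example.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(s_prev, x, params):
    """One GRU state transition s_t = S(s_{t-1}, x_t), following the equations above."""
    Wz, Uz, Wr, Ur, Ws, Us = params
    z = sigmoid(Wz @ x + Uz @ s_prev)               # update gate
    r = sigmoid(Wr @ x + Ur @ s_prev)               # reset gate
    s_tilde = np.tanh(Ws @ x + r * (Us @ s_prev))   # candidate state
    return z * s_prev + (1.0 - z) * s_tilde         # blend old state and candidate

# Hypothetical sizes and random weights, for illustration only.
rng = np.random.default_rng(0)
input_dim, state_dim = 3, 4
params = [rng.normal(scale=0.5, size=(state_dim, d))
          for d in (input_dim, state_dim) * 3]

s_prev = np.zeros(state_dim)
x_t = np.array([-0.1, 0.5, 1.2])
s_next = gru_step(s_prev, x_t, params)
print(s_next.shape)
```

Because the new state is a convex combination of the previous state and a tanh output, every entry stays inside (-1, 1) as long as the previous state does.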




## Multi-Layer RNN

<img src="https://blog.exxactcorp.com/wp-content/uploads/2019/01/0_iLv3Tisjx68QrYU7-257x300.png" width=300 >

In multi-layer RNNs, each of the green states is produced by its own set of weights.

$$ s^{n}_{t} =  S(s^{n}_{t-1}, s^{n-1}_{t}) $$

Where the first layer takes x as input, and each deeper layer (upward in the diagram) takes the output of the layer below it as input.

$$ s^{0}_{t} = x_{t}$$

The depth is fixed, so at each time step t the input is transformed exactly 4 times before it becomes the output: 3 RNN layers and 1 output layer.

## Adaptive Computation Time (ACT)

In Adaptive Computation Time (ACT), n is the number of times the same weights are applied repeatedly to update the state within one timestep.

The first of these updates takes as input the state from the previous timestep and the input with a binary flag set to 0, e.g. `[-0.1, 0.5, 1.2, 0]`. Subsequent updates take as input the state from the previous update of the same timestep and the input with the binary flag set to 1, e.g. `[-0.1, 0.5, 1.2, 1]`. The flag lets the state transition operation distinguish a repeated input at a new timestep from a repeated input at a further update of the same timestep.
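The binary flag described above amounts to appending one extra element to the input vector. A minimal sketch, where `with_flag` is a hypothetical helper name chosen for this example:

```python
def with_flag(x_t, n):
    """Append the binary flag: 0 for the first update (n = 1), 1 for later updates."""
    return list(x_t) + [0 if n == 1 else 1]

x_t = [-0.1, 0.5, 1.2]
first_update_input = with_flag(x_t, 1)   # flag = 0
later_update_input = with_flag(x_t, 2)   # flag = 1
print(first_update_input, later_update_input)
```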

<img src="https://miro.medium.com/max/729/1*4pOFTSf6clGBToAriB4i5w.png" width=400>

Here n is not the number of layers in a multi-layer RNN, where the transformation by each layer's weights is performed exactly once per timestep and the weights are not shared between layers.

For each timestep t, the state is updated a variable number of times, n. The number of updates, n, is a learned function of the state at time t, and in turn of the input at time t.

That learned function is the halting unit, h. It is a scalar, not to be confused with the hidden state.
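The difference between the two meanings of n can be made concrete with a toy example. The sketch below uses a scalar state and a hypothetical transition `tanh(w_s * s + w_x * x)`, and leaves out the binary flag for brevity. It is only meant to contrast "one pass through each of several differently weighted layers" with "several passes through the same weights".

```python
import math

def transition(s_prev, x, w_s, w_x):
    """Toy scalar state transition S(s, x) = tanh(w_s * s + w_x * x)."""
    return math.tanh(w_s * s_prev + w_x * x)

def stacked_step(prev_states, x_t, layer_weights):
    """Multi-layer RNN: one pass per layer, each layer with its own weights."""
    new_states, layer_input = [], x_t
    for s_prev, (w_s, w_x) in zip(prev_states, layer_weights):
        s_new = transition(s_prev, layer_input, w_s, w_x)
        new_states.append(s_new)
        layer_input = s_new              # output of this layer feeds the next one
    return new_states

def repeated_step(s_prev, x_t, weights, n_updates):
    """ACT-style: the same weights applied n_updates times within one timestep."""
    w_s, w_x = weights
    s = s_prev
    for _ in range(n_updates):
        s = transition(s, x_t, w_s, w_x)
    return s

layers = [(0.5, 1.0), (0.3, 0.8), (0.1, 0.6)]        # three distinct weight sets
stacked = stacked_step([0.0, 0.0, 0.0], 1.0, layers)  # depth fixed at 3
pondered = repeated_step(0.0, 1.0, layers[0], n_updates=3)  # depth chosen per timestep
print(stacked, pondered)
```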

<img src="https://miro.medium.com/max/489/1*f4PRPGPSnzgEmhdgWd86ow.png" width=300>

The halting unit lies between 0 and 1, hence the sigmoid. It is accumulated at each state update until the running sum of halting units reaches or exceeds 1 - epsilon.

Suppose epsilon = 0.1, so that 1 - epsilon = 0.9. The total number of updates at time t will then be N(t), where

$$ N(t) = \min\{ n' : \sum_{n=1}^{n'} h^{n}_{t} \geq 1 - \epsilon \} $$

In the table below, which lists the halting units produced at updates 1 through 4, N(t) = 3, since 0.1 + 0.3 + 0.5 = 0.9, so this timestep undergoes 3 updates. A 4th update would have produced a halting unit of 0.5, but it is never performed, because updating stops once the sum of halting units equals or exceeds 1 - epsilon.

<table width=200>
    
<tr>
<th>update n</th>
<th>h</th>
</tr>

<tr>
<td>4</td>
<td>0.5</td>
</tr>

<tr>
<td>3</td>
<td>0.5</td>
</tr>

<tr>
<td>2</td>
<td>0.3</td>
</tr>

<tr>
<td>1</td>
<td>0.1</td>
</tr>

</table>
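The worked example in the table above can be checked with a few lines of Python. This is a sketch of the N(t) definition only. `Fraction` is used instead of floats because the example sits exactly on the 1 - epsilon threshold, and the fallback for a sum that never reaches the threshold is an assumption of this example.

```python
from fractions import Fraction

def num_updates(halting_units, epsilon):
    """Return N(t): the first n' whose cumulative halting sum reaches 1 - epsilon."""
    total = 0
    for n, h in enumerate(halting_units, start=1):
        total += h
        if total >= 1 - epsilon:
            return n
    return len(halting_units)   # threshold never reached: use every available update

# Halting units for updates 1..4, as in the table.
h = [Fraction(1, 10), Fraction(3, 10), Fraction(5, 10), Fraction(5, 10)]
n_t = num_updates(h, epsilon=Fraction(1, 10))
print(n_t)   # 3
```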

The diagram below is the ACT version of the RNN diagram we have already seen. It leaves out one informative concept, though: how the final state s and output y are determined for timestep t. They are mean-field values, i.e. weighted sums of the intermediate states and outputs, weighted by the halting probabilities p. The halting probabilities p are the same as the halting units h at each update, except at the last update, where n = N(t).

<img src="https://miro.medium.com/max/888/1*SivgPX_-tcrlTuOTs2toEQ.png" width=400>

$$
    p^{n}_{t}=\left\{
                \begin{array}{ll}
                  R(t) & \text{if } n = N(t)\\
                  h^{n}_{t} & \text{otherwise}
                \end{array}
              \right.
$$

$$ R(t) = 1 -  \sum_{n=1}^{N(t)-1} h^{n}_{t} $$

So in our example, the halting probabilities are 0.1, 0.3, 0.6. The last halting unit, 0.5, pushed the accumulated h up to the 1 - epsilon threshold, so its halting probability is set to the remainder R(t) = 1 - (0.1 + 0.3) = 0.6. This way, the halting probabilities p form a probability distribution over all the updates for timestep t. The authors call this process of computing updates pondering.
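The halting probabilities and the mean-field state for this example can be sketched as follows. The intermediate states `[1, 2, 3]` are made-up scalar values used only to show the weighting, and exact fractions are used so the arithmetic matches the text digit for digit.

```python
from fractions import Fraction

def halting_probabilities(halting_units, n_t):
    """p_n = h_n for n < N(t); the last update receives the remainder R(t)."""
    head = list(halting_units[:n_t - 1])
    remainder = 1 - sum(head)
    return head + [remainder]

def mean_field(values, probs):
    """Weighted sum of intermediate values, weighted by the halting probabilities."""
    return sum(p * v for p, v in zip(probs, values))

h = [Fraction(1, 10), Fraction(3, 10), Fraction(5, 10)]   # updates 1..3
p = halting_probabilities(h, n_t=3)                        # 1/10, 3/10, 6/10
states = [1, 2, 3]                                         # hypothetical scalar states
s_t = mean_field(states, p)                                # 0.1*1 + 0.3*2 + 0.6*3
print([float(v) for v in p], float(s_t))
```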

<img src="https://miro.medium.com/max/1972/1*5dULqBM2KKGlQTHrKCeR3Q.png" width=600>


##  Limiting Computation Time

If the loss function L(x,y), e.g. binary cross entropy for the parity task, only rewards correct predictions and punishes wrong predictions, the ACT-RNN is encouraged to keep updating the state for as long as possible, i.e. to ponder indefinitely.

We want the network to ponder longer for timesteps when this is beneficial, but also not to waste computation when a small number of updates can predict correctly. 

To encourage efficient use of computation, the *Ponder Cost* P(x) is added to the loss function, scaled by the *time penalty* τ (*Tau*):

$$\hat{L}(x,y) = L(x,y) + \tau P(x) $$

$$ P(x) = \sum^{T}_{t=1} \left( N(t) + R(t) \right) $$

The example implementation simplifies the Ponder Cost to:

$$ P(x) = \sum^{T}_{t=1} \sum^{N(t)}_{n=1} - h^{n}_{t}  $$

Which encourages the network to produce larger halting units to bring down the ponder cost.
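Both forms of the ponder cost, and the way τ scales it into the loss, can be sketched in a few lines. The function names and the numbers are illustrative assumptions: one timestep with N(t) = 3, R(t) = 0.6 and halting units 0.1, 0.3, 0.5, a task loss of 0.25, and τ = 0.0001.

```python
def ponder_cost_paper(n_per_step, r_per_step):
    """P(x) = sum over timesteps of (N(t) + R(t))."""
    return sum(n + r for n, r in zip(n_per_step, r_per_step))

def ponder_cost_simplified(h_per_step):
    """Simplified cost: the negative sum of the halting units that were computed."""
    return -sum(sum(h_t) for h_t in h_per_step)

def total_loss(task_loss, ponder, tau):
    """L_hat(x, y) = L(x, y) + tau * P(x)."""
    return task_loss + tau * ponder

paper_cost = ponder_cost_paper([3], [0.6])               # about 3.6
simple_cost = ponder_cost_simplified([[0.1, 0.3, 0.5]])  # about -0.9
loss = total_loss(task_loss=0.25, ponder=paper_cost, tau=0.0001)
print(paper_cost, simple_cost, loss)
```

In both forms, larger halting units lower the cost: in the first because fewer updates are needed, in the second directly.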

Early in training, before this loss has taught the network to be parsimonious with its updates, computation can be kept down by initializing the halting bias b_n to a positive scalar, which makes the initial halting units large.

$$ h^{n}_{t} = \sigma(W_{h}s^{n}_{t} + b_{n}) $$
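To see why a positive bias shortens early pondering, consider the simplifying assumption that the weighted state term is close to zero at initialization, so every halting unit is roughly sigmoid(b). The sketch below counts how many updates are then needed to reach 1 - epsilon. The bias values are arbitrary examples.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def updates_with_constant_h(h, epsilon):
    """Number of updates needed if every halting unit equals h."""
    total, n = 0.0, 0
    while total < 1.0 - epsilon:
        total += h
        n += 1
    return n

for bias in (-2.0, 0.0, 1.0, 5.0):
    h = sigmoid(bias)
    print(f"bias={bias:+.1f}  h={h:.3f}  updates={updates_with_constant_h(h, 0.01)}")
```

Under this assumption a bias of -2 needs 9 updates, a bias of 0 or 1 needs 2, and a bias of 5 needs only 1.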

Also, a hard limit on the number of updates, M, is used during implementation.
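Putting the pieces together, the pondering loop for a single sample at a single timestep might look like the sketch below. It uses a scalar state, and `transition` and `halting_unit` are hypothetical callables supplied by the caller. Assigning the remainder to the last update when the hard limit M (`max_ponder`) is hit is an assumption of this sketch, not something the text above specifies.

```python
import math

def act_timestep(s_prev, x_t, transition, halting_unit, epsilon=0.01, max_ponder=100):
    """Ponder on one timestep; return (mean-field state, N(t), R(t))."""
    s, accum_h, mean_state = s_prev, 0.0, 0.0
    for n in range(1, max_ponder + 1):
        flag = 0 if n == 1 else 1            # binary flag: first update vs. repeats
        s = transition(s, x_t, flag)
        h = halting_unit(s)
        remainder = 1.0 - accum_h            # R(t) if this turns out to be the last update
        # Halt at the threshold, or (assumption) when the hard limit M is reached.
        if accum_h + h >= 1.0 - epsilon or n == max_ponder:
            mean_state += remainder * s
            return mean_state, n, remainder
        accum_h += h
        mean_state += h * s                  # p = h for every update except the last

def toy_transition(s, x, flag):
    return math.tanh(0.5 * s + x + 0.1 * flag)

def constant_halting(s):
    return 0.3                               # fixed halting unit, for illustration

state, n_t, r_t = act_timestep(0.0, 1.0, toy_transition, constant_halting)
print(n_t, round(r_t, 3))
```

With a constant halting unit of 0.3, the sum passes 0.99 on the 4th update, so N(t) = 4 and the remainder is about 0.1.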


## Implementation

This is not an exhaustive implementation of ACT. It is a distilled version of this [GitHub](https://github.com/zphang/adaptive-computation-time-pytorch) repository, complete enough only to demonstrate concretely how ACT processes the input, performs inference, and is trained.

The task is the parity task.

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/parity_example.png" width=200>

Parity training example. Each sequence consists of a single input and target vector. Similar to the example diagram, we use a vector of size 8, but in the paper the vector size was 64. To find the parity, count the number of 1.'s and ignore the -1.'s and 0.'s: if there is an odd number of 1.'s the parity is 1, and if there is an even number of 1.'s the parity is 0.
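The parity rule described above fits in one line. The sketch below is independent of the `ParityDataManager` used in the cells that follow. The two vectors are the ones printed in this notebook's outputs, with targets 0 and 1 respectively.

```python
def parity(vector):
    """Target is 1 if the number of 1's is odd, else 0; -1's and 0's are ignored."""
    return sum(1 for v in vector if v == 1) % 2

even_example = [-1., 1., 1., -1., -1., 0., 0., 0.]   # two 1's
odd_example = [-1., 1., 1., 1., 1., 1., 0., 0.]      # five 1's
print(parity(even_example), parity(odd_example))
```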

In [3]:
import pathlib
import numpy as np

import torch
from torch import nn
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data

from utils import ParityDataManager, test_epoch,  maybe_cuda_var, bool_to_idx

%load_ext autoreload
%autoreload 2

In [4]:
class Config:
    def __init__(self):
        pass
    
config = Config()
config.input_length = 8 
config.test_percentage = 0.1 
config.batch_size = 2 
config.num_epochs = 10
config.model_save_interval = 10
config.model_save_path = "models/parity/act_cpu" #"models/parity/act"
config.cuda =  False #True #
config.act_ponder_penalty = 0.0001
config.train_log = True
config.train_log_interval = 10
config.learning_rate = 0.001

In [13]:
class ACT(nn.Module):
    
    def __init__(self, rnn_size, input_size, 
                 max_ponder=100, epsilon=0.01):
        
        super(ACT, self).__init__()

        self.max_ponder = max_ponder
        self.epsilon = epsilon
        self.input_size = input_size
        self.rnn_size = rnn_size
        
        self.rnn = nn.RNNCell(
            # input_size +1 for binary flag
            input_size=self.input_size + 1,  
            hidden_size=self.rnn_size,
        )
        
        self.ponder_linear = nn.Linear(self.rnn.hidden_size, 1)
        self.fc1 = nn.Linear(self.rnn_size, 1)
        self.reset_parameters()
        
    def forward(self, input_, compute_ponder_cost=True):
        """
        This forward pass takes an input of shape (batch_size, sequence length, input_dim)
        and a boolean, compute_ponder_cost, that is True during training and False during
        testing. It controls whether the ponder cost is computed so that it can be added to
        the overall loss function to incentivize efficient use of compute.
        
        It puts the mean-field state through an affine layer to produce a logit for each sample
        called all_s which is shape (batch_size). It also outputs a dict with keys
        "ponder_cost" and "ponder_times", which are the ponder costs, shape (batch_size),
        and a list of the number of updates performed for each sample, respectively.
        """
        #(batch_size, time, input_dim)->(time, batch_size, input_dim)
        input_ = input_.transpose(0, 1) 
        
        time_size, batch_size, input_dim_size = input_.size()
    
        # The initial hidden state s_0 
        s = Variable(input_.data.new(batch_size, # (batch_size, hidden_size)
                     self.rnn.hidden_size).zero_())
            
        selector = input_.data.new(batch_size).byte() # indices, ie [0, 1] uint8 shape [2]
        
        s_list = []
        ponder_cost = 0
        ponder_times = []

        # For each timestep t: 1 thru T 
        for input_row in input_:
            
            accum_h = Variable(input_.data.new(batch_size).zero_())
            accum_s = Variable(input_.data.new( # vector of zeros (batch size,hidden_dim)
                       batch_size, self.rnn.hidden_size).zero_())

            selector = selector.fill_(1) #ones shape (batch size)

            step_count = Variable(input_.data.new(batch_size).zero_())

            input_row_with_flag = torch.cat([
                input_row,
                Variable(input_row.data.new(batch_size, 1).zero_())
            ], dim=1) # adds a 0 to the end of each input vector
            # for the first update (n = 1) this flag is 0; for all later updates it is set to 1

            if compute_ponder_cost:
                step_ponder_cost = Variable(input_.data.new(batch_size).zero_())

            # START LOOP for State Updates (1 thru N(t), per time step t)
            for act_step in range(self.max_ponder): # 100
                
                # [1,1,1]->[0,1,2], [1,0]->[0], [0,1,1]->[1,2]
                idx = bool_to_idx(selector)
                
                if compute_ponder_cost:
                    # incentivize large halting units 
                    step_ponder_cost[idx] = -accum_h[idx]
                    #print("step_ponder_cost", step_ponder_cost)

                # only update those hidden states for which the selector=1
                s[idx] = self.rnn(input_row_with_flag[idx], s[idx])
                
                # halting units from the state 
                h =  torch.sigmoid(self.ponder_linear(s[idx]).squeeze(1))
                accum_h[idx] += h # accumulate halting units 
                
                # halting probability, if accum_h is < 1, p = h, if the most recent h has
                # pushed accum_h over 1, set p to the remainder rather than h 
                p = h - (accum_h[idx] - 1).clamp(min=0) 

                # accumulate the mean-field of states weighted by the halting probability
                accum_s[idx] += p.unsqueeze(1) * s[idx]  

                step_count[idx] += 1 #keep track of total number of updates for this sample
                
                # prune the batch to include only samples that have not exceeded computational budget 
                selector = (accum_h < 1 - self.epsilon).data 
                if not selector.any(): # if all selectors == False, done processing s
                    break
                    
                #change last element of input to 1 if after first pondering update
                input_row_with_flag[:, input_dim_size] = 1 
                
                # END OF LOOP for State Updates (1 thru N(t), per time step t)
            
            # at each timestep, how many times was each sample in the batch processed 
            ponder_times.append(step_count.data.cpu().numpy())
            
            if compute_ponder_cost:
                ponder_cost += step_ponder_cost
        
            # append once for each time step (only once in parity)
            s_list.append(accum_s)
            
        # END OF LOOP for sequence length T 
        
        all_s = torch.stack(s_list)
        s = s.unsqueeze(0)
        all_s = all_s.transpose(0, 1)
        
        ponder_cost = {"ponder_cost": ponder_cost, "ponder_times": ponder_times}
        all_s = self.fc1(all_s).squeeze(1).squeeze(1)

        return all_s, ponder_cost 

    def reset_parameters(self):
        self.rnn.reset_parameters()
        self.ponder_linear.reset_parameters()
        self.ponder_linear.bias.data.fill_(1)

In [22]:
model = ACT(rnn_size = 16, input_size = config.input_length)

if config.cuda:
    model = model.cuda()

data_manager = ParityDataManager
test_data_loader = data_manager.create_dataloader(config, mode="test")

for batch_idx, (x, y) in enumerate(test_data_loader):
    
    x_var = maybe_cuda_var(x, cuda=config.cuda)
    y_var = Variable(y, requires_grad=False)
    if config.cuda:
        y_var = y_var.cuda()
    
    print("x", x_var, x_var.shape)
    print("y", y_var, y_var.shape)
    print("____________________________________________")
    break
    
y_hat, ponder_dict = model(x_var)
print("____________________________________________")
print("y_hat", y_hat, "ponder_dict", ponder_dict["ponder_times"])

x tensor([[[-1.,  1.,  1., -1., -1.,  0.,  0.,  0.]],

        [[-1., -1.,  0.,  0.,  0.,  0.,  0.,  0.]]]) torch.Size([2, 1, 8])
y tensor([ 0.,  0.]) torch.Size([2])
____________________________________________
____________________________________________
y_hat tensor([ 0.1588,  0.0086]) ponder_dict [array([2., 2.], dtype=float32)]


In [17]:
loss_func = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

for epoch in range(1, config.num_epochs + 1):
    train_data_loader = data_manager.create_dataloader(config)
    test_data_loader = data_manager.create_dataloader(config, mode="test")

    for batch_idx, (x, y) in enumerate(train_data_loader):
        
        x_var = maybe_cuda_var(x, cuda=config.cuda)
        y_var = Variable(y, requires_grad=False)
        
        if config.cuda:
            y_var = y_var.cuda()

        y_hat, ponder_dict = model(x_var) # Forward pass 
        
        loss = loss_func(y_hat, y_var) #BCE Loss
        if ponder_dict: # add Ponder Cost to BCE Loss 
            loss += (
                config.act_ponder_penalty * ponder_dict["ponder_cost"].mean()
            )

        optimizer.zero_grad()
        loss.backward() # backprop
        optimizer.step() # update weights

        if config.train_log and batch_idx % config.train_log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(x), len(train_data_loader.dataset),
                100. * batch_idx / len(train_data_loader), loss.item())
            )
    
    test_result = test_epoch(
        config=config, model=model,
        data_loader=test_data_loader,
        epoch=epoch,
    )

Epoch: 1, Average loss: 0.2680, Accuracy: 13/16 (81%), PT: 18.2
Epoch: 2, Average loss: 0.0090, Accuracy: 16/16 (100%), PT: 9.8
Epoch: 3, Average loss: 0.1034, Accuracy: 14/16 (88%), PT: 27.2
Epoch: 4, Average loss: 0.1550, Accuracy: 13/16 (81%), PT: 30.8
Epoch: 5, Average loss: 0.0862, Accuracy: 14/16 (88%), PT: 12.4
Epoch: 6, Average loss: 0.0450, Accuracy: 15/16 (94%), PT: 12.0
Epoch: 7, Average loss: 0.0332, Accuracy: 15/16 (94%), PT: 20.2
Epoch: 8, Average loss: 0.1200, Accuracy: 14/16 (88%), PT: 9.9
Epoch: 9, Average loss: 0.0677, Accuracy: 15/16 (94%), PT: 20.7
Epoch: 10, Average loss: 0.0306, Accuracy: 15/16 (94%), PT: 19.6
Saving checkpoint to models/parity/act_cpu/epoch_10.pt


In [20]:
model = ACT(rnn_size = 16, input_size = config.input_length)

if config.cuda:
    model = model.cuda()

epoch = 10
model_save_path = pathlib.Path(config.model_save_path)
model_save_file_path = (
    model_save_path / f"epoch_{epoch}.pt"
)
model.load_state_dict(torch.load(model_save_file_path))

data_manager = ParityDataManager
test_data_loader = data_manager.create_dataloader(config, mode="test")

for batch_idx, (x, y) in enumerate(test_data_loader):
    
    x_var = maybe_cuda_var(x, cuda=config.cuda)
    y_var = Variable(y, requires_grad=False)
    if config.cuda:
        y_var = y_var.cuda()
    
    print("x", x_var, x_var.shape)
    print("y", y_var, y_var.shape)
    print("____________________________________________")
    break
    
y_hat, ponder_dict = model(x_var)
print("____________________________________________")
print("y_hat", y_hat, "ponder_dict", ponder_dict["ponder_times"])

x tensor([[[-1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.]],

        [[-1.,  1.,  1.,  1.,  1.,  1.,  0.,  0.]]]) torch.Size([2, 1, 8])
y tensor([ 0.,  1.]) torch.Size([2])
____________________________________________
____________________________________________
y_hat tensor([-5.2169,  2.4202]) ponder_dict [array([11.,  4.], dtype=float32)]


## Parity Task 

The RNN used was a single layer vanilla RNN with hidden size 128

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/parity_bar.png">

Parity Error Rates. Bar heights show the mean error rates for different time penalties at the end of training. The error bars show the standard error in the mean.

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/parity_plot.png">

Parity Learning Curves and Error Rates Versus Ponder Time. Left: Errors for each 
τ value. ‘Iterations’ is the number of gradient updates per asynchronous worker. Right: Small circles represent individual runs after training is complete, large circles represent the mean over 20 runs for each  τ value. ‘Ponder’ is the mean number of computation steps per input timestep (minimum 1). The black dotted line shows the mean error for the networks without ACT. The height of the ellipses surrounding the mean values represents the standard error over error rates for that value of τ, while the width shows the standard error over ponder times.

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/parity_difficulty.png">

Parity Ponder Time and Error Rate Versus Input Difficulty. Faint lines are individual runs, bold lines are means over 20 networks. ‘Difficulty’ is the number of bits in the parity vectors, with a mean over 1,000 random vectors used for each data-point.

## Logic Task 

The RNN used was a single layer LSTM with hidden size 128

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/logic_example.png">

The first input sequence has input bits b_0 = F, b_1 = T. The first gate applied to these two bits is the NOR gate, one-hot encoded as 100. The output of this gate is b_2 = F. b_1 = F and b_2 = F are then operated on by the next gate, 010, which is the Xq gate. The Binary Truth Table shows that for Xq, F,F->F, therefore b_3 = F and thus the target is 0 for this first input vector.

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/logic_bar.png">


<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/logic_plot.png">

The network reaches a minimum sequence error rate of around 0.2 without ACT (compared to 0.5 for random guessing).

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/logic_difficulty.png">

Logic Ponder Time and Error Rate Versus Input Difficulty. The example pretends there are 2 gates in sequence and 3 choices for each gate. The experiment in the paper has 10 gates and 10 choices of gates for each. ‘Difficulty’ is the number of logic gates in each input vector

## Addition Task

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/addition_example.png" width=300>

The example performs the addition operations: 1038 + 392 = 1430, 68450 + 1430 = 69880

Each of those digits is actually one-hot encoded and concatenated into the input vector. Since there are 5 digits and 10 choices for each digit, the input vector is size 50.
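A sketch of this encoding and of the running sums from the example is given below. How numbers with fewer than 5 digits are laid out is an assumption here: the digits fill blocks from the left and unused blocks stay all-zero. `encode_number` is a hypothetical helper name.

```python
from itertools import accumulate

def encode_number(number, max_digits=5):
    """Concatenate one size-10 one-hot block per digit into a size-50 vector."""
    digits = [int(c) for c in str(number)]
    if len(digits) > max_digits:
        raise ValueError("number has too many digits")
    vec = [0] * (10 * max_digits)
    for pos, d in enumerate(digits):
        vec[10 * pos + d] = 1      # block `pos`, slot `d`
    return vec

inputs = [1038, 392, 68450]
encoded = [encode_number(n) for n in inputs]
running_sums = list(accumulate(inputs))   # cumulative totals after each input
print(len(encoded[0]), sum(encoded[0]), running_sums)
```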

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/addition_bar.png">

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/addition_plot.png">

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/addition_difficulty.png">
The relationship between the ponder time and the number of digits (Difficulty) was approximately linear for most of the ACT networks
<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/addition_sequence_wide_lines.png">

The grey lines show the total number of digits in the two numbers being summed at each step; this appears to give a rough lower bound on the ponder time, suggesting an internal algorithm that is approximately linear in the number of digits.

## Sort Task

Here the size-2 vectors, e.g. `[0, -0.03]`, are fed in one at a time. The first element becomes 1 when the vector holds the last value to be sorted, e.g. `[1, 0.55]`. Up to 15 of these vectors are fed in before the outputs start to be considered. Thereafter, the outputs are size-15 softmax classifiers that predict the indices of the 15 inputs in ascending order.
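The input and target construction described above can be sketched as follows. The values are made up, and `sort_task_example` is a hypothetical helper. The targets are simply the input indices ordered by ascending value.

```python
def sort_task_example(values):
    """Build (flag, value) input vectors and the index targets in ascending order."""
    n = len(values)
    inputs = [[1 if i == n - 1 else 0, v] for i, v in enumerate(values)]
    targets = sorted(range(n), key=lambda i: values[i])
    return inputs, targets

inputs, targets = sort_task_example([0.2, -0.03, 0.55])
print(inputs)    # the last vector carries flag = 1
print(targets)   # indices of the inputs, smallest value first
```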

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/sort_example.png" width=500>

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/sort_bar.png">

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/sort_plot.png">

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/sort_difficulty.png">

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/sort_sequence_3_black.png">

There is a large spike in ponder time near (though not precisely at) the end of the input sequence, presumably when the majority of the sort comparisons take place. The spike is much higher for the longer sequences.

## Wikipedia Character prediction

Language modeling was performed on raw unicode text. LSTM networks were used with a single layer of 1500 cells and a size 256 softmax classification layer to predict the next byte in the sequence.
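The size 256 softmax follows from treating the text as a stream of bytes: each input is one byte and the target is the byte that follows it. A minimal sketch of building such (input, target) pairs, assuming UTF-8 encoding and using a hypothetical helper name:

```python
def next_byte_pairs(text):
    """Pair each byte of the encoded text with the byte that follows it."""
    data = text.encode("utf-8")
    return list(zip(data[:-1], data[1:]))

pairs = next_byte_pairs("ACT")
print(pairs)   # each element is (current byte, next byte), both in 0..255
```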

<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/enwik_bar.png">
Error rates are fairly similar with and without ACT
<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/enwik_plot.png">
Learning curves suggest that the ACT networks are somewhat more data efficient. The amount of ponder per input is much lower than for the other problems, suggesting that the advantages of extra computation were slight for this task.
<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/enwik_sequence.png">
Ponder Time, Prediction loss and Prediction Entropy During a Wikipedia Text Sequence.
Character prediction networks trained with ACT consistently pause at spaces between words, and pause for longer at ‘boundary’ characters such as commas and full stops. We speculate that the extra computation is used to make predictions about the next ‘chunk’ in the data (word, sentence, clause)
<img src="https://media.arxiv-vanity.com/render-output/1687679/fig/enwik_sequence_2.png">
Ponder Time, Prediction loss and Prediction Entropy During a Wikipedia Sequence Containing XML Tags. Again, ACT is an effective detector of non-text transition markers such as the opening brackets of XML tags. ACT does not increase computation time during random or fundamentally unpredictable sequences like the two ID numbers.

## Conclusion 

ACT allows RNNs to dynamically adapt the amount of computation they use to the demands of the data. An experiment on real data suggests that the allocation of computation steps learned by ACT can yield insight into both the structure of the data and the computational demands of predicting it.