<a href="https://colab.research.google.com/github/charlesdgburns/NM_TinyRNN/blob/main/notebooks/nm_tinyrnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

Let's use this as a learning notebook about gated recurrent units.

We will fit these to sequential behavioural decision making later.

We are taking inspiration from the following blogpost:
https://medium.com/data-science/building-a-lstm-by-hand-on-pytorch-59c02a4ec091

**The information flow in a gated recurrent unit**

A gated recurrent unit has information flowing from the inputs $^{(1)}$ $x_{t-1}$ and the past hidden state $h_{t-1}$, which is gated via 'reset' and 'update' gates $r_t$ and $z_t$ before giving the final output $h_t$. The update gate decides whether or not to overwrite a long-term memory with inputs, while the reset gate decides how much of the past hidden state contributes to the candidate ('new') state $n_t$.

The gated recurrent unit allows the recurrent unit to persist its state and ignore its inputs.


$$ r_t = \sigma(W_{ir}x_{t-1}+b_{ir}+W_{hr}h_{t-1}+b_{hr}) \qquad \text{(reset)}$$

$$ z_t = \sigma(W_{iz}x_{t-1}+b_{iz}+W_{hz}h_{t-1}+b_{hz}) \qquad\text{(update)}$$

$$ n_t = \tanh(W_{in}x_{t-1}+b_{in}+r_{t}\odot (W_{hn}h_{t-1}+b_{hn})) \qquad\text{(new)} $$

$$ h_t = (1-z_t)\odot n_t +z_t \odot h_{t-1} \qquad\text{(gated recurrent output)}$$

*footnotes*
(1): the inputs at time $t$ are recent experience

Note that GRUs are Markovian: their current state ($h_t$) can be determined entirely from their previous state ($h_{t-1}$) and inputs ($x_{t-1}$).

In [None]:
import torch
import torch.nn as nn
import math
import numpy as np

class ManualGRU(nn.Module):
  '''Hand-written single-layer gated recurrent unit (GRU).

  The three gates (reset r, update z, new n) share two weight matrices whose
  columns are laid out as [r | z | n], each block being `hidden_size` wide:

    W_from_in : (input_size, 3*hidden_size)   input  -> gates
    W_from_h  : (hidden_size, 3*hidden_size)  hidden -> gates
    bias      : (6*hidden_size,)  first half is added to the input projection,
                                  second half to the hidden projection.

  Args:
    input_size: number of features in each input vector.
    hidden_size: number of hidden units.
    batch_first: stored for API familiarity only; `forward` always treats
      inputs as (batch_size, sequence_size, input_size).
  '''
  def __init__(self,input_size,hidden_size, batch_first = False):
    # nn.Module must be initialised before any nn.Parameter is assigned.
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size  # read by init_weights() and forward()
    self.batch_first = batch_first
    self.sigmoid = torch.nn.Sigmoid()
    self.tanh = torch.nn.Tanh()

    self.W_from_in = nn.Parameter(torch.empty(input_size, hidden_size*3))
    self.W_from_h = nn.Parameter(torch.empty(hidden_size, hidden_size*3))
    self.bias = nn.Parameter(torch.empty(hidden_size*6))
    self.init_weights()

  def init_weights(self):
    '''Fill every parameter with samples from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).'''
    stdv = 1.0 / math.sqrt(self.hidden_size)
    with torch.no_grad():
      for weight in self.parameters():
        weight.uniform_(-stdv, stdv)

  def forward(self, inputs, init_states = None):
    ''' inputs are a tensor of shape (batch_size, sequence_size, input_size)
        init_states is an optional (batch_size, hidden_size) initial hidden state (zeros if None)
        outputs are a tensor of shape (batch_size, sequence_size, hidden_size)
        together with the final hidden state of shape (batch_size, hidden_size)'''

    batch_size, sequence_size, _ = inputs.size()
    H = self.hidden_size
    hidden_sequence = []
    if init_states is None:
      h_past = torch.zeros(batch_size, H).to(inputs.device)
    else:
      h_past = init_states

    for t in range(sequence_size):
      x_past = inputs[:,t,:] #(n_batch,input_size)
      #for computational efficiency we do two matrix multiplications and then slice out each gate's columns:
      from_input = x_past@self.W_from_in + self.bias[:3*H]  #(n_batch,3*n_hidden)
      from_hidden = h_past@self.W_from_h + self.bias[3*H:]  #(n_batch,3*n_hidden)

      r_t = self.sigmoid(from_input[:, :H]+from_hidden[:, :H]) #(n_batch,n_hidden), ranging from 0 to 1
      z_t = self.sigmoid(from_input[:, H:2*H]+from_hidden[:, H:2*H]) #(n_batch,n_hidden), ranging from 0 to 1
      n_t = self.tanh(from_input[:, 2*H:]+r_t*from_hidden[:, 2*H:]) #(n_batch,n_hidden)
      h_past = (1-z_t)*n_t + z_t*h_past #(n_batch,n_hidden) NOTE: this is h_t now, but it is h_past for the next iteration.
      hidden_sequence.append(h_past.unsqueeze(0)) #appending (1,n_batch,n_hidden) to a big list.
    hidden_sequence = torch.cat(hidden_sequence, dim=0) #(n_sequence, n_batch, n_hidden) gather all states along the first dimension
    hidden_sequence = hidden_sequence.transpose(0, 1).contiguous() #reshape to batch first (n_batch,n_seq,n_hidden)
    return hidden_sequence, h_past #as in PyTorch: the sequence of hidden states alongside the most recent hidden state.

## for reference, look at a hand-crafted LSTM from the following tutorial #



class ManualSTM(nn.Module):
    """Hand-written single-layer LSTM, kept as a reference implementation.

    The four gates share two weight matrices whose columns are laid out as
    [input | forget | cell candidate | output], each block `hidden_sz` wide:

        W    : (input_sz, 4*hidden_sz)   input  -> gates
        U    : (hidden_sz, 4*hidden_sz)  hidden -> gates
        bias : (4*hidden_sz,)            one bias per gate unit

    Args:
        input_sz: number of features in each input vector.
        hidden_sz: number of hidden units.
        device: device on which the parameters are allocated.
    """

    def __init__(self, input_sz, hidden_sz, device):
        super().__init__()
        self.device=device
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        # Allocate directly on `device`; a later `.to()` on `.data` would only
        # produce a discarded copy and leave the parameters where they were.
        self.W = nn.Parameter(torch.empty(input_sz, hidden_sz * 4, device=device))
        self.U = nn.Parameter(torch.empty(hidden_sz, hidden_sz * 4, device=device))
        # The bias is added to the (batch, 4*hidden) gate pre-activations in
        # forward(), so it must have exactly 4*hidden elements.
        self.bias = nn.Parameter(torch.empty(hidden_sz * 4, device=device))
        self.init_weights()


    def init_weights(self):
        """Fill every parameter with samples from U(-1/sqrt(hidden), 1/sqrt(hidden))."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-stdv, stdv)

    def forward(self, x,
                init_states=None):
        """Run the LSTM over a batch of sequences.

        Assumes x is of shape (batch, sequence, feature). `init_states` is an
        optional `(h_0, c_0)` tuple, each of shape (batch, hidden); zeros are
        used when it is None.

        Returns `(hidden_seq, (h_t, c_t))` where `hidden_seq` has shape
        (batch, sequence, hidden) and `h_t`, `c_t` are the final hidden and
        cell states, each of shape (batch, hidden).
        """
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
                        torch.zeros(bs, self.hidden_size).to(x.device))
        else:
            h_t, c_t = init_states

        HS = self.hidden_size
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # batch the computations into a single matrix multiplication
            gates = x_t @ self.W + h_t @ self.U + self.bias
            i_t, f_t, g_t, o_t = (
                torch.sigmoid(gates[:, :HS]), # input
                torch.sigmoid(gates[:, HS:HS*2]), # forget
                torch.tanh(gates[:, HS*2:HS*3]), # cell candidate
                torch.sigmoid(gates[:, HS*3:]), # output
            )
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t.unsqueeze(0))
        hidden_seq = torch.cat(hidden_seq, dim=0)
        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        return hidden_seq, (h_t, c_t)


# NM-RNN

[Costacurta et al. (2024)](https://openreview.net/pdf?id=HbIBqn3grD) introduced NM-RNNs to bridge a gap between more standard RNNs today and biophysical models.

Here we rewrite equations (1) to (4) in Costacurta et al. (2024) in a form closer to the standard RNN notation above.

We have inputs $x_{t-1}$ and our past hidden state $h_{t-1}$ which we want to integrate to get a new hidden state $h_t$. However, we want to selectively change (by a gain) the weights of our recurrent network depending on a neuromodulation signal $s(z(t))$.

We therefore have a coupled network system, starting from a subnetwork state $z(t)$. In discretised terms:
$$ \tau_{z}  z_t = W_{zz} \phi(z_t)+W_{iz} x_{t-1} \qquad (1) $$

$$ \tau_{x} h_{t-1} = W_x(z_t)\cdot\phi(h_{t-1})+W_{ih} x_{t-1} \qquad (2)$$

$$s(z(t)) = \sigma(W_s z_t+ b_s) \qquad W_x(z_t)=\sum_{k=1}^K s_k(z_t)\ell_k r_k^T \qquad (4)$$

Note that instead of a low-rank recurrent weight component, we want a tiny RNN, so we could modulate the weights associated with a given unit. Now the 'dynamic modes' are not the low ranks of a large network, but the activity of single units.

In [None]:

class ManualNMRNN(nn.Module):
  '''Recurrent cell scaffold for the NM-RNN.

  NOTE: the forward pass below applies GRU-style reset/update gating
  (r, z, n); it does not yet contain a separate neuromodulatory subnetwork.

  The three gates share two weight matrices whose columns are laid out as
  [r | z | n], each block being `hidden_size` wide:

    W_from_in : (input_size, 3*hidden_size)   input  -> gates
    W_from_h  : (hidden_size, 3*hidden_size)  hidden -> gates
    bias      : (6*hidden_size,)  first half is added to the input projection,
                                  second half to the hidden projection.

  Args:
    input_size: number of features in each input vector.
    hidden_size: number of hidden units.
    batch_first: stored for API familiarity only; `forward` always treats
      inputs as (batch_size, sequence_size, input_size).
  '''
  def __init__(self,input_size,hidden_size, batch_first = False):
    # nn.Module must be initialised before any nn.Parameter is assigned.
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size  # read by init_weights() and forward()
    self.batch_first = batch_first
    self.sigmoid = torch.nn.Sigmoid()
    self.tanh = torch.nn.Tanh()

    # forward() reads three gate blocks from the input projection, so the
    # input weights need 3*hidden_size columns (same as the hidden weights).
    self.W_from_in = nn.Parameter(torch.empty(input_size, hidden_size*3))
    self.W_from_h = nn.Parameter(torch.empty(hidden_size, hidden_size*3))
    self.bias = nn.Parameter(torch.empty(hidden_size*6))
    self.init_weights()

  def init_weights(self):
    '''Fill every parameter with samples from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).'''
    stdv = 1.0 / math.sqrt(self.hidden_size)
    with torch.no_grad():
      for weight in self.parameters():
        weight.uniform_(-stdv, stdv)

  def forward(self, inputs, init_states = None):
    ''' inputs are a tensor of shape (batch_size, sequence_size, input_size)
        init_states is an optional (batch_size, hidden_size) initial hidden state (zeros if None)
        outputs are a tensor of shape (batch_size, sequence_size, hidden_size)
        together with the final hidden state of shape (batch_size, hidden_size)'''

    batch_size, sequence_size, _ = inputs.size()
    H = self.hidden_size
    hidden_sequence = []
    if init_states is None:
      h_past = torch.zeros(batch_size, H).to(inputs.device)
    else:
      h_past = init_states

    for t in range(sequence_size):
      x_past = inputs[:,t,:] #(n_batch,input_size)
      #for computational efficiency we do two matrix multiplications and then slice out each gate's columns:
      from_input = x_past@self.W_from_in + self.bias[:3*H]  #(n_batch,3*n_hidden)
      from_hidden = h_past@self.W_from_h + self.bias[3*H:]  #(n_batch,3*n_hidden)

      r_t = self.sigmoid(from_input[:, :H]+from_hidden[:, :H]) #(n_batch,n_hidden), ranging from 0 to 1
      z_t = self.sigmoid(from_input[:, H:2*H]+from_hidden[:, H:2*H]) #(n_batch,n_hidden), ranging from 0 to 1
      n_t = self.tanh(from_input[:, 2*H:]+r_t*from_hidden[:, 2*H:]) #(n_batch,n_hidden)
      h_past = (1-z_t)*n_t + z_t*h_past #(n_batch,n_hidden) NOTE: this is h_t now, but it is h_past for the next iteration.
      hidden_sequence.append(h_past.unsqueeze(0)) #appending (1,n_batch,n_hidden) to a big list.
    hidden_sequence = torch.cat(hidden_sequence, dim=0) #(n_sequence, n_batch, n_hidden) gather all states along the first dimension
    hidden_sequence = hidden_sequence.transpose(0, 1).contiguous() #reshape to batch first (n_batch,n_seq,n_hidden)
    return hidden_sequence, h_past #as in PyTorch: the sequence of hidden states alongside the most recent hidden state.
