# PonderICOM: Joint Modeling of Accuracy and Speed in Cognitive Tasks
## Intro

In the context of behavioral data, we are interested in simultaneously modeling speed and accuracy. Yet, most advanced machine-learning techniques cannot capture this duality in decision-making data.


Building on [PonderNet](https://arxiv.org/abs/2107.05407) and [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0), this notebook implements a neural model that captures speed and accuracy of human-like responses.

Given stimulus symbols as inputs, the model produces two outputs:

- Response symbol, which, in comparison with the input stimuli, can be used to measure accuracy.
- Halting probability ($\lambda_n$).

Under the hood, the model iterates over an ICOM-like component to reach a halting point in time. Unlike DDM and ICOM models, all the parameters and outcomes of the current model *seem* cognitively interpretable.

### Additional resources

- [ICOM network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)


## Problem setting

### Model
Given input and output data, we want to learn a supervised model of the function $X \to y$ as follows:

$
f: X,h_n \mapsto \tilde{y},h_{n+1}, \lambda_n
$

where $X$ and $y$ denote stimulus and response symbols, $\lambda_n$ denotes the halting probability at time $n$, and $h_{n}$ is the latent state of the model. The learning continues up to time point $N$.

For brevity and compatibility, both inputs and outputs are one-hot encoded.


### Input

One-hot encoded symbols.

### Output

One-hot encoded symbols.

### Criterion

$L = L_{\text{cross-entropy}} + L_{\text{halting}}$

In [1]:
%reload_ext autoreload
%autoreload 3

# Setup and imports
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split

from tqdm import tqdm

from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

from torch.utils.tensorboard import SummaryWriter
# import tensorboard as tb
# import tensorflow as tf
# tf.io.gfile = tb.compat.tensorflow_stub.io.gfile #FIX storing embeddings using tensorboard


from cogponder.nback_dataset import NBackDataset

In [244]:
# Produce a train of spikes and store the timestamp of each spike in `spike_timestamps`.
# NOTE: only the number of timesteps is defined so far; the three candidate
# methods below are still sketches (method 1 has no code, methods 2 and 3 are commented out).

n_total_timesteps = 20

# method 1: shuffle timesteps

# method 2: exponential ISI (inter-spike interval) -> timestamps
# isi = np.random.exponential(1 / rate, n_spikes)
# spike_timestamps = np.cumsum(isi)

# method 3: homogeneous spikes -> timestamps
# spike_timestamps = stats.uniform.rvs(loc=0, scale=max_duration_in_sec, size=n_spikes)

## Mock data

In [4]:
# Generate mock n-back data.
# NOTE(review): `NBackDataset` is imported from `cogponder.nback_dataset`; the
# meaning of its three positional arguments is inferred from the variable names
# passed here — confirm against that module.

n_subjects = 2
n_trials = 20
n_stimuli = 6

dataset = NBackDataset(n_subjects, n_trials, n_stimuli)

# Indexing the dataset at 0 yields four items
# (presumably the data of the first subject — TODO confirm).
X, y, accuracies, response_times = dataset[0]

# DEBUG
# X.shape, y.shape, accuracies.shape, response_times.shape

In [82]:
class ICOM(nn.Module):
    """ICOM-like network: encode -> transmit (RNN) -> decode.

    A sequence of integer symbols is one-hot encoded, linearly projected onto
    `n_channels` channels, passed through a recurrent "transmission" layer,
    and decoded into a distribution over `n_outputs` response symbols. Only
    the response at the last sequence position is returned.

    Args:
        n_inputs: number of distinct input symbols (size of the one-hot code).
        n_channels: number of channels; also the size of the RNN hidden state.
        n_outputs: number of distinct response symbols.
    """

    def __init__(self, n_inputs, n_channels, n_outputs):
        super(ICOM, self).__init__()

        self.n_inputs = n_inputs
        # Kept so `init_hidden` can size the recurrent state from the module
        # itself instead of relying on a notebook-level global.
        self.n_channels = n_channels

        # encode: x -> sent_msg
        self.encode = nn.Sequential(
            nn.Linear(n_inputs, n_channels, bias=False)
        )

        # transmit: sent_msg -> rcvd_msg
        self.transmit = nn.RNN(n_channels, n_channels, batch_first=False)

        # decode: rcvd_msg -> response
        self.decode = nn.Sequential(
            nn.Linear(n_channels, n_outputs, bias=False),
            nn.Softmax(dim=2)
        )

    def forward(self, x, h=None):
        """Run the sequence through the network and return the last response.

        Args:
            x: integer (long) tensor of symbol indices, shape
                (batch_size, seq_size), with values in [0, n_inputs).
            h: optional initial hidden state of shape
                (1, batch_size, n_channels); zeros are used when omitted.

        Returns:
            y: tensor of shape (batch_size,) holding the index of the most
                probable response symbol at the last sequence position.
            h: final hidden state of shape (1, batch_size, n_channels).
        """
        # shapes:
        #   x:   batch_size, seq_size
        #   msg: batch_size, seq_size, n_channels (seq-first only around the RNN)
        #   h:   1, batch_size, n_channels
        batch_size = x.size(0)

        if h is None:
            h = self.init_hidden(batch_size, device=x.device)

        msg = F.one_hot(x, num_classes=self.n_inputs).type(torch.float)
        msg = self.encode(msg)
        msg = msg.transpose(0, 1)  # to (seq, batch, channels) for the RNN
        msg, h = self.transmit(msg, h)
        msg = msg.transpose(0, 1)  # back to (batch, seq, channels) for the linear layer
        y = self.decode(msg)
        y = y[:, -1, :]  # last time step
        # NOTE: argmax returns hard symbol indices, so no gradient flows
        # through `y`; a loss cannot be back-propagated from this output.
        y = y.argmax(dim=1)
        return y, h

    def init_hidden(self, batch_size, device=None):
        """Return an all-zero hidden state of shape (1, batch_size, n_channels).

        Args:
            batch_size: number of sequences in the batch.
            device: optional device for the new tensor (default: PyTorch's
                current default device).
        """
        return torch.zeros(1, batch_size, self.n_channels, device=device)

# DEBUG
# model = ICOM(n_inputs=n_stimuli+1, n_channels=n_stimuli, n_outputs=2)
# y_pred, h = model(X)

In [118]:
class PonderNet(nn.Module):
    """PonderNet-style wrapper that re-applies an ICOM node until a sampled halting step.

    At every pondering step the output node produces a response and a new
    hidden state, and the halting node maps that hidden state to a halting
    probability `lambda_n`.

    Args:
        n_inputs: number of distinct input symbols.
        n_embeddings: size of the hidden state shared by the output and
            halting nodes.
        n_outputs: number of distinct response symbols.
        max_steps: maximum number of pondering steps; halting is forced at
            this step.
    """

    def __init__(self, n_inputs, n_embeddings, n_outputs, max_steps):
        super(PonderNet, self).__init__()

        self.n_embeddings = n_embeddings
        self.n_outputs = n_outputs
        self.max_steps = max_steps

        self.output_node = ICOM(n_inputs, n_embeddings, n_outputs)

        # The halting node predicts the probability of halting conditional on
        # not having halted before. It exactly computes the overall
        # probability of halting at each step as a geometric distribution.
        self.halt_node = nn.Sequential(
            nn.Linear(n_embeddings, 1),
            nn.Sigmoid()
        )

        # Loss: we don't regularize PonderNet to explicitly minimize the
        # number of computing steps, but incentivize exploration instead. The
        # pressure of using computation efficiently happens naturally as a
        # form of Occam's razor.

    def step(self, x, h, n):
        """Run a single pondering step.

        Args:
            x: batch of inputs with the batch dimension first; passed
                unchanged to `output_node`.
            h: previous hidden state of shape (1, batch_size, n_embeddings).
            n: 1-based index of the current step.

        Returns:
            y_n: response produced by `output_node` at this step.
            h: updated hidden state.
            lambda_n: tensor of shape (batch_size,) with the probability of
                the "continue -> halt" transition; all ones at `max_steps`.
        """
        batch_size = x.shape[0]

        y_n, h = self.output_node(x, h)

        if n == self.max_steps:
            # halting is certain at the last allowed step
            lambda_n = torch.ones((batch_size,), device=x.device)
        else:
            # (1, batch, 1) -> (batch,); an unqualified squeeze() would also
            # drop the batch dimension when batch_size == 1
            lambda_n = self.halt_node(h).reshape(batch_size)

        return y_n, h, lambda_n

    def forward(self, x):
        """Ponder over `x` until every sample's sampled halting step is reached.

        Args:
            x: batch of inputs with the batch dimension first.

        Returns:
            ys: responses of every executed step, shape (batch_size, n_steps).
            lambdas: conditional halting probabilities, shape
                (batch_size, n_steps).
            y: response of each sample at its own halting step, shape
                (batch_size,).
            p_halts: unconditional probability of halting at each step,
                `lambda_n * prod_{k<n} (1 - lambda_k)`, shape
                (batch_size, n_steps).
            halt_step_idx: 0-based halting step of each sample (int64), shape
                (batch_size,).
        """
        batch_size = x.shape[0]
        device = x.device

        h = torch.zeros(1, batch_size, self.n_embeddings, device=device)
        # Probability of not having halted before the current step. Kept at
        # shape (batch,) so it multiplies element-wise with `lambda_n`.
        p_continue = torch.ones(batch_size, device=device)

        ys = []
        ps = []
        lambdas = []

        halt_step = torch.zeros(batch_size, device=device)  # stopping step (1-based)

        for n in range(1, self.max_steps + 1):

            y_n, h, lambda_n = self.step(x, h, n)

            if n == self.max_steps:
                halt_step = torch.full_like(halt_step, n)
            else:
                # NOTE(review): heuristic kept from the original code — a fresh
                # geometric draw (with lambda_n scaled down by 5) is taken at
                # every step and the running maximum is kept. This differs
                # from the per-step Bernoulli halting described in the
                # PonderNet paper; confirm this is the intended scheme.
                _halt_step_dist = torch.distributions.Geometric(lambda_n / 5)
                halt_step = torch.maximum(_halt_step_dist.sample(), halt_step)

            p_halt = p_continue * lambda_n  # p_halt = (1-p)p
            p_continue = p_continue * (1 - lambda_n)  # update

            ys.append(y_n)
            lambdas.append(lambda_n)
            ps.append(p_halt)

            if (halt_step <= n).all():
                break

        # prepare outputs of the forward pass
        # A geometric draw can be 0, which would give index -1 and silently
        # wrap around to the last step; clamp so it maps to the first step.
        halt_step_idx = (halt_step.reshape(-1).to(torch.int64) - 1).clamp(min=0)
        ys = torch.stack(ys).transpose(0, 1)
        lambdas = torch.stack(lambdas).transpose(0, 1)
        p_halts = torch.stack(ps).transpose(0, 1)

        # pick, for every sample, the response at that sample's own halting step
        y = ys.gather(1, halt_step_idx.unsqueeze(1)).squeeze(1)

        return ys, lambdas, y, p_halts, halt_step_idx

# DEBUG: smoke-test the forward pass on the mock data.
# Arguments are (n_inputs, n_embeddings, n_outputs, max_steps); n_inputs is
# n_stimuli+1 — presumably reserving one extra symbol, TODO confirm.
model = PonderNet(n_stimuli+1, n_stimuli, 2, 100)
# keep only the halting-step responses and the halting indices of the five returned values
_, _, y_pred, _, halt_steps = model(X)

In [None]:

# Training cell (work in progress): builds the model, optimizer and data
# split, then iterates over epochs while logging to TensorBoard.

# split params
# NOTE(review): `random_split` below needs these two lengths to add up to
# len(dataset); this assumes X holds n_trials rows — TODO confirm.
train_size = int(n_trials * .8)
test_size = n_trials - train_size

# training params
n_epoches = 100

# TensorBoard writer; see the `tensorboard --logdir=runs` hint at the end of this cell
logs = SummaryWriter()

model = PonderNet(n_stimuli+1, n_stimuli, n_stimuli+1, 100)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

# labels are shifted down by one before being wrapped — presumably to make
# them zero-based class indices; TODO confirm
dataset = TensorDataset(torch.tensor(X), torch.tensor(y)-1)
train_subset, test_subset = random_split(dataset, lengths=(train_size,test_size))

X_train, y_train = dataset[train_subset.indices]
X_test, y_test = dataset[test_subset.indices]

for epoch in tqdm(range(n_epoches), desc='Epochs'):

  for X_batch, y_batch in DataLoader(train_subset, batch_size=3):
    model.train()
    optimizer.zero_grad()
    print(X_batch)
    # NOTE(review): the forward pass is commented out, so `y_pred` used below
    # is whatever an earlier cell left in scope, not a prediction for this batch.
    # ys, lambdas, y_pred, p_halts, halt_steps = model(X_batch)

  # NOTE(review): everything from here to `optimizer.step()` sits outside the
  # batch loop, so it runs once per epoch and only sees the last `y_batch`.
  model_accuracy = accuracy_score(y_batch, y_pred)
  logs.add_scalar('accuracy/train', model_accuracy, epoch)  

  # NOTE(review): `y_pred` holds the model's third return value (selected
  # responses), not per-class scores — verify it is a valid input for
  # CrossEntropyLoss and that gradients can flow back from it.
  loss = criterion(y_pred.unsqueeze(0), y_batch)
  
  logs.add_scalar('loss/train', loss, epoch)

  loss.backward()
  optimizer.step()

  # NOTE(review): the disabled evaluation below unpacks four values although
  # the model's forward returns five, and passes (target, prediction) to the
  # criterion in the opposite order from the training call above.
  # model.eval()
  # with torch.no_grad():
  #   _, _, y_pred, _ = model(X_test)
  #   loss = criterion(y_test, y_pred)
  #   logs.add_scalar('loss/test', loss.detach(), epoch)

# tensorboard --logdir=runs

In [None]:
# Switch the model to evaluation mode and predict for the whole dataset at once.
model.eval()
# `dataset.tensors[0]` is the input tensor of the TensorDataset; `unsqueeze(0)`
# adds a leading dimension of size 1 in front of it.
# NOTE(review): the call is not wrapped in `torch.no_grad()` — add it if
# gradients are not needed here.
_, _, y_pred, _, _ = model(dataset.tensors[0].unsqueeze(0))
# print(x.shape, y_m.shape, y_n.shape, is_halted.shape, p_m.shape, p_n.shape)
# notebook display: predictions next to the targets `y`
y_pred, y

In [None]:
# example code to decode a stimulus into multiple sequences (one per channel)

import torch
from torch import nn

# Sizes for the demo: symbol vocabulary, timesteps per channel, channel count.
n_inputs, max_timestep, n_channels = 7, 10, 5

# One-hot vector for stimulus symbol 4.
X = torch.nn.functional.one_hot(torch.tensor(4), num_classes=n_inputs).float()

# A single linear layer emits one value per (channel, timestep) pair ...
decode = nn.Linear(in_features=n_inputs, out_features=n_channels * max_timestep)
# ... which is then folded into a (channel, timestep) grid.
out = decode(X).reshape(n_channels, max_timestep)

print(out.shape)

In [None]:
# Batch of two geometric distributions with success probabilities 0.3 and 0.01.
d = torch.distributions.Geometric(probs=torch.tensor([0.3, 0.01]))

# Draw one sample per distribution (shown as the cell output).
d.sample()