# PonderICOM: Joint Modeling of Accuracy and Speed in Cognitive Tasks
## Intro

In the context of behavioral data, we are interested in simultaneously modeling speed and accuracy. Yet, most advanced techniques in machine learning cannot capture such a duality of decision making data.


Building on [PonderNet](https://arxiv.org/abs/2107.05407) and [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0), this notebook implements a neural model that captures speed and accuracy of human-like responses.

Given stimulus symbols as inputs, the model produces two outputs:

- Response symbol, which, in comparison with the input stimuli, can be used to measure accuracy).
- Halting probability ($\lambda_n$).

Under the hood, the model iterates over a ICOM-like component to reach a halting point in time. Unlike DDM and ICOM models, all the parameters and outcomes of the current model *seem* cognitively interpretable.

### Additional resources

- [ICOM network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharinghttps://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)


## Problem setting

### Model
Given input and output data, we want to learn a supervised model of the function $X \to y$ as follows:

$
f: X,h_n \mapsto \tilde{y},h_{n+1}, \lambda_n
$

where $X$ and $y$ denote stimulus and response symbols, $\lambda_n$ denotes halting probability at time $n$, and $h_{n}$ is the latent state of the model. The learninig continious up to the time point $N$.

For the brevity and compatibility, both data are one-hot encoded.


### Input

One-hot encoded symbols.

### Output

One-hot encoded symbols.

### Criterion

L = L_cross_entropy + L_halting

In [1]:

%reload_ext autoreload
%autoreload 3

# Setup and imports
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split

from tqdm import tqdm

from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

from torch.utils.tensorboard import SummaryWriter
# import tensorboard as tb
# import tensorflow as tf
# tf.io.gfile = tb.compat.tensorflow_stub.io.gfile #FIX storing embeddings using tensorboard


from cogponder import NBackDataset
from cogponder import PonderNet, ICOM, ReconstructionLoss, RegularizationLoss

## Data

In [2]:
# generate mock n-back data

max_steps = 20
n_subjects = 2
n_trials = 100
n_stimuli = 6

dataset = NBackDataset(n_subjects, n_trials, n_stimuli)

X, y, accuracies, response_times = dataset[0]
y = torch.where(y != 0.0, 1, 0)

# DEBUG
# X.shape, y.shape, accuracies.shape, response_times.shape

In [35]:
# training params
n_epoches = 100

logs = SummaryWriter()

model = PonderNet(ICOM, n_stimuli+1, n_stimuli, 2, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

dataset = TensorDataset(torch.tensor(X), torch.tensor(y).to(torch.float))


# split params
train_size = int(len(dataset) * .8)
test_size = len(dataset) - train_size

train_subset, test_subset = random_split(dataset, lengths=(train_size,test_size))

X_train, y_train = dataset[train_subset.indices]
X_test, y_test = dataset[test_subset.indices]

loss_rec_fn = ReconstructionLoss(nn.CrossEntropyLoss(reduction='mean'))
loss_reg_fn = RegularizationLoss(lambda_p=.5, max_steps=20)
loss_beta = .01

loss_fn = nn.CrossEntropyLoss(reduction='mean')

for epoch in tqdm(range(n_epoches), desc='Epochs'):

  model.train()
  optimizer.zero_grad()
  y_steps, _, y_pred, p_halt, halt_step = model(X_train)

  # FIXME
  y_pred = y_pred.argmax(dim=1).to(torch.float)
  # y_pred.requires_grad = True

  # loss = loss_fn(y_pred, y_train)

  loss_rec = loss_rec_fn(p_halt, y_steps, y_train)
  loss_reg = loss_reg_fn(p_halt)
  loss = loss_rec + loss_beta * loss_reg
  
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  logs.add_scalar('loss/train', loss, epoch)

  # accuracy_train = accuracy_score(y_train.detach().numpy(), y_pred.detach().numpy())
  # logs.add_scalar('accuracy/train', accuracy_train, epoch)  

  # model.eval()
  # with torch.no_grad():
  #   _, _, y_pred, _ = model(X_test)
  #   loss = criterion(y_test, y_pred)
  #   logs.add_scalar('loss/test', loss.detach(), epoch)

# tensorboard --logdir=runs

  dataset = TensorDataset(torch.tensor(X), torch.tensor(y).to(torch.float))


tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 0., 1.])


Epochs:   0%|          | 0/100 [00:00<?, ?it/s]


torch.Size([78, 2]) torch.Size([78])


RuntimeError: expected scalar type Long but found Float

In [None]:
# example code to decode a stimulus into multiple sequence (one per channel)

import torch
from torch import nn

n_inputs = 7
max_timestep = 10
n_channels = 5

X = torch.nn.functional.one_hot(torch.tensor(4), num_classes=n_inputs).type(torch.float)

decode = nn.Linear(n_inputs, n_channels * max_timestep)
out = decode(X).reshape((n_channels, max_timestep))

print(out.shape)

In [None]:
d = torch.distributions.Geometric(torch.tensor([0.3, .01]))

d.sample()