# PonderICOM: Joint Modeling of Accuracy and Speed in Cognitive Tasks
## Intro

When modeling behavioral data, we want to capture speed (response times) and accuracy simultaneously. Yet most machine learning models are trained to predict only the response, and therefore cannot capture this dual nature of decision-making data.


Building on [PonderNet](https://arxiv.org/abs/2107.05407) and [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0), this notebook implements a neural model that captures speed and accuracy of human-like responses.

Given stimulus symbols as inputs, the model produces two outputs:

- Response symbol, which can be compared with the input stimuli to measure accuracy.
- Halting probability ($\lambda_n$).

Under the hood, the model repeatedly applies an ICOM-like component until it reaches a halting point. Unlike drift-diffusion models (DDM) and ICOM models, all the parameters and outcomes of the current model *seem* cognitively interpretable.
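The halting loop described above can be sketched in plain Python. At inference time, a PonderNet-style model samples a halt at each step with probability $\lambda_n$ and stops unconditionally at the last step. In this sketch, `ponder` and `toy_step` are hypothetical illustrations and not part of the `cogponder` API. `toy_step` is an evidence-accumulating stand-in for the ICOM-like component.

```python
import random


def ponder(step, x, h0, max_steps, rng=None):
    """Iterate `step` until it halts (PonderNet-style inference).

    `step(x, h)` is a hypothetical stand-in for the ICOM-like component and must
    return (y_hat, h_next, lambda_n), where lambda_n is the halting probability.
    Returns the prediction at the halting step and the (1-based) step count.
    """
    if max_steps < 1:
        raise ValueError("max_steps must be >= 1")
    if rng is None:
        rng = random.Random()
    h = h0
    for n in range(1, max_steps + 1):
        y_hat, h, lambda_n = step(x, h)
        # halt with probability lambda_n, or unconditionally at the last step
        if n == max_steps or rng.random() < lambda_n:
            return y_hat, n


def toy_step(x, h):
    """Toy step: accumulate evidence; more evidence makes halting more likely."""
    h = h + x
    return h, h, min(1.0, h / 10.0)


y_hat, n_steps = ponder(toy_step, x=2.0, h0=0.0, max_steps=20, rng=random.Random(0))
print(y_hat, n_steps)
```

With `x=2.0` the toy halting probability reaches 1 at step 5, so the number of steps is random but never exceeds 5. Stronger inputs therefore tend to produce faster "responses".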

### Additional resources

- [ICOM network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)


## Problem setting

### Model
Given input and output data, we want to learn a supervised model of the function $X \to y$ as follows:

$$
f: (X, h_n) \mapsto (\tilde{y}, h_{n+1}, \lambda_n)
$$

where $X$ and $y$ denote stimulus and response symbols, $\tilde{y}$ is the predicted response, $\lambda_n$ is the halting probability at step $n$, and $h_n$ is the latent state of the model. The iteration continues up to the time step $N$.
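As in PonderNet, the conditional halting probabilities $\lambda_n$ induce a distribution $p_n$ over the step at which the model halts. We additionally assume that any probability mass left after step $N-1$ is assigned to the last step (equivalently, $\lambda_N = 1$):

$$
p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j) \quad (n < N), \qquad p_N = \prod_{j=1}^{N-1} (1 - \lambda_j)
$$

With this assumption, $\sum_{n=1}^{N} p_n = 1$. The expected halting step $\mathbb{E}[n] = \sum_{n=1}^{N} n \, p_n$ can then be read as the model's predicted response time, in units of steps.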

For brevity and compatibility, both stimuli and responses are one-hot encoded.
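One-hot encoding maps each symbol index to a vector with a single 1 at that index. Here is a minimal pure-Python sketch, where `one_hot` is a hypothetical helper written for illustration:

```python
def one_hot(symbols, n_symbols):
    """Encode a sequence of integer symbols (0..n_symbols-1) as one-hot rows."""
    encoded = []
    for s in symbols:
        if not 0 <= s < n_symbols:
            raise ValueError(f"symbol {s} outside 0..{n_symbols - 1}")
        row = [0] * n_symbols
        row[s] = 1  # a single active unit per symbol
        encoded.append(row)
    return encoded


print(one_hot([2, 0, 5], n_symbols=6))
# [[0, 0, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1]]
```

In PyTorch the same encoding is available as `torch.nn.functional.one_hot(torch.tensor(symbols), num_classes=n_symbols)`.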


### Input

One-hot encoded symbols.

### Output

One-hot encoded symbols.

### Criterion

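$$
L = L_{\text{cross-entropy}} + \beta \, L_{\text{halting}}
$$

where $L_{\text{cross-entropy}}$ is the (binary) cross-entropy reconstruction loss on the predicted responses, and $L_{\text{halting}}$ regularizes the halting probabilities. The halting term is weighted by $\beta$ (`loss_beta = .2` in the training code).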
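For reference, the sketch below computes the original PonderNet objective for a single trial in plain Python. The reconstruction term is the per-step binary cross-entropy weighted by the halting distribution $p_n$. The halting term is the KL divergence from $p_n$ to a geometric prior with parameter $\lambda_p$, truncated at $N$ steps and renormalized.

This is an assumption about what `ReconstructionLoss` and `RegularizationLoss` compute. In this notebook, `RegularizationLoss` also receives the observed response times, so its actual form likely differs. All function names here are hypothetical. The defaults mirror the notebook's `lambda_p=.5` and `loss_beta=.2`.

```python
import math


def bce(p, y, eps=1e-7):
    """Binary cross-entropy for one prediction p in (0, 1) and target y in {0, 1}."""
    p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def halting_distribution(lambdas):
    """p_n = lambda_n * prod_{j<n} (1 - lambda_j); leftover mass goes to the last step."""
    p, not_halted = [], 1.0
    for n, lam in enumerate(lambdas):
        if n == len(lambdas) - 1:
            p.append(not_halted)  # force halting at step N
        else:
            p.append(not_halted * lam)
            not_halted *= 1 - lam
    return p


def geometric_prior(lambda_p, max_steps):
    """Geometric prior lambda_p * (1 - lambda_p)^(n-1), truncated and renormalized.

    Requires 0 < lambda_p < 1 so that every step has non-zero prior mass.
    """
    w = [lambda_p * (1 - lambda_p) ** n for n in range(max_steps)]
    total = sum(w)
    return [x / total for x in w]


def ponder_loss(y_steps, lambdas, y, lambda_p=0.5, beta=0.2):
    """Reconstruction loss + beta * KL(p || geometric prior) for one trial."""
    p = halting_distribution(lambdas)
    loss_rec = sum(p_n * bce(y_n, y) for p_n, y_n in zip(p, y_steps))
    prior = geometric_prior(lambda_p, len(p))
    loss_reg = sum(p_n * math.log(p_n / q_n) for p_n, q_n in zip(p, prior) if p_n > 0)
    return loss_rec + beta * loss_reg


print(ponder_loss(y_steps=[0.9, 0.6, 0.8], lambdas=[0.3, 0.4, 0.5], y=1))
```

Because each step's loss is weighted by $p_n$, the model is pushed to halt at the steps where its prediction is already good. The KL term keeps the halting distribution close to the prior.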

```python
%reload_ext autoreload
%autoreload 3

# Setup and imports
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split

from tqdm import tqdm

from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from torch.utils.tensorboard import SummaryWriter
# import tensorboard as tb
# import tensorflow as tf
# tf.io.gfile = tb.compat.tensorflow_stub.io.gfile  # FIX: storing embeddings using tensorboard

from cogponder import NBackDataset
from cogponder import PonderNet, ICOM, ReconstructionLoss, RegularizationLoss
```

## Data

```python
# generate mock n-back data
n_subjects = 2
n_trials = 100
n_stimuli = 6

dataset = NBackDataset(n_subjects, n_trials, n_stimuli)

# assuming dataset[i] returns the trials of subject i, use the first subject
X, responses, matches, response_times = dataset[0]
```
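The internals of `cogponder`'s `NBackDataset` are not shown in this notebook. For readers unfamiliar with the task, here is a hypothetical generator of mock n-back data in plain Python. On trial $t$ the correct answer is "match" when the current stimulus equals the one shown `n_back` trials earlier. The function `mock_nback` and its parameters `n_back` and `match_rate` are illustrative assumptions, not arguments of `NBackDataset`.

```python
import random


def mock_nback(n_trials, n_stimuli, n_back=2, match_rate=0.3, seed=None):
    """Generate a random n-back stimulus sequence and its match labels.

    Trial t is a match when stimulus[t] == stimulus[t - n_back].
    """
    rng = random.Random(seed)
    stimuli, matches = [], []
    for t in range(n_trials):
        if t >= n_back and rng.random() < match_rate:
            s = stimuli[t - n_back]  # deliberately repeat the n-back stimulus
        else:
            s = rng.randrange(n_stimuli)  # random draw; may still match by chance
        stimuli.append(s)
        # label from the actual sequence, so accidental matches are counted too
        matches.append(t >= n_back and s == stimuli[t - n_back])
    return stimuli, matches


stimuli, matches = mock_nback(n_trials=100, n_stimuli=6, seed=0)
print(stimuli[:10], matches[:10])
```

Each label is computed from the generated sequence rather than from the coin flip. This keeps the labels correct even when a random draw happens to repeat the n-back stimulus.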


```python
# training params
n_epochs = 1000

logs = SummaryWriter()

model = PonderNet(ICOM, n_stimuli+1, n_stimuli, 2, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

dataset = TensorDataset(X, matches.float(), response_times)

# split params: 80% train, 20% test
train_size = int(len(dataset) * .8)
test_size = len(dataset) - train_size

train_subset, test_subset = random_split(dataset, lengths=(train_size, test_size))

X_train, y_train, rt_train = dataset[train_subset.indices]
X_test, y_test, rt_test = dataset[test_subset.indices]

loss_rec_fn = ReconstructionLoss(nn.BCELoss(reduction='mean'))
loss_reg_fn = RegularizationLoss(lambda_p=.5, max_steps=20)
loss_beta = .2  # weight of the halting (regularization) loss

for epoch in tqdm(range(n_epochs), desc='Epochs'):

  # training step on the full training set
  model.train()
  optimizer.zero_grad()
  y_steps, p_halt, halt_step = model(X_train)

  loss_rec = loss_rec_fn(p_halt, y_steps, y_train)
  loss_reg = loss_reg_fn(p_halt, rt_train)
  loss = loss_rec + loss_beta * loss_reg

  loss.backward()
  optimizer.step()

  logs.add_scalar('loss/rec_train', loss_rec.item(), epoch)
  logs.add_scalar('loss/reg_train', loss_reg.item(), epoch)
  logs.add_scalar('loss/train', loss.item(), epoch)

  y_pred = y_steps.detach()[0, halt_step].argmax(dim=1)
  train_accuracy = accuracy_score(y_train, y_pred)
  logs.add_scalar('accuracy/train', train_accuracy, epoch)

  # evaluation on the held-out trials
  model.eval()
  with torch.no_grad():
    y_steps, p_halt, halt_step = model(X_test)
    loss_rec = loss_rec_fn(p_halt, y_steps, y_test)
    loss_reg = loss_reg_fn(p_halt, rt_test)
    loss = loss_rec + loss_beta * loss_reg
    logs.add_scalar('loss/test', loss.item(), epoch)

    y_pred = y_steps[0, halt_step].argmax(dim=1)
    test_accuracy = accuracy_score(y_test, y_pred)
    logs.add_scalar('accuracy/test', test_accuracy, epoch)

logs.close()

# to inspect the logs, run in a shell: tensorboard --logdir=runs
```

Epochs: 100%|██████████| 1000/1000 [00:36<00:00, 27.44it/s]
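The training cell selects predictions with `y_steps.detach()[0, halt_step].argmax(dim=1)`. The meaning of this indexing depends on the output layout of `PonderNet`, which is not shown here.

The sketch below makes the selection explicit: each trial is scored by the prediction it made at its own halting step. It assumes a hypothetical layout in which `y_steps[n][i]` is the match probability predicted at step `n` for trial `i`, and `halt_step[i]` is the 0-based step at which trial `i` halted. `accuracy_at_halt` is an illustrative helper, not part of `cogponder`.

```python
def accuracy_at_halt(y_steps, halt_step, y_true, threshold=0.5):
    """Fraction of trials whose prediction at their own halting step is correct.

    Assumed layout (hypothetical): y_steps[n][i] is the match probability at
    step n for trial i, and halt_step[i] is the 0-based halting step of trial i.
    """
    if len(y_true) == 0:
        raise ValueError("y_true must not be empty")
    correct = 0
    for i, (n, target) in enumerate(zip(halt_step, y_true)):
        predicted = int(y_steps[n][i] > threshold)  # threshold the probability
        correct += int(predicted == int(target))
    return correct / len(y_true)


# two steps x three trials
y_steps = [[0.2, 0.7, 0.4],
           [0.9, 0.1, 0.6]]
print(accuracy_at_halt(y_steps, halt_step=[1, 0, 1], y_true=[1, 1, 0]))  # 0.6666666666666666
```

Suppose `y_steps` is a PyTorch tensor of shape `(max_steps, n_trials)` with `halt_step` a tensor of step indices. Under that assumed layout, the same selection would be `y_steps[halt_step, torch.arange(len(halt_step))] > 0.5`.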
