# CogPonder: An Interoperable Model of Response Times in Cognitive Tasks

We are interested in a model of response times that is interoperable across human data and computational agents.

Inspired by [PonderNet](https://arxiv.org/abs/2107.05407), this notebook demonstrates CogPonder, a differentiable model that produces human-like speed-accuracy trade-offs in decision-making tasks.

The model repeatedly applies a recurrent decision network and stops once it reaches a halting step. The network is trained to 1) maximize decision accuracy and 2) maximize the similarity between the distribution of human response times and the distribution of the network's halting steps.


## Tasks

### N-back Task

N-back is a cognitive test commonly used to tap into working memory. In this task, human subjects are presented with a sequence of symbols and are asked to press the "target" button when the current symbol matches the one presented N steps earlier in the sequence. The load factor N can be adjusted to control the difficulty of the task.

Here, we use a 2-back dataset from the [Self-Regulation Ontology study]() to evaluate the architecture. The mock dataset includes, for each subject, trial-level $X$ (symbols), $\text{is\_targets}$ (or $y$; whether the trial was a target), $\text{responses}$, and $\text{response\_times}$.


The dataset interface provides the following features:

- $X_i$: the previous 3 symbols for trial $i$; $X_i$ is a 3-dimensional vector of integers.
- $\text{trial\_type}_i$: see the "Trial types" section below.
- $\text{is\_target}_i$: whether trial $i$ is a match; $\text{is\_target}_i$ is a boolean.
- $\text{response}_i$: the subject's response on trial $i$; $\text{response}_i$ is a boolean.
- $\text{response\_step}_i$: the subject's response step on trial $i$; $\text{response\_step}_i$ is a float. Response steps represent RT in 50 ms steps.
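The 50 ms step convention amounts to dividing the raw RT by the step duration. A minimal sketch, assuming raw RTs are recorded in milliseconds and that no rounding is applied (the list above only states that $\text{response\_step}_i$ is a float); `rt_to_steps` and `STEP_MS` are hypothetical names, not part of the `cogponder` package:

```python
import torch

# Step duration stated above for response_step (assumed to be in milliseconds).
STEP_MS = 50.0


def rt_to_steps(rt_ms: torch.Tensor) -> torch.Tensor:
    """Convert raw response times (ms) into fractional 50 ms response steps."""
    return rt_ms / STEP_MS


# Illustrative RTs, not taken from the SRO dataset.
steps = rt_to_steps(torch.tensor([450.0, 620.0, 1000.0]))
print(steps)  # approximately [9.0, 12.4, 20.0]
```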

#### Trial types

The dataset includes 4 trial types, depending on the response and on whether the symbol was a target:

1. correct match
2. correct non-match
3. incorrect match
4. incorrect non-match
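A minimal sketch of how the four trial types could be derived from $\text{is\_target}_i$ and $\text{response}_i$. It assumes that "match"/"non-match" refers to whether the trial was a target, that a `True` response means the subject pressed the target button, and that the integer codes follow the order of the list above; `TrialType` and `trial_type` are hypothetical names, not the dataset's actual encoding:

```python
from enum import IntEnum


class TrialType(IntEnum):
    # Codes follow the order of the list above (assumed encoding).
    CORRECT_MATCH = 1
    CORRECT_NON_MATCH = 2
    INCORRECT_MATCH = 3
    INCORRECT_NON_MATCH = 4


def trial_type(is_target: bool, response: bool) -> TrialType:
    """Classify a trial by whether it was a target and whether the response was correct."""
    correct = is_target == response  # a correct response agrees with the trial label
    if is_target:
        return TrialType.CORRECT_MATCH if correct else TrialType.INCORRECT_MATCH
    return TrialType.CORRECT_NON_MATCH if correct else TrialType.INCORRECT_NON_MATCH


print(trial_type(True, False).name)  # 'INCORRECT_MATCH' (a missed target)
```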


### Decision Model

We want to learn a supervised approximation of the mapping $X \to y$ as follows:

$$
f: X, h_n \mapsto \tilde{y}, h_{n+1}, \lambda_n
$$

where $X$ and $y$ denote the recent stimuli and responses, $h_n$ is the latent state of the decision model at step $n$, and $\lambda_n$ is the halting probability at step $n$, i.e., the probability of halting at step $n$ given that the network has not halted before (a Bernoulli variable). The recurrent decision process runs for at most $N_{max}$ steps. The duration of each step is a hyperparameter of the model; for simplicity it is considered 100 ms here. $p_n$ is the (unconditional) probability of halting at step $n$, which implies that the network did not halt at any of the previous steps. The halting probability $\lambda_n$ is a function of the latent state $h_n$ and the current input $X$.
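Because halting at step $n$ requires not halting at any earlier step, the halting distribution follows from the per-step halting probabilities, as in PonderNet:

$$
p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)
$$

A minimal PyTorch sketch, assuming the $\lambda_n$ are stacked step-first with shape `(n_steps, batch)`; `halting_distribution` is a hypothetical helper, not a `cogponder` function:

```python
import torch


def halting_distribution(lambdas: torch.Tensor) -> torch.Tensor:
    """Map per-step halting probabilities lambda_n to p_n.

    lambdas: (n_steps, batch), values in [0, 1].
    Returns p: (n_steps, batch), probability of halting exactly at each step.
    """
    # Probability of not having halted up to and including step n.
    not_halted = torch.cumprod(1 - lambdas, dim=0)
    # Shift by one step: before step 1 the network has certainly not halted.
    survival = torch.cat([torch.ones_like(lambdas[:1]), not_halted[:-1]], dim=0)
    return lambdas * survival


# Toy example; the last lambda is set to 1 so that p sums to 1.
lambdas = torch.tensor([[0.5], [0.5], [1.0]])
p = halting_distribution(lambdas)
print(p.squeeze(-1).tolist())  # [0.5, 0.25, 0.25]
```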

For the N-back task, we define $X$ as a sliding window of the N+1 most recent symbols, e.g., [A, B, C], [B, C, D], ...; and $y$ is either NON_MATCH (False or 0) or MATCH (True or 1).
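The windowing can be sketched with `torch.Tensor.unfold`, assuming the symbols are already encoded as integers (A=0, B=1, C=2). `sliding_windows` is a hypothetical helper, not the dataset's actual implementation. The first N trials are dropped because they have no complete window:

```python
from typing import Tuple

import torch


def sliding_windows(symbols: torch.Tensor, n_back: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build N-back inputs X and MATCH/NON_MATCH labels y from a 1-D symbol sequence."""
    # Each row is [s_{t-N}, ..., s_t]: the N+1 most recent symbols.
    X = symbols.unfold(0, n_back + 1, 1)
    # MATCH if the newest symbol equals the one N steps earlier (oldest in the window).
    y = X[:, -1] == X[:, 0]
    return X, y


# A, B, A, C, A, C encoded as integers
X, y = sliding_windows(torch.tensor([0, 1, 0, 2, 0, 2]), n_back=2)
print(X.tolist())  # [[0, 1, 0], [1, 0, 2], [0, 2, 0], [2, 0, 2]]
print(y.tolist())  # [True, False, True, True]
```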

### Output

The *CogPonder* network outputs $y\_steps$, $p\_halts$, $halt\_steps$ for each item in the batch. Batch items represent the trials in the N-back task.
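To make the three outputs concrete, here is a toy pondering network in the spirit of the decision model above. It is a hypothetical sketch, not `CogPonderModel`: the `GRUCell` recurrence, the linear heads, forcing $\lambda = 1$ at the last step, and sampling `halt_steps` from `p_halts` with `torch.distributions.Categorical` are all assumptions. For simplicity $\lambda_n$ is computed from $h_n$ alone, which already depends on $X$:

```python
import torch
import torch.nn as nn


class ToyPonder(nn.Module):
    """Hypothetical minimal pondering network (not the CogPonderModel implementation)."""

    def __init__(self, n_inputs: int, n_hidden: int, n_outputs: int, max_steps: int):
        super().__init__()
        self.max_steps = max_steps
        self.cell = nn.GRUCell(n_inputs, n_hidden)      # h_{n+1} = f(X, h_n)
        self.output_head = nn.Linear(n_hidden, n_outputs)  # y~_n
        self.lambda_head = nn.Linear(n_hidden, 1)           # lambda_n (before sigmoid)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        h = x.new_zeros(batch_size, self.cell.hidden_size)
        not_halted = x.new_ones(batch_size)  # probability of not having halted yet
        y_steps, p_halts = [], []
        for n in range(self.max_steps):
            h = self.cell(x, h)
            y_steps.append(self.output_head(h))
            lam = torch.sigmoid(self.lambda_head(h)).squeeze(-1)
            if n == self.max_steps - 1:
                lam = torch.ones_like(lam)  # force halting at the last step
            p_halts.append(not_halted * lam)  # p_n = lambda_n * prod_{j<n} (1 - lambda_j)
            not_halted = not_halted * (1 - lam)
        y_steps = torch.stack(y_steps)  # (max_steps, batch, n_outputs)
        p_halts = torch.stack(p_halts)  # (max_steps, batch)
        # Sample one halting step per trial (1-indexed) from the halting distribution.
        halt_steps = torch.distributions.Categorical(probs=p_halts.T).sample() + 1
        return y_steps, p_halts, halt_steps


torch.manual_seed(0)
model = ToyPonder(n_inputs=15, n_hidden=32, n_outputs=2, max_steps=20)
y_steps, p_halts, halt_steps = model(torch.randn(8, 15))  # batch of 8 trials
print(y_steps.shape, p_halts.shape, halt_steps.shape)
# torch.Size([20, 8, 2]) torch.Size([20, 8]) torch.Size([8])
```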

### Criterion

$L = L_{\text{reconstruction}} + L_{\text{cognitive}}$
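The criterion only names its two terms. One plausible instantiation is sketched here: a PonderNet-style reconstruction term (cross-entropy at each step, weighted by the probability of halting there) and, in line with the training objective stated at the top, a cognitive term given by the KL divergence between the histogram of human response steps and the batch-averaged halting distribution. Both choices, the function names, and the tensor shapes are assumptions rather than the CogPonder implementation; in particular, the role of `loss_beta` in the tuning config is not modeled:

```python
import torch
import torch.nn.functional as F


def reconstruction_loss(y_steps, p_halts, y):
    """Expected cross-entropy over halting steps (PonderNet-style, assumed).

    y_steps: (n_steps, batch, n_classes) logits; p_halts: (n_steps, batch); y: (batch,) labels.
    """
    n_steps = y_steps.shape[0]
    ce = torch.stack([F.cross_entropy(y_steps[n], y, reduction='none') for n in range(n_steps)])
    return (p_halts * ce).sum(0).mean()


def cognitive_loss(p_halts, rt_steps, eps=1e-8):
    """KL(human response-step histogram || mean halting distribution); hypothetical choice."""
    n_steps = p_halts.shape[0]
    # Histogram of integer response steps (1-indexed), clipped to the unrolled range.
    counts = torch.bincount(rt_steps.clamp(1, n_steps) - 1, minlength=n_steps).float()
    rt_dist = counts / counts.sum()
    model_dist = p_halts.mean(dim=1)
    # Terms with rt_dist == 0 contribute 0.
    return (rt_dist * ((rt_dist + eps).log() - (model_dist + eps).log())).sum()


# Random tensors with the assumed shapes (20 steps, batch of 8, 2 classes).
torch.manual_seed(0)
y_steps = torch.randn(20, 8, 2)
p_halts = torch.softmax(torch.randn(20, 8), dim=0)
y = torch.randint(0, 2, (8,))
rt_steps = torch.randint(5, 15, (8,))
loss = reconstruction_loss(y_steps, p_halts, y) + cognitive_loss(p_halts, rt_steps)
```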


In [1]:
%reload_ext autoreload
%autoreload 3

import os
from functools import partial
from pathlib import Path

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from ray import tune, air
from ray.tune import JupyterNotebookReporter
from ray.tune.integration.pytorch_lightning import TuneReportCallback

from cogponder import CogPonderModel
from cogponder.datasets import NBackMockDataset, CogPonderDataModule, NBackSRODataset

# verbose Ray pickling diagnostics (helps debug serialization of the trainable)
os.environ['RAY_PICKLE_VERBOSE_DEBUG'] = '1'


In [2]:

# load the data
# mock_data = NBackMockDataset(n_subjects=1, n_trials=198, n_stimuli=5, n_back=2)

# create data module for handling data in PyTorch Lightning


In [10]:
def train_tune(config, epochs=10, data_file=None):

    data = NBackSRODataset(n_subjects=1, n_back=2, data_file=data_file) # shape (n_subjects, (...))
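    # NBackDataModule is never imported; use the imported CogPonderDataModule
    # (assumes it accepts the same arguments)
    datamodule = CogPonderDataModule(data, batch_size=8)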
    n_symbols = torch.unique(data[0][0]).shape[0]
    max_response_step = data[0][4].max() + 10
    lambda_p = 1. / data[0][4].median()

    max_response_step = max_response_step.item()
    lambda_p = lambda_p.item()

    config['max_response_step'] = max_response_step
    config['lambda_p'] = lambda_p
    config['n_symbols'] = n_symbols
    config['embeddings_dim'] = n_symbols

    # decision_model = ICOM(n_inputs=n_symbols+1, n_embeddings=n_symbols, n_outputs=2)

    # pondering model
    model = CogPonderModel(
        config,
        example_input_array=data[0][0][:2])

    trainer = pl.Trainer(
        logger=TensorBoardLogger(save_dir=os.getcwd(), name='', version='.'),
        max_epochs=epochs,
        log_every_n_steps=8,
        enable_progress_bar=False,
        # auto_scale_batch_size=True,
        callbacks=[
            TuneReportCallback(['val_loss'], on='validation_end'),
            # RichProgressBar(),
            EarlyStopping(monitor='val_loss', patience=25, mode='min', min_delta=0.01),
        ])

    trainer.fit(model, datamodule=datamodule)

# parameters to tune
config = {
    'loss_beta': .5,
    'loss_by_trial_type': False,
    'learning_rate': 1e-4,
}

# WORKAROUND for PTL 1.7 and RayTune
# See https://github.com/ray-project/ray/issues/28197
import ray
ray.shutdown()
ray.init(runtime_env={"env_vars": {"PL_DISABLE_FORK": "1"}})

# hyperparameter tuner
tuner = tune.Tuner(
    trainable=partial(
        train_tune,
        epochs=100,
        data_file='~/workspace/CogPonder/data/Self_Regulation_Ontology/adaptive_n_back.csv.gz'),
    tune_config=tune.TuneConfig(
        metric='val_loss',
        mode='min',
        num_samples=1,
    ),
    run_config=air.RunConfig(
        log_to_file=True,
        verbose=0,
        progress_reporter=JupyterNotebookReporter(
            parameter_columns=['loss_beta'],
            metric_columns=['val_loss']),
    ),
    param_space=config,
)

# run the tuner
results = tuner.fit()

| Trial name | status | loc | loss_beta | val_loss |
|---|---|---|---|---|
| train_tune_48d14_00000 | TERMINATED | 127.0.0.1:64765 | 0.5 | 0.6956 |


[2m[33m(raylet)[0m   aiogrpc.init_grpc_aio()
[2m[33m(raylet)[0m   loop = asyncio.get_event_loop()
[2m[36m(func pid=64765)[0m GPU available: False, used: False
[2m[36m(func pid=64765)[0m TPU available: False, using: 0 TPU cores
[2m[36m(func pid=64765)[0m IPU available: False, using: 0 IPUs
[2m[36m(func pid=64765)[0m HPU available: False, using: 0 HPUs
[2m[36m(func pid=64765)[0m   rank_zero_deprecation(
[2m[36m(func pid=64765)[0m   rank_zero_deprecation("The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8.")
[2m[36m(func pid=64765)[0m   rank_zero_deprecation(
[2m[36m(func pid=64765)[0m   rank_zero_deprecation(
[2m[36m(func pid=64765)[0m   rank_zero_deprecation(
[2m[36m(func pid=64765)[0m   rank_zero_deprecation(
[2m[36m(func pid=64765)[0m 
[2m[36m(func pid=64765)[0m   | Name           | Type       | Params | In sizes            | Out sizes          
[2m[36m(func pid=64765)[0m -------------------------------------

Result for train_tune_48d14_00000:
  date: 2022-10-04_16-25-00
  done: false
  experiment_id: c7051382dfc34529a0c452786854ff14
  hostname: MP0159
  iterations_since_restore: 1
  node_ip: 127.0.0.1
  pid: 64765
  time_since_restore: 12.50996994972229
  time_this_iter_s: 12.50996994972229
  time_total_s: 12.50996994972229
  timestamp: 1664893500
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: 48d14_00000
  val_loss: 0.7188934087753296
  warmup_time: 0.004273891448974609
  
Result for train_tune_48d14_00000:
  date: 2022-10-04_16-25-07
  done: false
  experiment_id: c7051382dfc34529a0c452786854ff14
  hostname: MP0159
  iterations_since_restore: 2
  node_ip: 127.0.0.1
  pid: 64765
  time_since_restore: 19.746084690093994
  time_this_iter_s: 7.236114740371704
  time_total_s: 19.746084690093994
  timestamp: 1664893507
  timesteps_since_restore: 0
  training_iteration: 2
  trial_id: 48d14_00000
  val_loss: 0.7180671691894531
  warmup_time: 0.004273891448974609
  

2022-10-04 16:32:07,433	INFO tune.py:758 -- Total run time: 446.32 seconds (446.20 seconds for the tuning loop).
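`train_tune` sets `lambda_p = 1 / median(response_step)`. Assuming `lambda_p` plays the role of the parameter of PonderNet's geometric prior $p_G(n) = \lambda_p (1 - \lambda_p)^{n-1}$ (this notebook does not show how `CogPonderModel` uses it), this choice makes the prior's mean, $1/\lambda_p$, equal to the median human response step. A quick numerical check with made-up response steps; `geometric_prior` is a hypothetical helper:

```python
import torch


def geometric_prior(lambda_p: float, max_steps: int) -> torch.Tensor:
    """Truncated geometric prior over halting steps 1..max_steps."""
    n = torch.arange(1, max_steps + 1, dtype=torch.float32)
    return lambda_p * (1 - lambda_p) ** (n - 1)


# Illustrative response steps (median = 10), not SRO data.
rt_steps = torch.tensor([8.0, 10.0, 12.0, 10.0, 30.0])
lambda_p = 1.0 / rt_steps.median().item()  # same rule as in train_tune
prior = geometric_prior(lambda_p, max_steps=500)
# Mean of the (renormalized) truncated prior.
expected = (torch.arange(1, 501) * prior).sum() / prior.sum()
print(round(expected.item(), 2))  # close to 10, the median response step
```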


# WARNING: the remaining code won't work with Ray Tune.

Please use a previous version of this notebook from the git history.

In [None]:
# DEBUG

import torch

X_train, _, _, _, rt_train = datamodule.dataset[datamodule.train_dataset.indices]
X_test, _, _, _, rt_test = datamodule.dataset[datamodule.test_dataset.indices]

with torch.no_grad():
    model.eval()
    rt_train_pred = model(X_train)[2].detach()
    rt_test_pred = model(X_test)[2].detach()

    # DEBUG report the ground truth and predicted response times
    print('TRUE TRAIN:', rt_train.detach().tolist(), '\nPRED TRAIN:',  rt_train_pred.tolist())
    print('TRUE TEST:', rt_test.detach().tolist(), '\nPRED TEST:',  rt_test_pred.tolist())

# DEBUG report medians
# rt_train_pred.median(), rt_train.float().median()
# rt_test_pred.median(), rt_test.float().median()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

sns.ecdfplot(rt_train.detach(), label='True (train)')
sns.ecdfplot(rt_train_pred.detach(), label='Predicted (train)')

plt.title('Evaluation of PonderNet on simulated train split')
plt.xlabel('response time (steps)')

plt.legend()
plt.show()

sns.ecdfplot(rt_test.detach(), label='True (test)')
sns.ecdfplot(rt_test_pred.detach(), label='Predicted (test)')

plt.title('Evaluation of PonderNet on simulated test split')
plt.xlabel('response time (steps)')
plt.legend()
plt.show()
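The ECDF plots can be summarized by a single number, the two-sample Kolmogorov-Smirnov statistic (the largest vertical gap between the two ECDFs). A minimal NumPy sketch; `ks_statistic` is a hypothetical helper, and `scipy.stats.ks_2samp` computes the same statistic together with a p-value:

```python
import numpy as np


def ks_statistic(a, b) -> float:
    """Maximum vertical distance between the empirical CDFs of samples a and b."""
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))
    # Both ECDFs are step functions, so evaluating them at all observed points suffices.
    grid = np.concatenate([a, b])
    cdf_a = np.searchsorted(a, grid, side='right') / a.size
    cdf_b = np.searchsorted(b, grid, side='right') / b.size
    return float(np.max(np.abs(cdf_a - cdf_b)))


print(ks_statistic([1, 2], [3, 4]))  # 1.0 (no overlap)
```

For example, `ks_statistic(rt_test, rt_test_pred)` would score the test-split fit, provided the debug cell above has produced those tensors (CPU tensors convert via `np.asarray`).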

In [None]:

sns.kdeplot(rt_train.detach(), label='Train (TRUE)', cut=0)
sns.kdeplot(rt_train_pred.detach(), label='Train (PRED)', cut=0)

sns.kdeplot(rt_test.detach(), label='Test (TRUE)', cut=0)
sns.kdeplot(rt_test_pred.detach(), label='Test (PRED)', cut=0)


plt.title('Evaluation of PonderNet on SRO single-subject 2-back')
plt.xlabel('response time (steps)')
plt.legend()
plt.show()