# CogPonder: Stroop Task with Fixed Hyper-parameters

This notebook implements the CogPonder algorithm using PyTorch Lightning to perform the Stroop task. It assumes fixed hyper-parameters and fits the model to a single-subject dataset. It wraps a simple linear network with a pondering layer and trains it on the *Self-Regulation Ontology* dataset.


## Data

Here, we fit the Stroop data from the *Self-Regulation Ontology*. The data is loaded from the `data/Self_Regulation_ontology/` directory. See the `data/Self_Regulation_ontology/README.md` file for more information.

### Input and Output

#### Stroop

The input is 1) the color and 2) the letter of the current stimulus. The output is the human response on the current trial (red, green, or blue).
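
As a minimal sketch of this input/output format, the snippet below encodes one trial as integer indices. It assumes that the "letter" feature is a word naming one of the three colors; `COLORS`, `COLOR_TO_INDEX`, and `encode_trial` are hypothetical names for illustration and are not part of the CogPonder code base.

```python
# Hypothetical encoding of a single Stroop trial (illustration only).
COLORS = ['red', 'green', 'blue']
COLOR_TO_INDEX = {name: i for i, name in enumerate(COLORS)}


def encode_trial(ink_color, word, response):
    """Map a trial to (inputs, target) integer codes."""
    # inputs: [ink color index, word index]; assumes the word names a color
    x = [COLOR_TO_INDEX[ink_color], COLOR_TO_INDEX[word]]
    # target: index of the human response
    y = COLOR_TO_INDEX[response]
    return x, y


x, y = encode_trial(ink_color='red', word='blue', response='red')
is_congruent = x[0] == x[1]  # color and word agree only on congruent trials
```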


#### N-Back

The previous N+1 presented symbols are used as input; the last of these is the current symbol. The output is the human response on the (N+1)th trial.
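
The windowing described here can be sketched in plain Python. This is an illustration under the assumption that symbols are integer codes in presentation order; `nback_windows` is a hypothetical helper, not the logic used by `NBackSRODataset`.

```python
from typing import List, Sequence


def nback_windows(symbols: Sequence[int], n_back: int) -> List[List[int]]:
    """Return one window of n_back + 1 symbols per trial with enough history."""
    width = n_back + 1
    return [list(symbols[i - width + 1:i + 1])
            for i in range(width - 1, len(symbols))]


stream = [3, 1, 3, 2, 1]
windows = nback_windows(stream, n_back=2)

# The last element of each window is the current symbol,
# the first element is the symbol presented N trials earlier.
is_match = [w[0] == w[-1] for w in windows]
```

With a stream of 5 symbols and N = 2, only the last 3 trials have a full window, so the first N trials yield no input.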


## Hyper-parameters

- `n_embeddings`: number of hidden units in the operator model. Defaults to $N_{\text{symbols}} + 1$.
- `rec_loss_beta`: the weight (beta) of the reconstruction term of the loss function. Defaults to 0.5.
- `cog_loss_beta`: the weight (beta) of the cognitive term of the loss function. Defaults to 0.5.
- `learning_rate`: the learning rate of the optimizer. Defaults to 0.0001.
- `max_response_step`: maximum response step in the dataset. Defaults to $\max(\text{response\_step}) + 10$.
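
The defaults listed above can be collected as in the following sketch. It reads the `n_embeddings` default as the number of unique symbols plus one, which is an assumption; `default_hyperparameters` is a hypothetical helper and the inputs are made-up toy values.

```python
def default_hyperparameters(symbols, response_steps):
    """Build the default hyper-parameters from toy data (illustration only)."""
    n_symbols = len(set(symbols))
    return {
        'n_embeddings': n_symbols + 1,  # assumed reading of the default
        'rec_loss_beta': 0.5,
        'cog_loss_beta': 0.5,
        'learning_rate': 1e-4,
        # leave a margin of 10 steps beyond the slowest observed response
        'max_response_step': max(response_steps) + 10,
    }


params = default_hyperparameters(symbols=[0, 1, 2, 1, 0],
                                 response_steps=[54, 30, 32, 44])
```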

## Criterion

$L = L_{\text{reconstruction}} + L_{\text{cognitive}}$
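
One plausible way the two beta hyper-parameters enter this criterion is as weights on the two terms. The sketch below assumes exactly that; the weighting is not confirmed by the text above, and `total_loss` is a hypothetical function. With both betas set to 1 it reduces to the unweighted sum in the formula.

```python
def total_loss(rec_loss, cog_loss, rec_loss_beta=0.5, cog_loss_beta=0.5):
    """Weighted sum of the reconstruction and cognitive loss terms (assumed form)."""
    return rec_loss_beta * rec_loss + cog_loss_beta * cog_loss


weighted = total_loss(2.0, 4.0)              # default betas of 0.5
unweighted = total_loss(2.0, 4.0, 1.0, 1.0)  # plain sum, as in the formula
```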


In [2]:
%reload_ext autoreload
%autoreload 3

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

import torch
import torchmetrics
import pytorch_lightning as pl
import pandas as pd
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar
from src.cogponder import CogPonderModel
from src.cogponder.datasets import NBackSRODataset, CogPonderDataModule
from pathlib import Path

In [3]:
# this notebook fits a single SRO subject; set its SRO SubjectID here
SRO_SUBJECT_ID = 202

# maximum number of training epochs
MAX_EPOCHS = 10000

# upon successful training, the model will be saved to this path
CHECKPOINT_PATH = Path('models/nback/') / f'cogponder_subject-{SRO_SUBJECT_ID}_epochs-{MAX_EPOCHS}.ckpt'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [17]:
# Load the dataset and configs

print('Loading N-Back dataset... ', end='')

dataset = NBackSRODataset(n_subjects=-1, response_step_interval=10, non_decision_time='auto')
data = dataset[0]

# determine the number of loaded subjects
n_subjects = data[0].size(1)


n_symbols = torch.unique(data[0]).shape[0]

# parameter space
CONFIG = {
    'task': 'nback',
    'resp_loss_beta': 1.,
    'time_loss_beta': 10.,
    # 'non_decision_time': 10,  # in milliseconds
    'loss_by_trial_type': False,
    'learning_rate': 1e-2,
    'max_response_step': data[4].max().int().item() + 10,
    'inputs_dim': data[0].size(1) - 1,  # minus subject_id (first column)
    'embeddings_dim': n_symbols,
    'outputs_dim': torch.unique(data[3]).size(0),  # number of unique responses
    'auto_lr_find': False,
    'batch_size': 72,
    'n_subjects': 1
}

datamodule = CogPonderDataModule(data,
                                 batch_size=CONFIG['batch_size'],
                                 num_workers=8)

print('Done!')

Loading N-Back dataset... Done!


3

In [52]:
# Define the pondering model and run the trainer

model = CogPonderModel(CONFIG)  # , example_input_array=data[0][:1].to(device))

# Trainer
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    # min_epochs=100,
    accelerator='auto',
    auto_lr_find=CONFIG['auto_lr_find'],
    log_every_n_steps=1,
    # overfit_batches=True,
    # accumulate_grad_batches=2,
    callbacks=[
        RichProgressBar(),
        EarlyStopping(monitor='val/total_loss',
                      patience=max(10, MAX_EPOCHS // 10),
                      mode='min', min_delta=0.001),
    ])

# Auto-detect learning-rate if the flag is set
if CONFIG['auto_lr_find']:
    trainer.tune(model, datamodule=datamodule)

# Fit and evaluate the model
trainer.fit(model, datamodule=datamodule)

# Save the latest checkpoint
trainer.save_checkpoint(CHECKPOINT_PATH)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()
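
The `patience` and `min_delta` arguments passed to `EarlyStopping` above control when training halts. The following is a simplified pure-Python sketch of that stopping rule for a monitored loss in `'min'` mode. It is not Lightning's implementation, and `epochs_until_stop` is a hypothetical helper.

```python
def epochs_until_stop(val_losses, patience, min_delta=0.001):
    """Return the 1-indexed epoch at which training would stop."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            # improvement larger than min_delta: reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    # never triggered: training runs through all epochs
    return len(val_losses)


MAX_EPOCHS = 10000
patience = max(10, MAX_EPOCHS // 10)  # same rule as in the trainer cell

stop_epoch = epochs_until_stop([1.0, 0.9, 0.9, 0.9, 0.9], patience=2)
```

With `MAX_EPOCHS = 10000`, the patience is 1000 epochs, so a stalled validation loss is tolerated for a tenth of the training budget before stopping.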

In [59]:
# DEBUG - Load the checkpoint

model_ckpt = CogPonderModel.load_from_checkpoint(CHECKPOINT_PATH)
model_ckpt.eval()

if 'datamodule' not in locals() or not hasattr(datamodule, 'train_dataset'):
    print('loading N-back dataset...', end='')
    data = NBackSRODataset(n_subjects=-1, response_step_interval=10)[SRO_SUBJECT_ID]
    datamodule = CogPonderDataModule(data, batch_size=CONFIG['batch_size'], num_workers=8)
    datamodule.prepare_data()
    print('Done!')

X_train, trial_types_train, is_corrects_train, y_train, rt_train = datamodule.train_dataset[:]
X_test, trial_types_test, is_corrects_test, y_test, rt_test = datamodule.test_dataset[:]

with torch.no_grad():
    # forward pass on the train and test splits
    y_train_steps, p_train, rt_train_pred = model_ckpt(X_train)
    y_test_steps, p_test, rt_test_pred = model_ckpt(X_test)

    # most likely response at each step
    y_train_steps = torch.argmax(y_train_steps, dim=-1)
    y_test_steps = torch.argmax(y_test_steps, dim=-1)

    # pick the response at the predicted response step (minus one for 0-based indexing)
    y_train_pred = y_train_steps.gather(dim=0, index=rt_train_pred[None, :] - 1)[0]  # (batch_size,)
    y_test_pred = y_test_steps.gather(dim=0, index=rt_test_pred[None, :] - 1)[0]  # (batch_size,)

    train_res = pd.DataFrame({'true_rt_train': rt_train.detach().tolist(),
                              'pred_rt_train': rt_train_pred.tolist()})
    test_res = pd.DataFrame({'true_rt_test': rt_test.detach().tolist(),
                             'pred_rt_test': rt_test_pred.tolist()})

    display(train_res.T, test_res.T)

# DEBUG report mean-RT
print(f'RT train mean (pred/true): '
      f'{rt_train_pred.float().mean().item():.2f}, '
      f'{rt_train.float().mean().item():.2f}')

print(f'RT test  mean (pred/true): '
      f'{rt_test_pred.float().mean().item():.2f}, '
      f'{rt_test.float().mean().item():.2f}')

# DEBUG - report sd-RT
print(f'RT train std (pred/true): '
      f'{rt_train_pred.float().std().item():.2f}, '
      f'{rt_train.float().std().item():.2f}')

print(f'RT test  std (pred/true): '
      f'{rt_test_pred.float().std().item():.2f}, '
      f'{rt_test.float().std().item():.2f}')


cuda:0
RT TRUE TRAIN: [54, 30, 32, 44, 35, 45, 33, 38, 33, 63, 45, 38, 38, 30, 38, 26, 25, 31, 47, 37, 44, 49, 44, 39, 50, 37, 46, 51, 37, 44, 54, 34, 42, 54, 36, 45, 44, 43, 41, 51, 30, 69, 51, 43, 39, 31, 47, 32, 32, 25, 60, 45, 46, 45, 52, 53, 59, 57, 47, 52, 42, 50, 32, 26, 23, 24, 44, 28, 49, 48, 39] 
RT PRED TRAIN: [54, 18, 92, 84, 100, 5, 49, 61, 44, 100, 45, 46, 93, 18, 16, 19, 4, 100, 100, 45, 100, 84, 46, 23, 6, 100, 72, 26, 39, 46, 45, 15, 46, 43, 32, 46, 5, 91, 46, 16, 100, 11, 45, 48, 29, 69, 46, 46, 13, 17, 47, 55, 44, 8, 47, 9, 45, 66, 10, 62, 57, 45, 46, 46, 49, 45, 28, 78, 100, 17, 20]
RT TRUE TEST: [34, 35, 57, 31, 29, 39, 66, 49, 31, 73, 67, 41, 38, 44, 33, 72, 47, 40, 56, 36, 66, 37, 46, 29] 
RT PRED TEST: [55, 47, 51, 21, 100, 46, 92, 48, 48, 48, 18, 46, 45, 74, 100, 46, 58, 49, 49, 13, 100, 78, 13, 48]
RT mean (pred/true):  47.43661880493164 41.78873062133789
RT mean (pred/true):  53.875 45.66666793823242
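
The `gather` calls in the cell above select, for each trial, the response predicted at that trial's response step. The same indexing can be sketched with plain lists, assuming per-step predictions are laid out as `(steps, batch)` and response steps are 1-indexed; `select_at_response_step` is a hypothetical helper and the values are toy data.

```python
def select_at_response_step(y_steps, response_steps):
    """Pick y_steps[step - 1][b] for every trial b.

    y_steps[t][b] is the predicted response of trial b at step t (0-based);
    response_steps[b] is the 1-indexed response step of trial b.
    """
    return [y_steps[step - 1][b] for b, step in enumerate(response_steps)]


# 3 steps x 3 trials of predicted response labels
y_steps = [[0, 0, 1],
           [1, 2, 2],
           [2, 1, 0]]
response_steps = [2, 1, 3]

y_pred = select_at_response_step(y_steps, response_steps)
```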
