# CogPonder: N-Back PyTorch Lightning

This notebook implements the CogPonder algorithm using PyTorch Lightning. It assumes fixed hyperparameters and fits the model to a single-subject dataset. It wraps a simple RNN with a pondering lambda layer and trains it on the *Self-Regulation Ontology* dataset.

## Data

Either the N-back or the Stroop task is used as the dataset (selected via `TASK`). The data is loaded from the `data/Self_Regulation_ontology/` directory.

### Input and Output

#### N-Back

The previous N+1 presented symbols are used as input; the last of them is the current symbol. The output is the human response to that (N+1)-th trial.

#### Stroop

The input is the color and letter of the current stimulus. The output is the human response to the current trial.


## Hyperparameters

- `embeddings_dim`: number of hidden units in the recurrent ICOM model. Defaults to $N_{\text{symbols}} + 1$.
- `rec_loss_beta`: the beta parameter of the reconstruction loss ($L_{\text{reconstruction}}$). Defaults to 0.5.
- `cog_loss_beta`: the beta parameter of the cognitive loss ($L_{\text{cognitive}}$). Defaults to 0.5.
- `learning_rate`: the learning rate of the optimizer. Defaults to 0.0001.
- `max_response_step`: maximum response step in the dataset. Defaults to $\max(\text{response\_step}) + 10$.

## Criterion

$L = L_{\text{reconstruction}} + L_{\text{cognitive}}$


In [4]:
%reload_ext autoreload
%autoreload 3

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from ray import tune, air
from ray.tune import JupyterNotebookReporter
from functools import partial
from pytorch_lightning.callbacks import RichProgressBar, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from src.cogponder import CogPonderModel
from src.cogponder.datasets import StroopSRODataset, NBackSRODataset, CogPonderDataModule
from src.cogponder.losses import ReconstructionLoss, CognitiveLoss
from pathlib import Path




In [5]:
# Load the dataset and configs

TASK = 'stroop'  # or stroop

match TASK:
    case 'nback':
        print('Loading N-back dataset...')
        data = NBackSRODataset(n_subjects=1, n_back=2) # shape (n_subjects, (...))
        n_symbols = torch.unique(data[0][0]).shape[0]
        embeddings_dim = n_symbols
        max_response_step = 30 # OR something like "2 * max observed RT"
    case 'stroop':
        print('Loading Stroop dataset...')
        data = StroopSRODataset(n_subjects=1)
        embeddings_dim = 2
        max_response_step = 100

# parameter space
CONFIG = {
    'rec_loss_beta': 1.,
    'cog_loss_beta': .5,
    'loss_by_trial_type': False,
    'learning_rate': 1e-2,
    'max_response_step': max_response_step,
    'inputs_dim': data[0][0].size(1),
    'embeddings_dim': embeddings_dim,
    'auto_lr_find': False,
    'task': TASK,
    'batch_size': 196
}

device = 'cuda' if torch.cuda.is_available() else 'cpu'


Loading Stroop dataset...


In [6]:
# pondering model

datamodule = CogPonderDataModule(data, batch_size=CONFIG['batch_size'], num_workers=1)
model = CogPonderModel(CONFIG, example_input_array=data[0][0][:1].to(device))

# # DEBUG
# X = data[0][0][:10]
# y_true = data[0][3][:10]
# rt_true = data[0][4][:10]
# y_steps, p_halts, rt_pred = model(X)
# loss_func = CognitiveLoss(CONFIG['max_response_step'])
# l = loss_func(rt_pred, rt_true)
# 'l', l, rt_pred, rt_true

# DEBUG setting: Lightning's `overfit_batches` takes an int (number of batches) or a
# float (fraction of the data). The previous value `True` was silently read as 1 batch.
# Set to 0 to train on the full split.
OVERFIT_BATCHES = 1

trainer = pl.Trainer(
    max_epochs=10000,
    # min_epochs=1000,
    accelerator='auto',
    auto_lr_find=CONFIG['auto_lr_find'],
    log_every_n_steps=1,
    overfit_batches=OVERFIT_BATCHES,
    # accumulate_grad_batches=2,
    callbacks=[
        RichProgressBar(),
        EarlyStopping(monitor='val_loss', patience=500, mode='min', min_delta=0.001),
    ])

# Optionally let Lightning pick the learning rate before fitting.
if CONFIG['auto_lr_find']:
    trainer.tune(model, datamodule=datamodule)

trainer.fit(model, datamodule=datamodule)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

ValueError: If `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary.

In [None]:
# DEBUG: Plot LR tuning results
# Optional (uncomment to use): run the trainer's learning-rate finder on `model` and
# `datamodule`, plot the loss-vs-LR curve with the suggested value marked, store the
# suggestion in `model.hparams.learning_rate`, and refit with it.
# lr_finder = trainer.tuner.lr_find(model, max_lr=2, datamodule=datamodule)
# fig = lr_finder.plot(suggest=True)
# fig.show()
# model.hparams.learning_rate = lr_finder.suggestion()
# trainer.fit(model, datamodule=datamodule)

In [None]:
# DEBUG

import torch
import torchmetrics

# Lightning moves the module back to the CPU when `fit` tears down, so run inference
# on whatever device the model is on now rather than on the global `device`.
eval_device = model.device

X_train, _, _, y_train, rt_train = datamodule.dataset[datamodule.train_dataset.indices]
X_test, _, _, y_test, rt_test = datamodule.dataset[datamodule.test_dataset.indices]


def response_at_halt(y_steps, rt_pred):
    """Select each trial's output at its predicted response step.

    `y_steps` is indexed by step along dim 0 (2-D, as `gather` requires here).
    `rt_pred` holds one response step per trial; the `- 1` turns it into a
    0-based index. Returns a tensor of shape (batch_size,).
    """
    return y_steps.gather(dim=0, index=rt_pred[None, :] - 1)[0]


def split_accuracy(y_pred, y_true):
    """Accuracy of `y_pred` against `y_true`, using a fresh metric instance.

    NOTE(review): `torchmetrics.Accuracy()` without `task` infers the task from its
    inputs. The ValueError raised during `fit` ("`preds` are floats, `target` should
    be binary") suggests this call can fail the same way. TODO confirm the encoding
    of responses and predictions.
    """
    accuracy_fn = torchmetrics.Accuracy().to(eval_device)
    return accuracy_fn(y_pred, y_true.int().to(eval_device))


model.eval()
with torch.no_grad():
    y_train_steps, _, rt_train_pred = model(X_train.to(eval_device))
    y_test_steps, _, rt_test_pred = model(X_test.to(eval_device))

    y_train_pred = response_at_halt(y_train_steps, rt_train_pred)  # (batch_size,)
    y_test_pred = response_at_halt(y_test_steps, rt_test_pred)  # (batch_size,)

    train_accuracy = split_accuracy(y_train_pred, y_train)
    print('TRAIN ACCURACY', train_accuracy.item())

    test_accuracy = split_accuracy(y_test_pred, y_test)
    print('TEST ACCURACY', test_accuracy.item())

    # DEBUG report the ground truth and predicted response times
    print('TRUE TRAIN:', rt_train.detach().tolist(), '\nPRED TRAIN:',  rt_train_pred.tolist())
    print('TRUE TEST:', rt_test.detach().tolist(), '\nPRED TEST:',  rt_test_pred.tolist())

# DEBUG report medians
# rt_train_pred.median(), rt_train.float().median()
# rt_test_pred.median(), rt_test.float().median()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Predicted response times at or above this cap are left out of the plots.
RT_CAP = max_response_step # data[0][4].max().item()


def plot_rt_ecdf(rt_true, rt_pred, split):
    """Plot ECDFs of true vs. predicted response times for one data split.

    Args:
        rt_true: ground-truth response times (in steps) for the split (tensor).
        rt_pred: predicted response times (tensor); entries >= RT_CAP are excluded.
        split: split name used in the legend and title, e.g. 'train'.
    """
    fig, ax = plt.subplots()
    sns.ecdfplot(rt_true.cpu(), label=f'True ({split})', ax=ax)
    sns.ecdfplot(rt_pred[rt_pred < RT_CAP].cpu(), label=f'Predicted ({split})', ax=ax)
    # The data come from the SRO dataset selected by TASK, not from a simulation.
    ax.set(title=f'Evaluation of PonderNet on SRO {TASK} {split} split',
           xlabel='response time (steps)')
    ax.legend()
    plt.show()


plot_rt_ecdf(rt_train, rt_train_pred, 'train')
plot_rt_ecdf(rt_test, rt_test_pred, 'test')

In [None]:

# Overlay KDEs of true vs. predicted response times for both splits on one axis.
# Predicted RTs at or above RT_CAP are excluded.
fig, ax = plt.subplots()
for split, rt_true, rt_pred in (('Train', rt_train, rt_train_pred),
                                ('Test', rt_test, rt_test_pred)):
    sns.kdeplot(rt_true.cpu(), label=f'{split} (TRUE)', cut=0, ax=ax)
    sns.kdeplot(rt_pred[rt_pred < RT_CAP].cpu(), label=f'{split} (PRED)', cut=0, ax=ax)

# The title follows TASK instead of a hardcoded "2-back" (TASK may be 'stroop').
ax.set(title=f'Evaluation of PonderNet on SRO single-subject {TASK}',
       xlabel='response time (steps)')
ax.legend()
plt.show()