# CogPonder: N-Back PyTorch Lightning

This notebook implements the CogPonder algorithm using PyTorch Lightning. It assumes fixed hyperparameters and fits the model to a single-subject dataset. It wraps a simple RRN with a pondering lambda layer and trains it on the *Self-Regulation Ontology* dataset.

## Data
The SRO-2back dataset interface provides the following features from the *Self-Regulation Ontology* study:

- input `X`: previous 3 symbols for the subject $i$ and trial $j$; For each subject, $X_i$ is a 2-dimensional vector of integers of shape (3, $N_{\text{trials}}$).
- `trial_type`: Correct match, incorrect match, correct-non-match, incorrect-non-match for each trial $i$.
- `is_target`: whether the trial $i$ is a match; it is a boolean.
- output `response`: the response of the subject for the trial i; it is a boolean.
- `response_step`: the response step of the subject for the trial i; Response step is an integer and represents *response times* in 50ms steps. This step duration is a hyperparameter of the data module.

## Hyperparameters

- `n_embeddings`: number of hidden units in the recurrent ICOM model. Defaults to $N_{\text{symbols} + 1}$
- `rec_loss_beta`: the beta parameter of the loss function. Defaults to 0.5.
- `cog_loss_beta`: the beta parameter of the loss function. Defaults to 0.5.
- `learning_rate`: the learning rate of the optimizer. Defaults to 0.0001.
- `max_response_step`: maximum response step in the dataset. Defaults to $\max(\text{response\_step}) + 10$.

## Criterion

$L = L_{\text{reconstruction}} + L_{\text{cognitive}}$


In [1]:
%reload_ext autoreload
%autoreload 3

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from ray import tune, air
from ray.tune import JupyterNotebookReporter
from functools import partial
from pytorch_lightning.callbacks import RichProgressBar, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from cogponder import CogPonderNet
from cogponder.datasets import NBackMockDataset, NBackDataModule, NBackSRODataset
from pathlib import Path

from cogponder.losses import ReconstructionLoss, CognitiveLoss



In [2]:
# data
data = NBackSRODataset(n_subjects=1, n_back=2) # shape (n_subjects, (...))
datamodule = NBackDataModule(data, batch_size=32, num_workers=1)
n_symbols = torch.unique(data[0][0]).shape[0]
max_response_step = 25 # data[0][4].max().item() + 1 # max number of steps = 2 * max observed RT

# parameter space
CONFIG = {
    'rec_loss_beta': 1.,
    'cog_loss_beta': .1,
    'loss_by_trial_type': False,
    'learning_rate': 1e-2,
    'max_response_step': max_response_step,
    'inputs_dim': data[0][0].size(1),
    'embeddings_dim': n_symbols,
    'auto_lr_find': False,
}
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [3]:
# pondering model

model = CogPonderNet(CONFIG, example_input_array=data[0][0][:1].to(device))

# # DEBUG
# X = data[0][0][:10]
# y_true = data[0][3][:10]
# rt_true = data[0][4][:10]
# y_steps, p_halts, rt_pred = model(X)
# loss_func = CognitiveLoss(CONFIG['max_response_step'])
# l = loss_func(rt_pred, rt_true)
# 'l', l, rt_pred, rt_true

trainer = pl.Trainer(
    max_epochs=5000,
    # min_epochs=1000,
    accelerator='auto',
    auto_lr_find=CONFIG['auto_lr_find'],
    log_every_n_steps=4,
    # overfit_batches=True,
    # accumulate_grad_batches=4,
    callbacks=[
        RichProgressBar(),
        EarlyStopping(monitor='val_loss', patience=1000, mode='min', min_delta=0.01),
    ])

if CONFIG['auto_lr_find']:
    trainer.tune(model, datamodule=datamodule)

trainer.fit(model, datamodule=datamodule)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [4]:
# DEBUG: Plot LR tuning results
# lr_finder = trainer.tuner.lr_find(model, max_lr=2, datamodule=datamodule)
# fig = lr_finder.plot(suggest=True)
# fig.show()
# model.hparams.learning_rate = lr_finder.suggestion()
# trainer.fit(model, datamodule=datamodule)

In [5]:
# DEBUG

import torch
import torchmetrics

X_train, _, y_train, _, rt_train = datamodule.dataset[datamodule.train_dataset.indices]
X_test, _, y_test, _, rt_test = datamodule.dataset[datamodule.test_dataset.indices]

with torch.no_grad():
    model.eval()
    y_train_steps,_,rt_train_pred = model(X_train.to(device))
    y_test_steps,_,rt_test_pred = model(X_test.to(device))

    y_train_pred = y_train_steps.gather(dim=0, index=rt_train_pred[None, :] - 1,)[0]  # (batch_size,)
    y_test_pred = y_test_steps.gather(dim=0, index=rt_test_pred[None, :] - 1,)[0]  # (batch_size,)

    accuracy_fn = torchmetrics.Accuracy().to(device)
    train_accuracy = accuracy_fn(y_train_pred, y_train.to(device))
    print('TRAIN ACCURACY', train_accuracy.item())


    accuracy_fn = torchmetrics.Accuracy().to(device)
    test_accuracy = accuracy_fn(y_test_pred, y_test.to(device))
    print('TEST ACCURACY', test_accuracy.item())

    # DEBUG report the ground truth and predicted response times
    print('TRUE TRAIN:', rt_train.detach().tolist(), '\nPRED TRAIN:',  rt_train_pred.tolist())
    print('TRUE TEST:', rt_test.detach().tolist(), '\nPRED TEST:',  rt_test_pred.tolist())

# DEBUG report medians
# rt_train_pred.median(), rt_train.float().median()
# rt_test_pred.median(), rt_test.float().median()


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_mm)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

RT_CAP = max_response_step # data[0][4].max().item()

sns.ecdfplot(rt_train, label='True (train)')
sns.ecdfplot(rt_train_pred[rt_train_pred < RT_CAP], label='Predicted (train)')

plt.title('Evaluation of PonderNet on simulated train split')
plt.xlabel('response time (steps)')

plt.legend()
plt.show()

sns.ecdfplot(rt_test, label='True (test)')
sns.ecdfplot(rt_test_pred[rt_test_pred < RT_CAP], label='Predicted (test)')

plt.title('Evaluation of PonderNet on simulated test split')
plt.xlabel('response time (steps)')
plt.legend()
plt.show()

In [None]:

sns.kdeplot(rt_train, label='Train (TRUE)', cut=0)
sns.kdeplot(rt_train_pred[rt_train_pred < RT_CAP], label='Train (PRED)', cut=0)

sns.kdeplot(rt_test, label='Test (TRUE)', cut=0)
sns.kdeplot(rt_test_pred[rt_test_pred < RT_CAP], label='Test (PRED)', cut=0)


plt.title('Evaluation of PonderNet on SRO single-subject 2-back')
plt.xlabel('response time (steps)')
plt.legend()
plt.show()