# CogPonder: N-Back Task with Fixed Hyper-parameters

This notebook implements the CogPonder algorithm using PyTorch Lightning to perform the 2-back task. It assumes fixed hyper-parameters and fits the model to a single-subject dataset. It wraps a simple linear network with a pondering layer and trains it on the *Self-Regulation Ontology* dataset.


## Data

Here, we use the 2-back data from the *Self-Regulation Ontology*. The data is loaded from the `data/Self_Regulation_ontology/` directory. See the `data/Self_Regulation_ontology/README.md` file for more information.

### Input and Output

#### 2-back

The input is the sequence of the 3 most recent symbols (N+1 symbols for the N-back task), including the symbol shown in the current trial; the last element of the input is the current symbol. The output is the human response to the current trial (match or non-match).

## Hyper-parameters

- `n_embeddings`: number of hidden units in the operator model. Defaults to $N_{\text{symbols}} + 1$.
- `rec_loss_beta`: the beta parameter (weight) of the reconstruction term of the loss function. Defaults to 0.5.
- `cog_loss_beta`: the beta parameter (weight) of the cognitive term of the loss function. Defaults to 0.5.
- `learning_rate`: the learning rate of the optimizer. Defaults to 0.0001.
- `max_response_step`: maximum response step in the dataset. Defaults to $\max(\text{response\_step}) + 10$.

## Criterion

$L = L_{\text{reconstruction}} + L_{\text{cognitive}}$


In [None]:
%reload_ext autoreload
%autoreload 3

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

import torch
import torchmetrics
import pytorch_lightning as pl
import pandas as pd
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar
from src.cogponder import CogPonderModel
from src.cogponder.datasets import NBackSRODataset, CogPonderDataModule
from pathlib import Path

In [None]:
# Notebook-wide constants.

# SRO-SubjectID of the single subject this notebook is concerned with
SRO_SUBJECT_ID = 202

# training budget: upper bound on the number of epochs, and the batch size
MAX_EPOCHS = 10000
BATCH_SIZE = 72

# destination of the checkpoint written after a successful training run;
# the file name encodes both the subject id and the epoch budget
CHECKPOINT_PATH = Path(
    'models/nback/',
    f'cogponder_subject-{SRO_SUBJECT_ID}_epochs-{MAX_EPOCHS}.ckpt',
)

# prefer the GPU whenever torch reports that CUDA is usable
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Load the dataset and configs
#
# Builds the N-back dataset and its data module, derives the data-dependent
# model dimensions, and collects every hyper-parameter in `CONFIG`.

print('Loading N-Back dataset... ', end='')

dataset = NBackSRODataset(n_subjects=-1, response_step_interval=20, non_decision_time=0)  #'auto')
datamodule = CogPonderDataModule(dataset, batch_size=BATCH_SIZE, num_workers=8)

# Slice the whole dataset once and reuse the result, instead of re-indexing
# the full dataset for every derived parameter.
all_fields = dataset[:]

# determine some parameters from data
# NOTE(review): the field positions (0: subject ids, 1: symbols, 3: responses,
# 4: response steps) are inferred from the variable names — confirm against
# NBackSRODataset.
n_subjects = torch.unique(all_fields[0]).size(0)
n_symbols = torch.unique(all_fields[1]).size(0)
n_inputs = 1
n_outputs = torch.unique(all_fields[3]).size(0)
# leave 10 extra steps of headroom above the largest observed response step
max_response_step = all_fields[4].max().item() + 10

# the sliced fields are only needed for the statistics above
del all_fields

# parameter space
CONFIG = {
    'task': 'nback',
    'resp_loss_beta': 1.,
    'time_loss_beta': 10.,
    # 'non_decision_time': 10,  # in milliseconds
    'loss_by_trial_type': False,
    'learning_rate': 1e-2,
    'max_response_step': max_response_step,
    'inputs_dim': n_inputs,  # number of input features
    # NOTE(review): the notebook header describes the embedding size as
    # N_symbols + 1, while n_symbols is used here — confirm which is intended.
    'embeddings_dim': n_symbols,
    'outputs_dim': n_outputs,
    'auto_lr_find': False,
    'batch_size': BATCH_SIZE,
    'n_contexts': n_subjects,
}

print('Done!')


In [None]:
# Build the CogPonder model from CONFIG, train it, and write a checkpoint.

model = CogPonderModel(CONFIG)

# Early stopping watches the validation total loss (lower is better).
# Patience is a tenth of the epoch budget, with a floor of 10.
early_stopping = EarlyStopping(
    monitor='val/total_loss',
    mode='min',
    min_delta=0.001,
    patience=max(10, MAX_EPOCHS // 10),
)

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator='auto',
    auto_lr_find=CONFIG['auto_lr_find'],
    log_every_n_steps=1,
    callbacks=[RichProgressBar(), early_stopping],
)

# Run the tuner first, but only when learning-rate search is switched on
if CONFIG['auto_lr_find']:
    trainer.tune(model, datamodule=datamodule)

# Train the model on the data module
trainer.fit(model, datamodule=datamodule)

# Persist the trainer's latest state to the configured checkpoint path
trainer.save_checkpoint(CHECKPOINT_PATH)

In [None]:
# DEBUG
#
# Reloads the saved checkpoint, runs it on the train and test splits, and
# reports predicted vs. true response steps (per trial, then mean and SD).

# DEBUG - Load the checkpoint

model_ckpt = CogPonderModel.load_from_checkpoint(CHECKPOINT_PATH)
# evaluation mode is set once here and stays in effect for the forward passes below
model_ckpt.eval()

# Rebuild the data module only when none is in scope, or when the one in scope
# has no `train_dataset` attribute.
if 'datamodule' not in locals() or not hasattr(datamodule, 'train_dataset'):
    print('loading N-back dataset...', end='')
    # Use the same dataset arguments as the training cell
    # (response_step_interval=20, non_decision_time=0) so that the data fed to
    # the checkpointed model is prepared the way it was for fitting.
    # NOTE(review): the training cell passes the whole dataset to the data
    # module, whereas this path indexes it with SRO_SUBJECT_ID first — confirm
    # that this indexing yields a dataset CogPonderDataModule accepts.
    data = NBackSRODataset(n_subjects=-1, response_step_interval=20, non_decision_time=0)[SRO_SUBJECT_ID]
    datamodule = CogPonderDataModule(data, batch_size=CONFIG['batch_size'], num_workers=8)
    datamodule.prepare_data()
    print('Done!')

# NOTE(review): assumes each split slices into exactly these five fields, in
# this order — confirm against the dataset definition.
X_train, trial_types_train, is_corrects_train, y_train, rt_train = datamodule.train_dataset[:]
X_test, trial_types_test, is_corrects_test, y_test, rt_test = datamodule.test_dataset[:]

with torch.no_grad():
    # NOTE(review): the model presumably returns per-step outputs, a second
    # per-trial quantity, and the predicted response step — confirm.
    y_train_steps, p_train, rt_train_pred = model_ckpt(X_train)
    y_test_steps, p_test, rt_test_pred = model_ckpt(X_test)

    # most likely output class at every step
    y_train_steps = torch.argmax(y_train_steps, dim=-1)
    y_test_steps = torch.argmax(y_test_steps, dim=-1)

    # pick, along dim 0, the entry at index (predicted response step - 1)
    y_train_pred = y_train_steps.gather(dim=0, index=rt_train_pred[None, :] - 1,)[0]  # (batch_size,)
    y_test_pred = y_test_steps.gather(dim=0, index=rt_test_pred[None, :] - 1,)[0]  # (batch_size,)

    train_res = pd.DataFrame({'true_rt_train': rt_train.detach().tolist(),
                              'pred_rt_train': rt_train_pred.tolist()})
    test_res = pd.DataFrame({'true_rt_test': rt_test.detach().tolist(),
                             'pred_rt_test': rt_test_pred.tolist()})

    # transposed so each table has one row for true and one for predicted values
    display(train_res.T, test_res.T)

# DEBUG report mean-RT
print(f'RT train mean (pred/true): '
      f'{rt_train_pred.float().mean().item():.2f}, '
      f'{rt_train.float().mean().item():.2f}')

print(f'RT test  mean (pred/true): '
      f'{rt_test_pred.float().mean().item():.2f}, '
      f'{rt_test.float().mean().item():.2f}')

# DEBUG - report sd-RT
print(f'RT train std (pred/true): '
      f'{rt_train_pred.float().std().item():.2f}, '
      f'{rt_train.float().std().item():.2f}')

print(f'RT test  std (pred/true): '
      f'{rt_test_pred.float().std().item():.2f}, '
      f'{rt_test.float().std().item():.2f}')
