# CogPonder: N-Back Task with Fixed Hyper-parameters

This notebook implements the CogPonder algorithm in PyTorch Lightning and applies it to the 2-back task. Hyper-parameters are fixed (no search is performed), and the model is fit to a single-subject dataset. The model wraps a simple linear network with a pondering layer and is trained on 2-back data from the *Self-Regulation Ontology* dataset.


## Data

Here, we use the 2-back data from the *Self-Regulation Ontology*. The data is loaded from the `data/Self_Regulation_ontology/` directory. See the `data/Self_Regulation_ontology/README.md` file for more information.

### Input and Output

#### 2-back

The input is the 3 most recent symbols (N+1 symbols for the N-back task), ordered so that the symbol of the current trial comes last. The output is the human response to the current trial (match or non-match).
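
To illustrate this input layout, the following minimal sketch builds the N+1-symbol windows for a toy sequence. The helper `make_nback_windows` and the toy sequence are hypothetical and are not part of `NBackSRODataset`. Note also that the label computed here is the ground-truth match (the current symbol equals the symbol N trials back), whereas the model in this notebook is trained on the recorded human response.

```python
from typing import List, Sequence, Tuple


def make_nback_windows(
    symbols: Sequence[str], n_back: int = 2
) -> List[Tuple[Tuple[str, ...], bool]]:
    """Return (window, is_match) pairs for every trial with a full history.

    Each window holds the n_back + 1 most recent symbols, current symbol last.
    Hypothetical helper for illustration only.
    """
    windows = []
    for t in range(n_back, len(symbols)):
        # slice covers trials t - n_back ... t (inclusive)
        window = tuple(symbols[t - n_back:t + 1])
        # ground-truth match: current symbol equals the one n_back trials ago
        is_match = window[0] == window[-1]
        windows.append((window, is_match))
    return windows


toy_sequence = ['A', 'B', 'A', 'C', 'A', 'A']
for window, is_match in make_nback_windows(toy_sequence, n_back=2):
    print(window, 'match' if is_match else 'non-match')
```

The first `n_back` trials are skipped in this sketch because they have no complete history; how the actual dataset handles them is not specified here.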

## Hyper-parameters

- `n_embeddings`: number of hidden units in the operator model. Defaults to $N_{\text{symbols}} + 1$.
- `rec_loss_beta`: the weight (beta) of the reconstruction term of the loss function. Defaults to 0.5.
- `cog_loss_beta`: the weight (beta) of the cognitive term of the loss function. Defaults to 0.5.
- `learning_rate`: the learning rate of the optimizer. Defaults to 0.0001.
- `max_response_step`: maximum response step in the dataset. Defaults to $\max(\text{response\_step}) + 10$.
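
The two data-dependent defaults in this list can be written out explicitly. The sketch below is an illustration only: `derive_defaults` is a hypothetical helper (not part of `src.cogponder`), and it assumes that $N_{\text{symbols}}$ is the number of distinct symbols and that response steps are available as integers.

```python
from typing import Dict, Sequence


def derive_defaults(
    symbols: Sequence[str], response_steps: Sequence[int]
) -> Dict[str, int]:
    """Compute the data-dependent defaults (hypothetical helper)."""
    # ASSUMPTION: N_symbols is the number of distinct symbols in the data
    n_symbols = len(set(symbols))
    return {
        'n_embeddings': n_symbols + 1,
        # leave a margin of 10 steps beyond the slowest observed response
        'max_response_step': max(response_steps) + 10,
    }


defaults = derive_defaults(['A', 'B', 'A', 'C'], [12, 30, 25, 18])
print(defaults)
```

The training cell in this notebook sets `max_response_step` explicitly instead of deriving it from the data.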

## Criterion

$L = L_{\text{reconstruction}} + L_{\text{cognitive}}$
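
One plausible way the two beta hyper-parameters enter this criterion is as linear weights on the two terms. The sketch below shows that reading with plain floats. It is an assumption and not a description of the `CogPonderModel` internals: `total_loss` is a hypothetical function, and the way each term is computed is left out entirely.

```python
def total_loss(
    rec_loss: float,
    cog_loss: float,
    rec_loss_beta: float = 0.5,
    cog_loss_beta: float = 0.5,
) -> float:
    """Combine the two loss terms (hypothetical sketch).

    ASSUMPTION: each beta linearly scales its own term; the actual model
    may combine the terms differently.
    """
    return rec_loss_beta * rec_loss + cog_loss_beta * cog_loss


# with the default betas both terms contribute equally
print(total_loss(0.8, 0.2))
# with both betas set to 1 this reduces to the unweighted sum
print(total_loss(0.8, 0.2, rec_loss_beta=1.0, cog_loss_beta=1.0))
```

Under this reading, raising one beta relative to the other shifts the optimization toward fitting that term more closely.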


In [1]:
%reload_ext autoreload
%autoreload 3

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

import torch
import pytorch_lightning as pl
import pandas as pd
from pytorch_lightning.callbacks import EarlyStopping
from src.cogponder import CogPonderModel
from src.cogponder.data import NBackSRODataset, CogPonderDataModule
from pathlib import Path

In [2]:
# Parameters

# number of maximum epochs to train (early stopping will be applied)
# early stopping patience is 10% of max_epochs (min 10 epochs)
MAX_EPOCHS = 1000
BATCH_SIZE = 1024

In [3]:
# Data

print('Loading N-back dataset... ', end='')

dataset = NBackSRODataset(n_back=2, response_step_interval=20)
datamodule = CogPonderDataModule(dataset, batch_size=BATCH_SIZE, num_workers=8)

# determine some parameters from data
n_features = 1
n_contexts = torch.unique(dataset[:][2]).size(0)
n_subjects = torch.unique(dataset[:][1]).size(0)
n_outputs = torch.unique(dataset[:][4]).size(0)
max_response_step = 100

configs = {
    'inputs_dim': n_features,
    'outputs_dim': n_outputs,
    'embeddings_dim': 8,
    'response_loss_beta': 1.,
    'time_loss_beta': 10.,
    'learning_rate': 1e-2,
    'max_response_step': max_response_step,
    'n_contexts': n_contexts,
    'n_subjects': n_subjects,
    'subject_embeddings_dim': 2,
    'task': 'nback'
}

print('Done!')

Loading N-back dataset... Done!


In [4]:
# Experiment

model = CogPonderModel(**configs)

# TODO: check if torch>=2.0 is installed
# model = torch.compile(model)

# Trainer
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    min_epochs=200,
    accelerator='auto',
    log_every_n_steps=1,
    # overfit_batches=True,
    # accumulate_grad_batches=2,
    callbacks=[
        EarlyStopping(monitor='val/total_loss',
                      patience=100,
                      mode='min', min_delta=0.001),
    ])

# Fit and evaluate the model
trainer.fit(model, datamodule=datamodule)

trainer.current_epoch

# save checkpoint
CHECKPOINT_PATH = f'models/checkpoints/nback/cogponder_epochs-{trainer.current_epoch}.ckpt'
trainer.save_checkpoint(CHECKPOINT_PATH)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name            | Type                 | Params
---------------------------------------------------------
0 | operator_node   | SimpleOperatorModule | 90    
1 | halt_node       | HaltingModule        | 81    
2 | recurrence_node | RecurrenceModule     | 1.4 K 
3 | embeddings      | Embedding            | 16    
4 | resp_loss_fn    | ResponseLoss         | 0     
5 | time_loss_fn    | ResponseTimeLoss     | 0     
---------------------------------------------------------
1.5 K     Trainable params
0         Non-trainable params
1.5 K     Total params
0.006     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# DEBUG

model_ckpt = CogPonderModel.load_from_checkpoint(CHECKPOINT_PATH)