# CogPonder: Stroop task with fixed hyper-parameters

This notebook implements the CogPonder framework using PyTorch Lightning and evaluates it on the Stroop task.


It uses fixed hyper-parameters and fits the model to a single subject: a simple linear network is wrapped with a pondering layer and trained on data from a random subject of the *Self-Regulation Ontology* dataset.
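The pondering idea can be illustrated with a PonderNet-style halting distribution. The sketch below is an assumption about how such a layer works in general, not the code of CogPonder's `HaltingModule` or `RecurrenceModule`. A hypothetical recurrent operator updates a hidden state at each step, a hypothetical halting head emits a per-step halting probability $\lambda_n$, and these are converted into the probability $p_n$ of responding exactly at step $n$.

```python
import torch


def halting_distribution(lambdas: torch.Tensor) -> torch.Tensor:
    """Turn per-step halting probabilities into a distribution over steps.

    lambdas: (batch, max_steps) tensor with values in [0, 1].
    Returns p with p[:, n] = lambdas[:, n] * prod_{j<n} (1 - lambdas[:, j]).
    The final step is forced to halt so that each row sums to 1.
    """
    lambdas = lambdas.clone()
    lambdas[:, -1] = 1.0  # any remaining probability mass halts at the last step
    # probability of still running after each step
    still_running = torch.cumprod(1.0 - lambdas, dim=1)
    # shift right: probability of still running *before* each step
    before = torch.cat([torch.ones_like(still_running[:, :1]), still_running[:, :-1]], dim=1)
    return lambdas * before


# Toy pondering loop; all layer sizes and modules are hypothetical
torch.manual_seed(0)
batch, n_in, n_hidden, max_steps = 4, 6, 8, 5
operator = torch.nn.Linear(n_in + n_hidden, n_hidden)  # hypothetical recurrent operator
halt_head = torch.nn.Linear(n_hidden, 1)                # hypothetical halting head

x = torch.randn(batch, n_in)
h = torch.zeros(batch, n_hidden)
step_lambdas = []
for _ in range(max_steps):
    h = torch.tanh(operator(torch.cat([x, h], dim=-1)))
    step_lambdas.append(torch.sigmoid(halt_head(h)).squeeze(-1))

p = halting_distribution(torch.stack(step_lambdas, dim=1))  # (batch, max_steps)
```

In this view, each row of `p` is a distribution over response steps, which is what makes it possible to compare the model's pondering time with human response times.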


## Data

Here, we fit the Stroop data from the *Self-Regulation Ontology*. The data is loaded from the `data/Self_Regulation_ontology/` directory; see the corresponding [README](../data/Self_Regulation_Ontology/README.md) for more information on the data structure.

### Input and Output


The input is a list of trials, each described by 1) the ink color and 2) the word of the current stimulus. The output is the human response on the current trial (i.e., red, green, or blue) and its response time.
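One simple way to turn a Stroop trial into a model input is to concatenate a one-hot ink color with a one-hot word. This is a hypothetical encoding for illustration only; the features actually produced by `StroopSRODataset` may differ, and the response times below are made-up values.

```python
import torch
import torch.nn.functional as F

COLORS = ['red', 'green', 'blue']  # the three response options


def encode_trial(ink_color: str, word: str) -> torch.Tensor:
    """Hypothetical encoding: one-hot ink color followed by one-hot word (6 features)."""
    color = F.one_hot(torch.tensor(COLORS.index(ink_color)), num_classes=len(COLORS))
    text = F.one_hot(torch.tensor(COLORS.index(word)), num_classes=len(COLORS))
    return torch.cat([color, text]).float()


# a congruent trial (RED in red ink) and an incongruent one (GREEN in blue ink)
X = torch.stack([encode_trial('red', 'red'), encode_trial('blue', 'green')])

# targets: index of the color the participant responded with, and the RT in ms (made-up values)
y_response = torch.tensor([COLORS.index('red'), COLORS.index('blue')])
y_rt_ms = torch.tensor([612.0, 785.0])
```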


## Hyper-parameters

- `embeddings_dim`: number of hidden units in the operator model. Defaults to 8.
- `response_loss_beta`: weight of the response loss ($L_{\text{response}}$) in the total loss. Defaults to 1.
- `time_loss_beta`: weight $\beta$ of the response-time loss ($L_{\text{time}}$) in the total loss. Defaults to 10.
- `learning_rate`: learning rate of the optimizer. Defaults to 0.001 (this notebook uses 0.01).
- `max_response_step`: maximum number of response steps. Defaults to the maximum response time in the dataset divided by `step_duration`.
- `step_duration`: duration of a single response step, in milliseconds. Defaults to 10 ms (this notebook uses 20 ms).
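The default for `max_response_step` is simple arithmetic on the response times. The sketch below assumes that response times are discretized into step indices by floor division and that one extra step is added so the slowest response still maps to a valid index, mirroring `dataset[:][5].max() + 1` later in this notebook (assuming index 5 holds response steps). The helper names are hypothetical, and the sketch ignores any non-decision time.

```python
def to_response_steps(response_times_ms, step_duration_ms):
    """Discretize response times into step indices (floor division is an assumption)."""
    return [int(rt // step_duration_ms) for rt in response_times_ms]


def infer_max_response_step(response_times_ms, step_duration_ms=10):
    # +1 so that the slowest response's step index is still a valid step
    return max(to_response_steps(response_times_ms, step_duration_ms)) + 1


rts = [420, 655, 1010]  # illustrative response times in ms
print(infer_max_response_step(rts, step_duration_ms=10))  # 102
print(infer_max_response_step(rts, step_duration_ms=20))  # 51
```

Doubling `step_duration` roughly halves the number of steps the model must ponder through, which trades temporal resolution for a shorter recurrence.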

## Criterion

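The loss function is a weighted sum of two terms: the response loss ($L_{\text{response}}$), which measures the cross-entropy between the human response and the predicted response, and the response-time loss ($L_{\text{time}}$), which acts as a regularizer and measures the KL divergence between the distributions of human and predicted response times.

$L_{\text{total}} = \sum_{n=1}^{N} p_n L_{\text{response}}^{(n)} + \beta L_{\text{time}}$

where $p_n$ is the probability of halting (responding) at step $n$, $L_{\text{response}}^{(n)}$ is the response loss of the prediction made at step $n$, and $N$ is `max_response_step`.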
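A minimal PyTorch sketch of this criterion, in the style of PonderNet, is shown below. It is not the code of CogPonder's `ResponseLoss` or `ResponseTimeLoss`. Several choices are assumptions: the human response-time distribution is estimated as a histogram over the batch, the predicted distribution is the halting distribution averaged over trials, the KL divergence is taken as KL(human || predicted), and `beta` plays the role of `time_loss_beta`.

```python
import torch
import torch.nn.functional as F


def ponder_loss(logits, p_halt, y_resp, y_step, beta=10.0, eps=1e-8):
    """Sketch of L_total = sum_n p_n * CE_n + beta * KL(human RTs || predicted RTs).

    logits: (batch, max_steps, n_outputs) response logits produced at each step
    p_halt: (batch, max_steps) halting distribution p_n (each row sums to 1)
    y_resp: (batch,) human response labels
    y_step: (batch,) human response times as integer step indices in [0, max_steps)
    """
    batch, max_steps, n_outputs = logits.shape
    # cross-entropy of the prediction at every step -> (batch, max_steps)
    ce = F.cross_entropy(
        logits.reshape(-1, n_outputs),
        y_resp.repeat_interleave(max_steps),
        reduction='none',
    ).reshape(batch, max_steps)
    # expected response loss under the halting distribution
    resp_loss = (p_halt * ce).sum(dim=1).mean()

    # empirical distribution of human response steps in the batch
    rt_hist = torch.bincount(y_step, minlength=max_steps).float()
    rt_dist = rt_hist / rt_hist.sum()
    # predicted distribution of response steps, averaged over trials
    pred_dist = p_halt.mean(dim=0)
    # F.kl_div expects log-probabilities as input and computes KL(target || input)
    time_loss = F.kl_div((pred_dist + eps).log(), rt_dist, reduction='sum')

    return resp_loss + beta * time_loss
```

With uninformative logits and a halting distribution that matches the human response steps, the time term vanishes and the loss reduces to $\log 3$ for three response options.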


In [1]:
%reload_ext autoreload
%autoreload 3

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

import torch
import pytorch_lightning as pl
import pandas as pd
from pytorch_lightning.callbacks import EarlyStopping
from src.cogponder import CogPonderModel
from src.cogponder.data import StroopSRODataset, CogPonderDataModule
from pathlib import Path

In [2]:
# Parameters

# number of maximum epochs to train (early stopping will be applied)
# early stopping patience is 10% of max_epochs (min 10 epochs)
max_epochs = 1000
batch_size = 256
step_duration = 20
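The patience rule stated in the comment of this cell (10% of `max_epochs`, at least 10 epochs) fits in a one-line helper. The function name is hypothetical; for `max_epochs = 1000` it gives a patience of 100 epochs.

```python
def early_stopping_patience(max_epochs: int, min_patience: int = 10) -> int:
    """10% of max_epochs, but never fewer than `min_patience` epochs."""
    return max(min_patience, max_epochs // 10)


print(early_stopping_patience(1000))  # 100
print(early_stopping_patience(50))    # 10
```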

In [3]:
# Data

print('Loading Stroop dataset... ', end='')

dataset = StroopSRODataset(step_duration=step_duration, non_decision_time='auto')
datamodule = CogPonderDataModule(dataset, batch_size=batch_size, num_workers=8)

# determine some parameters from data
n_features = dataset[:][3].size(-1)
n_subjects = torch.unique(dataset[:][0]).size(0)
n_contexts = torch.unique(dataset[:][2]).size(0)
n_outputs = torch.unique(dataset[:][4]).size(0)
max_response_step = dataset[:][5].max().int().item() + 1

configs = {
    'inputs_dim': n_features,
    'outputs_dim': n_outputs,
    'embeddings_dim': 8,
    'response_loss_beta': 1.,
    'time_loss_beta': 10.,
    'learning_rate': 1e-2,
    'max_response_step': max_response_step,
    'n_contexts': n_contexts,
    'n_subjects': n_subjects,
    'subject_embeddings_dim': 2,
    'task': 'stroop',
    'operator_type': 'simple'
}

print('Done!')

Loading Stroop dataset... Done!
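The cell above reads everything it needs from `dataset[:]`, which returns a tuple of tensors. The toy `TensorDataset` below mimics that layout so the derived parameters can be checked by hand. The meaning of each index is inferred from how the notebook uses it (0: subject, 2: context, 3: features, 4: response, 5: response step); index 1 is not used here, so it is only a placeholder, and all values are made up.

```python
import torch
from torch.utils.data import TensorDataset

n_trials = 12
toy = TensorDataset(
    torch.arange(n_trials) // 4,                           # 0: subject index (3 subjects)
    torch.zeros(n_trials),                                 # 1: placeholder, not used in this notebook
    torch.arange(n_trials) % 2,                            # 2: context, e.g. congruent vs. incongruent (assumed)
    torch.zeros(n_trials, 6),                              # 3: stimulus features (6 per trial)
    torch.arange(n_trials) % 3,                            # 4: response label (red, green, blue)
    torch.arange(n_trials, dtype=torch.float32) * 4 + 10,  # 5: response time in steps
)

# same expressions as in the data cell of this notebook
n_features = toy[:][3].size(-1)                        # 6
n_subjects = torch.unique(toy[:][0]).size(0)           # 3
n_contexts = torch.unique(toy[:][2]).size(0)           # 2
n_outputs = torch.unique(toy[:][4]).size(0)            # 3
max_response_step = toy[:][5].max().int().item() + 1   # 44 + 10 + 1 = 55
```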


In [4]:
# Experiment

model = CogPonderModel(**configs)

# TODO: check if torch>=2.0 is installed
# model = torch.compile(model)

# Trainer
trainer = pl.Trainer(
    max_epochs=max_epochs,
    min_epochs=100,
    accelerator='auto',
    log_every_n_steps=1,
    # overfit_batches=True,
    # accumulate_grad_batches=2,
    callbacks=[
        EarlyStopping(monitor='val/total_loss',
                      patience=max(10, max_epochs // 10),  # 10% of max_epochs, min 10
                      mode='min', min_delta=0.001),
    ])

# Fit and evaluate the model
trainer.fit(model, datamodule=datamodule)

# save checkpoint
checkpoint_path = f'models/checkpoints/stroop/cogponder_epochs-{trainer.current_epoch}.ckpt'
trainer.save_checkpoint(checkpoint_path)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                 | Params
---------------------------------------------------------
0 | halt_node       | HaltingModule        | 81    
1 | recurrence_node | RecurrenceModule     | 1.4 K 
2 | operator_node   | SimpleOperatorModule | 108   
3 | embeddings      | Embedding            | 16    
4 | resp_loss_fn    | ResponseLoss         | 0     
5 | time_loss_fn    | ResponseTimeLoss     | 0     
---------------------------------------------------------
1.6 K     Trainable params
0         Non-trainable params
1.6 K     Total params
0.006     Total estimated model params size (MB)
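With `mode='min'`, `patience=100`, `min_delta=0.001`, and `min_epochs=100`, training stops once the validation loss has failed to improve by more than 0.001 for 100 consecutive validation checks, but not before 100 epochs. The function below is a simplified, framework-free approximation of that rule for reasoning about when a run will stop. It is not Lightning's implementation, the name is hypothetical, and it assumes one validation check per epoch.

```python
def should_stop(val_losses, patience=100, min_delta=0.001, min_epochs=100):
    """Return the epoch at which training would stop, or None if it never stops."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            # improvement larger than min_delta: reset the counter
            best, wait = loss, 0
        else:
            wait += 1
        # stop only when patience is exhausted and min_epochs has been reached
        if wait >= patience and epoch >= min_epochs:
            return epoch
    return None
```

For example, a loss that improves for 10 epochs and then plateaus stops at epoch 110 under these settings.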



In [None]:
# DEBUG

model_ckpt = CogPonderModel.load_from_checkpoint(checkpoint_path)