# CogPonder: N-Back Task with Fixed Hyper-parameters

This notebook trains a single-task CogPonder agent on the N-back task. The operator is a spatiotemporal encoder (CNN+LSTM) with fixed hyper-parameters and a binary classification head.

## Data

Here, we only use the 2-back data from the *Self-Regulation Ontology* dataset ([see `data/Self_Regulation_Ontology/`](../data/Self_Regulation_Ontology/README.md)).

#### 2-back

The input is the N+1 most recent symbols (here N=2, so three symbols), ordered so that the last one is the symbol presented in the current trial. The output is the response to the current trial: either match or non-match.
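
As a minimal sketch of this windowing, the snippet below builds (window, label) pairs from a toy symbol stream. The function name `make_nback_trials` and the toy stream are hypothetical and are not part of `NBackSRODataset`. The label here is the ground-truth match status, whereas the dataset records the participant's actual response.

```python
from typing import List, Tuple


def make_nback_trials(symbols: List[str], n_back: int = 2) -> List[Tuple[List[str], int]]:
    """Build (window, label) pairs; each window holds the n_back + 1 most recent symbols."""
    trials = []
    for t in range(n_back, len(symbols)):
        # the window ends at the current symbol and starts n_back steps earlier
        window = symbols[t - n_back : t + 1]
        # 1 = match, 0 = non-match (ground truth, not a participant response)
        is_match = int(window[-1] == window[0])
        trials.append((window, is_match))
    return trials


# toy stream (hypothetical): the first n_back symbols have no complete window
trials = make_nback_trials(list("ABACCBC"), n_back=2)
for window, label in trials:
    print(window, "match" if label else "non-match")
```

The first `n_back` symbols produce no trial because their window would be incomplete; how the actual dataset handles those early trials is not shown here.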


In [1]:
%reload_ext autoreload
%autoreload 3

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, clear_output

import torch
import pytorch_lightning as pl
import pandas as pd
from pytorch_lightning.callbacks import EarlyStopping
from src.cogponder import CogPonderModel
from src.cogponder.data import NBackSRODataset, CogPonderDataModule
from pathlib import Path

In [2]:
# Parameters

# maximum number of training epochs
# intended early-stopping patience: 10% of max_epochs (at least 10 epochs);
# note that the EarlyStopping callback in the Trainer cell is currently commented out
max_epochs = 3000
batch_size = 1024
step_duration = 10
n_subjects = 5
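
The patience rule stated in the parameter comments (10% of `max_epochs`, at least 10 epochs) can be written out explicitly. This is only a sketch: the helper name `early_stopping_patience` is hypothetical, and the notebook itself does not compute the value this way.

```python
def early_stopping_patience(max_epochs: int, percent: int = 10, minimum: int = 10) -> int:
    """Return `percent` % of max_epochs, but never less than `minimum` epochs."""
    # integer arithmetic avoids floating-point rounding surprises
    return max(minimum, max_epochs * percent // 100)


patience = early_stopping_patience(3000)
short_run_patience = early_stopping_patience(50)
print(patience, short_run_patience)
```

The resulting value could be passed as the `patience` argument of `EarlyStopping` if that callback were enabled.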

In [3]:
# Data

print('Loading N-back dataset... ', end='')

dataset = NBackSRODataset(n_subjects=n_subjects,
                          n_back=2,
                          step_duration=step_duration)

datamodule = CogPonderDataModule(dataset, batch_size=batch_size, num_workers=8)
datamodule.prepare_data()
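# determine some parameters from data
# (index the full dataset once instead of once per field;
# the field meanings noted below are inferred from the variable names)
data = dataset[:]
n_subjects = torch.unique(data[0]).size(0)           # distinct subject IDs
n_contexts = torch.unique(data[2]).size(0)           # distinct contexts
n_features = data[3].size(-1)                        # size of the input feature dimension
n_outputs = torch.unique(data[4]).size(0)            # distinct responses
max_response_step = data[5].max().int().item() + 1   # largest response step, plus one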

configs = {
    'inputs_dim': n_features,
    'outputs_dim': n_outputs,
    'embeddings_dim': 8,
    'time_loss_beta': 10.,
    'learning_rate': 1e-2,
    'max_response_step': max_response_step,
    'n_contexts': n_contexts,
    'n_subjects': n_subjects,
    'subject_embeddings_dim': 2,
    'task': 'nback',
    'operator_type': 'spatiotemporal'
}

print('Done!')

Loading N-back dataset... Done!


In [4]:
# Experiment

model = CogPonderModel(**configs)

# model = torch.compile(model)

# Trainer
trainer = pl.Trainer(
    max_epochs=max_epochs,
    min_epochs=200,
    accelerator='cpu',
    log_every_n_steps=1,
    # overfit_batches=True,
    # accumulate_grad_batches=2,
    callbacks=[
        # EarlyStopping(monitor='val/total_loss',
        #               patience=100,
        #               mode='min', min_delta=0.001),
    ])

# Fit the model, resuming training from the 2000-epoch checkpoint
trainer.fit(model,
            ckpt_path='models/checkpoints/nback/cogponder_5subjects_2000epochs.ckpt',
            datamodule=datamodule)

# save checkpoint
ckpt_path = f'models/checkpoints/nback/cogponder_{n_subjects}subjects_{trainer.current_epoch}epochs.ckpt'
trainer.save_checkpoint(ckpt_path)

clear_output()

print('Saved checkpoint to:', ckpt_path)

Saved checkpoint to: models/checkpoints/nback/cogponder_5subjects_3000epochs.ckpt
