In [None]:
from brainnet.config import get_cfg_defaults

# Dataset location. Manually download and unzip the subject archive matching
# SUBJECT below (e.g. subj02.zip) from the Algonauts 2023 challenge:
# https://docs.google.com/forms/d/e/1FAIpQLSehZkqZOUNk18uTjRTuLj7UYmRGz-OkdsU25AyO3Wm6iAb0VA/viewform
ALGONAUTS_ROOT = "/home/admin/Algonaut/data/algonauts2023"
SUBJECT = "subj02"  # must match the name of the unzipped subject folder

cfg = get_cfg_defaults()
cfg.DATASET.DATA_DIR = f"{ALGONAUTS_ROOT}/{SUBJECT}"
cfg.DATASET.BATCH_SIZE = 10

In [None]:
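# Google Colab only: uncomment to install the extra dependencies
# (NATTEN wheel built for CUDA 11.7 / torch 2.0.0, and mmdet via openmim).
# Prefer `%pip` over `!pip` so packages install into the running kernel's environment.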
#!pip3 install natten -f https://shi-labs.com/natten/wheels/cu117/torch2.0.0/index.html --quiet 
#!pip install openmim
#!mim install mmdet

In [None]:
import torch
import torch.nn as nn
import gc

# Release GPU memory held by objects from a previous run of this notebook.
# To drop the previous model/trainer first, uncomment the line below. Unlike a
# bare `del`, `globals().pop(name, None)` is a no-op for names that were never
# defined, so it is also safe on a fresh kernel.
# for _name in ("trainer", "plmodel", "backbone"): globals().pop(_name, None)

# Collect garbage FIRST so tensors kept alive only by reference cycles are
# returned to PyTorch's caching allocator; empty_cache() then releases those
# cached blocks. In the reverse order, freshly collected tensors would stay cached.
gc.collect()
torch.cuda.empty_cache()

In [None]:
#extract backbone state from checkpoints

def extract_backbone(ckpt):
    state = {}
    for k in ckpt.keys():
        if 'backbone' in k:
            new_k = k.split('backbone.')[1]
            state[new_k] = ckpt[k]
    return state

ckpt_f = '/home/admin/Algonaut/dat_backbones/cmrcn_dat_b_3x.pth'
ckpt = torch.load(ckpt_f)['state_dict']
state = extract_backbone(ckpt)
torch.save(state, '/home/admin/Algonaut/dat_backbones/bkbn_cmrcn_dat_b_3x.pth')

In [None]:
from dat_backbone.dat import DAT
backbone = DAT()
# NOTE(review): the extraction cell writes 'bkbn_cmrcn_dat_b_3x.pth', but this cell
# loads 'bkbn_upn_dat_b_160k.pth', so the freshly extracted weights are NOT the
# ones used here. Confirm which pretrained backbone is intended.
# map_location='cpu': DAT() is not moved to a device here, and load_state_dict
# copies tensors into the module's own parameters, so staging the checkpoint on
# the GPU first would only cost extra device memory.
backbone.load_state_dict(
    torch.load('/home/admin/Algonaut/dat_backbones/bkbn_upn_dat_b_160k.pth', map_location='cpu')
)
cfg.MODEL.LAYERS = list(range(4))
# Channel width of each selected layer; assumed to match the stage dims of the
# DAT() configuration above - TODO confirm against dat_backbone.dat.
cfg.MODEL.LAYER_WIDTHS = [128, 256, 512, 1024]
cfg.MODEL.BOTTLENECK_DIM = 128  # can be reduced to speed up

In [None]:
from brainnet.plmodel import PLModel
import pytorch_lightning as pl


# Wrap the config and backbone in the Lightning module, then move it to the GPU.
#   draw:   draw on each epoch end
#   cached: cache the features into cpu memory in first epoch
# Drawing is triggered from plmodel.validation_epoch_end() on validation epochs.
# NOTE(review): Lightning 2.x no longer calls a `validation_epoch_end` hook itself
# (it uses `on_validation_epoch_end`) - confirm how brainnet.plmodel wires this up.
plmodel_options = dict(draw=False, cached=False)
plmodel = PLModel(cfg, backbone, **plmodel_options).cuda()

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep one checkpoint per epoch. dirpath is deliberately left unset so the
# directory is derived from the Trainer (default_root_dir / logger version),
# which keeps each run's checkpoints in its own versioned folder.
checkpoint_callback = ModelCheckpoint(
    save_top_k=-1,                 # -1 keeps EVERY checkpoint, not just the best/latest
    every_n_epochs=1,              # write a checkpoint at the end of each epoch
    filename="model_{epoch:02d}",  # rendered as e.g. "model_epoch=03.ckpt"
)

In [None]:
# Single-GPU training run; 40 min on default colab.
trainer_options = {
    "accelerator": "gpu",
    "devices": [0],
    "max_epochs": 10,
    "precision": 32,  # full precision: auto_fp16 already in dat code
    "gradient_clip_val": 0.5,
    "limit_train_batches": 1.0,  # fraction of the train set used per epoch (all of it)
    "limit_val_batches": 1.0,    # fraction of the val set used per epoch (all of it)
    "enable_checkpointing": True,
    "callbacks": [checkpoint_callback],
}
trainer = pl.Trainer(**trainer_options)
trainer.fit(plmodel)