In [2]:
%load_ext autoreload
%autoreload 2

import sys 
import logging
import os

# Expose only GPU index 4 to this process. The assignment sits before
# `import torch` below. NOTE(review): the device index is hard-coded and
# machine-specific — confirm it is still the intended GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import logging

import hydra
import lightning as L
import torch
from model import (
    BhvrDecoder,
    ContextManager,
    Decoder,
    Encoder,
    MaeMaskManager,
    SpikesPatchifier,
    SslDecoder,
    NDT2Model,
    # MTMMaskManager
)
from transforms import FilterUnit, Ndt2Tokenizer
from train import DataModule
from lightning.pytorch.utilities import CombinedLoader
from omegaconf import OmegaConf, open_dict
from torch import optim
from torchmetrics import R2Score

from train import TrainWrapper, set_callbacks

# Logger used by the batch-size messages further down.
# NOTE(review): no handler or level configuration is visible in this notebook.
log = logging.getLogger(__name__)

# Alternative loader for the probe_superv config (disabled; kept for reference):
# def load_cfg():
#     sys.argv = [sys.argv[0]]
#     cfg = OmegaConf.load("./configs/probe_superv.yaml")
#     dflt_cfg = OmegaConf.load("./configs/_default.yaml")
#     cfg = OmegaConf.merge(cfg, dflt_cfg)
#     cfg.dataset = OmegaConf.load("./configs/dataset/test.yaml")
#     del cfg.defaults
#     return cfg
# cfg.dataset[0].selection[0].sessions = [cfg.dataset[0].selection[0].sessions[0]]

def load_cfg():
    """Assemble the SSL training config from the YAML files under ./configs.

    Loads ``train_ssl.yaml`` and ``_default.yaml``, merges them, attaches the
    training dataset description under ``dataset`` and removes the
    ``defaults`` key from the result.

    NOTE(review): ``OmegaConf.merge`` gives precedence to later arguments, so
    keys present in both files take their value from ``_default.yaml`` —
    confirm this is the intended precedence.

    Returns:
        The merged OmegaConf config object.
    """
    # Keep only the program name so kernel/launcher arguments are discarded.
    sys.argv = [sys.argv[0]]

    merged = OmegaConf.merge(
        OmegaConf.load("./configs/train_ssl.yaml"),
        OmegaConf.load("./configs/_default.yaml"),
    )
    merged.dataset = OmegaConf.load("./configs/dataset/train.yaml")

    # The `defaults` entry is not wanted in the final config.
    del merged.defaults
    return merged

# Build the config and switch off wandb logging for this notebook session.
cfg = load_cfg()
cfg.wandb.enable = False

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# Replace the first selection entry of the first dataset with one that
# specifies only the brainset; every other key of that entry is dropped.
cfg.dataset[0].selection[0] = {"brainset": "perich_miller_population_2018"}

In [16]:
# Display the merged config for inspection.
cfg

{'batch_size': 512, 'mask_ratio': 0.5, 'is_ssl': True, 'unsorted': True, 'keep_M1_units': True, 'train_ratio': 0.8, 'optimizer': {'scheduler': True, 'lr': 0.0005, 'weight_decay': 0.01, 'start_factor': 0.1, 'warmup_steps': 100, 'decay_steps': 2500, 'lr_min': 1e-06}, 'load_from_checkpoint': False, 'callbacks': {'checkpoint': True, 'early_stop': True, 'patience': 250}, 'wandb': {'enable': False, 'entity': 'aandre8-gatech', 'project': 'reproduce_NDT2', 'run_name': None}, 'seed': 0, 'split_seed': 0, 'superv_batch_size': 512, 'epochs': 800, 'eval_epochs': 1, 'precision': 'bf16-mixed', 'num_workers': 16, 'log_dir': './logs', 'log_every_n_steps': 1, 'fast_dev_run': False, 'num_sanity_val_steps': 0, 'model': {'dim': 256, 'max_time_patches': 256, 'max_space_patches': 256, 'patchifier': {}, 'encoder': {'depth': 6, 'heads': 4, 'dropout': 0.1, 'ffn_mult': 1}, 'predictor': {'depth': 2, 'heads': 4, 'dropout': 0.1, 'ffn_mult': 1}, 'bhv_decoder': {'depth': 2, 'heads': 4, 'dropout': 0.1, 'ffn_mult': 1, 

In [12]:
# Seed the random number generators through Lightning for reproducibility.
L.seed_everything(cfg.seed)

# Debug runs: no wandb logging and no data-loading worker processes.
if cfg.fast_dev_run:
    cfg.wandb.enable = False
    cfg.num_workers = 0

# open_dict allows keys that are not yet in the config to be added.
with open_dict(cfg):
    # Adjust batch size for multi-gpu.
    # Use the real number of visible devices, falling back to 1 on a CPU-only
    # machine so the divisions below stay valid. (Previously this was
    # `device_count() + 1`, which over-counted by one whenever a GPU was
    # visible and therefore shrank the per-GPU batch sizes.)
    num_gpus = max(torch.cuda.device_count(), 1)
    cfg.batch_size_per_gpu = cfg.batch_size // num_gpus
    # A falsy supervised batch size falls back to the main batch size.
    cfg.superv_batch_size = cfg.superv_batch_size or cfg.batch_size
    cfg.superv_batch_size_per_gpu = cfg.superv_batch_size // num_gpus
    # NOTE(review): the recorded output of this cell shows only the seed
    # message, so these INFO records are apparently not displayed — configure
    # logging if they should be visible.
    log.info(f"Number of GPUs: {num_gpus}")
    log.info(f"Batch size per GPU: {cfg.batch_size_per_gpu}")
    log.info(f"Superv batch size per GPU: {cfg.superv_batch_size_per_gpu}")

# Width shared by every sub-module built below.
dim = cfg.model.dim

# The MAE mask manager is only created for self-supervised runs.
# Alternative (disabled): MTMMaskManager(cfg.mask_ratio)
mae_mask_manager = MaeMaskManager(cfg.mask_ratio) if cfg.is_ssl else None

# Context manager and spikes patchifier.
ctx_manager = ContextManager(dim)
spikes_patchifier = SpikesPatchifier(dim, cfg.patch_size)

# Patch-count limits passed identically to the encoder and to either decoder.
patch_limits = dict(
    max_time_patches=cfg.model.max_time_patches,
    max_space_patches=cfg.model.max_space_patches,
)

# Model = Encoder + Decoder
encoder = Encoder(dim=dim, **patch_limits, **cfg.model.encoder)

# SSL runs use the SSL decoder; otherwise the behaviour decoder is used.
if cfg.is_ssl:
    decoder = SslDecoder(
        dim=dim, **patch_limits, patch_size=cfg.patch_size, **cfg.model.predictor
    )
else:
    decoder = BhvrDecoder(
        dim=dim, **patch_limits, bin_time=cfg.bin_time, **cfg.model.bhv_decoder
    )

model = NDT2Model(mae_mask_manager, ctx_manager, spikes_patchifier, encoder, decoder)

# Wrap the model together with the config for training.
train_wrapper = TrainWrapper(cfg, model)

# Tokenizer, using the context tokenizer supplied by the context manager.
ctx_tokenizer = ctx_manager.get_ctx_tokenizer()
tokenizer = Ndt2Tokenizer(
    ctx_time=cfg.ctx_time,
    bin_time=cfg.bin_time,
    patch_size=cfg.patch_size,
    pad_val=cfg.pad_val,
    ctx_tokenizer=ctx_tokenizer,
)


Seed set to 0


In [13]:
# set up data module
# Arguments passed positionally: the config, the tokenizer, and the SSL flag.
data_module = DataModule(cfg, tokenizer, cfg.is_ssl)
# setup() is invoked directly here; no Trainer is active in the visible cells.
data_module.setup()


Seed set to 0


In [14]:

# register context
# Ask the data module for the context vocabulary of the keys the context
# manager declares, then initialise the manager's vocabulary with it.
ctx_manager.init_vocab(data_module.get_ctx_vocab(ctx_manager.keys))

# Seed again after the vocabulary has been initialised.
L.seed_everything(cfg.seed)

# Callbacks
# NOTE(review): `callbacks` is only referenced by the commented-out Trainer
# in this cell.
callbacks = set_callbacks(cfg)

# Set up trainer
# trainer = L.Trainer(
#     logger=wandb_logger,
#     default_root_dir=cfg.log_dir,
#     check_val_every_n_epoch=cfg.eval_epochs,
#     max_epochs=cfg.epochs,
#     log_every_n_steps=cfg.log_every_n_steps,
#     callbacks=callbacks,
#     accelerator="gpu",
#     precision=cfg.precision,
#     fast_dev_run=cfg.fast_dev_run,
#     num_sanity_val_steps=cfg.num_sanity_val_steps,
#     strategy="ddp_find_unused_parameters_true",
# )

Seed set to 0


In [15]:
# Literal list of dotted parameter names covering the context-manager
# embeddings and the decoder (cls token, two transformer layers, positional
# embeddings and output layer).
# NOTE(review): the visible cells do not show how this list is used —
# confirm its purpose before relying on it.
[
    "model.ctx_manager.session_emb.weight",
    "model.ctx_manager.subject_emb.weight",
    "model.decoder.cls_token",
    "model.decoder.decoder.transformer.layers.0.self_attn.in_proj_weight",
    "model.decoder.decoder.transformer.layers.0.self_attn.in_proj_bias",
    "model.decoder.decoder.transformer.layers.0.self_attn.out_proj.weight",
    "model.decoder.decoder.transformer.layers.0.self_attn.out_proj.bias",
    "model.decoder.decoder.transformer.layers.0.linear1.weight",
    "model.decoder.decoder.transformer.layers.0.linear1.bias",
    "model.decoder.decoder.transformer.layers.0.linear2.weight",
    "model.decoder.decoder.transformer.layers.0.linear2.bias",
    "model.decoder.decoder.transformer.layers.0.norm1.weight",
    "model.decoder.decoder.transformer.layers.0.norm1.bias",
    "model.decoder.decoder.transformer.layers.0.norm2.weight",
    "model.decoder.decoder.transformer.layers.0.norm2.bias",
    "model.decoder.decoder.transformer.layers.1.self_attn.in_proj_weight",
    "model.decoder.decoder.transformer.layers.1.self_attn.in_proj_bias",
    "model.decoder.decoder.transformer.layers.1.self_attn.out_proj.weight",
    "model.decoder.decoder.transformer.layers.1.self_attn.out_proj.bias",
    "model.decoder.decoder.transformer.layers.1.linear1.weight",
    "model.decoder.decoder.transformer.layers.1.linear1.bias",
    "model.decoder.decoder.transformer.layers.1.linear2.weight",
    "model.decoder.decoder.transformer.layers.1.linear2.bias",
    "model.decoder.decoder.transformer.layers.1.norm1.weight",
    "model.decoder.decoder.transformer.layers.1.norm1.bias",
    "model.decoder.decoder.transformer.layers.1.norm2.weight",
    "model.decoder.decoder.transformer.layers.1.norm2.bias",
    "model.decoder.decoder.positional_encoding.time_emb.weight",
    "model.decoder.decoder.positional_encoding.space_emb.weight",
    "model.decoder.out.weight",
    "model.decoder.out.bias",
]

['model.ctx_manager.session_emb.weight',
 'model.ctx_manager.subject_emb.weight',
 'model.decoder.cls_token',
 'model.decoder.decoder.transformer.layers.0.self_attn.in_proj_weight',
 'model.decoder.decoder.transformer.layers.0.self_attn.in_proj_bias',
 'model.decoder.decoder.transformer.layers.0.self_attn.out_proj.weight',
 'model.decoder.decoder.transformer.layers.0.self_attn.out_proj.bias',
 'model.decoder.decoder.transformer.layers.0.linear1.weight',
 'model.decoder.decoder.transformer.layers.0.linear1.bias',
 'model.decoder.decoder.transformer.layers.0.linear2.weight',
 'model.decoder.decoder.transformer.layers.0.linear2.bias',
 'model.decoder.decoder.transformer.layers.0.norm1.weight',
 'model.decoder.decoder.transformer.layers.0.norm1.bias',
 'model.decoder.decoder.transformer.layers.0.norm2.weight',
 'model.decoder.decoder.transformer.layers.0.norm2.bias',
 'model.decoder.decoder.transformer.layers.1.self_attn.in_proj_weight',
 'model.decoder.decoder.transformer.layers.1.self_at