In [3]:
import logging

import hydra
import lightning as L
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from lightning.pytorch.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
)
from lightning.pytorch.profilers import AdvancedProfiler, SimpleProfiler
from omegaconf import DictConfig, OmegaConf
from temporaldata import Data

from torch_brain.registry import MODALITY_REGISTRY, ModalitySpec
from torch_brain.optim import SparseLamb
from torch_brain.models.poyo import POYO
from torch_brain.models.poyo_adap_win import POYOAdapContextPad
from torch_brain.utils import callbacks as tbrain_callbacks
from torch_brain.utils import seed_everything
from torch_brain.utils.stitcher import (
    DecodingStitchEvaluator,
    DataForDecodingStitchEvaluator,
)
from torch_brain.data import Dataset, collate
from torch_brain.data.sampler import (
    DistributedStitchingFixedWindowSampler,
    RandomFixedWindowSampler,
    TrialSampler,
    RandomBoundedWindowSampler,
)
from torch_brain.transforms import Compose

# higher speed on machines with tensor cores
torch.set_float32_matmul_precision("medium")


logger = logging.getLogger(__name__)


class TrainWrapper(L.LightningModule):
    """Lightning module wrapping a model with its optimizer, LR schedule and
    the train / validation / test step logic.

    Args:
        cfg: Run configuration. This class reads ``batch_size`` and, under
            ``optim``, ``base_lr``, ``weight_decay`` and ``lr_decay_start``.
        model: Network to train. It must expose ``unit_emb`` and
            ``session_emb`` submodules (used to build the optimizer groups).
        modality_spec: Readout specification; its ``loss_fn`` computes the
            training loss.
    """

    def __init__(
        self,
        cfg: DictConfig,
        model: nn.Module,
        modality_spec: ModalitySpec,
    ):
        super().__init__()

        self.cfg = cfg
        self.model = model
        self.modality_spec = modality_spec
        # Hyperparameters are stored from the container form of the config.
        self.save_hyperparameters(OmegaConf.to_container(cfg))

    def configure_optimizers(self):
        """Build the SparseLamb optimizer and a per-step OneCycleLR schedule.

        Returns:
            A dict with the ``optimizer`` and an ``lr_scheduler`` entry whose
            ``interval`` is ``"step"``.
        """
        optim_cfg = self.cfg.optim

        # linear scaling rule: the base LR is multiplied by the batch size
        peak_lr = optim_cfg.base_lr * self.cfg.batch_size

        # Unit and session embeddings go in their own group flagged as sparse.
        embedding_params = [
            *self.model.unit_emb.parameters(),
            *self.model.session_emb.parameters(),
        ]
        # Everything else is selected by parameter name.
        dense_params = [
            param
            for name, param in self.model.named_parameters()
            if not ("unit_emb" in name or "session_emb" in name)
        ]
        param_groups = [
            {"params": embedding_params, "sparse": True},
            {"params": dense_params},
        ]

        optimizer = SparseLamb(
            param_groups,
            lr=peak_lr,
            weight_decay=optim_cfg.weight_decay,
        )

        # With div_factor=1 the initial LR equals max_lr (see the OneCycleLR
        # documentation: initial_lr = max_lr / div_factor).
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=peak_lr,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=optim_cfg.lr_decay_start,
            anneal_strategy="cos",
            div_factor=1,
        )

        scheduler_entry = {"scheduler": scheduler, "interval": "step"}
        return {"optimizer": optimizer, "lr_scheduler": scheduler_entry}

    def training_step(self, batch, batch_idx):
        """Run one training step and return the loss.

        Reads ``batch["model_inputs"]`` (forwarded to the model as keyword
        arguments), ``batch["target_values"]`` and ``batch["target_weights"]``.
        Only entries selected by ``model_inputs["output_mask"]`` contribute to
        the loss.
        """
        inputs = batch["model_inputs"]

        # forward pass
        predictions = self.model(**inputs)

        # keep only the outputs selected by the mask, then compute the loss
        keep = inputs["output_mask"]
        loss = self.modality_spec.loss_fn(
            predictions[keep],
            batch["target_values"][keep],
            batch["target_weights"][keep],
        )

        self.log("train_loss", loss, prog_bar=True)

        # Disabled per-target batch statistics:
        # for name in target_values.keys():
        #     preds = torch.cat([pred[name] for pred in output if name in pred])
        #     self.log(f"predictions/mean_{name}", preds.mean())
        #     self.log(f"predictions/std_{name}", preds.std())
        #
        #     targets = target_values[name].float()
        #     self.log(f"targets/mean_{name}", targets.mean())
        #     self.log(f"targets/std_{name}", targets.std())

        # Summary statistics of the input unit indices in this batch.
        unit_idx = inputs["input_unit_index"].float()
        self.log("inputs/mean_unit_index", unit_idx.mean())
        self.log("inputs/std_unit_index", unit_idx.std())

        return loss

    def validation_step(self, batch, batch_idx):
        """Run the model on a validation batch and package the result.

        Returns:
            A ``DataForDecodingStitchEvaluator`` holding predictions, targets,
            timestamps and bookkeeping fields; it is consumed by
            ``DecodingStitchEvaluator.on_validation_batch_end``.
        """
        inputs = batch["model_inputs"]

        # forward pass
        predictions = self.model(**inputs)

        return DataForDecodingStitchEvaluator(
            timestamps=inputs["output_timestamps"],
            preds=predictions,
            targets=batch["target_values"],
            eval_masks=batch["eval_mask"],
            session_ids=batch["session_id"],
            absolute_starts=batch["absolute_start"],
        )

    def test_step(self, batch, batch_idx):
        """Test batches are handled exactly like validation batches."""
        return self.validation_step(batch, batch_idx)


class DataModule(L.LightningDataModule):
    """Lightning data module providing train / validation / test loaders.

    The datasets are not created in ``__init__``; call
    :meth:`setup_dataset_and_link_model` with the model first, since the
    model supplies the tokenizers and the window length.

    Args:
        cfg: Run configuration. This class reads ``data_root``, ``dataset``,
            ``train_transforms``, ``eval_transforms``, ``batch_size``,
            ``eval_batch_size``, ``num_workers`` and ``seed``.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        self.log = logging.getLogger(__name__)

    def setup_dataset_and_link_model(self, model: POYOAdapContextPad):
        r"""Setup Dataset objects, and update a given model's embedding vocabs
        (session_emb and unit_emb).

        The training split is tokenized with ``model.tokenize_varlen``; the
        validation and test splits use ``model.tokenize``. The vocabularies
        are initialized from the training dataset before the evaluation
        datasets are built.
        """
        self.sequence_length = model.sequence_length

        train_transforms = hydra.utils.instantiate(self.cfg.train_transforms)
        self.train_dataset = self._make_dataset(
            "train", train_transforms, model.tokenize_varlen
        )

        self._init_model_vocab(model)

        eval_transforms = hydra.utils.instantiate(self.cfg.eval_transforms)
        self.val_dataset = self._make_dataset(
            "valid", eval_transforms, model.tokenize
        )
        self.test_dataset = self._make_dataset(
            "test", eval_transforms, model.tokenize
        )

    def _make_dataset(self, split: str, transforms, tokenizer):
        """Build the Dataset for ``split`` with ``tokenizer`` appended to
        ``transforms``, and disable its data leakage check."""
        dataset = Dataset(
            root=self.cfg.data_root,
            config=self.cfg.dataset,
            split=split,
            transform=Compose([*transforms, tokenizer]),
        )
        dataset.disable_data_leakage_check()
        return dataset

    def _init_model_vocab(self, model: POYOAdapContextPad):
        """Initialize the model's unit and session vocabularies from the
        training dataset."""
        # TODO: Add code for finetuning situation (when model already has a vocab)
        model.unit_emb.initialize_vocab(self.get_unit_ids())
        model.session_emb.initialize_vocab(self.get_session_ids())

    def get_session_ids(self):
        """Return the session ids of the training dataset."""
        return self.train_dataset.get_session_ids()

    def get_unit_ids(self):
        """Return the unit ids of the training dataset."""
        return self.train_dataset.get_unit_ids()

    def get_recording_config_dict(self):
        """Return the recording config dict of the training dataset."""
        return self.train_dataset.get_recording_config_dict()

    def train_dataloader(self) -> DataLoader:
        """Build the training loader.

        Windows are drawn by ``RandomBoundedWindowSampler`` with lengths
        between half of and the full ``sequence_length``, seeded with
        ``cfg.seed + 1``.
        """
        # TrialSampler and RandomFixedWindowSampler (imported at file level)
        # take the same sampling_intervals / generator arguments and can be
        # swapped in here; RandomFixedWindowSampler uses
        # window_length=self.sequence_length instead of the min/max bounds.
        train_sampler = RandomBoundedWindowSampler(
            sampling_intervals=self.train_dataset.get_sampling_intervals(),
            min_window_length=self.sequence_length / 2,
            max_window_length=self.sequence_length,
            generator=torch.Generator().manual_seed(self.cfg.seed + 1),
        )

        # Worker-related options only apply when worker processes are used.
        use_workers = self.cfg.num_workers > 0

        train_loader = DataLoader(
            self.train_dataset,
            sampler=train_sampler,
            collate_fn=collate,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            drop_last=True,
            pin_memory=True,
            persistent_workers=use_workers,
            prefetch_factor=2 if use_workers else None,
        )

        self.log.info(f"Training on {len(train_sampler)} samples")
        self.log.info(f"Training on {len(self.get_unit_ids())} units")
        self.log.info(f"Training on {len(self.get_session_ids())} sessions")

        return train_loader

    def _build_eval_loader(self, dataset):
        """Build the stitching sampler and loader shared by validation and test.

        Uses ``cfg.eval_batch_size`` when it is set (truthy), otherwise
        ``cfg.batch_size``. Windows have length ``sequence_length`` and are
        stepped by half that length; the sampler is given the trainer's
        world size and global rank.

        Returns:
            ``(loader, sampler)`` so callers can log the sampler length.
        """
        batch_size = self.cfg.eval_batch_size or self.cfg.batch_size

        sampler = DistributedStitchingFixedWindowSampler(
            sampling_intervals=dataset.get_sampling_intervals(),
            window_length=self.sequence_length,
            step=self.sequence_length / 2,
            batch_size=batch_size,
            num_replicas=self.trainer.world_size,
            rank=self.trainer.global_rank,
        )

        loader = DataLoader(
            dataset,
            sampler=sampler,
            shuffle=False,
            batch_size=batch_size,
            collate_fn=collate,
            num_workers=self.cfg.num_workers,
            drop_last=False,
        )
        return loader, sampler

    def val_dataloader(self) -> DataLoader:
        """Build the validation loader (see :meth:`_build_eval_loader`)."""
        val_loader, val_sampler = self._build_eval_loader(self.val_dataset)
        self.log.info(f"Expecting {len(val_sampler)} validation steps")
        return val_loader

    def test_dataloader(self) -> DataLoader:
        """Build the test loader (see :meth:`_build_eval_loader`)."""
        test_loader, test_sampler = self._build_eval_loader(self.test_dataset)
        self.log.info(f"Testing on {len(test_sampler)} samples")
        return test_loader

In [4]:
from pathlib import Path
# Checkpoint directory: ModelCheckpoint writes here (dirpath) and
# find_resume_ckpt searches it for a checkpoint to resume from.
CKPT_DIR = Path("/home/mila/p/pingsheng.li/scratch/poyo_adp_win_ckpt/")

def find_resume_ckpt(ckpt_dir: Path) -> str | None:
    """Return a checkpoint path if one exists; else None."""
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    last = ckpt_dir / "last.ckpt"        # created if save_last=True (see below)
    if last.exists():
        return str(last)
    # otherwise pick most recent .ckpt by mtime
    cands = sorted(ckpt_dir.glob("*.ckpt"), key=lambda p: p.stat().st_mtime, reverse=True)
    return str(cands[0]) if cands else None

In [None]:
# Read the Hydra config file and seed the run.
from hydra import compose, initialize
from omegaconf import OmegaConf

# For use outside of @hydra.main()
# NOTE(review): the output recorded for this cell is a MissingConfigException
# whose path repeats "examples/poyo_adp_win" twice, so the relative
# config_path below did not resolve to the intended directory in that run.
# Confirm which directory Hydra resolves "./configs" against in this notebook.
with initialize(version_base=None, config_path="./configs"):
    cfg = compose(config_name="train_poyo_adp_win_mp.yaml")

logger.info("POYO ADP WIN!")
# fix random seed, skipped if cfg.seed is None
seed_everything(cfg.seed)

# Show the fully composed configuration as YAML.
print(OmegaConf.to_yaml(cfg))

MissingConfigException: Primary config directory not found.
Check that the config directory '/home/mila/p/pingsheng.li/projects/neuro-foundation-model/buildathon/torch_brain/examples/poyo_adp_win/examples/poyo_adp_win/configs' exists and readable

In [None]:
log = logging.getLogger(__name__)

# Weights & Biases logging is optional: the logger stays None unless
# cfg.wandb.enable is set, and is later passed to the Trainer as `logger`.
wandb_logger = None
if cfg.wandb.enable:
    wandb_kwargs = dict(
        save_dir=cfg.log_dir,
        entity=cfg.wandb.entity,
        name=cfg.wandb.run_name,
        project=cfg.wandb.project,
        log_model=cfg.wandb.log_model,
    )
    wandb_logger = L.pytorch.loggers.WandbLogger(**wandb_kwargs)

In [None]:
# Look up the readout modality spec in the registry. Only the first entry of
# cfg.dataset is consulted for the readout id.
readout_id = cfg.dataset[0].config.readout.readout_id
readout_spec = MODALITY_REGISTRY[readout_id]


In [None]:
# Instantiate the model from the config, passing the readout spec.
model = hydra.utils.instantiate(cfg.model, readout_spec=readout_spec)
data_module = DataModule(cfg)
# Builds the train/valid/test datasets and initializes the model's unit and
# session vocabularies from the training dataset.
data_module.setup_dataset_and_link_model(model)



In [None]:
# Wrap the model in the Lightning training module.
wrapper = TrainWrapper(cfg=cfg, model=model, modality_spec=readout_spec)

# Evaluator callback that consumes the objects returned by validation_step /
# test_step.
stitch_evaluator = DecodingStitchEvaluator(
    session_ids=data_module.get_session_ids(),
    modality_spec=readout_spec,
)

callbacks = [stitch_evaluator]
# Displays the number of parameters in the model.
callbacks.append(ModelSummary(max_depth=2))
# Checkpoints go to CKPT_DIR; the "best" one maximizes average_val_metric and
# "last.ckpt" is kept for resuming.
callbacks.append(
    ModelCheckpoint(
        dirpath=str(CKPT_DIR),
        filename="{epoch}-{step}-{average_val_metric:.3f}",
        save_last=True,
        monitor="average_val_metric",
        mode="max",
        save_on_train_epoch_end=True,
        every_n_epochs=cfg.eval_epochs,
    )
)
callbacks.extend(
    [
        LearningRateMonitor(logging_interval="step"),
        tbrain_callbacks.MemInfo(),
        tbrain_callbacks.EpochTimeLogger(),
        tbrain_callbacks.ModelWeightStatsLogger(),
    ]
)

# Per-block timings are written to <cfg.log_dir>/pl_simple.txt.
# For per-function timings, AdvancedProfiler(dirpath=cfg.log_dir,
# filename="pl_advanced.txt") can be used instead.
simple_profiler = SimpleProfiler(dirpath=cfg.log_dir, filename="pl_simple.txt")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    accelerator = "gpu"
else:
    accelerator = "cpu"

# -1 runs the full validation set as a sanity check; 0 skips it.
sanity_val_steps = -1 if cfg.sanity_check_validation else 0

trainer = L.Trainer(
    logger=wandb_logger,
    default_root_dir=cfg.log_dir,
    check_val_every_n_epoch=cfg.eval_epochs,
    max_epochs=cfg.epochs,
    log_every_n_steps=1,
    callbacks=callbacks,
    precision=cfg.precision,
    accelerator=accelerator,
    devices=cfg.gpus,
    num_nodes=cfg.nodes,
    limit_val_batches=None,  # Ensure no limit on validation batches
    num_sanity_val_steps=sanity_val_steps,
    profiler=simple_profiler,
)

# Resume from an existing checkpoint in CKPT_DIR when there is one
# (None starts from scratch). Alternative: ckpt_path=cfg.ckpt_path.
resume_ckpt = find_resume_ckpt(CKPT_DIR)

# Train
trainer.fit(wrapper, data_module, ckpt_path=resume_ckpt)

# Test with the best checkpoint tracked by ModelCheckpoint.
trainer.test(wrapper, data_module, ckpt_path="best")

/home/mila/p/pingsheng.li/projects/neuro-foundation-model/torch_brain/.venv/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3 /home/mila/p/pingsheng.li/projects/neuro-foundation ...
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mpingsheng[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

   | Name                 | Type                   | Params | Mode 
-------------------------------------------------------------------------
0  | model                | POYOAdapContextPad     | 1.9 M  | train
1  | model.unit_emb       | InfiniteVocabEmbedding | 622 K  | train
2  | model.session_emb    | InfiniteVocabEmbedding | 6.4 K  | train
3  | model.token_type_emb | Embedding              | 256    | train
4  | model.latent_emb     | Embedding              | 1.1 K  | train
5  | model.rotary_emb     | RotaryTimeEmbedding    | 0      | train
6  | model.dropout        | Dropout                | 0      | train
7  | model.enc_atn        | RotaryCrossAttention   | 33.1 K | train
8  | model.enc_ffn        | Sequential             | 49.9 K | train
9  | model.proc_layers    | ModuleList             | 1.1 M  | train
10 | model.dec_atn        | RotaryCrossAttention   | 33.1 K | train


Training: |                                                                                      | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

`Trainer.fit` stopped: `max_epochs=1` reached.
Restoring states from the checkpoint path at /network/scratch/p/pingsheng.li/poyo_adp_win_ckpt/epoch=0-step=4105-average_val_metric=-0.014.ckpt
/home/mila/p/pingsheng.li/projects/neuro-foundation-model/torch_brain/.venv/lib/python3.10/site-packages/lightning/fabric/utilities/cloud_io.py:57: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommen

Testing: |                                                                                       | 0/? [00:00<…

[{'perich_miller_population_2018/c_20131003_center_out_reaching': -0.00854647159576416,
  'perich_miller_population_2018/c_20131009_random_target_reaching': -0.003414660692214966,
  'perich_miller_population_2018/c_20131010_random_target_reaching': -0.0016041398048400879,
  'perich_miller_population_2018/c_20131011_random_target_reaching': -0.006069481372833252,
  'perich_miller_population_2018/c_20131022_center_out_reaching': -0.010137677192687988,
  'perich_miller_population_2018/c_20131023_center_out_reaching': -0.014978379011154175,
  'perich_miller_population_2018/c_20131028_random_target_reaching': 2.384185791015625e-07,
  'perich_miller_population_2018/c_20131029_random_target_reaching': -0.0025427937507629395,
  'perich_miller_population_2018/c_20131031_center_out_reaching': -0.00036022067070007324,
  'perich_miller_population_2018/c_20131101_center_out_reaching': -0.003862828016281128,
  'perich_miller_population_2018/c_20131203_center_out_reaching': -0.040819764137268066,
  '

[1;34mwandb[0m: 
[1;34mwandb[0m: 🚀 View run [33mpoyo_adp_win_mp[0m at: [34m[0m
