### Demo and submission notebook

See [README.md#File structure](README.md#file-structure) for how to store the dataset and model checkpoints, and follow [README.md#Installation](README.md#installation) for environment setup instructions.

In `torch.float32`, inference on the bonus set requires a GPU with at least 20GB of VRAM, while 11GB is sufficient for the live and final main sets.

In [1]:
import argparse
import os
from typing import Any, Dict, List, Literal

import numpy as np
import torch
import pandas as pd
from einops import rearrange
from torch import nn
from tqdm import tqdm
from torch.utils.data import DataLoader

from viv1t import data
from viv1t.model import Model
from viv1t.utils import utils
from viv1t.metrics import single_trial_correlation

Run configuration
- `DEVICE`: `torch.device` to run the model on. In this notebook, we use a single Nvidia RTX 2080 Ti 11GB to run inference with the trained model.
- `DATA_DIR`: path to directory where data from the 5 mice are stored, with the following format:
    ```bash
    data/
        sensorium/
            dynamic29515-10-12-Video-9b4f6a1a067fe51e15306b9628efea20.zip
            dynamic29623-4-9-Video-9b4f6a1a067fe51e15306b9628efea20.zip
            dynamic29647-19-8-Video-9b4f6a1a067fe51e15306b9628efea20.zip
            dynamic29712-5-9-Video-9b4f6a1a067fe51e15306b9628efea20.zip
            dynamic29755-2-8-Video-9b4f6a1a067fe51e15306b9628efea20.zip
    ```
    The data reader automatically unzips the files if a folder of the same name does not exist. **Note that we recompute the statistics for normalization/standardization from the training set and store the result in `<mouse_dir>/statistics.pkl`.**
- `PRECISION`: computation precision of the core module; only `32` (`torch.float32`) and `bf16` (`torch.bfloat16`) are supported. Note that all of our models were trained and used for inference in `torch.bfloat16`.
- `MOUSE_IDS`: we assign each mouse a shorter unique ID for convenience; this dictionary maps the short ID to the original mouse ID.
- `LIVE_MAIN_FILENAME`: filename to store live main test set parquet submission
- `FINAL_MAIN_FILENAME`: filename to store final main test set parquet submission 

In [2]:
DEVICE = torch.device("cuda:0")
# DEVICE = torch.device("cpu")

DATA_DIR = "data/sensorium"

SKIP = 51

PRECISION = "32"
assert PRECISION in ("bf16", "32")
if PRECISION == "bf16" and not utils.support_bf16(DEVICE):
    raise TypeError(
        f"Device {DEVICE} does not support torch.bfloat16, please use torch.float32 instead."
    )

MOUSE_IDS = {
    "F": "dynamic29515-10-12-Video-9b4f6a1a067fe51e15306b9628efea20",
    "G": "dynamic29623-4-9-Video-9b4f6a1a067fe51e15306b9628efea20",
    "H": "dynamic29647-19-8-Video-9b4f6a1a067fe51e15306b9628efea20",
    "I": "dynamic29712-5-9-Video-9b4f6a1a067fe51e15306b9628efea20",
    "J": "dynamic29755-2-8-Video-9b4f6a1a067fe51e15306b9628efea20",
}

LIVE_MAIN_FILENAME = "predictions_live_main.parquet.brotli"
FINAL_MAIN_FILENAME = "predictions_final_main.parquet.brotli"

#### Dataset configuration

Please check `train.py` (`python train.py --help`) for all available options and their function.

In [3]:
args = argparse.Namespace()
args.data = DATA_DIR
args.device = DEVICE
args.precision = PRECISION
args.mouse_ids = list(MOUSE_IDS.keys())
args.batch_size = 1
args.micro_batch_size = 1
args.ds_mode = 3
args.stat_mode = 1
args.transform_mode = 2
args.num_workers = 0
args.verbose = 1

#### Load validation and test sets

If this is the first run, the data statistics of the training set are computed and stored in `<mouse_dir>/statistics.pkl`.

In [4]:
val_ds, test_ds = data.get_submission_ds(
    args,
    data_dir=args.data,
    mouse_ids=args.mouse_ids,
    batch_size=args.batch_size,
    device=args.device,
)

Unzipping data/sensorium/dynamic29515-10-12-Video-9b4f6a1a067fe51e15306b9628efea20.zip...
Compute statistics in dynamic29515-10-12-Video-9b4f6a1a067fe51e15306b9628efea20...
Unzipping data/sensorium/dynamic29623-4-9-Video-9b4f6a1a067fe51e15306b9628efea20.zip...
Compute statistics in dynamic29623-4-9-Video-9b4f6a1a067fe51e15306b9628efea20...
Unzipping data/sensorium/dynamic29647-19-8-Video-9b4f6a1a067fe51e15306b9628efea20.zip...
Compute statistics in dynamic29647-19-8-Video-9b4f6a1a067fe51e15306b9628efea20...
Unzipping data/sensorium/dynamic29712-5-9-Video-9b4f6a1a067fe51e15306b9628efea20.zip...
Compute statistics in dynamic29712-5-9-Video-9b4f6a1a067fe51e15306b9628efea20...
Unzipping data/sensorium/dynamic29755-2-8-Video-9b4f6a1a067fe51e15306b9628efea20.zip...
Compute statistics in dynamic29755-2-8-Video-9b4f6a1a067fe51e15306b9628efea20...
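
The compute-on-first-run behaviour described above can be sketched as follows. This is a minimal illustration and not the `viv1t` data reader: the helper `load_or_compute_statistics`, the `compute` callback, and the stored keys `mean` and `std` are hypothetical. Only the file name `statistics.pkl` comes from this notebook.

```python
import pickle
import statistics
import tempfile
from pathlib import Path
from typing import Callable, Dict, List


def load_or_compute_statistics(
    mouse_dir: Path, compute: Callable[[], Dict[str, float]]
) -> Dict[str, float]:
    """Return cached statistics if present, otherwise compute and store them."""
    filename = mouse_dir / "statistics.pkl"
    if filename.exists():
        with open(filename, "rb") as file:
            return pickle.load(file)
    stats = compute()
    with open(filename, "wb") as file:
        pickle.dump(stats, file)
    return stats


# toy "training set" standing in for the recorded values of one mouse
train_values = [0.0, 1.0, 2.0, 3.0, 4.0]
calls: List[int] = []  # records how often the statistics are computed


def compute_stats() -> Dict[str, float]:
    calls.append(1)
    return {
        "mean": statistics.fmean(train_values),
        "std": statistics.pstdev(train_values),
    }


with tempfile.TemporaryDirectory() as tmp_dir:
    mouse_dir = Path(tmp_dir)
    # first call computes the statistics and writes statistics.pkl
    first = load_or_compute_statistics(mouse_dir, compute_stats)
    # second call reads the cached file and does not recompute
    second = load_or_compute_statistics(mouse_dir, compute_stats)
```

With a cache like this, deleting `statistics.pkl` forces the statistics to be recomputed on the next run.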


Validation DataLoaders have the format of `{mouse_id: DataLoader}`

In [5]:
print(val_ds)

{'F': <torch.utils.data.dataloader.DataLoader object at 0x7f5ec7a27b90>, 'G': <torch.utils.data.dataloader.DataLoader object at 0x7f5ec8b07cd0>, 'H': <torch.utils.data.dataloader.DataLoader object at 0x7f5ec7a4c750>, 'I': <torch.utils.data.dataloader.DataLoader object at 0x7f5ec7759110>, 'J': <torch.utils.data.dataloader.DataLoader object at 0x7f5ec7779f10>}


Test DataLoaders have the format of `{tier: {mouse_id: DataLoader}}` where `tier` is `['live_main', 'live_bonus', 'final_main', 'final_bonus']`.

In [6]:
print(test_ds.keys())
print(test_ds["live_main"].keys())

dict_keys(['live_main', 'live_bonus', 'final_main', 'final_bonus'])
dict_keys(['F', 'G', 'H', 'I', 'J'])
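
As a small illustration of the nested `{tier: {mouse_id: DataLoader}}` layout, the sketch below walks the same structure with placeholder strings in place of real `DataLoader` objects. Selecting tiers by the `_main` suffix is an assumption made for this example, not something the data reader does.

```python
from typing import Dict, List

# placeholder strings stand in for the DataLoader objects
tiers = ["live_main", "live_bonus", "final_main", "final_bonus"]
mouse_ids = ["F", "G", "H", "I", "J"]
test_ds: Dict[str, Dict[str, str]] = {
    tier: {mouse_id: f"<DataLoader {tier}/{mouse_id}>" for mouse_id in mouse_ids}
    for tier in tiers
}

# select the tiers needed for the main-track submissions
main_tiers: List[str] = [tier for tier in test_ds if tier.endswith("_main")]
for tier in main_tiers:
    for mouse_id, loader in test_ds[tier].items():
        print(f"{tier:<11} mouse {mouse_id}: {loader}")
```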


#### Ensemble model that returns mean response over all models

In [7]:
class EnsembleModel(nn.Module):
    def __init__(
        self,
        args: Any,
        saved_models: Dict[str, str],
        neuron_coordinates: Dict[str, torch.Tensor],
    ):
        super(EnsembleModel, self).__init__()
        self.input_shapes = args.input_shapes
        self.output_shapes = args.output_shapes
        self.ensemble = nn.ModuleDict()
        for name, output_dir in saved_models.items():
            self.ensemble[name] = self.load_model(
                output_dir,
                neuron_coordinates=neuron_coordinates,
                device=args.device,
                precision=args.precision,
            )
        self.ensemble.requires_grad_(False)

    def load_model(
        self,
        output_dir: str,
        neuron_coordinates: Dict[str, torch.Tensor],
        device: torch.device,
        precision: Literal["bf16", "32"] = None,
    ):
        # load model configuration and initialize model
        model_args = argparse.Namespace()
        model_args.output_dir = output_dir
        model_args.device = device
        model_args.precision = precision
        utils.load_args(model_args)
        model = Model(model_args, neuron_coordinates=neuron_coordinates)
        # load checkpoint dictionary to CPU
        filename = os.path.join(output_dir, "ckpt", "model_state.pt")
        ckpt = torch.load(filename, map_location="cpu")
        # restore weights from checkpoint that exists in current model
        state_dict = model.state_dict()
        state_dict.update({k: v for k, v in ckpt["model"].items() if k in state_dict})
        model.load_state_dict(state_dict)
        print(
            f"Loaded checkpoint from {output_dir} (validation correlation: {ckpt['value']:.04f})."
        )
        del ckpt, model_args
        return model

    @torch.inference_mode()
    def forward(
        self,
        inputs: torch.Tensor,
        mouse_id: str,
        behaviors: torch.Tensor,
        pupil_centers: torch.Tensor,
    ):
        outputs = []
        for name in self.ensemble.keys():
            y_pred, _ = self.ensemble[name](
                inputs,
                mouse_id=mouse_id,
                behaviors=behaviors,
                pupil_centers=pupil_centers,
            )
            outputs.append(rearrange(y_pred, "b t n -> b t n 1"))
        outputs = torch.cat(outputs, dim=-1)
        outputs = torch.mean(outputs, dim=-1)
        return outputs
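
The averaging step in `forward` can be checked in isolation. The sketch below uses NumPy instead of PyTorch and random arrays instead of model outputs; the array sizes are arbitrary toy values. It adds a trailing model axis, concatenates along it, and averages it away, then compares the result with stacking on a new leading axis.

```python
import numpy as np

rng = np.random.default_rng(0)
num_models, batch, dim_1, dim_2 = 5, 1, 4, 3  # toy sizes

# one prediction array per ensemble member, all with the same shape
member_outputs = [rng.random((batch, dim_1, dim_2)) for _ in range(num_models)]

# add a trailing "model" axis, concatenate along it, then average it away
stacked = np.concatenate([y[..., None] for y in member_outputs], axis=-1)
ensemble_mean = stacked.mean(axis=-1)

# equivalent formulation: stack on a new leading axis and average over it
alternative = np.mean(np.stack(member_outputs, axis=0), axis=0)
```

Both formulations give the same mean, so the choice between them is a matter of readability.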

##### Initialize ensemble model of 5 ViViT models trained with different random seeds
- `saved_models` specifies the name of the model and the path to the model checkpoint.
- `neuron_coordinates` is a dictionary of `{mouse_id: neuron_coordinates}` from `cell_motor_coordinates.npy`

In [8]:
model = EnsembleModel(
    args,
    saved_models={
        "001": "runs/viv1t_001",
        "002": "runs/viv1t_002",
        "003": "runs/viv1t_003",
        "004": "runs/viv1t_004",
        "005": "runs/viv1t_005",
    },
    neuron_coordinates={
        mouse_id: ds.dataset.neuron_coordinates for mouse_id, ds in val_ds.items()
    },
)

Loaded checkpoint from runs/viv1t_001 (validation correlation: 0.2507).
Loaded checkpoint from runs/viv1t_002 (validation correlation: 0.2530).
Loaded checkpoint from runs/viv1t_003 (validation correlation: 0.2509).
Loaded checkpoint from runs/viv1t_004 (validation correlation: 0.2494).
Loaded checkpoint from runs/viv1t_005 (validation correlation: 0.2495).


In [9]:
@torch.inference_mode()
def inference(ds: DataLoader, model: nn.Module, device: torch.device) -> Dict[str, Any]:
    """
    Run inference on the data in DataLoader and return a dictionary with
    entries for submission.

    The test sets have a variable number of frames, so we run inference on
    1 sample at a time and return a list of (N, T) arrays.

    Returns:
        result: Dict[str, Any]
            - prediction: List[np.ndarray], list of predicted responses in (N, T)
            - response: List[np.ndarray], list of recorded responses in (N, T)
            - mouse: List[str], list of original mouse IDs
            - trial_indices: List[int], list of trial indices
            - neuron_ids: List[List[int]], list of neuron IDs
    """
    model = model.to(device)
    model.train(False)
    result = {"prediction": [], "response": []}
    mouse_id = ds.dataset.mouse_id
    to_batch = lambda x: torch.unsqueeze(x, dim=0).to(device)
    for i in tqdm(range(len(ds.dataset.trial_ids))):
        sample = ds.dataset.__getitem__(i, to_tensor=True)
        t = sample["video"].shape[1] - SKIP
        predictions = model(
            inputs=to_batch(sample["video"]),
            mouse_id=mouse_id,
            behaviors=to_batch(sample["behavior"]),
            pupil_centers=to_batch(sample["pupil_center"]),
        )
        result["prediction"].append(predictions[0, :, -t:].cpu().numpy())
        result["response"].append(sample["response"][:, -t:].cpu().numpy())
    # metadata for submission
    num_trials = len(result["prediction"])
    result["mouse"] = [MOUSE_IDS[mouse_id]] * num_trials
    result["trial_indices"] = ds.dataset.trial_ids.tolist()
    result["neuron_ids"] = [ds.dataset.neuron_ids.tolist()] * num_trials
    return result
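
The slicing in `inference` keeps the last `t = T - SKIP` frames of each trial, which drops the first `SKIP` frames. A minimal NumPy sketch with toy sizes (the neuron and frame counts are not taken from the dataset) makes this concrete:

```python
import numpy as np

SKIP = 51  # value used in this notebook
num_neurons, num_frames = 3, 300  # toy sizes, not taken from the dataset

# toy prediction whose value at each time step is the frame index
prediction = np.tile(np.arange(num_frames), (num_neurons, 1))

t = num_frames - SKIP  # number of frames kept for evaluation
kept = prediction[:, -t:]  # frames SKIP, SKIP + 1, ..., num_frames - 1

# same result expressed as "drop the first SKIP frames"
dropped_prefix = prediction[:, SKIP:]
```

One Python detail to keep in mind: if a trial had exactly `SKIP` frames, `t` would be 0 and `x[:, -0:]` would return every frame rather than none, whereas `x[:, SKIP:]` would return an empty array.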

#### Inference validation set
We recorded an average validation correlation of 0.2608 when inferencing in `torch.bfloat16`.

In [10]:
val_corrs = {}
for mouse_id in MOUSE_IDS.keys():
    print(f"Mouse {mouse_id} ({MOUSE_IDS[mouse_id]})")
    result = inference(ds=val_ds[mouse_id], model=model, device=args.device)
    val_corrs[mouse_id] = single_trial_correlation(
        y_true=result["response"], y_pred=result["prediction"]
    ).item()
    print(f"Mouse {mouse_id} validation correlation: {val_corrs[mouse_id]:.04f}\n")

Mouse F (dynamic29515-10-12-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 58/58 [01:18<00:00,  1.36s/it]


Mouse F validation correlation: 0.2472

Mouse G (dynamic29623-4-9-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 56/56 [01:28<00:00,  1.58s/it]


Mouse G validation correlation: 0.2683

Mouse H (dynamic29647-19-8-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 60/60 [01:26<00:00,  1.45s/it]


Mouse H validation correlation: 0.2475

Mouse I (dynamic29712-5-9-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 60/60 [01:31<00:00,  1.52s/it]


Mouse I validation correlation: 0.2701

Mouse J (dynamic29755-2-8-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 59/59 [01:28<00:00,  1.50s/it]


Mouse J validation correlation: 0.2706



In [11]:
val_corr = np.mean(list(val_corrs.values()))
print(f"average validation single trial correlation: {val_corr:.04f}")

average validation single trial correlation: 0.2608
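
For readers unfamiliar with the metric, the sketch below shows one common way to compute a single trial correlation: concatenate all trials along time, compute the Pearson correlation per neuron, and average over neurons. This is an assumption about the general technique and may differ from `viv1t.metrics.single_trial_correlation`. The function name `single_trial_correlation_sketch` and the `eps` term are hypothetical.

```python
from typing import List

import numpy as np


def single_trial_correlation_sketch(
    y_true: List[np.ndarray], y_pred: List[np.ndarray], eps: float = 1e-8
) -> float:
    """Mean per-neuron Pearson correlation over all frames of all trials."""
    # join trials of different lengths along the time axis -> (N, sum of T_i)
    true = np.concatenate(y_true, axis=1)
    pred = np.concatenate(y_pred, axis=1)
    # centre each neuron's trace
    true = true - true.mean(axis=1, keepdims=True)
    pred = pred - pred.mean(axis=1, keepdims=True)
    # Pearson correlation per neuron; eps guards against division by zero
    covariance = (true * pred).sum(axis=1)
    norm = np.sqrt((true**2).sum(axis=1) * (pred**2).sum(axis=1))
    per_neuron = covariance / (norm + eps)
    return float(per_neuron.mean())


rng = np.random.default_rng(0)
# two toy trials with 4 neurons and different numbers of frames
responses = [rng.random((4, 50)), rng.random((4, 70))]
perfect = single_trial_correlation_sketch(responses, responses)
inverted = single_trial_correlation_sketch(responses, [-r for r in responses])
```

A perfect prediction scores close to 1 and a sign-flipped prediction close to -1, which is a quick sanity check for any implementation of the metric.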


#### Function to inference and create parquet file for submission

Note that compressing the output responses to a parquet file can take more than 30 minutes.

In [12]:
def create_parquet(
    mouse_ids: List[str],
    ds: Dict[str, DataLoader],
    model: EnsembleModel,
    device: torch.device,
    filename: str,
) -> pd.DataFrame:
    """
    Inference dataset and create parquet file for submission

    Args:
        mouse_ids: List[str], list of mouse IDs to inference
        ds: Dict[str, DataLoader], dictionary of DataLoader for each mouse
        model: EnsembleModel, model to inference
        device: torch.device, device to inference on
        filename: str, filename to save parquet file
    Returns:
        df: pd.DataFrame, DataFrame containing submission data for mouse_ids
    """
    df = []
    for mouse_id in mouse_ids:
        print(f"Mouse {mouse_id} ({MOUSE_IDS[mouse_id]})")
        result = inference(ds=ds[mouse_id], model=model, device=device)
        del result["response"]
        # convert list of np.ndarray to list of list of float
        result["prediction"] = [v.tolist() for v in result["prediction"]]
        df.append(pd.DataFrame(result))
        print("")
    df = pd.concat(df, ignore_index=True)

    # create folder if not exists
    dirname = os.path.dirname(filename)
    if dirname and not os.path.isdir(dirname):
        os.makedirs(dirname)

    # create parquet file
    print("Creating parquet file...")
    df.to_parquet(filename, compression="brotli", engine="pyarrow", index=False)
    print(f"Saved parquet file to {filename}")
    return df
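
To make the table layout concrete, the sketch below builds the rows that `pd.DataFrame(result)` would produce for one mouse, using plain Python and made-up values. The consistency checks are illustrative assumptions about a well-formed table and are not an official validation of the submission format.

```python
from typing import Any, Dict, List

# toy result for one mouse with 2 trials and 3 neurons; values are made up
result: Dict[str, List[Any]] = {
    "prediction": [
        [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # trial 9: 3 neurons x 2 frames
        [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2], [1.3, 1.4, 1.5]],  # trial 13: 3 x 3
    ],
    "mouse": ["toy-mouse"] * 2,
    "trial_indices": [9, 13],
    "neuron_ids": [[1, 3, 4]] * 2,
}

# every column must have one entry per trial before it can become a table
num_trials = len(result["prediction"])
assert all(len(column) == num_trials for column in result.values())

# one row per trial, one key per column
rows = [
    {name: column[i] for name, column in result.items()} for i in range(num_trials)
]

# each prediction has one inner list per neuron ID
for row in rows:
    assert len(row["prediction"]) == len(row["neuron_ids"])
```

Each row holds the full `(N, T)` prediction of one trial as a nested list, so the number of frames may differ between rows.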

#### Inference live main test set

The `parquet` files for live main and final main submissions will be saved to `LIVE_MAIN_FILENAME` and `FINAL_MAIN_FILENAME`.

In [13]:
live_main_df = create_parquet(
    mouse_ids=list(MOUSE_IDS.keys()),
    ds=test_ds["live_main"],
    model=model,
    device=args.device,
    filename=LIVE_MAIN_FILENAME,
)

Mouse F (dynamic29515-10-12-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 56/56 [01:19<00:00,  1.41s/it]



Mouse G (dynamic29623-4-9-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 53/53 [01:22<00:00,  1.56s/it]



Mouse H (dynamic29647-19-8-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 60/60 [01:25<00:00,  1.42s/it]



Mouse I (dynamic29712-5-9-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 60/60 [01:31<00:00,  1.52s/it]



Mouse J (dynamic29755-2-8-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 60/60 [01:25<00:00,  1.43s/it]



Creating parquet file...
Saved parquet file to predictions_live_main.parquet.brotli


In [14]:
live_main_df.head()

| | prediction | mouse | trial_indices | neuron_ids |
|---|---|---|---|---|
| 0 | [[0.29032811522483826, 0.27776122093200684, 0.... | dynamic29515-10-12-Video-9b4f6a1a067fe51e15306... | 9 | [1, 3, 4, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, ... |
| 1 | [[0.9097850918769836, 0.84004145860672, 0.6410... | dynamic29515-10-12-Video-9b4f6a1a067fe51e15306... | 13 | [1, 3, 4, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, ... |
| 2 | [[0.07010897248983383, 0.06662942469120026, 0.... | dynamic29515-10-12-Video-9b4f6a1a067fe51e15306... | 17 | [1, 3, 4, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, ... |
| 3 | [[0.7700170874595642, 0.7540055513381958, 0.70... | dynamic29515-10-12-Video-9b4f6a1a067fe51e15306... | 57 | [1, 3, 4, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, ... |
| 4 | [[0.34478139877319336, 0.32067403197288513, 0.... | dynamic29515-10-12-Video-9b4f6a1a067fe51e15306... | 58 | [1, 3, 4, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, ... |


#### Inference final main test set

In [15]:
final_main_df = create_parquet(
    mouse_ids=list(MOUSE_IDS.keys()),
    ds=test_ds["final_main"],
    model=model,
    device=args.device,
    filename=FINAL_MAIN_FILENAME,
)

Mouse F (dynamic29515-10-12-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 57/57 [01:18<00:00,  1.37s/it]



Mouse G (dynamic29623-4-9-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 56/56 [01:21<00:00,  1.46s/it]



Mouse H (dynamic29647-19-8-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 59/59 [01:32<00:00,  1.57s/it]



Mouse I (dynamic29712-5-9-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 60/60 [01:26<00:00,  1.44s/it]



Mouse J (dynamic29755-2-8-Video-9b4f6a1a067fe51e15306b9628efea20)


100%|██████████| 60/60 [01:35<00:00,  1.60s/it]



Creating parquet file...
Saved parquet file to predictions_final_main.parquet.brotli


In [16]:
final_main_df.head()

| | prediction | mouse | trial_indices | neuron_ids |
|---|---|---|---|---|
| 0 | [[0.22323818504810333, 0.199410542845726, 0.21... | dynamic29515-10-12-Video-9b4f6a1a067fe51e15306... | 4 | [1, 3, 4, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, ... |
| 1 | [[1.3814209699630737, 1.6150273084640503, 1.57... | dynamic29515-10-12-Video-9b4f6a1a067fe51e15306... | 7 | [1, 3, 4, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, ... |
| 2 | [[0.028004322201013565, 0.028461778536438942, ... | dynamic29515-10-12-Video-9b4f6a1a067fe51e15306... | 19 | [1, 3, 4, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, ... |
| 3 | [[0.22700917720794678, 0.2122136652469635, 0.2... | dynamic29515-10-12-Video-9b4f6a1a067fe51e15306... | 20 | [1, 3, 4, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, ... |
| 4 | [[0.023461228236556053, 0.022433314472436905, ... | dynamic29515-10-12-Video-9b4f6a1a067fe51e15306... | 22 | [1, 3, 4, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, ... |
