## Demo code to load the best V1T model and run inference on the test set

Please follow the instructions in [README.md](README.md) to set up the conda environment.

In [1]:
import torch
import typing as t
from tqdm import tqdm
from torch.utils.data import DataLoader

from sensorium import data
from sensorium.utils import utils
from sensorium.models import Model
from sensorium.metrics import Metrics
from sensorium.utils.scheduler import Scheduler

Define a dummy `Args` object that mimics an `argparse` namespace.

In [2]:
class Args:
    """Minimal stand-in for the namespace an argparse parser would produce.

    Only the settings this notebook sets explicitly are defined here: the
    device, the batch size, the run directory and the dataset directory.
    """

    def __init__(self):
        # Insertion order matches the order in which the attributes are listed.
        defaults = {
            "device": torch.device("cpu"),
            "batch_size": 16,
            "output_dir": "runs/best_v1t",
            "dataset": "data/sensorium",
        }
        for name, value in defaults.items():
            setattr(self, name, value)

In [3]:
# Instantiate the dummy argument object defined in the previous cell.
args = Args()
# `args` is passed to the helper and its return value is not used, so the
# helper presumably fills in the remaining settings on `args` in place
# (NOTE(review): confirm against sensorium.utils.load_args).
utils.load_args(args)  # load arguments from checkpoint

Load the test set.

In [4]:
# Build the submission/test data. The call returns two values: the first is
# kept as `test_ds` (indexed by mouse-ID strings such as "2" further down),
# the second is discarded.
# NOTE(review): confirm what the discarded second value contains.
test_ds, _ = data.get_submission_ds(
    args,
    data_dir=args.dataset,
    batch_size=args.batch_size,
    device=args.device,
)

Initialize the model and restore the checkpoint.

In [5]:
# Construct the model from the loaded arguments and the test data, then move
# it to the configured device.
model = Model(args, ds=test_ds)
model.to(args.device)

# The Scheduler is used here only to restore the saved model state.
# NOTE(review): save_optimizer=False presumably skips optimizer state and
# force=True presumably makes a missing checkpoint an error — confirm in
# sensorium.utils.scheduler.
scheduler = Scheduler(args, model=model, save_optimizer=False)
# Bind the return value to `_` so the notebook does not echo it as cell output.
_ = scheduler.restore(force=True)


Loaded checkpoint from epoch 120 (correlation: 0.4612).



In [6]:
@torch.no_grad()
def inference(
    ds: DataLoader,
    model: torch.nn.Module,
    batch_size: int,
    device: t.Union[str, torch.device] = "cpu",
) -> t.Dict[str, t.Union[torch.Tensor, t.List[t.Any]]]:
    """Run the model over every batch in DataLoader ds and collect the outputs.

    Gradients are disabled for the whole call. The model is moved to ``device``
    and switched to evaluation mode; it is left in that state afterwards.

    Args:
        ds: DataLoader of a single mouse; ``ds.dataset.mouse_id`` selects the
            mouse that is passed to the model.
        model: the model to evaluate.
        batch_size: micro-batch size handed to ``data.micro_batching``.
        device: device on which the forward pass is run.
    Returns:
        results: t.Dict[str, torch.Tensor]
            - predictions: first output of the model, moved to the CPU
            - targets: the "response" entries of the micro-batches
            - trial_ids: the "trial_id" entries of the micro-batches
            - image_ids: the "image_id" entries of the micro-batches
        An entry is concatenated along dim 0 when its items are tensors;
        otherwise the list of per-micro-batch items is returned unchanged.
        When ``ds`` yields no batches every entry is an empty list.
    """
    results = {
        "predictions": [],
        "targets": [],
        "trial_ids": [],
        "image_ids": [],
    }
    mouse_id = ds.dataset.mouse_id
    model.to(device)
    model.eval()
    for batch in tqdm(ds, desc=f"Mouse {data.convert_id(mouse_id)}"):
        for micro_batch in data.micro_batching(batch, batch_size=batch_size):
            # The model returns three values; only the first is used here.
            predictions, _, _ = model(
                inputs=micro_batch["image"].to(device),
                mouse_id=mouse_id,
                behaviors=micro_batch["behavior"].to(device),
                pupil_centers=micro_batch["pupil_center"].to(device),
            )
            results["predictions"].append(predictions.cpu())
            results["targets"].append(micro_batch["response"])
            results["image_ids"].append(micro_batch["image_id"])
            results["trial_ids"].append(micro_batch["trial_id"])
    # Guard on `v` first: an empty loader leaves every list empty, and indexing
    # v[0] would raise IndexError.
    return {
        k: torch.cat(v, dim=0) if v and isinstance(v[0], torch.Tensor) else v
        for k, v in results.items()
    }

In [None]:
# Evaluate each listed mouse separately: run inference on its test loader,
# then compute the two correlation metrics from the collected outputs.
for mouse_id in ["2", "3", "4", "5", "6"]:
    outputs = inference(
        ds=test_ds[mouse_id],
        model=model,
        batch_size=args.batch_size,
        device=args.device,
    )
    # Metrics receives both the loader and the inference outputs of this mouse.
    metrics = Metrics(ds=test_ds[mouse_id], results=outputs)

    # per_neuron=False: both values are formatted as single numbers below, so
    # the metrics presumably aggregate over neurons — confirm in sensorium.metrics.
    single_trial_correlation = metrics.single_trial_correlation(per_neuron=False)
    correlation_to_average = metrics.correlation_to_average(per_neuron=False)
    # The printed text does not name the mouse; the tqdm bar inside
    # inference() carries the mouse label.
    print(
        f"single trial correlation: {single_trial_correlation:.03f}\n"
        f"correlation to average: {correlation_to_average:.03f}\n"
    )

Mouse A:   3%|▎         | 2/63 [00:03<02:00,  1.98s/it]