In [1]:
import getpass
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

### Load model

In [2]:
from in_silico.model.mlflow_loader import ModelPaths, DataPaths, load_free_viewing_model_from_mlflow

# MLflow artifacts (checkpoint + config) of the trained model, and the recording session to load.
model_paths = ModelPaths(
    checkpoint_uri="mlflow-artifacts:/621818231566971674/2f85fd6f5dda46e280456d3186618e1c/artifacts/6806be20120f307fa684cd4c637ad949_final.pth.tar",
    config_uri="mlflow-artifacts:/621818231566971674/2f85fd6f5dda46e280456d3186618e1c/artifacts/6806be20120f307fa684cd4c637ad949_final_cfg.pth.tar",
)
data_paths = DataPaths(session_dirs=["/mnt/data1/enigma/goliath_10_20_sandbox/37_3843837605846_0_V3A_V4/"])

# Credentials are read from MLflow's standard environment variables instead of being
# hardcoded: a password written in a notebook leaks through git history and shared copies.
# SECURITY: a password was previously committed in this cell — rotate that credential.
mlflow_username = os.environ.get("MLFLOW_TRACKING_USERNAME", "mlflow-runner")
mlflow_password = os.environ.get("MLFLOW_TRACKING_PASSWORD") or getpass.getpass("MLflow password: ")

out = load_free_viewing_model_from_mlflow(
    model_paths,
    data_paths,
    cuda_visible_devices="9",
    mlflow_tracking_uri="https://mlflow.enigmatic.stanford.edu/",
    mlflow_username=mlflow_username,
    mlflow_password=mlflow_password,
)

Skipping import of cpp extensions due to incompatible torch version 2.7.0a0+7c8ec84dab.nv25.03 for torchao version 0.16.0             Please see https://github.com/pytorch/ao/issues/2919 for more info


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]



Dataset 0: 37_3843837605846_0_V3A_V4, length = 215312
Sessions: ['37_3843837605846_0_V3A_V4']
Batches per session: {'37_3843837605846_0_V3A_V4': 26914}
Total batches: 26914
Created FastSessionDataLoader with 1 sessions and 26914 total batches
Dataset 0: 37_3843837605846_0_V3A_V4, length = 22977
Sessions: ['37_3843837605846_0_V3A_V4']
Batches per session: {'37_3843837605846_0_V3A_V4': 2872}
Total batches: 2872
Created FastSessionDataLoader with 1 sessions and 2872 total batches


In [3]:
from in_silico.model.wrapper import ModelWrapper

# Unpack the 4-tuple produced by load_free_viewing_model_from_mlflow(...).
model, skip_samples, cfg, extra = out

# Prefer the loader's skip_samples; fall back to the trainer config only when the loader returned None.
skip_samples = cfg.trainer.skip_n_samples if skip_samples is None else skip_samples

wrapper = ModelWrapper(model, skip_samples=skip_samples)

### Load indices

In [4]:
# Indices of the neurons attributed to V3A (used below to pick neurons to visualise).
INDICES_V3A_PATH = "/workdir/analysis_parametric/indices_v3a.npy"
indices_v3a = np.load(INDICES_V3A_PATH)
# Fail early with a clear message: the sampling step below needs a non-empty 1-D array.
assert indices_v3a.ndim == 1 and indices_v3a.size > 0, "expected a non-empty 1-D array of neuron indices"

### Dotmapping analysis

In [None]:
from in_silico.analyses.dotmapping import (
    predict_responses,
    compute_sta,
    sta_to_rgb,
    compute_spatial_sta,
)


In [None]:
# Dot-mapping simulation settings for this session, grouped so they can be inspected or
# tweaked in one place. Note: `fps` must agree with any frame->ms conversion done later.
dotmap_kwargs = dict(
    key="37_3843837605846_0_V3A_V4",
    num_samples=12,
    dot_offset_samples=3,
    dot_duration_samples=6,
    fps=30.0,
    square_size_px=25,
    dots_per_frame=100,
    base_seed=61,
    N=5000,
    batch_size=10,
    win_start=2,
    win_dur=6,
    ds_factor=4,
)

# Run the model on the random-dot stimuli; `frames_all` and `pred_all` feed the STA below.
frames_all, pred_all, avg_resp_all, seeds = predict_responses(wrapper, **dotmap_kwargs)


In [None]:
# Response-weighted stimulus average (STA) from the simulated frames and predicted responses.
# Only `sta` is used further down; `denom` and `stim_mean` are kept for inspection.
sta, denom, stim_mean = compute_sta(frames_all, pred_all, t_frame=5, zscore_normalize=True)

# Persist the STA in the working directory for later reuse.
np.save("sta_spatial.npy", sta)


### Visualise

In [None]:
# Pick one V3A neuron at random; a fixed seed keeps the choice (and the figures) reproducible.
VIS_SEED = 0
rng = np.random.default_rng(VIS_SEED)
idx = int(rng.choice(indices_v3a))
rgb = sta_to_rgb(sta[idx], mode="robust", p=99.5)  # displayed in the next cell

# Overlay the predicted response time course for 100 evenly spaced stimulus samples.
sample_ids = np.linspace(0, pred_all.shape[0] - 1, 100).astype(int)
fig, ax = plt.subplots()
ax.plot(pred_all[sample_ids, idx, :].T, color="gray", alpha=0.5, lw=0.8)
ax.set(title=f"Neuron {idx}", xlabel="Time bin", ylabel="Predicted response")
sns.despine(ax=ax);


In [None]:
# Show the STA image of the neuron picked in the previous cell, titled so the figure stands alone.
fig, ax = plt.subplots()
ax.imshow(rgb)
ax.set_title(f"STA — neuron {idx}")
ax.axis("off");


In [None]:
# Summary figure: one row per neuron — predicted activity traces (left) and STA image (right).
idxs = [50, 459, 88, 392, 41, 112, 110]
n_neurons = len(idxs)  # derived, so the grid can never disagree with the list

# Frame rate used to convert time bins to ms; must match `fps` passed to predict_responses.
FPS = 30.0

# squeeze=False keeps `axes` 2-D even when `idxs` has a single entry.
fig, axes = plt.subplots(n_neurons, 2, figsize=(8, n_neurons * 2), squeeze=False)

# 100 evenly spaced stimulus samples, shared by every row.
sample_ids = np.linspace(0, pred_all.shape[0] - 1, 100).astype(int)

# The time axis is identical for all rows, so build it once.
# NOTE(review): labels assume time bin 0 is stimulus onset — confirm against
# predict_responses(dot_offset_samples=..., win_start=...).
n_timepoints = pred_all.shape[2]
times_ms = np.arange(n_timepoints) * (1000 / FPS)
tick_labels = [f"{t:.0f}" for t in times_ms]

# Distinct loop names so this cell does not overwrite the single-neuron `idx` / `rgb` globals.
for row, neuron_idx in enumerate(idxs):
    is_first = row == 0
    is_last = row == n_neurons - 1

    # Left column: predicted response traces.
    ax_trace = axes[row, 0]
    ax_trace.plot(pred_all[sample_ids, neuron_idx, :].T, color="gray", alpha=0.5, lw=0.8)
    sns.despine(ax=ax_trace)
    title = f"Neuron {neuron_idx}  —  Predicted activity" if is_first else f"Neuron {neuron_idx}"
    ax_trace.set_title(title, fontsize=12, loc="left")
    ax_trace.set_xticks(np.arange(n_timepoints))
    if is_last:
        ax_trace.set_xticklabels(tick_labels, fontsize=11, rotation=45, ha="right")
        ax_trace.set_xlabel("Time from stim onset (ms)", fontsize=12)
    else:
        ax_trace.set_xticklabels([])

    # Right column: STA converted to an image via sta_to_rgb.
    ax_sta = axes[row, 1]
    ax_sta.imshow(sta_to_rgb(sta[neuron_idx], mode="robust", p=99.5, gamma=1.0))
    ax_sta.axis("off")
    if is_first:
        ax_sta.set_title("STA", fontsize=12, loc="left")

fig.tight_layout()
fig.savefig("Dotmap_RFs.png", dpi=300)
