In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

## Drifting grating analysis

This notebook estimates, for each neuron:

1. **Optimal grating parameters** – the combination of spatial frequency, temporal
   frequency and colour axis that drives the neuron most strongly.
2. **Direction tuning curve** – mean response at 12 evenly-spaced directions (30° steps)
   presented at each neuron's optimal parameters.
3. **Selectivity indices** – Direction Selectivity Index (DSI) and Orientation Selectivity
   Index (OSI), plus polar-plot visualisations.

### Colour axes

| Condition | Description | RGB modulation |
|-----------|-------------|----------------|
| `achromatic` | Luminance grating | Equal ΔR = ΔG = ΔB |
| `lm` | L−M isoluminant (red–green) | ΔR = +1, ΔG ≈ −0.51, ΔB = 0 |
| `s` | S-cone isoluminant (blue–yellow) | ΔR = ΔG ≈ −0.13, ΔB = +1 |

The RGB values in the table give the direction of each colour axis before normalisation.
Every colour vector is scaled to unit length, so `contrast` controls the modulation
amplitude uniformly across conditions.

### Sliding-window prediction

The model accepts 12-frame windows and returns 9 valid predictions per window: the first
`skip_samples` frames (3 for this model) have no valid prediction.
For the 2-second presentations used here (60 valid frames at 30 fps) the stimulus is
split into 7 consecutive non-overlapping windows (⌈60 / 9⌉ = 7). Their predictions are
concatenated and averaged over time to give one scalar response per neuron.

### Load model

In [None]:
import getpass
import os

from in_silico.model.mlflow_loader import ModelPaths, DataPaths, load_free_viewing_model_from_mlflow

# MLflow artifacts (checkpoint + config) of the trained free-viewing model.
model_paths = ModelPaths(
    checkpoint_uri="mlflow-artifacts:/621818231566971674/2f85fd6f5dda46e280456d3186618e1c/artifacts/6806be20120f307fa684cd4c637ad949_final.pth.tar",
    config_uri="mlflow-artifacts:/621818231566971674/2f85fd6f5dda46e280456d3186618e1c/artifacts/6806be20120f307fa684cd4c637ad949_final_cfg.pth.tar",
)

# Recording session the model is loaded for.
data_paths = DataPaths(session_dirs=["/mnt/data1/enigma/goliath_10_20_sandbox/37_3843837605846_0_V3A_V4/"])

# Credentials come from the standard MLflow environment variables rather than being
# hard-coded here. Hard-coded secrets leak through version control and shared copies
# of the notebook. If MLFLOW_TRACKING_PASSWORD is unset, prompt for it interactively.
# NOTE(review): the password previously committed in this cell should be rotated.
mlflow_username = os.environ.get("MLFLOW_TRACKING_USERNAME", "mlflow-runner")
mlflow_password = os.environ.get("MLFLOW_TRACKING_PASSWORD") or getpass.getpass("MLflow password: ")

out = load_free_viewing_model_from_mlflow(
    model_paths,
    data_paths,
    cuda_visible_devices="9",
    mlflow_tracking_uri="https://mlflow.enigmatic.stanford.edu/",
    mlflow_username=mlflow_username,
    mlflow_password=mlflow_password,
)

In [None]:
from in_silico.model.wrapper import ModelWrapper

# The loader returns (model, skip count, training config, extras).
model, skip_samples, cfg, extra = out

# When the loader does not report a skip count, use the trainer's setting.
skip_samples = cfg.trainer.skip_n_samples if skip_samples is None else skip_samples

wrapper = ModelWrapper(model=model, skip_samples=skip_samples)
print(f"skip_samples = {skip_samples}")
# A window of num_frames inputs yields num_frames - skip_samples predicted frames,
# e.g. 12 - 3 = 9 for the 12-frame windows used below.

### Load neuron indices

In [None]:
# Indices of the V3A/V4 neurons for this session.
# NOTE(review): `indices_v3a` is not referenced by any later cell visible in this
# notebook. Confirm whether the analysis was meant to be restricted to these neurons.
indices_path = '/workdir/analysis_parametric/indices_v3a.npy'
indices_v3a = np.load(indices_path)
n_indices = len(indices_v3a)
print(f"{n_indices} V3A/V4 neurons")

### Preview: grating stimulus

Each condition is a sinusoidal grating defined by direction, spatial frequency,
temporal frequency and colour axis.  Below we generate a single grating and
display several consecutive frames to confirm the drift.

In [None]:
from in_silico.stimuli.drifting_grating import COLORS, DriftingGratingSpec, make_drifting_grating

# One 12-frame achromatic grating drifting diagonally; used only for visual sanity checks.
preview_spec = DriftingGratingSpec(
    num_frames=12,
    direction_deg=45.0,  # diagonal drift
    sf_cpd=2.0,          # spatial frequency: 2 cycles per degree
    tf_hz=4.0,           # temporal frequency: 4 Hz
    color="achromatic",
    contrast=0.5,
    mean_lum=0.5,
)

preview_frames = make_drifting_grating(preview_spec)  # (T, 3, H, W)
frame_min, frame_max = preview_frames.min(), preview_frames.max()
print(f"frames shape: {preview_frames.shape}  dtype: {preview_frames.dtype}")
print(f"value range: [{frame_min:.3f}, {frame_max:.3f}]")

In [None]:
# Show the preview frames (up to 12, in a 2 x 6 grid) to visualise the temporal drift.
n_rows, n_cols = 2, 6
n_show = min(len(preview_frames), n_rows * n_cols)  # never index past the last frame

fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6))
for t, ax in enumerate(axes.flat):
    ax.axis("off")
    if t >= n_show:
        continue  # fewer frames than panels: leave this panel blank
    ax.imshow(preview_frames[t].transpose(1, 2, 0))  # (3, H, W) -> (H, W, 3)
    ax.set_title(f"frame {t}")
plt.suptitle(
    f"Drifting grating – dir={preview_spec.direction_deg}°,  "
    f"SF={preview_spec.sf_cpd} cpd,  TF={preview_spec.tf_hz} Hz,  "
    f"color={preview_spec.color}",
    y=1.01,
)
plt.tight_layout()
plt.show()

In [None]:
# Compare colour conditions: the first frame of each, all other parameters held fixed.
# One panel per entry in COLORS, so no condition is silently dropped.
fig, axes = plt.subplots(1, len(COLORS), figsize=(5 * len(COLORS), 4), squeeze=False)
for ax, color in zip(axes[0], COLORS):
    spec = DriftingGratingSpec(
        num_frames=12,
        direction_deg=0.0,
        sf_cpd=2.0,
        tf_hz=4.0,
        color=color,
        contrast=0.5,
        mean_lum=0.5,  # same mean luminance as the preview and the analysis stimuli
    )
    frame = make_drifting_grating(spec)[0]  # first frame, (3, H, W)
    ax.imshow(frame.transpose(1, 2, 0))
    ax.set_title(f"color = '{color}'", fontsize=12)
    ax.axis("off")
plt.suptitle("Colour conditions – same SF/TF/direction, SF=2 cpd, TF=4 Hz, dir=0°")
plt.tight_layout()
plt.show()

### Step-by-step pipeline

#### Phase 1: sweep SF × TF × color

We test all combinations of spatial frequency, temporal frequency and colour
axis at four cardinal reference directions (0°, 90°, 180°, 270°).  Responses are
averaged over the reference directions so that direction preference does not bias
the parameter selection.

In [None]:
from in_silico.analyses.drifting_grating import (
    DEFAULT_COLORS,
    DEFAULT_SF_CPD,
    DEFAULT_TF_HZ,
    compute_dsi,
    compute_osi,
    find_optimal_params,
    plot_polar_tuning,
    sweep_directions,
    sweep_grating_params,
)

# --- Configuration shared by Phase 1 and Phase 2 ------------------------------
KEY = "37_3843837605846_0_V3A_V4"  # session key; same name as the session directory loaded above
WIN_LEN = 12                       # frames per model input window
SKIP_FRAMES = skip_samples         # leading frames per window without a valid prediction
N_SECONDS = 2.0                    # presentation length of each grating condition (s)

# --- Phase 1: SF x TF x colour sweep at the four cardinal directions -----------
# Candidates are the library defaults:
#   SF (0.5, 1, 2, 4, 8) cpd, TF (1, 2, 4, 8) Hz, colours achromatic / lm / s.
cardinal_directions_deg = (0.0, 90.0, 180.0, 270.0)
param_responses = sweep_grating_params(
    wrapper,
    sf_cpd_list=DEFAULT_SF_CPD,
    tf_hz_list=DEFAULT_TF_HZ,
    colors=DEFAULT_COLORS,
    ref_directions_deg=cardinal_directions_deg,
    contrast=0.5,
    mean_lum=0.5,
    n_seconds=N_SECONDS,
    response_start_s=0.0,
    key=KEY,
    win_len=WIN_LEN,
    skip_frames=SKIP_FRAMES,
)

print(f"param_responses shape: {param_responses.shape}")
print("  → (n_sf, n_tf, n_color, n_ref_dir, U)")

In [None]:
# Find optimal (SF, TF, color) per neuron.
# Columns of optimal_indices index into the SF, TF and colour lists respectively.
optimal_indices = find_optimal_params(param_responses)  # (U, 3)
print(f"optimal_indices shape: {optimal_indices.shape}")

sf_list = list(DEFAULT_SF_CPD)
tf_list = list(DEFAULT_TF_HZ)
color_list = list(DEFAULT_COLORS)

# How many neurons prefer each candidate value of each parameter.
# Each row: (header label, candidate values, column of optimal_indices, value formatter).
param_summaries = (
    ("SF", sf_list, 0, lambda v: f"{v:4.1f} cpd "),
    ("TF", tf_list, 1, lambda v: f"{v:4.1f} Hz  "),
    ("color", color_list, 2, lambda v: f"{v:<12s}"),
)
for label, values, column, fmt in param_summaries:
    print(f"\nOptimal {label} distribution:")
    chosen = optimal_indices[:, column]
    for idx, value in enumerate(values):
        count = (chosen == idx).sum()
        print(f"  {fmt(value)}: {count} neurons")

#### Phase 2: direction sweep at optimal parameters

For each unique (SF, TF, color) combination that is optimal for at least one neuron
we present gratings at 12 directions (0°, 30°, …, 330°) and collect the mean response.
The model predicts all neurons simultaneously, so only one forward pass per direction
per unique parameter combination is needed.

In [None]:
# Phase 2: present 12 directions (30° apart) at each neuron's optimal SF / TF / colour.
N_DIRECTIONS = 12  # 12 x 30° covers the full 360°
direction_responses, directions_deg = sweep_directions(
    wrapper,
    optimal_indices,
    sf_cpd_list=DEFAULT_SF_CPD,
    tf_hz_list=DEFAULT_TF_HZ,
    colors=DEFAULT_COLORS,
    n_directions=N_DIRECTIONS,
    contrast=0.5,
    mean_lum=0.5,
    n_seconds=N_SECONDS,
    response_start_s=0.0,
    key=KEY,
    win_len=WIN_LEN,
    skip_frames=SKIP_FRAMES,
)

for label, value in (
    ("direction_responses shape", direction_responses.shape),
    ("directions_deg", directions_deg),
):
    print(f"{label}: {value}")

#### Tuning metrics

In [None]:
# Per-neuron selectivity indices plus preferred direction / orientation.
dsi, preferred_dir_deg = compute_dsi(direction_responses, directions_deg)
osi, preferred_ori_deg = compute_osi(direction_responses, directions_deg)

# One summary line (mean, median, range) per index.
for label, values in (("DSI", dsi), ("OSI", osi)):
    print(
        f"{label} – mean: {values.mean():.3f}  median: {np.median(values):.3f}  "
        f"range: [{values.min():.3f}, {values.max():.3f}]"
    )

In [None]:
# Distributions of DSI, OSI and preferred direction across neurons.
# Wrap preferred directions into [0, 360) so the fixed 0–360° bins capture every neuron,
# whatever angle convention compute_dsi uses (e.g. (-180, 180] from arctan2).
# This is a no-op when the values are already in [0, 360).
preferred_dir_wrapped = np.mod(preferred_dir_deg, 360.0)

# Each row: (values, bins, colour, x-label, title).
hist_panels = (
    (dsi, 30, "steelblue", "DSI", "Direction Selectivity Index"),
    (osi, 30, "darkorange", "OSI", "Orientation Selectivity Index"),
    (preferred_dir_wrapped, np.linspace(0, 360, 25), "seagreen",
     "Preferred direction (°)", "Preferred direction distribution"),
)

fig, axes = plt.subplots(1, len(hist_panels), figsize=(15, 4))
for ax, (values, bins, color, xlabel, title) in zip(axes, hist_panels):
    ax.hist(values, bins=bins, color=color, edgecolor="white")
    ax.set_xlabel(xlabel)
    ax.set_ylabel("# neurons")
    ax.set_title(title)
    sns.despine(ax=ax)

plt.tight_layout()
plt.show()

#### Polar plots – direction tuning curves

In [None]:
# Select the (up to) 12 neurons with the highest DSI for display.
# Sort on -dsi so that any NaN DSI values land at the END of the ascending argsort.
# The previous `argsort(dsi)[::-1]` put NaNs FIRST, so they were shown as "top" neurons.
n_top = min(12, len(dsi))
top_dsi_neurons = np.argsort(-dsi, kind="stable")[:n_top]

# Panel titles: neuron id, its optimal SF / TF / colour, and its DSI.
panel_titles = []
for u in top_dsi_neurons:
    sf_i, tf_i, col_i = optimal_indices[u]
    sf = sf_list[sf_i]
    tf = tf_list[tf_i]
    col = color_list[col_i]
    panel_titles.append(
        f"n{u} | {sf} cpd {tf} Hz\n{col} | DSI={dsi[u]:.2f}"
    )

fig = plot_polar_tuning(
    direction_responses,
    directions_deg,
    neuron_indices=top_dsi_neurons,
    titles=panel_titles,
    n_cols=4,
    normalize=True,
)
fig.suptitle(f"Top-{n_top} neurons by DSI – direction tuning curves", y=1.02, fontsize=13)
plt.show()

In [None]:
# Also show the (up to) 12 most orientation-selective neurons.
# Sort on -osi so that any NaN OSI values land at the END of the ascending argsort.
# The previous `argsort(osi)[::-1]` put NaNs FIRST, so they were shown as "top" neurons.
n_top_osi = min(12, len(osi))
top_osi_neurons = np.argsort(-osi, kind="stable")[:n_top_osi]

# Panel titles: neuron id, its optimal SF / TF / colour, and its OSI.
panel_titles_osi = []
for u in top_osi_neurons:
    sf_i, tf_i, col_i = optimal_indices[u]
    sf = sf_list[sf_i]
    tf = tf_list[tf_i]
    col = color_list[col_i]
    panel_titles_osi.append(
        f"n{u} | {sf} cpd {tf} Hz\n{col} | OSI={osi[u]:.2f}"
    )

fig2 = plot_polar_tuning(
    direction_responses,
    directions_deg,
    neuron_indices=top_osi_neurons,
    titles=panel_titles_osi,
    n_cols=4,
    normalize=True,
)
fig2.suptitle(f"Top-{n_top_osi} neurons by OSI – direction tuning curves", y=1.02, fontsize=13)
plt.show()

### Full pipeline (single call)

Once you are happy with the parameters, use `run_drifting_grating_analysis` to run
the full pipeline and optionally save all results to a `.npz` file.

In [None]:
from pathlib import Path

from in_silico.analyses.drifting_grating import run_drifting_grating_analysis

# Where the pipeline saves its results; set to None to skip saving.
OUTPUT_PATH = "../results/drifting_grating.npz"
if OUTPUT_PATH is not None:
    # Create the results directory up front so saving does not fail on a fresh checkout.
    Path(OUTPUT_PATH).parent.mkdir(parents=True, exist_ok=True)

# Reuse the configuration constants from the step-by-step cells. This keeps the two
# runs consistent; previously the window length and duration were duplicated here
# as literals.
results = run_drifting_grating_analysis(
    wrapper,
    key=KEY,
    # Stimulus
    contrast=0.5,
    mean_lum=0.5,
    n_seconds=N_SECONDS,
    response_start_s=0.0,
    # Phase 1
    sf_cpd_list=DEFAULT_SF_CPD,
    tf_hz_list=DEFAULT_TF_HZ,
    colors=DEFAULT_COLORS,
    ref_directions_deg=(0., 90., 180., 270.),
    # Phase 2
    n_directions=12,
    # Misc
    win_len=WIN_LEN,
    skip_frames=SKIP_FRAMES,
    output_path=OUTPUT_PATH,
    show_progress=True,
)

print("Keys:", list(results.keys()))
print(f"direction_responses : {results['direction_responses'].shape}")
print(f"dsi mean/median     : {results['dsi'].mean():.3f} / {np.median(results['dsi']):.3f}")
print(f"osi mean/median     : {results['osi'].mean():.3f} / {np.median(results['osi']):.3f}")

In [None]:
# Reload the saved .npz and check that it round-trips the in-memory `results`.
# Loading goes into a separate name so earlier variables (dsi, osi, ...) are not clobbered.
from pathlib import Path

saved_path = Path("../results/drifting_grating.npz")
if saved_path.exists():
    # allow_pickle=True as before. The file is produced by this notebook, so it is trusted.
    with np.load(saved_path, allow_pickle=True) as loaded:
        for name in ("direction_responses", "dsi", "osi"):
            np.testing.assert_allclose(loaded[name], results[name], equal_nan=True)
    print(f"Round-trip OK: {saved_path}")
else:
    print(f"{saved_path} not found (saving disabled?) – skipping round-trip check.")