**FlyVis Perturbations: Evaluate Model Performance**

This notebook will take in FlyVis models (pre-trained on optic flow task) that have been systematically perturbed in their connectivity, and then output various performance metrics on visual classification tasks and neuron tuning properties.

INPUTS: [FlyVis model]

OUTPUTS: {NETWORK}_performance.h5py

In [6]:
import os
import subprocess
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch

# flyvis imports
from flyvis import NetworkView
from flyvis.datasets.rendering import BoxEye
from flyvis.analysis.animations import HexScatter


import h5py
import datamate.io        # Where the function is defined
import datamate.directory # Where the function is actually CALLED causing the error

# Set up plotting
plt.rcParams["figure.figsize"] = [5, 3]
plt.rcParams["font.size"] = 6
plt.rcParams["figure.dpi"] = 200


data_path = Path("data/flyvis_data")
data_path.mkdir(parents=True, exist_ok=True)

env = os.environ.copy()
env["FLYVIS_ROOT_DIR"] = str(data_path)



network_view = NetworkView("flow/0000/000")

[2026-01-05 19:34:13] network_view:122 Initialized network view at C:\Users\madis\Documents\Perturbations\fly_wire_perturbations\.venv\Lib\site-packages\flyvis\data\results\flow\0000\000


In [None]:
import copy

[2026-01-05 19:48:29] chkpt_utils:36 Recovered network state.


In [23]:
from flyvis_cell_type_pert import FlyvisCellTypePert, PerturbationType
pert = FlyvisCellTypePert()

connectome = network_view.connectome
synapse_df = connectome.edges.to_df()
cell_type_df = synapse_df.groupby(['source_type', 'target_type']).size().reset_index(name='syn_count')

#original_network = network_view.init_network()
#perturbed_network = network_view.init_network()

original_network = network_view.init_network()
perturbed_network = copy.deepcopy(original_network)

pairs_to_perturb = [('Am', 'L3')]
pert_conn = pert.perturb(conn=cell_type_df,
                        perturbation_type=PerturbationType.PAIR_WISE,
                        pairs=pairs_to_perturb)

# validate the pert_weight is 0
pert_conn[pert_conn.pert_weight == 0]

#override network weights
def get_pair_index_in_network(perturbed_network, source_type, target_type):
    syn_str = perturbed_network.edge_params.syn_strength.keys
    for idx, (src, tar) in enumerate(syn_str):
        if src == source_type and tar == target_type:
            return idx
    return None

idx = get_pair_index_in_network(perturbed_network, 'Am', 'L3')
print('Am -> L3 index:', idx)
weight_before = perturbed_network.edge_params.syn_strength.raw_values[idx]
print('Weight before perturbation:', weight_before)

pert.override_network(perturbed_network, pert_conn)
weight_after = perturbed_network.edge_params.syn_strength.raw_values[idx]
print('Weight after perturbation:', weight_after)

[2026-01-05 19:49:27] chkpt_utils:36 Recovered network state.


Am -> L3 index: 127
Weight before perturbation: tensor(0.0402, grad_fn=<SelectBackward0>)
Weight after perturbation: tensor(0., grad_fn=<SelectBackward0>)


In [24]:
weight_after = perturbed_network.edge_params.syn_strength.raw_values[idx]
print(weight_after)
weight_original = original_network.edge_params.syn_strength.raw_values[idx]
print(weight_original)

tensor(0., grad_fn=<SelectBackward0>)
tensor(0.0402, grad_fn=<SelectBackward0>)


In [25]:
# Compare synaptic strength tensors
orig_weights = original_network.edge_params.syn_strength.raw_values
pert_weights = perturbed_network.edge_params.syn_strength.raw_values

# Find differences
weight_diff = (pert_weights - orig_weights).abs()
changed_mask = weight_diff > 1e-8  # tolerance for floating point comparison
changed_indices = torch.where(changed_mask)[0]

print(f"Number of connections changed: {len(changed_indices)}")

if len(changed_indices) > 0:
    syn_keys = original_network.edge_params.syn_strength.keys
    print("\nChanged connections:")
    for idx in changed_indices:
        src, tar = syn_keys[idx]
        print(f"  {src} -> {tar}: {orig_weights[idx].item():.4f} → {pert_weights[idx].item():.4f}")
else:
    print("No changes detected!")

Number of connections changed: 1

Changed connections:
  Am -> L3: 0.0402 → 0.0000


In [1]:
# We run this to download the pretrained model ensemble
!flyvis download-pretrained

[2026-01-04 17:17:21] utils:164 NumExpr defaulting to 4 threads.
[2026-01-04 17:17:33] __init__:49 file_cache is only supported with oauth2client<4.0.0
Downloading results_umap_and_clustering.zip to /usr/local/lib/python3.11/dist-packages/data/results_umap_and_clustering.zip.
Progress results_umap_and_clustering.zip: 100%
Checksum OK for results_umap_and_clustering.zip.
Unpacked results_umap_and_clustering.zip.
Downloading results_pretrained_models.zip to /usr/local/lib/python3.11/dist-packages/data/results_pretrained_models.zip.
Progress results_pretrained_models.zip: 100%
Checksum OK for results_pretrained_models.zip.
Unpacked results_pretrained_models.zip.


In [4]:
# TODO: Replace with perturbed trained model
# Load a pretrained network model from the ensemble
# The identifier "flow/0000/000" specifies: task/ensemble_id/network_id
# - "flow": the training task (optic flow prediction)
# - "0000": the first ensemble (collection of models trained with the same configuration)
# - "000": the first network within that ensemble
network_view = NetworkView("flow/0000/000")

In [6]:
# TODO: Replace with perturbed connectome
connectome = network_view.connectome
cell_types = connectome.unique_cell_types

**Plot receptive and projective fields**

**Plot neural traces**

**Plot response to Sintel movie set (naturalistic stimuli)**

**Plot tuning curves for direction and orientation selectivity**

**Find and plot maximally exciting stimulus (artificial)**

**Quantify model performance on optic flow task**

**Other potential evaluation metrics:**

* dimensionality of population activity
* frequency and speed tuning curves
