In [1]:
from flyvis.datasets.moving_bar import MovingEdge
import numpy as np
import os
from flyvis_cell_type_pert import FlyvisCellTypePert
import pandas as pd
from tqdm import tqdm
from moving_edge import MovingEdgeWrapper
from pathlib import Path
from flyvis_cell_type_pert import PerturbationType

 -> Patched datamate.directory._write_h5
Importing flyvis...


In [2]:
import datamate

data_path = Path("data/flyvis_data")
data_path.mkdir(parents=True, exist_ok=True)

env = os.environ.copy()
env["FLYVIS_ROOT_DIR"] = str(data_path)

In [3]:
# Cell-type connectivity table; it is passed to FlyvisCellTypePert.perturb in the cells below.
connectivity_csv = f'{data_path}/flyvis_cell_type_connectivity.csv'
cell_type_df = pd.read_csv(connectivity_csv)


# Normal stimulus (edge intensities 0 and 1)

In [None]:
# Moving-edge stimulus for the pairwise, outgoing and baseline runs in this section.
# Units follow the original annotations: offsets and height are in LED widths of
# radians(2.25); speed is in multiples of radians(5.8) per second.
dataset = MovingEdge(
    intensities=[0, 1],                  # edge intensities to render
    speeds=[19],                         # edge speed, in multiples of radians(5.8) / s
    angles=list(np.arange(0, 360, 30)),  # motion directions (orthogonal to the edge), 30-deg steps
    offsets=[-10, 11],                   # edge offset from center, in radians(2.25) LED widths
    height=80,                           # edge height, in radians(2.25) LED widths
    t_pre=1.0,                           # length of the pre-stimulus period
    t_post=1.0,                          # length of the post-stimulus period
    post_pad_mode="continue",            # post-stimulus frames repeat the last stimulus frame
    dt=1 / 200,                          # time step of the rendered video
)

In [None]:
# Edge hits to perturb one at a time, each written as "<source>_<target>".
# dict.fromkeys keeps first-seen order while dropping repeats: 'Mi4_T4d' used to be listed
# twice, which re-ran the identical perturbation and overwrote its own result CSV/figures.
pairwise_hits = list(dict.fromkeys([
    'CT1(Lo1)_T5b',
    'CT1(M10)_Mi1',
    'Mi4_T4d',
    'R8_Mi4',
    'Tm9_T5c',
    'TmY15_Tm4',
    'Mi1_T4a',
    'R8_Mi1',
    'Mi4_Tm9',
]))

In [None]:
# Run each pairwise edge perturbation through the moving-edge simulation, writing one
# result CSV and one figure directory per (source, target) pair.
# Set to True to skip pairs whose result CSV already exists (replaces the old
# commented-out guard; default False keeps the original always-rerun behavior).
skip_existing_pairwise = False

for hit in pairwise_hits:
    # Splitting on '_' assumes cell-type names contain no underscore (true for the hits
    # listed above); a malformed hit raises ValueError here rather than running silently.
    src, tar = hit.split('_')
    result_output_path = f'data/flyvis_data/pairwise_edge_pert_results/{src}_{tar}.csv'
    plot_output_dir = f'data/flyvis_data/pairwise_edge_fig/{src}_{tar}'
    if skip_existing_pairwise and os.path.exists(result_output_path):
        print(f"Skipping existing perturbation: {src} -> {tar}")
        continue
    print(f"Running perturbation: {src} -> {tar}")

    # A new FlyvisCellTypePert is built for every pair; only the (src, tar) edge is perturbed.
    pert = FlyvisCellTypePert()
    pert.perturb(cell_type_df, PerturbationType.PAIR_WISE, pairs=[(src, tar)])

    wrapper = MovingEdgeWrapper(dataset, pert=pert,
                                pert_folder_name=f'{src}_{tar}_perturbation',
                                output_file_name=result_output_path,
                                plot_output_dir=plot_output_dir)
    wrapper.run()

## Outgoing perturbations

For each cell type in `outgoing_hits`, perturb all of its outgoing edges:

In [None]:
# Source cell types whose outgoing edges are perturbed (one run per entry).
# dict.fromkeys keeps order while dropping repeats: 'R8' used to appear twice, which re-ran
# the same perturbation and overwrote the same R8.csv result and figure directory.
outgoing_hits = list(dict.fromkeys(['Tm9', 'CT1(M10)', 'Mi2', 'Mi4', 'R8', 'Mi1']))

In [None]:
# One run per source cell type: perturb every outgoing edge of `src`, then simulate the
# moving-edge responses and hand the result CSV path and figure directory to the wrapper.
for src in outgoing_hits:
    run_tag = f'{src}_outgoing'
    result_output_path = f'data/flyvis_data/outgoing_edge_pert_results/{src}.csv'
    plot_output_dir = f'data/flyvis_data/outgoing_edge_fig/{run_tag}'
    print(f"Running outgoing perturbation: {src}")

    pert = FlyvisCellTypePert()
    pert.perturb(cell_type_df, PerturbationType.OUTGOING, source_outgoing=src)

    wrapper = MovingEdgeWrapper(
        dataset,
        pert=pert,
        pert_folder_name=f'{run_tag}_perturbation',
        output_file_name=result_output_path,
        plot_output_dir=plot_output_dir,
    )
    wrapper.run()

In [None]:
# Reference run of the network with no perturbation applied (pert=None), on the same stimulus.
result_output_path = 'data/flyvis_data/moving_edge_original.csv'
plot_output_dir = 'data/flyvis_data/moving_edge_original_fig'
print("Running original (unperturbed) network")

wrapper = MovingEdgeWrapper(
    dataset,
    pert=None,
    pert_folder_name='original_no_perturbation',
    output_file_name=result_output_path,
    plot_output_dir=plot_output_dir,
)
wrapper.run()

# Feedback motif (edge intensities 0 and 0.75)

In [4]:
# Moving-edge stimulus for the feedback-motif runs. It matches the earlier stimulus except
# that the brighter edge intensity is 0.75 instead of 1.
# NOTE(review): the reason for 0.75 is not recorded here — document it.
# Units follow the original annotations: offsets and height are in LED widths of
# radians(2.25); speed is in multiples of radians(5.8) per second.
dataset = MovingEdge(
    intensities=[0, 0.75],               # edge intensities to render
    speeds=[19],                         # edge speed, in multiples of radians(5.8) / s
    angles=list(np.arange(0, 360, 30)),  # motion directions (orthogonal to the edge), 30-deg steps
    offsets=[-10, 11],                   # edge offset from center, in radians(2.25) LED widths
    height=80,                           # edge height, in radians(2.25) LED widths
    t_pre=1.0,                           # length of the pre-stimulus period
    t_post=1.0,                          # length of the post-stimulus period
    post_pad_mode="continue",            # post-stimulus frames repeat the last stimulus frame
    dt=1 / 200,                          # time step of the rendered video
)

In [None]:
# Visual sanity check: animate one stimulus sample on the hex lattice.
# NOTE(review): an earlier note said this sample is "intensity=1, radius=6", but the dataset
# above uses intensities [0, 0.75] — confirm which stimulus parameters index 3 corresponds to.
from flyvis.analysis.animations.hexscatter import HexScatter

stimulus_sample = dataset[3]
# Add singleton axes around the strided axis and keep every 25th entry along it
# (presumably the time axis, to keep the animation short — verify).
animation = HexScatter(stimulus_sample[None, ::25, None], vmin=0, vmax=1)
animation.animate_in_notebook()

In [5]:
# Motif triplets to test; the loop below reads cell types from columns 'a', 'b' and 'c'.
motif_csv = f'{data_path}/motif_edges/motif_46_feedback_test.csv'
triplets = pd.read_csv(motif_csv)

In [None]:
# For each motif triplet (a, b, c), run three independent pairwise perturbations:
#   [(a, c)]          - edge a -> c only
#   [(b, c)]          - edge b -> c only
#   [(a, b), (b, a)]  - both directions between a and b together
# Each experiment gets its own results/plots directory keyed by triplet index and experiment.
for idx, row in triplets.iterrows():
    a, b, c = row['a'], row['b'], row['c']
    print(f"Triplet {idx}: a={a}, b={b}, c={c}")

    experiments = [[(a, c)], [(b, c)], [(a, b), (b, a)]]
    print(experiments)
    for exp in experiments:
        # NOTE(review): str(exp) yields names like "[('Tm3', 'T4a')]" (brackets, quotes,
        # commas, spaces). Kept unchanged so existing output paths stay valid, but such
        # directory names are awkward for glob patterns and shell commands.
        exp_name = str(exp)
        # Trailing '/' kept as-is in case MovingEdgeWrapper appends filenames to this string.
        plot_output_dir = f'data/flyvis_data/motif_feedback_test/moving_edge/triplet_{idx}/{exp_name}/'
        # os.path.join avoids the doubled separator ("...//res.csv") that
        # f'{plot_output_dir}/res.csv' produced with the trailing slash above.
        result_output_path = os.path.join(plot_output_dir, 'res.csv')
        pert = FlyvisCellTypePert()
        pert.perturb(cell_type_df, PerturbationType.PAIR_WISE, pairs=exp)

        wrapper = MovingEdgeWrapper(dataset, pert=pert,
                                    pert_folder_name=f'triplet_idx{idx}-{exp_name}_pairwise_perturbation',
                                    output_file_name=result_output_path,
                                    plot_output_dir=plot_output_dir)
        wrapper.run()

[2026-01-08 00:22:11] network_view:122 Initialized network view at C:\Users\dean\Documents\dev\fly_winter_school\fly_wire_perturbations\data\flyvis_data\results\flow\0000\000_triplet_idx0-[('Tm3', 'T4a')]_pairwise_perturbation


Triplet 0: a=Tm3, b=Mi1, c=T4a
[[('Tm3', 'T4a')], [('Mi1', 'T4a')], [('Tm3', 'Mi1'), ('Mi1', 'Tm3')]]


[2026-01-08 00:22:11] logging_utils:23 epe not in C:\Users\dean\Documents\dev\fly_winter_school\fly_wire_perturbations\data\flyvis_data\results\flow\0000\000_triplet_idx0-[('Tm3', 'T4a')]_pairwise_perturbation\validation, but 'loss' is. Falling back to 'loss'. You can rerun the ensemble validation to make appropriate recordings of the losses.
[2026-01-08 00:22:16] network:222 Initialized network with NumberOfParams(free=734, fixed=2959) parameters.
[2026-01-08 00:22:16] chkpt_utils:36 Recovered network state.


Running moving edge simulation...
Applying perturbation to network in memory...


[2026-01-08 00:22:16] network_view:122 Initialized network view at C:\Users\dean\Documents\dev\fly_winter_school\fly_wire_perturbations\data\flyvis_data\results\flow\0000\000_triplet_idx0-[('Tm3', 'T4a')]_pairwise_perturbation


Overwriting disk checkpoints with perturbed weights...
 -> Updated: data\flyvis_data\results\flow\0000\000_triplet_idx0-[('Tm3', 'T4a')]_pairwise_perturbation\best_chkpt
 -> Updated: data\flyvis_data\results\flow\0000\000_triplet_idx0-[('Tm3', 'T4a')]_pairwise_perturbation\chkpts\chkpt_00000
Clearing caches...
 -> Removed __cache__
Generating moving edge responses...


[2026-01-08 00:22:20] network:222 Initialized network with NumberOfParams(free=734, fixed=2959) parameters.
[2026-01-08 00:22:20] chkpt_utils:36 Recovered network state.
[2026-01-08 00:22:22] network:757 Computing 24 stimulus responses.


Batch:   0%|          | 0/6 [00:00<?, ?it/s]

[2026-01-08 00:23:23] xarray_joblib_backend:56 Store item C:\Users\dean\Documents\dev\fly_winter_school\fly_wire_perturbations\data\flyvis_data\results\flow\0000\000_triplet_idx0-[('Tm3', 'T4a')]_pairwise_perturbation\__cache__\flyvis\analysis\stimulus_responses\compute_responses\8faab3b6530bd9052ceaaa396e79cb3c\output.h5


Evaluating performance...
Generating response plots...
Saved plot: data/flyvis_data/motif_feedback_test/moving_edge/triplet_0/[('Tm3', 'T4a')]/T4a_intensity0.0.png
Saved tuning curve: data/flyvis_data/motif_feedback_test/moving_edge/triplet_0/[('Tm3', 'T4a')]/T4a_tuning_intensity0.0.png
Saved plot: data/flyvis_data/motif_feedback_test/moving_edge/triplet_0/[('Tm3', 'T4a')]/T4a_intensity0.75.png
Saved tuning curve: data/flyvis_data/motif_feedback_test/moving_edge/triplet_0/[('Tm3', 'T4a')]/T4a_tuning_intensity0.75.png
Saved plot: data/flyvis_data/motif_feedback_test/moving_edge/triplet_0/[('Tm3', 'T4a')]/T4b_intensity0.0.png
Saved tuning curve: data/flyvis_data/motif_feedback_test/moving_edge/triplet_0/[('Tm3', 'T4a')]/T4b_tuning_intensity0.0.png
Saved plot: data/flyvis_data/motif_feedback_test/moving_edge/triplet_0/[('Tm3', 'T4a')]/T4b_intensity0.75.png
Saved tuning curve: data/flyvis_data/motif_feedback_test/moving_edge/triplet_0/[('Tm3', 'T4a')]/T4b_tuning_intensity0.75.png
Saved plo

[2026-01-08 00:23:30] network_view:122 Initialized network view at C:\Users\dean\Documents\dev\fly_winter_school\fly_wire_perturbations\data\flyvis_data\results\flow\0000\000_triplet_idx0-[('Mi1', 'T4a')]_pairwise_perturbation
[2026-01-08 00:23:30] logging_utils:23 epe not in C:\Users\dean\Documents\dev\fly_winter_school\fly_wire_perturbations\data\flyvis_data\results\flow\0000\000_triplet_idx0-[('Mi1', 'T4a')]_pairwise_perturbation\validation, but 'loss' is. Falling back to 'loss'. You can rerun the ensemble validation to make appropriate recordings of the losses.
[2026-01-08 00:23:36] network:222 Initialized network with NumberOfParams(free=734, fixed=2959) parameters.
[2026-01-08 00:23:36] chkpt_utils:36 Recovered network state.


Running moving edge simulation...
Applying perturbation to network in memory...


[2026-01-08 00:23:36] network_view:122 Initialized network view at C:\Users\dean\Documents\dev\fly_winter_school\fly_wire_perturbations\data\flyvis_data\results\flow\0000\000_triplet_idx0-[('Mi1', 'T4a')]_pairwise_perturbation


Overwriting disk checkpoints with perturbed weights...
 -> Updated: data\flyvis_data\results\flow\0000\000_triplet_idx0-[('Mi1', 'T4a')]_pairwise_perturbation\best_chkpt
 -> Updated: data\flyvis_data\results\flow\0000\000_triplet_idx0-[('Mi1', 'T4a')]_pairwise_perturbation\chkpts\chkpt_00000
Clearing caches...
 -> Removed __cache__
Generating moving edge responses...


[2026-01-08 00:23:41] network:222 Initialized network with NumberOfParams(free=734, fixed=2959) parameters.
[2026-01-08 00:23:41] chkpt_utils:36 Recovered network state.
[2026-01-08 00:23:44] network:757 Computing 24 stimulus responses.


Batch:   0%|          | 0/6 [00:00<?, ?it/s]