In [1]:
from flyvis_cell_type_pert import FlyvisCellTypePert, PerturbationType
from flyvis.datasets.sintel import MultiTaskSintel
from pathlib import Path
import os
import h5py
import datamate
import pandas as pd
import numpy as np
import torch
import shutil
from tqdm import tqdm

from optic_flow import SintelWrapper

  from .autonotebook import tqdm as notebook_tqdm


 -> Patched datamate.directory._write_h5
Importing flyvis...


In [2]:
# Root directory for every flyvis artifact used in this notebook
# (pretrained results, connectivity CSV, evaluation outputs).
data_path = Path("data/flyvis_data")
data_path.mkdir(parents=True, exist_ok=True)

# Point flyvis at data_path. This must go into the real os.environ: assigning
# into a *copy* of os.environ has no effect on the running process.
# NOTE(review): flyvis modules are already imported in the first cell, before
# this runs; if flyvis reads FLYVIS_ROOT_DIR only at import time, move this
# assignment above those imports.
os.environ["FLYVIS_ROOT_DIR"] = str(data_path)

# Snapshot of the environment (now including FLYVIS_ROOT_DIR), kept as a global
# for any code that passes it on explicitly, e.g. to a subprocess.
env = os.environ.copy()

def fixed_write_h5(path, val):
    """
    Write ``val`` to an HDF5 file as a single dataset named ``"data"``.

    Windows-safe replacement for datamate's ``_write_h5``: it skips the
    'read-before-write' check and opens the file straight in ``"w"`` mode,
    so any existing file at ``path`` is replaced. Missing parent directories
    are created first.

    Args:
        path: Destination file path (str or path-like).
        val: Array-like payload handed to ``create_dataset(data=...)``.
    """
    target = Path(path)
    target.parent.mkdir(exist_ok=True, parents=True)

    h5_handle = h5py.File(target, mode="w", libver="latest")
    with h5_handle as h5_file:
        h5_file.create_dataset("data", data=val)

# Install the Windows-safe writer. datamate.io is patched unconditionally;
# datamate.directory only if it holds its own _write_h5 reference.
# NOTE(review): flyvis was already imported in the first cell, so this only
# affects code that looks _write_h5 up on these modules at call time - confirm.
setattr(datamate.io, "_write_h5", fixed_write_h5)
if not hasattr(datamate.directory, "_write_h5"):
    print(" -> Warning: Could not find _write_h5 in directory module")
else:
    setattr(datamate.directory, "_write_h5", fixed_write_h5)
    print(" -> Patched datamate.directory._write_h5")

print("Importing flyvis...")
from flyvis import NetworkView


 -> Patched datamate.directory._write_h5
Importing flyvis...


In [3]:
print("Initializing Sintel dataset...")

# Flow-task evaluation setup: no augmentation, no random temporal crop,
# 19 frames per sequence with a time step of dt = 1/50.
sintel_kwargs = dict(
    tasks=["flow"],
    boxfilter={"extent": 15, "kernel_size": 13},
    vertical_splits=1,  # 1 for faster testing, 3 for the full dataset
    n_frames=19,
    dt=1 / 50,  # temporal resolution
    augment=False,  # disabled for evaluation
    resampling=True,
    interpolate=True,
    all_frames=False,
    random_temporal_crop=False,
)
dataset = MultiTaskSintel(**sintel_kwargs)

print(f"Dataset initialized with {len(dataset)} sequences")
if hasattr(dataset, 'arg_df'):
    first_names = dataset.arg_df['name'].tolist()[:5]
    print(f"First 5 sequences: {first_names}")
    display(dataset.arg_df.head())

[2026-01-07 23:01:00] sintel_utils:331 Found Sintel at c:\Users\madis\Documents\Perturbations\fly_wire_perturbations\.venv\Lib\site-packages\flyvis\data\SintelDataSet


Initializing Sintel dataset...
Dataset initialized with 23 sequences
First 5 sequences: ['sequence_00_alley_1_split_00', 'sequence_01_alley_2_split_00', 'sequence_02_ambush_2_split_00', 'sequence_03_ambush_4_split_00', 'sequence_04_ambush_5_split_00']


Unnamed: 0,index,original_index,name,original_n_frames
0,0,0,sequence_00_alley_1_split_00,50
1,1,1,sequence_01_alley_2_split_00,50
2,2,2,sequence_02_ambush_2_split_00,21
3,3,3,sequence_03_ambush_4_split_00,33
4,4,4,sequence_04_ambush_5_split_00,50


In [None]:
# %% Inspect a Single Sample
print("\nInspecting first sample...")
sample = dataset[0]
print(f"Sample keys: {sample.keys()}")
# Report the shape of each tensor in the sample dict.
for key, label in (("lum", "Luminance"), ("flow", "Flow")):
    print(f"{label} shape: {sample[key].shape}")

In [None]:
#  %% Load Network and Decoder
print("\nLoading network...")
# Pretrained flow-task model (presumably ensemble 0000, member 000).
src_folder = data_path / "results/flow/0000/000"
network_view = NetworkView(src_folder)
network = network_view.init_network()

print("Loading decoder...")
decoder = network_view.init_decoder()["flow"]
# Eval mode: the decoder contains Dropout and BatchNorm2d layers (see the
# DecoderGAVP repr in the log output), which must not run in training mode.
decoder.eval()

print("Network initialized successfully")
for component_name, component in (("network", network), ("decoder", decoder)):
    n_params = sum(param.numel() for param in component.parameters())
    print(f"Number of {component_name} parameters: {n_params}")

In [None]:
# %% Test Single Sequence Prediction
print("\nTesting prediction on first sequence...")
data = dataset[0]
lum = data["lum"]
flow = data["flow"]

# Inference only: disable autograd so no graph is built for the simulation.
with torch.no_grad():
    # Fade the network into a stationary state on the first frame, then simulate.
    stationary_state = network.fade_in_state(1.0, dataset.dt, lum[[0]])
    responses = network.simulate(lum[None], dataset.dt, initial_state=stationary_state)

    # Decode flow from neural responses
    y_pred = decoder(responses)

# End-point error: Euclidean norm of the flow error over the flow-channel axis.
# lum[None] adds a leading batch axis, so y_pred has one more dim than flow.
# Assumes y_pred is (batch, frames, channels, hexals) and flow is
# (frames, channels, hexals) - TODO confirm with the shape prints below.
# The channel axis is second-to-last in both, hence dim=-2; dim=1 would sum
# over the time axis of the batched prediction.
epe = torch.sqrt(((y_pred - flow) ** 2).sum(dim=-2))

print(f"Prediction shape: {y_pred.shape}")
print(f"Ground truth shape: {flow.shape}")
print(f"EPE shape: {epe.shape}")
print(f"Mean EPE: {epe.mean().item():.4f} pixels")
print(f"Median EPE: {epe.median().item():.4f} pixels")

In [None]:
def evaluate_network(network, decoder, dataset, output_file=None):
    """
    Evaluate a network + flow decoder on every sequence of a Sintel dataset.

    For each sequence, the network is faded into a stationary state on the first
    luminance frame and then simulated over the whole sequence. The decoder's
    flow prediction is compared against the ground-truth flow using the per-pixel
    end-point error (EPE): the Euclidean norm of the prediction error over the
    flow-channel axis.

    Args:
        network: Object exposing ``fade_in_state(t, dt, x)`` and
            ``simulate(x, dt, initial_state=...)`` (e.g. a flyvis Network).
        decoder: Callable mapping network responses to predicted flow. It should
            already be in eval mode (``decoder.eval()``); this function does not
            change its mode.
        dataset: Indexable object with ``__len__``, a ``dt`` attribute, and
            ``__getitem__`` returning a dict with ``"lum"`` and ``"flow"`` tensors.
            If it has an ``arg_df`` DataFrame, its ``name`` column labels rows.
        output_file: Optional CSV path for the results. Parent directories are
            created as needed; a bare filename is written to the working directory.

    Returns:
        DataFrame whose first row (``sequence == 'overall'``) holds statistics
        pooled over every EPE value of every sequence. It is followed by one row
        per sequence (``seq_000``, ``seq_001``, ...).

    Raises:
        ValueError: If ``dataset`` contains no sequences.
    """
    n_sequences = len(dataset)
    if n_sequences == 0:
        raise ValueError("evaluate_network: dataset contains no sequences")

    has_names = hasattr(dataset, 'arg_df')

    print('Generating Sintel optic flow responses...')

    all_epe = []
    # Inference only: no autograd graph is needed for evaluation.
    with torch.no_grad():
        for i in range(n_sequences):
            if i % 5 == 0:
                print(f"Processing sequence {i+1}/{n_sequences}...")

            data = dataset[i]
            lum = data["lum"]
            flow = data["flow"]

            # Simulate network response
            stationary_state = network.fade_in_state(1.0, dataset.dt, lum[[0]])
            responses = network.simulate(lum[None], dataset.dt, initial_state=stationary_state)

            # Decode flow from neural responses
            y_pred = decoder(responses)

            # The flow-channel axis is second-to-last both in y_pred (which has an
            # extra leading batch axis from lum[None]) and in flow; dim=1 would
            # instead sum over time in the batched prediction.
            epe = torch.sqrt(((y_pred - flow) ** 2).sum(dim=-2))
            all_epe.append(epe.detach().cpu())

    print('Evaluating performance...')

    # Pool every per-pixel EPE value. Flattening first also tolerates sequences
    # of different lengths, which a dim-0 torch.cat of the raw tensors rejects.
    pooled_epe = torch.cat([seq_epe.flatten() for seq_epe in all_epe])

    overall = {
        'sequence': 'overall',
        'n_sequences': n_sequences,
        'mean_epe': float(pooled_epe.mean()),
        'median_epe': float(pooled_epe.median()),
        'std_epe': float(pooled_epe.std()),
        'epe_pixel_1': float((pooled_epe < 1).float().mean()),
        'epe_pixel_3': float((pooled_epe < 3).float().mean()),
        'epe_pixel_5': float((pooled_epe < 5).float().mean()),
    }
    results = [overall]

    # Per-sequence statistics
    for i, seq_epe in enumerate(all_epe):
        results.append({
            'sequence': f'seq_{i:03d}',
            'sequence_name': dataset.arg_df.iloc[i]['name'] if has_names else f'seq_{i}',
            'mean_epe': float(seq_epe.mean()),
            'median_epe': float(seq_epe.median()),
            'std_epe': float(seq_epe.std()),
        })

    results_df = pd.DataFrame(results)

    if output_file:
        # Path(...).parent is "." for a bare filename, so this never calls
        # makedirs("") (which raises FileNotFoundError).
        Path(output_file).parent.mkdir(parents=True, exist_ok=True)
        results_df.to_csv(output_file, index=False)
        print(f"\nResults saved to {output_file}")

    # Print summary
    print(f"\nOverall Performance:")
    print(f"  Mean EPE: {overall['mean_epe']:.4f} pixels")
    print(f"  Median EPE: {overall['median_epe']:.4f} pixels")
    print(f"  % pixels with EPE < 3px: {overall['epe_pixel_3']*100:.2f}%")

    return results_df

In [None]:
# Baseline (unperturbed) evaluation over the whole dataset.
result = evaluate_network(network=network, decoder=decoder, dataset=dataset)


In [None]:
# Pairwise perturbation sweep: for every (source_type, target_type) row of the
# connectivity table, perturb that connection, evaluate on Sintel, and collect
# the tagged results into one CSV.
cell_type_df = pd.read_csv(data_path / "flyvis_cell_type_connectivity.csv")

pair_results = []
for src, tar in tqdm(cell_type_df[['source_type', 'target_type']].values):
    print(f"Running perturbation: {src} -> {tar}")

    pert = FlyvisCellTypePert()
    pert.perturb(cell_type_df, PerturbationType.PAIR_WISE, pairs=[(src, tar)])

    wrapper = SintelWrapper(dataset, pert=pert, pert_folder_name=f'{src}_{tar}_perturbation')
    result = wrapper.run()

    # Tag every row with the perturbed connection so runs can be told apart.
    result['source_target_pair'] = f'{src}_{tar}'
    result['source_type'] = src
    result['target_type'] = tar

    pair_results.append(result)

# Concatenate once at the end; repeated pd.concat inside the loop is quadratic.
# An empty sweep still yields an empty frame.
result_csv = pd.concat(pair_results, ignore_index=True) if pair_results else pd.DataFrame()

# Same location as before (data/flyvis_data/optic_flow), derived from data_path
# and created if missing so to_csv cannot fail on a nonexistent directory.
output_dir = data_path / "optic_flow"
output_dir.mkdir(parents=True, exist_ok=True)
result_csv.to_csv(output_dir / "pairwise_perturbation_sintel_results.csv", index=False)

In [None]:
# Single-pair run of the perturbation pipeline (Tm9 -> T5a), written to its own
# CSV so it does not clobber the full sweep's results.
cell_type_df = pd.read_csv(data_path / "flyvis_cell_type_connectivity.csv")

src = 'Tm9'
tar = 'T5a'
pert = FlyvisCellTypePert()
pert.perturb(cell_type_df, PerturbationType.PAIR_WISE, pairs=[(src, tar)])

wrapper = SintelWrapper(dataset, pert=pert, pert_folder_name=f'{src}_{tar}_perturbation')
result = wrapper.run()

# Add the source-target info
result['source_target_pair'] = f'{src}_{tar}'
result['source_type'] = src
result['target_type'] = tar

# A single frame needs no concatenation (concat with an empty DataFrame is a
# no-op that pandas now warns about); just give it a fresh RangeIndex, matching
# what pd.concat(..., ignore_index=True) produced.
result_csv = result.reset_index(drop=True)

# data/flyvis_data/optic_flow, created if missing so to_csv cannot fail.
output_dir = data_path / "optic_flow"
output_dir.mkdir(parents=True, exist_ok=True)
result_csv.to_csv(output_dir / "perturbation_results_test.csv", index=False)

[2026-01-07 23:01:10] network_view:122 Initialized network view at C:\Users\madis\Documents\Perturbations\fly_wire_perturbations\data\flyvis_data\results\flow\0000\000_Tm9_T5a_perturbation
[2026-01-07 23:01:10] logging_utils:23 epe not in C:\Users\madis\Documents\Perturbations\fly_wire_perturbations\data\flyvis_data\results\flow\0000\000_Tm9_T5a_perturbation\validation, but 'loss' is. Falling back to 'loss'. You can rerun the ensemble validation to make appropriate recordings of the losses.
[2026-01-07 23:01:13] network:222 Initialized network with NumberOfParams(free=734, fixed=2959) parameters.
[2026-01-07 23:01:13] chkpt_utils:36 Recovered network state.


Running Sintel optic flow simulation...
Applying perturbation to network in memory...
Overwriting disk checkpoints with perturbed weights...
 -> Updated: data\flyvis_data\results\flow\0000\000_Tm9_T5a_perturbation\best_chkpt
 -> Updated: data\flyvis_data\results\flow\0000\000_Tm9_T5a_perturbation\chkpts\chkpt_00000
Clearing caches...
 -> Removed __cache__


[2026-01-07 23:01:14] network_view:122 Initialized network view at C:\Users\madis\Documents\Perturbations\fly_wire_perturbations\data\flyvis_data\results\flow\0000\000_Tm9_T5a_perturbation
[2026-01-07 23:01:17] network:222 Initialized network with NumberOfParams(free=734, fixed=2959) parameters.
[2026-01-07 23:01:17] chkpt_utils:36 Recovered network state.


Initializing decoder...


[2026-01-07 23:01:19] decoder:282 Initialized decoder with NumberOfParams(free=7427, fixed=0) parameters.
[2026-01-07 23:01:19] decoder:283 DecoderGAVP(
  (base): Sequential(
    (0): Conv2dHexSpace(34, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Softplus(beta=1.0, threshold=20.0)
    (3): Dropout(p=0.5, inplace=False)
  )
  (decoder): Sequential(
    (0): Conv2dHexSpace(8, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  )
  (head): Sequential()
)
[2026-01-07 23:01:19] chkpt_utils:65 Recovered flow decoder state.
  return func(*args, **kwargs)


Generating Sintel optic flow responses...
Evaluating performance...

Results saved to 
