In [None]:
from flyvis_cell_type_pert import FlyvisCellTypePert, PerturbationType
from flyvis.datasets.sintel import MultiTaskSintel
from pathlib import Path
import os
import h5py
import datamate
import pandas as pd
import numpy as np
import torch
import shutil


In [None]:
# Local directory that holds the flyvis data and results used by this notebook.
data_path = Path("data/flyvis_data")
data_path.mkdir(parents=True, exist_ok=True)

# The variable has to be set on os.environ itself: a value stored only in a
# copy of the environment is invisible to every library running in this process.
# NOTE(review): flyvis presumably resolves FLYVIS_ROOT_DIR when it is first
# imported, and the import cell at the top of the notebook already imports from
# flyvis -- if so, this assignment must move ahead of that import. Confirm.
os.environ["FLYVIS_ROOT_DIR"] = str(data_path)

# Snapshot of the environment (now including FLYVIS_ROOT_DIR), kept under its
# original name so it can still be handed to subprocesses if needed.
env = os.environ.copy()

def fixed_write_h5(path, val):
    """
    Store ``val`` in an HDF5 file at ``path`` under the dataset name "data".

    The file is opened in "w" mode, so an existing file is replaced without
    being read first, and missing parent directories are created. Intended
    as a Windows-safe stand-in for datamate's ``_write_h5``.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    with h5py.File(target, mode="w", libver="latest") as h5_file:
        h5_file.create_dataset("data", data=val)

# Swap in the replacement writer. `datamate.directory` may hold its own
# reference to `_write_h5`, so patch it there too when it is present.
datamate.io._write_h5 = fixed_write_h5
if not hasattr(datamate.directory, "_write_h5"):
    print(" -> Warning: Could not find _write_h5 in directory module")
else:
    datamate.directory._write_h5 = fixed_write_h5
    print(" -> Patched datamate.directory._write_h5")

# Imported only after the patch above has been installed.
print("Importing flyvis...")
from flyvis import NetworkView


In [None]:
print("Initializing Sintel dataset...")

# Construction arguments for the Sintel optic-flow dataset, gathered in one
# mapping so the whole configuration can be read at a glance.
sintel_kwargs = {
    "tasks": ["flow"],
    "boxfilter": {"extent": 15, "kernel_size": 13},
    "vertical_splits": 1,  # Use 1 for faster testing, 3 for full dataset
    "n_frames": 19,
    "dt": 1 / 50,  # Temporal resolution
    "augment": False,  # Set to False for evaluation
    "resampling": True,
    "interpolate": True,
    "all_frames": False,
    "random_temporal_crop": False,
}
dataset = MultiTaskSintel(**sintel_kwargs)

print(f"Dataset initialized with {len(dataset)} sequences")
# The sequence table is optional; only show it when the dataset exposes one.
if hasattr(dataset, 'arg_df'):
    print(f"First 5 sequences: {dataset.arg_df['name'].tolist()[:5]}")
    display(dataset.arg_df.head())

In [None]:
# %% Inspect a Single Sample
# Sanity check: fetch the first dataset item and report which keys it holds
# and the shapes of its 'lum' and 'flow' entries.
print("\nInspecting first sample...")
sample = dataset[0]
print(f"Sample keys: {sample.keys()}")
print(f"Luminance shape: {sample['lum'].shape}")
print(f"Flow shape: {sample['flow'].shape}")

In [None]:
#  %% Load Network and Decoder
# Build the network and its "flow" decoder from the stored results folder.
# `src_folder` is reused further down as the source of the perturbed copy.
print("\nLoading network...")
src_folder = data_path / "results/flow/0000/000"
network_view = NetworkView(src_folder)
network = network_view.init_network()

print("Loading decoder...")
# init_decoder() is indexed by task name; only the "flow" decoder is used here.
decoder = network_view.init_decoder()["flow"]
decoder.eval()

print(f"Network initialized successfully")
# Parameter counts: total number of elements over all parameter tensors.
print(f"Number of network parameters: {sum(p.numel() for p in network.parameters())}")
print(f"Number of decoder parameters: {sum(p.numel() for p in decoder.parameters())}")

In [None]:
# %% Test Single Sequence Prediction
# Run the network and decoder on the first sequence and report its end-point
# error (EPE) as a smoke test before the full evaluation loop.
print("\nTesting prediction on first sequence...")
data = dataset[0]
lum = data["lum"]
flow = data["flow"]

# Pure inference: no autograd graph is needed, so skip building one.
with torch.no_grad():
    # Simulate network response (lum[None] adds a leading batch axis)
    stationary_state = network.fade_in_state(1.0, dataset.dt, lum[[0]])
    responses = network.simulate(lum[None], dataset.dt, initial_state=stationary_state)

    # Decode flow from neural responses
    y_pred = decoder(responses)

# Compute EPE for this sequence: Euclidean norm over the flow-component axis.
# Assumes prediction and target are laid out as (..., 2, n_hexals), as in the
# flyvis optic-flow tutorial -- TODO confirm. dim=-2 addresses that axis with
# or without the batch axis; dim=1 would reduce over frames once the batch
# axis is present.
epe = torch.sqrt(((y_pred - flow) ** 2).sum(dim=-2))

print(f"Prediction shape: {y_pred.shape}")
print(f"Ground truth shape: {flow.shape}")
print(f"EPE shape: {epe.shape}")
print(f"Mean EPE: {epe.mean().item():.4f} pixels")
print(f"Median EPE: {epe.median().item():.4f} pixels")

In [None]:
# ## Full Evaluation Function

# %% Define Evaluation Function
def evaluate_network(network, decoder, dataset, output_file=None):
    """
    Evaluate a network/decoder pair on every sequence of a Sintel dataset.

    For each sequence the network is faded in on the first frame, simulated on
    the whole sequence, and the decoder output is compared with the ground
    truth flow through the end-point error (EPE).

    Args:
        network: object providing ``fade_in_state(t, dt, frames)`` and
            ``simulate(movie, dt, initial_state=...)``.
        decoder: callable mapping the simulated responses to a flow prediction.
        dataset: indexable collection with ``__len__``, a ``dt`` attribute and
            items that are mappings holding "lum" and "flow" tensors. An
            optional ``arg_df`` table with a "name" column supplies sequence
            names.
        output_file: optional CSV path; parent directories are created.

    Returns:
        pandas.DataFrame whose first row ("overall") holds pooled statistics
        and whose remaining rows hold per-sequence statistics.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    n_sequences = len(dataset)
    if n_sequences == 0:
        raise ValueError("Cannot evaluate on an empty dataset")

    print('Generating Sintel optic flow responses...')

    # arg_df may be missing entirely or present but None (e.g. on a wrapper).
    arg_df = getattr(dataset, 'arg_df', None)

    all_epe = []

    # Pure inference: avoid building autograd graphs for every sequence.
    with torch.no_grad():
        for i in range(n_sequences):
            if i % 5 == 0:
                print(f"Processing sequence {i+1}/{len(dataset)}...")

            data = dataset[i]
            lum = data["lum"]
            flow = data["flow"]

            # Simulate network response (lum[None] adds a leading batch axis)
            stationary_state = network.fade_in_state(1.0, dataset.dt, lum[[0]])
            responses = network.simulate(lum[None], dataset.dt, initial_state=stationary_state)

            # Decode flow from neural responses
            y_pred = decoder(responses)

            # EPE = Euclidean norm over the flow-component axis. Assumes the
            # layout (..., 2, n_hexals) used by the flyvis optic-flow tutorial
            # -- TODO confirm. dim=-2 addresses that axis with or without the
            # batch axis; dim=1 would reduce over frames instead.
            epe = torch.sqrt(((y_pred - flow) ** 2).sum(dim=-2))

            # Flatten so sequences of different shapes can still be pooled.
            all_epe.append(epe.detach().cpu().flatten())

    print('Evaluating performance...')

    # Aggregate metrics over every value of every sequence
    all_epe_tensor = torch.cat(all_epe, dim=0)

    results = [{
        'sequence': 'overall',
        'n_sequences': n_sequences,
        'mean_epe': float(all_epe_tensor.mean()),
        'median_epe': float(all_epe_tensor.median()),
        'std_epe': float(all_epe_tensor.std()),
        'epe_pixel_1': float((all_epe_tensor < 1).float().mean()),
        'epe_pixel_3': float((all_epe_tensor < 3).float().mean()),
        'epe_pixel_5': float((all_epe_tensor < 5).float().mean()),
    }]

    # Per-sequence statistics
    for i, epe in enumerate(all_epe):
        results.append({
            'sequence': f'seq_{i:03d}',
            'sequence_name': arg_df.iloc[i]['name'] if arg_df is not None else f'seq_{i}',
            'mean_epe': float(epe.mean()),
            'median_epe': float(epe.median()),
            'std_epe': float(epe.std()),
        })

    results_df = pd.DataFrame(results)

    if output_file:
        # Path.parent of a bare filename is ".", which mkdir accepts, whereas
        # os.makedirs(os.path.dirname("file.csv")) fails on the empty string.
        Path(output_file).parent.mkdir(parents=True, exist_ok=True)
        results_df.to_csv(output_file, index=False)
        print(f"\nResults saved to {output_file}")

    # Print summary
    overall = results_df[results_df['sequence'] == 'overall'].iloc[0]
    print(f"\nOverall Performance:")
    print(f"  Mean EPE: {overall['mean_epe']:.4f} pixels")
    print(f"  Median EPE: {overall['median_epe']:.4f} pixels")
    print(f"  % pixels with EPE < 3px: {overall['epe_pixel_3']*100:.2f}%")

    return results_df

In [None]:
# Create a smaller dataset for quick testing
# Dataset instance used by the quick-test cells below. Its configuration is
# the same as the one used for `dataset`; the reduction to a handful of
# sequences happens later through a subset wrapper.
test_sintel_kwargs = {
    "tasks": ["flow"],
    "boxfilter": {"extent": 15, "kernel_size": 13},
    "vertical_splits": 1,
    "n_frames": 19,
    "dt": 1 / 50,
    "augment": False,
    "resampling": True,
    "interpolate": True,
    "all_frames": False,
    "random_temporal_crop": False,
}
test_dataset = MultiTaskSintel(**test_sintel_kwargs)

In [None]:
# Evaluate on first 3 sequences only for quick test
class SubsetDataset:
    """
    Read-only view onto the first ``n_samples`` items of ``dataset``.

    Mirrors the attributes the evaluation code reads from a dataset:
    ``dt`` is copied from the wrapped dataset and ``arg_df`` is the wrapped
    dataset's table truncated to the subset (``None`` when the wrapped
    dataset has no table).

    Args:
        dataset: indexable collection with ``__len__`` and a ``dt`` attribute.
        n_samples: number of leading items to expose; clamped to
            ``[0, len(dataset)]``.
    """

    def __init__(self, dataset, n_samples=3):
        self.dataset = dataset
        # Clamp on both sides so len() never goes negative or past the data.
        self.n_samples = max(0, min(n_samples, len(dataset)))
        self.dt = dataset.dt
        full_arg_df = getattr(dataset, 'arg_df', None)
        self.arg_df = full_arg_df.iloc[:self.n_samples] if full_arg_df is not None else None

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """Return item ``idx`` of the subset; negative indices count from its end.

        Raises:
            IndexError: if ``idx`` falls outside the subset. This also lets
                plain ``for item in subset`` iteration stop after
                ``n_samples`` items instead of walking the wrapped dataset.
        """
        if idx < 0:
            idx += self.n_samples
        if not 0 <= idx < self.n_samples:
            raise IndexError(f"index out of range for subset of size {self.n_samples}")
        return self.dataset[idx]

# Quick baseline: evaluate the unperturbed network on the first 3 sequences
# only and write the per-sequence table to CSV. `subset_dataset` is reused
# later for the perturbed network so both runs see the same sequences.
subset_dataset = SubsetDataset(test_dataset, n_samples=3)
results_original = evaluate_network(
    network, 
    decoder, 
    subset_dataset,
    output_file="data/flyvis_data/perf/sintel_original_quick_test.csv"
)

display(results_original)


In [None]:
# %% Load Connectivity and Set Up Perturbation
conn_df = pd.read_csv('data/flyvis_data/flyvis_cell_type_connectivity.csv')
print(f"Loaded connectivity with {len(conn_df)} connections")
print(f"\nColumns: {conn_df.columns.tolist()}")

# %% Define Perturbation
pert = FlyvisCellTypePert()
pairs_to_perturb = [('L4', 'L4')]

pert.perturb(conn_df, PerturbationType.PAIR_WISE, pairs=pairs_to_perturb)


In [None]:
# Create a copy of the trained network folder, overwrite its checkpoints with
# perturbed weights, reload the network from that copy and evaluate it.
tar_folder = data_path / "results/flow/0000/000/sintel_test/000_test_pert"
shutil.rmtree(tar_folder, ignore_errors=True)

# tar_folder lives inside src_folder, and the rmtree above only removes the
# leaf directory. On a re-run the intermediate directory therefore already
# exists inside src_folder, and an unfiltered copytree would descend into the
# very tree it is writing to. Skip that top-level entry while copying.
try:
    _rel_parts = tar_folder.relative_to(src_folder).parts
except ValueError:
    # Target is not nested in the source: nothing needs to be skipped.
    _rel_parts = ()
nested_entry = _rel_parts[0] if _rel_parts else None


def _skip_target_subtree(directory, names):
    """copytree ``ignore`` hook: skip the subtree that contains tar_folder."""
    if nested_entry is not None and nested_entry in names and Path(directory) == Path(src_folder):
        return [nested_entry]
    return []


shutil.copytree(src_folder, tar_folder, ignore=_skip_target_subtree, dirs_exist_ok=True)
print(f"Created perturbed network folder: {tar_folder}")

# Apply perturbation
# override_network modifies the network it is given, so it is applied to a
# freshly loaded instance. `network` keeps the original weights and can still
# serve as the unperturbed baseline in later cells.
print('\nApplying perturbation to network in memory...')
network_to_perturb = NetworkView(src_folder).init_network()
pert.override_network(network_to_perturb)

# Save perturbed weights
print("Saving perturbed weights to disk...")
checkpoint_template = torch.load(src_folder / "best_chkpt", map_location='cpu')
# Shallow copy: every entry is kept except 'network', which is replaced below.
perturbed_checkpoint = checkpoint_template.copy()
perturbed_checkpoint['network'] = network_to_perturb.state_dict()

target_best_chkpt = tar_folder / "best_chkpt"
torch.save(perturbed_checkpoint, target_best_chkpt)
print(f" -> Updated: {target_best_chkpt}")

# Update checkpoint files
chkpts_dir = tar_folder / "chkpts"
if chkpts_dir.exists():
    for chkpt_file in chkpts_dir.glob("*"):
        # Only regular files can be overwritten with torch.save.
        if not chkpt_file.is_file():
            continue
        torch.save(perturbed_checkpoint, chkpt_file)
        print(f" -> Updated: {chkpt_file}")

# Clear caches
# Remove cache/storage directories copied from the source folder.
# NOTE(review): presumably these hold results derived from the original
# weights -- confirm.
print("Clearing caches...")
for cache_name in ["__cache__", "__storage__"]:
    cache_dir = tar_folder / cache_name
    if cache_dir.exists():
        shutil.rmtree(cache_dir)
        print(f" -> Removed {cache_name}")

print("\nPerturbation applied and saved successfully!")

# %% Reload Perturbed Network
print("Reloading perturbed network from disk...")
network_view_pert = NetworkView(tar_folder)
network_pert = network_view_pert.init_network()
decoder_pert = network_view_pert.init_decoder()["flow"]
decoder_pert.eval()

print("Perturbed network loaded successfully!")


# Same subset as the baseline run so the two result tables are comparable.
results_perturbed = evaluate_network(
    network_pert,
    decoder_pert,
    subset_dataset,
    output_file="data/flyvis_data/perf/sintel_L4_L4_pert_quick_test.csv"
)

display(results_perturbed)

In [None]:
def _overall_metrics(results_df):
    """Return [mean EPE, median EPE, % values with EPE < 3] from the 'overall' row."""
    overall_rows = results_df[results_df['sequence'] == 'overall']
    return [
        overall_rows['mean_epe'].values[0],
        overall_rows['median_epe'].values[0],
        # stored as a fraction; reported in percent
        overall_rows['epe_pixel_3'].values[0] * 100,
    ]


# Side-by-side table of the overall metrics for both networks, plus the
# absolute and relative difference (perturbed minus original).
comparison = pd.DataFrame({
    'metric': ['mean_epe', 'median_epe', 'epe_pixel_3'],
    'original': _overall_metrics(results_original),
    'perturbed': _overall_metrics(results_perturbed),
}).assign(
    change=lambda table: table['perturbed'] - table['original'],
    percent_change=lambda table: (table['change'] / table['original']) * 100,
)

print("\nSummary:")
display(comparison)

# Plot comparison
import matplotlib.pyplot as plt

# Metric-indexed view for scalar lookups below.
by_metric = comparison.set_index('metric')

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Plot 1: EPE comparison (grouped bars, one group per metric)
ax1 = axes[0]
metrics = ['mean_epe', 'median_epe']
x = np.arange(len(metrics))
width = 0.35

original_heights = [by_metric.loc[m, 'original'] for m in metrics]
perturbed_heights = [by_metric.loc[m, 'perturbed'] for m in metrics]
bars1 = ax1.bar(x - width/2, original_heights, width, label='Original', alpha=0.8)
bars2 = ax1.bar(x + width/2, perturbed_heights, width, label='Perturbed (L4-L4)', alpha=0.8)

ax1.set_ylabel('EPE (pixels)')
ax1.set_title('End Point Error Comparison')
ax1.set_xticks(x)
ax1.set_xticklabels(['Mean EPE', 'Median EPE'])
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Accuracy comparison (share of values with EPE below 3)
ax2 = axes[1]
orig_acc = by_metric.loc['epe_pixel_3', 'original']
pert_acc = by_metric.loc['epe_pixel_3', 'perturbed']

bars = ax2.bar(['Original', 'Perturbed (L4-L4)'], [orig_acc, pert_acc], alpha=0.8)
ax2.set_ylabel('% pixels with EPE < 3px')
ax2.set_title('Accuracy Comparison')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

mean_epe_change = by_metric.loc['mean_epe', 'change']
mean_epe_percent_change = by_metric.loc['mean_epe', 'percent_change']
accuracy_change = by_metric.loc['epe_pixel_3', 'change']

print(f"\nPerturbation Effect:")
print(f"  Mean EPE change: {mean_epe_change:+.4f} pixels ({mean_epe_percent_change:+.2f}%)")
print(f"  Accuracy change: {accuracy_change:+.2f} percentage points")

# %% [markdown]
# ## Full Dataset Evaluation (Optional)
# Remove the triple quotes around the code below and run this cell to evaluate on the full dataset (takes longer)

# %% Full Evaluation (Optional)
"""
# Full original network evaluation
print("FULL EVALUATION - Original Network")
results_full_original = evaluate_network(
    network,
    decoder,
    test_dataset,
    output_file="data/flyvis_data/perf/sintel_original_full.csv"
)

# Full perturbed network evaluation  
print("\nFULL EVALUATION - Perturbed Network")
results_full_perturbed = evaluate_network(
    network_pert,
    decoder_pert,
    test_dataset,
    output_file="data/flyvis_data/perf/sintel_L4_L4_pert_full.csv"
)

print("\nFull evaluation complete!")
"""

# %%