In [1]:
from flyvis.datasets.moving_bar import MovingEdge
import numpy as np
import os
from flyvis_cell_type_pert import FlyvisCellTypePert
import pandas as pd
from tqdm import tqdm
from moving_edge import MovingEdgeWrapper
from pathlib import Path
from flyvis_cell_type_pert import PerturbationType

 -> Patched datamate.directory._write_h5
Importing flyvis...


In [2]:
import datamate

data_path = Path("data/flyvis_data")
data_path.mkdir(parents=True, exist_ok=True)

env = os.environ.copy()
env["FLYVIS_ROOT_DIR"] = str(data_path)

In [3]:
# Cell-type connectivity table consumed by every FlyvisCellTypePert.perturb call below.
cell_type_df = pd.read_csv(data_path / "flyvis_cell_type_connectivity.csv")


# Normal

In [None]:
# Moving-edge stimulus shared by every perturbation experiment in this notebook.
dataset = MovingEdge(
    # --- stimulus geometry / content ---
    angles=list(np.arange(0, 360, 30)),  # motion direction (orthogonal to edge), every 30 degrees
    speeds=[19],                         # speed of bar in 1 * radians(5.8) / s
    intensities=[0, 1],                  # intensity of bar
    offsets=[-10, 11],                   # offset of bar from center in 1 * radians(2.25) led size
    height=80,                           # height of moving bar in 1 * radians(2.25) led size
    # --- timing ---
    dt=1 / 200,                          # temporal resolution of rendered video
    t_pre=1.0,                           # duration of pre-stimulus period
    t_post=1.0,                          # duration of post-stimulus period
    post_pad_mode="continue",            # post-stimulus period keeps showing the last stimulus frame
)

In [None]:
# Pairwise edge perturbations to simulate, written as "<source>_<target>".
# Each hit appears exactly once so no (expensive) simulation is run twice.
pairwise_hits = ['CT1(Lo1)_T5b', 'CT1(M10)_Mi1', 'Mi4_T4d', 'R8_Mi4', 'Tm9_T5c',
                 'TmY15_Tm4', 'Mi1_T4a', 'R8_Mi1', 'Mi4_Tm9']

In [None]:
# Run every pairwise edge perturbation through the moving-edge stimulus.
# MovingEdgeWrapper receives a result-CSV path and a figure folder per hit.
# Set SKIP_EXISTING = True to resume an interrupted sweep: hits whose result
# CSV already exists are skipped. False (default) re-runs every hit.
SKIP_EXISTING = False

# dict.fromkeys drops repeated hits (keeping first-seen order) so the same
# simulation is never run twice.
for hit in dict.fromkeys(pairwise_hits):
    src, tar = hit.split('_')
    result_output_path = f'data/flyvis_data/pairwise_edge_pert_results/{src}_{tar}.csv'
    plot_output_dir = f'data/flyvis_data/pairwise_edge_fig/{src}_{tar}'
    if SKIP_EXISTING and os.path.exists(result_output_path):
        print(f"Skipping existing perturbation: {src} -> {tar}")
        continue
    print(f"Running perturbation: {src} -> {tar}")

    # Perturb only the single src -> tar connection.
    pert = FlyvisCellTypePert()
    pert.perturb(cell_type_df, PerturbationType.PAIR_WISE, pairs=[(src, tar)])

    wrapper = MovingEdgeWrapper(dataset, pert=pert,
                                pert_folder_name=f'{src}_{tar}_perturbation',
                                output_file_name=result_output_path,
                                plot_output_dir=plot_output_dir)
    wrapper.run()

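## Outgoing-edge perturbations

For each source cell type, perturb its outgoing connections (`PerturbationType.OUTGOING`):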

In [None]:
# Source cell types whose outgoing connections are perturbed.
# Each type appears exactly once so no simulation is run twice.
outgoing_hits = ['Tm9', 'CT1(M10)', 'Mi2', 'Mi4', 'R8', 'Mi1']

In [None]:
# Run every outgoing-edge perturbation through the moving-edge stimulus.
# Set SKIP_EXISTING = True to skip source types whose result CSV already exists;
# False (default) re-runs every source type.
SKIP_EXISTING = False

# dict.fromkeys drops repeated source types (keeping first-seen order) so the
# same simulation is never run twice.
for src in dict.fromkeys(outgoing_hits):
    result_output_path = f'data/flyvis_data/outgoing_edge_pert_results/{src}.csv'
    plot_output_dir = f'data/flyvis_data/outgoing_edge_fig/{src}_outgoing'
    if SKIP_EXISTING and os.path.exists(result_output_path):
        print(f"Skipping existing outgoing perturbation: {src}")
        continue
    print(f"Running outgoing perturbation: {src}")

    pert = FlyvisCellTypePert()
    pert.perturb(cell_type_df, PerturbationType.OUTGOING, source_outgoing=src)

    wrapper = MovingEdgeWrapper(dataset, pert=pert,
                                pert_folder_name=f'{src}_outgoing_perturbation',
                                output_file_name=result_output_path,
                                plot_output_dir=plot_output_dir)
    wrapper.run()

In [None]:
# Feedback-motif test. For every (a, b, c) triplet, run three separate pairwise
# perturbations: a -> c, b -> c, and the reciprocal pair a <-> b.
# NOTE(review): `triplets` is not defined in any cell shown here; it must be
# created beforehand (the loop calls .iterrows() and reads columns 'a', 'b', 'c').
for idx, row in triplets.iterrows():
    a, b, c = row['a'], row['b'], row['c']
    print(f"Triplet {idx}: a={a}, b={b}, c={c}")

    experiments = [[(a, c)], [(b, c)], [(a, b), (b, a)]]
    print(experiments)
    for exp in experiments:
        # The repr of the pair list (e.g. "[('a', 'c')]") names the output folder.
        exp_name = str(exp)
        # No trailing slash, so the CSV path below has a single separator
        # (the trailing slash previously produced ".../<exp_name>//res.csv").
        plot_output_dir = f'data/flyvis_data/motif_feedback_test/moving_edge/triplet_{idx}/{exp_name}'
        result_output_path = f'{plot_output_dir}/res.csv'
        pert = FlyvisCellTypePert()
        pert.perturb(cell_type_df, PerturbationType.PAIR_WISE, pairs=exp)

        wrapper = MovingEdgeWrapper(dataset, pert=pert,
                                    pert_folder_name=f'triplet_idx{idx}-{exp_name}_pairwise_perturbation',
                                    output_file_name=result_output_path,
                                    plot_output_dir=plot_output_dir)
        wrapper.run()

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

def create_comparison_figure(original_dir, perturbed_dir, perturbation_name, output_path,
                            cell_type_family="T4", plot_type="tuning"):
    """
    Create a side-by-side comparison figure showing original vs perturbed plots.

    Layout is 2 rows x 4 columns. The top row holds the unperturbed images from
    ``original_dir``. The bottom row holds the perturbed images from
    ``os.path.join(perturbed_dir, perturbation_name)``. Each column is one cell
    type of the selected family. A missing image is shown as a "Not found"
    placeholder instead of raising.

    Parameters:
    -----------
    original_dir : str
        Path to original/unperturbed figures
    perturbed_dir : str
        Parent folder containing one sub-folder of figures per perturbation
    perturbation_name : str
        Name of the perturbation sub-folder (e.g., 'Mi4_T4d')
    output_path : str
        Where to save the comparison figure (its parent folder is created if needed)
    cell_type_family : str
        Either "T4" (uses the intensity-1 plots) or "T5" (uses the intensity-0 plots)
    plot_type : str
        Either "tuning" for tuning curves (default) or "dynamics" for time traces

    Raises:
    -------
    ValueError
        If ``cell_type_family`` or ``plot_type`` is not one of the values above.
    """
    # Select cell types and intensity based on family
    if cell_type_family == "T4":
        cell_types = ["T4a", "T4b", "T4c", "T4d"]
        intensity = 1
    elif cell_type_family == "T5":
        cell_types = ["T5a", "T5b", "T5c", "T5d"]
        intensity = 0
    else:
        raise ValueError("cell_type_family must be 'T4' or 'T5'")

    # File-name template and title wording for each plot type
    if plot_type == "tuning":
        filename_template = "{cell_type}_tuning_intensity{intensity}.png"
        title_suffix = "Tuning Curves"
    elif plot_type == "dynamics":
        filename_template = "{cell_type}_intensity{intensity}.png"
        title_suffix = "Response Dynamics"
    else:
        raise ValueError("plot_type must be 'tuning' or 'dynamics'")

    def _draw_panel(ax, image_path, title=None):
        # Show one PNG on ``ax``, or a "Not found" placeholder if it is missing.
        if os.path.exists(image_path):
            # ``with`` closes the file handle; imshow converts the PIL image
            # to an array immediately, so closing it afterwards is safe.
            with Image.open(image_path) as img:
                ax.imshow(img)
        else:
            ax.text(0.5, 0.5, 'Not found', ha='center', va='center')
        ax.axis('off')
        if title is not None:
            ax.set_title(title, fontsize=12, fontweight='bold')

    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    try:
        for col, cell_type in enumerate(cell_types):
            plot_filename = filename_template.format(cell_type=cell_type, intensity=intensity)
            # Top row: original (column title is set even if the image is missing)
            _draw_panel(axes[0, col], os.path.join(original_dir, plot_filename), title=cell_type)
            # Bottom row: perturbed
            _draw_panel(axes[1, col],
                        os.path.join(perturbed_dir, perturbation_name, plot_filename))

        # Row labels on the left
        fig.text(0.02, 0.75, 'Original', fontsize=14, fontweight='bold',
                 va='center', ha='right', rotation=90)
        fig.text(0.02, 0.25, f'Perturbed\n{perturbation_name}', fontsize=14, fontweight='bold',
                 va='center', ha='right', rotation=90)

        fig.suptitle(f"{cell_type_family} {title_suffix} (intensity={intensity})",
                     fontsize=16, fontweight='bold', y=0.98)

        # Reserve a 3% left margin for the rotated row labels and space for the suptitle.
        fig.tight_layout(rect=[0.03, 0, 1, 0.96])

        output_parent = os.path.dirname(output_path)
        if output_parent:
            os.makedirs(output_parent, exist_ok=True)
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure, even on error, so batch loops don't accumulate open figures.
        plt.close(fig)
    print(f"Saved comparison figure: {output_path}")
    print(f"Saved comparison figure: {output_path}")


# Folders used by the comparison step: unperturbed figures, pairwise-perturbation
# figures, and the destination for the side-by-side comparison figures.
original_dir, perturbed_dir, output_dir = (
    'data/flyvis_data/perf/moving_edge_original_fig',
    'data/flyvis_data/perf/moving_edge_pairwise_fig',
    'data/flyvis_data/perf/comparison_figures',
)
os.makedirs(output_dir, exist_ok=True)


In [None]:
from PIL import Image
import matplotlib.pyplot as plt

def create_comparison_figure(original_dir, perturbed_dir, perturbation_name, output_path,
                            cell_type_family="T4", plot_type="tuning"):
    """
    Create a side-by-side comparison figure showing original vs perturbed plots.

    Layout is 2 rows x 4 columns. The top row holds the unperturbed images from
    ``original_dir``. The bottom row holds the perturbed images from
    ``os.path.join(perturbed_dir, perturbation_name)``. Each column is one cell
    type of the selected family. A missing image is shown as a "Not found"
    placeholder instead of raising.

    Parameters:
    -----------
    original_dir : str
        Path to original/unperturbed figures
    perturbed_dir : str
        Parent folder containing one sub-folder of figures per perturbation
    perturbation_name : str
        Name of the perturbation sub-folder (e.g., 'Mi4_T4d')
    output_path : str
        Where to save the comparison figure (its parent folder is created if needed)
    cell_type_family : str
        Either "T4" (uses the intensity-1 plots) or "T5" (uses the intensity-0 plots)
    plot_type : str
        Either "tuning" for tuning curves or "dynamics" for time traces

    Raises:
    -------
    ValueError
        If ``cell_type_family`` or ``plot_type`` is not one of the values above.
    """
    # Select cell types and intensity based on family
    if cell_type_family == "T4":
        cell_types = ["T4a", "T4b", "T4c", "T4d"]
        intensity = 1
    elif cell_type_family == "T5":
        cell_types = ["T5a", "T5b", "T5c", "T5d"]
        intensity = 0
    else:
        raise ValueError("cell_type_family must be 'T4' or 'T5'")

    # File-name template and title wording for each plot type
    if plot_type == "tuning":
        filename_template = "{cell_type}_tuning_intensity{intensity}.png"
        title_suffix = "Tuning Curves"
    elif plot_type == "dynamics":
        filename_template = "{cell_type}_intensity{intensity}.png"
        title_suffix = "Response Dynamics"
    else:
        raise ValueError("plot_type must be 'tuning' or 'dynamics'")

    def _draw_panel(ax, image_path, title=None):
        # Show one PNG on ``ax``, or a "Not found" placeholder if it is missing.
        if os.path.exists(image_path):
            # ``with`` closes the file handle; imshow converts the PIL image
            # to an array immediately, so closing it afterwards is safe.
            with Image.open(image_path) as img:
                ax.imshow(img)
        else:
            ax.text(0.5, 0.5, 'Not found', ha='center', va='center')
        ax.axis('off')
        if title is not None:
            ax.set_title(title, fontsize=12, fontweight='bold')

    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    try:
        for col, cell_type in enumerate(cell_types):
            plot_filename = filename_template.format(cell_type=cell_type, intensity=intensity)
            # Top row: original (column title is set even if the image is missing)
            _draw_panel(axes[0, col], os.path.join(original_dir, plot_filename), title=cell_type)
            # Bottom row: perturbed
            _draw_panel(axes[1, col],
                        os.path.join(perturbed_dir, perturbation_name, plot_filename))

        # Row labels on the left
        fig.text(0.02, 0.75, 'Original', fontsize=14, fontweight='bold',
                 va='center', ha='right', rotation=90)
        fig.text(0.02, 0.25, f'Perturbed\n{perturbation_name}', fontsize=14, fontweight='bold',
                 va='center', ha='right', rotation=90)

        fig.suptitle(f"{cell_type_family} {title_suffix} (intensity={intensity})",
                     fontsize=16, fontweight='bold', y=0.98)

        # Reserve a 3% left margin for the rotated row labels and space for the suptitle.
        fig.tight_layout(rect=[0.03, 0, 1, 0.96])

        output_parent = os.path.dirname(output_path)
        if output_parent:
            os.makedirs(output_parent, exist_ok=True)
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure, even on error, so batch loops don't accumulate open figures.
        plt.close(fig)
    print(f"Saved comparison figure: {output_path}")


In [None]:
# Pairwise hits ("<source>_<target>") to build comparison figures for.
# Each hit appears exactly once so no comparison figure is regenerated.
pairwise_hits = ['CT1(Lo1)_T5b', 'CT1(M10)_Mi1', 'Mi4_T4d', 'R8_Mi4', 'Tm9_T5c',
                 'TmY15_Tm4', 'Mi1_T4a', 'R8_Mi1', 'Mi4_Tm9']

In [None]:
# Build original-vs-perturbed comparison figures (T4/T5 x tuning/dynamics) for
# every pairwise hit; each hit gets its own output sub-folder.
# NOTE(review): the pairwise run cell writes figures to
# 'data/flyvis_data/pairwise_edge_fig/<src>_<tar>', while this cell reads them
# from perturbed_dir below. Confirm they were copied there; otherwise the
# bottom-row panels will show "Not found".
original_dir = 'data/flyvis_data/perf/moving_edge_original_fig'
perturbed_dir = 'data/flyvis_data/perf/moving_edge_pairwise_fig'
output_base_dir = 'data/flyvis_data/perf/comparison_figures'

for hit in pairwise_hits:
    src, tar = hit.split('_')
    perturbation_name = f"{src}_{tar}"

    output_dir = os.path.join(output_base_dir, perturbation_name)
    os.makedirs(output_dir, exist_ok=True)

    # Order per family: tuning first, then dynamics.
    for cell_family in ("T4", "T5"):
        for plot_type in ("tuning", "dynamics"):
            output_path = os.path.join(output_dir, f"comparison_{cell_family}_{plot_type}.png")
            create_comparison_figure(original_dir, perturbed_dir, perturbation_name, output_path,
                                     cell_type_family=cell_family, plot_type=plot_type)

Do the same for the outgoing perturbations: build original-vs-perturbed comparison figures for each source cell type.

In [28]:
# Generate comparison figures for all outgoing perturbations
# (T4/T5 x tuning/dynamics per source cell type).
original_dir = 'data/flyvis_data/perf/moving_edge_original_fig'
perturbed_dir = 'data/flyvis_data/perf/moving_edge_outgoing_fig'  # outgoing-perturbation figures
output_base_dir = 'data/flyvis_data/perf/comparison_figures_outgoing'  # kept apart from pairwise comparisons

# NOTE(review): the outgoing run cell writes figures to
# 'data/flyvis_data/outgoing_edge_fig/<src>_outgoing', not to perturbed_dir.
# Confirm they were copied there; otherwise bottom-row panels show "Not found".

# Each source type appears once so no figure set is regenerated.
outgoing_hits = ['Tm9', 'CT1(M10)', 'Mi2', 'Mi4', 'R8', 'Mi1']

for src in outgoing_hits:
    # Perturbed figures live in "<src>_outgoing"; outputs go into a folder named just "<src>".
    perturbation_name = f"{src}_outgoing"
    output_dir = os.path.join(output_base_dir, src)
    os.makedirs(output_dir, exist_ok=True)

    for cell_family in ("T4", "T5"):
        for plot_type in ("tuning", "dynamics"):
            output_path = os.path.join(output_dir, f"comparison_{cell_family}_{plot_type}.png")
            create_comparison_figure(original_dir, perturbed_dir, perturbation_name, output_path,
                                     cell_type_family=cell_family, plot_type=plot_type)

Saved comparison figure: data/flyvis_data/perf/comparison_figures_outgoing\Tm9\comparison_T4_tuning.png
Saved comparison figure: data/flyvis_data/perf/comparison_figures_outgoing\Tm9\comparison_T4_dynamics.png
Saved comparison figure: data/flyvis_data/perf/comparison_figures_outgoing\Tm9\comparison_T5_tuning.png
Saved comparison figure: data/flyvis_data/perf/comparison_figures_outgoing\Tm9\comparison_T5_dynamics.png
Saved comparison figure: data/flyvis_data/perf/comparison_figures_outgoing\CT1(M10)\comparison_T4_tuning.png
Saved comparison figure: data/flyvis_data/perf/comparison_figures_outgoing\CT1(M10)\comparison_T4_dynamics.png
Saved comparison figure: data/flyvis_data/perf/comparison_figures_outgoing\CT1(M10)\comparison_T5_tuning.png
Saved comparison figure: data/flyvis_data/perf/comparison_figures_outgoing\CT1(M10)\comparison_T5_dynamics.png
Saved comparison figure: data/flyvis_data/perf/comparison_figures_outgoing\Mi2\comparison_T4_tuning.png
Saved comparison figure: data/flyvis