In [2]:
import pandas as pd
import os
import numpy as np
import seaborn as sns 
import plotly.express as px
from sklearn.decomposition import PCA
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.cm as cm

os.chdir("/net/trapnell/vol1/home/mdcolon/proj/morphseq")

# Project helper modules; names used later (e.g. build_splines_and_segments)
# presumably come from these star imports.
from src.functions.embryo_df_performance_metrics import *
from src.functions.spline_morph_spline_metrics import *

model_index = 74

# Output folders for this analysis run (created if missing)
results_dir = "/net/trapnell/vol1/home/mdcolon/proj/morphseq/results/mcolon/20250407"
data_dir, plot_dir = (os.path.join(results_dir, sub_name) for sub_name in ("data", "plots"))
for out_dir in (plot_dir, data_dir):
    os.makedirs(out_dir, exist_ok=True)

# Embryo-level table; latent-mean columns contain "z_mu", the biological subset additionally contains "b"
df_orig = pd.read_csv("/net/trapnell/vol1/home/mdcolon/proj/morphseq/results/mcolon/20250315/data/embryo_stats_df.csv")
z_mu_columns = [name for name in df_orig.columns if 'z_mu' in name]
z_mu_biological_columns = [name for name in z_mu_columns if "b" in name]

In [3]:
# cep290 perturbation names (short_pert_name) used both to select mutants and to exclude them from the wt pool
cep290_perts = ['cep290_het_cep290', 'wt_cep290', 'cep290_homo_cep290']

# Filter for relevant cep290 perturbations. .copy() makes this an independent frame, so the
# phenotype assignment below does not write into a view of df_orig (SettingWithCopyWarning).
df_cep290 = df_orig[df_orig["short_pert_name"].isin(cep290_perts)].copy()

# Assign phenotype based on short_pert_name
df_cep290["phenotype"] = df_cep290["short_pert_name"]

# Wild-type samples that are not part of the cep290 experiment
df_wt = df_orig[df_orig["phenotype"].isin(["wt"]) & ~df_orig["short_pert_name"].isin(cep290_perts)]

# Combine wild-type and cep290 datasets
df = pd.concat([df_wt, df_cep290], ignore_index=True)

# Filter by temperature; copy so the phenotype rewrite below is not a chained assignment on a slice
df = df[df["temperature"].isin([30.0, 29.0, 22.0])].copy()

# Append temperature to phenotype as a suffix, e.g. "wt_cep290" -> "wt_cep290_temp30"
df["phenotype"] = df["phenotype"] + "_temp" + df["temperature"].astype(int).astype(str)

# Embryos excluded from the analysis (manually flagged)
flagged_embryo_ids = ["20250305_G09_e00", "20250305_B02_e00", "20250305_H12_e00", "20250305_F07_e00"]
df = df[~df["embryo_id"].isin(flagged_embryo_ids)]

# Project helper; presumably adds PCA coordinates computed from the biological latent columns
df = apply_pca_on_pert_comparisons(df, z_mu_biological_columns)

In [6]:
# Tuning parameters forwarded to the project helper build_splines_and_segments
spline_kwargs = dict(
    bandwidth=0.5,
    max_iter=250,
    tol=1e-3,
    angle_penalty_exp=2,
    early_stage_offset=1.0,
    late_stage_offset=3.0,
    k=50,
)

# Returns per-phenotype splines, the augmented point table, and per-segment info
pert_splines, df_augmented, segment_info_df = build_splines_and_segments(
    df=df,
    model_index=model_index,
    LocalPrincipalCurveClass=LocalPrincipalCurve,
    **spline_kwargs,
)

In [19]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, KFold  # note KFold (not StratifiedKFold, since class=1 is uniform)
# from your_module import logistic_regression_multiclass, plot_avg_predictions_multiclass


def split_control_once_by_embryo_id(df_ctrl, embryo_id_col='embryo_id', test_size=0.2, random_state=42):
    """
    Split control (class=0) rows into train/test by embryo ID, so all rows of an
    embryo land on the same side.

    Parameters
    ----------
    df_ctrl : pd.DataFrame
        Control rows only.
    embryo_id_col : str
        Column holding embryo IDs.
    test_size : float
        Fraction of embryo IDs sent to the test split (passed to train_test_split).
    random_state : int
        Seed for train_test_split.

    Returns
    -------
    (control_train_df, control_test_df) : tuple of DataFrame copies
    """
    embryo_ids = df_ctrl[embryo_id_col].unique()

    train_ids, test_ids = train_test_split(embryo_ids, test_size=test_size, random_state=random_state)

    in_train = df_ctrl[embryo_id_col].isin(train_ids)
    in_test = df_ctrl[embryo_id_col].isin(test_ids)
    control_train_df = df_ctrl.loc[in_train].copy()
    control_test_df = df_ctrl.loc[in_test].copy()

    print(f"Control split by embryo ID: train={len(control_train_df)}, test={len(control_test_df)}")

    return control_train_df, control_test_df


def kfold_by_embryo_id_single_class(
    df_class1,
    embryo_id_col='embryo_id',
    n_splits=3,
    shuffle=True,
    random_state=42
):
    """
    Generator of K-fold splits over the embryo IDs of a single-class (class=1) frame.

    Each fold yields (train_1_df, test_1_df) as copies of df_class1; every embryo ID
    appears in the test side of exactly one fold. Plain KFold (not stratified) is
    used because all rows share the same class.

    Parameters
    ----------
    df_class1 : pd.DataFrame
        Only the phenotype=1 rows.
    embryo_id_col : str
        Column holding embryo IDs.
    n_splits : int
        Number of folds.
    shuffle : bool
        Whether KFold shuffles the IDs before splitting.
    random_state : int
        KFold seed.
    """
    embryo_ids = df_class1[embryo_id_col].unique()
    splitter = KFold(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    for train_positions, test_positions in splitter.split(embryo_ids):
        fold_train_ids = set(embryo_ids[train_positions])
        fold_test_ids = set(embryo_ids[test_positions])

        yield (
            df_class1[df_class1[embryo_id_col].isin(fold_train_ids)].copy(),
            df_class1[df_class1[embryo_id_col].isin(fold_test_ids)].copy(),
        )


# -------------------------------
# Example usage below
# -------------------------------

# 1) Filter your main DataFrame
# 1) Restrict the main DataFrame to the two phenotypes being compared (past 20 hpf)
ctrl = "wt_cep290_temp30"
phenotype = "cep290_homo_cep290_temp30"
pert_comparisons_pair = [ctrl, phenotype]

df_filtered = df[
    (df["predicted_stage_hpf"] > 20) &
    (df["phenotype"].isin(pert_comparisons_pair))
].copy()

# Numeric label: 0 = control, 1 = mutant
df_filtered["class_num"] = df_filtered["phenotype"].apply(lambda x: 0 if x == ctrl else 1)

# Separate control (0) vs. mutant (1)
df_ctrl = df_filtered[df_filtered["class_num"] == 0].copy()
df_pheno = df_filtered[df_filtered["class_num"] == 1].copy()

# 2) Split control ONCE by embryo ID. The same control train/test embryos are reused in
#    every fold below, so only the mutant side is actually cross-validated.
control_train_df, control_test_df = split_control_once_by_embryo_id(
    df_ctrl, embryo_id_col='embryo_id', test_size=0.2, random_state=42
)

# 3) K-fold over mutant embryo IDs (each mutant embryo is in a test fold exactly once)
n_splits = 3
pheno_kf = kfold_by_embryo_id_single_class(
    df_pheno,
    embryo_id_col='embryo_id',
    n_splits=n_splits,
    shuffle=True,
    random_state=123
)

# Results collected across folds
fold_accuracies = []
fold_coeffs = []

# Separate name so the notebook-level `plot_dir` from the config cell is not overwritten
kfold_plot_dir = "./kfold_only_on_class1_results"
os.makedirs(kfold_plot_dir, exist_ok=True)

for fold_idx, (pheno_train_df, pheno_test_df) in enumerate(pheno_kf, start=1):
    print(f"\n=== Fold {fold_idx}/{n_splits} ===")
    print(f"  Mutant train: {len(pheno_train_df)}  |  Mutant test: {len(pheno_test_df)}")

    # Combine control + phenotype
    train_df_all = pd.concat([control_train_df, pheno_train_df], ignore_index=True)
    test_df_all = pd.concat([control_test_df, pheno_test_df], ignore_index=True)

    print(f"  Final train size: {len(train_df_all)}  |  Final test size: {len(test_df_all)}")

    # 4) Train the logistic regression (project helper)
    (y_test_all,
     y_pred_proba_all,
     log_reg_all,
     train_df_out,
     test_df_out) = logistic_regression_multiclass(
        train_df_all,
        test_df_all,
        z_mu_biological_columns,
        pert_comparisons_pair,
        tol=1e-5
    )

    # 5) Evaluate on the full test set (both classes)
    accuracy = log_reg_all.score(
        test_df_all[z_mu_biological_columns],
        test_df_all["class_num"]
    )
    fold_accuracies.append(accuracy)
    print(f"  Fold Accuracy: {accuracy:.4f}")

    # 6) Coefficients for the (single) decision function
    coef_df = pd.DataFrame({
        "feature": z_mu_biological_columns,
        "coefficient": log_reg_all.coef_[0]
    }).sort_values("coefficient", ascending=False)
    fold_coeffs.append(coef_df)

    # Per-fold outputs
    fold_plot_dir = os.path.join(kfold_plot_dir, f"fold_{fold_idx}")
    os.makedirs(fold_plot_dir, exist_ok=True)

    plot_avg_predictions_multiclass(
        test_df_out,
        y_pred_proba_all,
        pert_comparisons=pert_comparisons_pair,
        window_size=20,
        max_hpf=80,
        save_dir=fold_plot_dir,
        filename=f"kfold_{fold_idx}_{phenotype}_vs_{ctrl}_predictions.html"
    )

    coef_path = os.path.join(fold_plot_dir, f"coefficients_fold_{fold_idx}.csv")
    coef_df.to_csv(coef_path, index=False)
    print(f"  Coefficients saved to: {coef_path}")

# 7) Summarize results across folds
if fold_accuracies:
    mean_acc = np.mean(fold_accuracies)
    std_acc = np.std(fold_accuracies)  # population std (ddof=0) across folds
    print(f"\n== Summary of {n_splits}-fold CROSS-VAL (only on phenotype=1) ==")
    print(f"   Mean Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")

    # Store the fold number as a real column. concat(keys=...) puts it in the index,
    # which to_csv(index=False) would silently discard.
    all_coeffs_df = pd.concat(
        [coef.assign(fold=fold_num) for fold_num, coef in enumerate(fold_coeffs, start=1)],
        ignore_index=True,
    )
    all_coeffs_path = os.path.join(kfold_plot_dir, "all_fold_coefficients.csv")
    all_coeffs_df.to_csv(all_coeffs_path, index=False)
    print(f"   All fold coefficients saved to '{all_coeffs_path}'.")
else:
    print("No folds were processed. Check data or parameters.")

In [22]:
df_augmented.loc[df_augmented["phenotype"] == "cep290_homo_cep290_temp30", "embryo_id"].unique()

In [29]:
# Only a cell's last expression is rendered, so the column list must be displayed explicitly
display(list(segment_info_df.columns))
segment_info_df


In [None]:
# Reference: the wt_cep290_temp30 rows of the spline segment table
is_wt_reference = segment_info_df["phenotype"] == "wt_cep290_temp30"
wt_splines_n_planes = segment_info_df[is_wt_reference]

# Project every point in df_augmented onto the reference spline (project helper)
df_points = project_points_onto_reference_spline(df_augmented, wt_splines_n_planes)

# Keep only the new projection columns (plus the snip_id join key) before merging
overlapping_cols = [c for c in df_points.columns if c in df_augmented.columns and c != "snip_id"]
df_augmented_projec_wt = pd.merge(
    df_augmented,
    df_points.drop(columns=overlapping_cols),
    on="snip_id"
)

In [31]:
def plot_hypotenuse_over_stage(
    df,
    phenotypes_to_include=None,
    window_size=5,
    figsize=(12, 8),
    palette="tab10",
    alpha=0.5,
    plot_individual_embryos=True,
    plot_average=True,
    plot_median=False,
    highlight_embryos=None,
    highlight_phenotypes=None,
    highlight_alpha=0.9,
    max_hpf=None,
    min_points_per_embryo=5,
    save_path=None,
    show_legend=True,
    title=None
):
    """
    Plot hypotenuse (distance from spline) over predicted stage with smoothing.
    
    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing the data, must have columns:
        ['hypotenuse', 'embryo_id', 'phenotype', 'predicted_stage_hpf']
    phenotypes_to_include : list, optional
        List of phenotypes to include in the plot. If None, all phenotypes are included.
    window_size : int, default=5
        Window size for rolling average smoothing.
    figsize : tuple, default=(12, 8)
        Figure size (width, height) in inches.
    palette : str or dict, default="tab10"
        Color palette name or dict mapping phenotypes to colors.
    alpha : float, default=0.5
        Transparency level for individual embryo lines.
    plot_individual_embryos : bool, default=True
        Whether to plot individual embryo traces.
    plot_average : bool, default=True
        Whether to plot average line per phenotype.
    plot_median : bool, default=False
        Whether to plot median line per phenotype.
    highlight_embryos : list, optional
        List of embryo_ids to highlight with thicker lines.
    highlight_phenotypes : list, optional
        List of phenotypes to highlight with higher opacity.
    highlight_alpha : float, default=0.9
        Transparency level for highlighted phenotypes.
    max_hpf : float, optional
        Maximum hours post-fertilization to include in the plot.
    min_points_per_embryo : int, default=5
        Minimum number of data points required for an embryo to be included.
    save_path : str, optional
        Path to save the figure, if provided.
    show_legend : bool, default=True
        Whether to show the legend.
    title : str, optional
        Plot title. If None, a default title is used.
    
    Returns:
    --------
    matplotlib.figure.Figure
        The created figure
    
    Example:
    --------
    # Basic usage with all phenotypes
    fig = plot_hypotenuse_over_stage(my_dataframe)
    
    # Highlighting specific phenotypes
    fig = plot_hypotenuse_over_stage(
        my_dataframe,
        highlight_phenotypes=['cep290_homo_cep290_temp30'],
        phenotypes_to_include=['cep290_het_cep290_temp30', 'wt_cep290_temp30', 'cep290_homo_cep290_temp30']
    )
    
    # Plotting median instead of average
    fig = plot_hypotenuse_over_stage(
        my_dataframe,
        plot_average=False,
        plot_median=True
    )
    
    # Customizing appearance
    fig = plot_hypotenuse_over_stage(
        my_dataframe,
        figsize=(15, 10),
        palette="Set2",
        alpha=0.3,
        highlight_alpha=1.0,
        title="My Custom Plot"
    )
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    
    # Make a copy to avoid modifying the original
    df = df.copy()
    
    # Ensure required columns exist
    required_cols = ['hypotenuse', 'embryo_id', 'phenotype', 'predicted_stage_hpf']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")
    
    # Filter by phenotype if specified
    if phenotypes_to_include is not None:
        df = df[df['phenotype'].isin(phenotypes_to_include)]
    
    # Filter by max_hpf if specified
    if max_hpf is not None:
        df = df[df['predicted_stage_hpf'] <= max_hpf]
    
    # Filter embryos with too few data points
    embryo_counts = df.groupby('embryo_id').size()
    valid_embryos = embryo_counts[embryo_counts >= min_points_per_embryo].index
    df = df[df['embryo_id'].isin(valid_embryos)]
    
    # If DataFrame is empty after filtering, return empty plot
    if df.empty:
        plt.figure(figsize=figsize)
        plt.text(0.5, 0.5, "No data available after filtering", 
                 horizontalalignment='center', verticalalignment='center',
                 fontsize=14)
        plt.gca().set_axis_off()
        return plt.gcf()
    
    # Setup figure
    plt.figure(figsize=figsize)
    
    # Get unique phenotypes and assign colors
    unique_phenotypes = df['phenotype'].unique()
    if isinstance(palette, str):
        color_palette = sns.color_palette(palette, n_colors=len(unique_phenotypes))
        phenotype_colors = {phenotype: color_palette[i] for i, phenotype in enumerate(unique_phenotypes)}
    else:
        # If palette is a dict, use it directly
        phenotype_colors = palette
    
    # Store lines for legend
    phenotype_avg_lines = {}
    phenotype_med_lines = {}
    highlight_lines = {}
    
    # Plot individual embryos
    if plot_individual_embryos:
        for embryo_id, group in df.groupby('embryo_id'):
            phenotype = group['phenotype'].iloc[0]
            color = phenotype_colors.get(phenotype, 'gray')
            
            # Sort by predicted_stage_hpf
            group = group.sort_values('predicted_stage_hpf')
            
            # Apply smoothing with rolling window
            group['smooth_hypotenuse'] = group['hypotenuse'].rolling(
                window=window_size, min_periods=1, center=True
            ).mean()
            
            # Determine line properties
            is_highlight_embryo = highlight_embryos is not None and embryo_id in highlight_embryos
            is_highlight_phenotype = highlight_phenotypes is not None and phenotype in highlight_phenotypes
            
            # Set alpha and line width based on highlight status
            line_alpha = highlight_alpha if (is_highlight_embryo or is_highlight_phenotype) else alpha
            line_width = 2.5 if is_highlight_embryo else 1
            
            # Plot the line
            line = plt.plot(
                group['predicted_stage_hpf'], 
                group['smooth_hypotenuse'], 
                color=color, 
                alpha=line_alpha,
                linewidth=line_width,
                label=None  # We'll add to legend separately
            )
            
            # Store for legend if this is a highlight embryo
            if is_highlight_embryo:
                highlight_lines[embryo_id] = line[0]
    
    # Plot average line per phenotype
    if plot_average:
        for phenotype, group in df.groupby('phenotype'):
            color = phenotype_colors.get(phenotype, 'gray')
            
            # Group by predicted_stage_hpf bins (0.5 hour increments)
            group['stage_bin'] = (group['predicted_stage_hpf'] * 2).astype(int) / 2
            
            # Calculate mean per bin
            stage_means = group.groupby('stage_bin')['hypotenuse'].mean().reset_index()
            
            # Sort by stage
            stage_means = stage_means.sort_values('stage_bin')
            
            # Apply smoothing
            stage_means['smooth_hypotenuse'] = stage_means['hypotenuse'].rolling(
                window=window_size, min_periods=1, center=True
            ).mean()
            
            # Determine if this is a highlighted phenotype
            is_highlight = highlight_phenotypes is not None and phenotype in highlight_phenotypes
            line_alpha = highlight_alpha if is_highlight else 1.0
            
            # Plot average line
            line = plt.plot(
                stage_means['stage_bin'], 
                stage_means['smooth_hypotenuse'], 
                color=color, 
                linewidth=5,
                alpha=line_alpha,
                linestyle='-',
                label=f"{phenotype} (mean, n={len(group['embryo_id'].unique())})"
            )
            
            # Store for legend
            phenotype_avg_lines[phenotype] = line[0]
    
    # Plot median line per phenotype
    if plot_median:
        for phenotype, group in df.groupby('phenotype'):
            color = phenotype_colors.get(phenotype, 'gray')
            
            # Group by predicted_stage_hpf bins (0.5 hour increments)
            group['stage_bin'] = (group['predicted_stage_hpf'] * 2).astype(int) / 2
            
            # Calculate median per bin
            stage_medians = group.groupby('stage_bin')['hypotenuse'].median().reset_index()
            
            # Sort by stage
            stage_medians = stage_medians.sort_values('stage_bin')
            
            # Apply smoothing
            stage_medians['smooth_hypotenuse'] = stage_medians['hypotenuse'].rolling(
                window=window_size, min_periods=1, center=True
            ).mean()
            
            # Determine if this is a highlighted phenotype
            is_highlight = highlight_phenotypes is not None and phenotype in highlight_phenotypes
            line_alpha = highlight_alpha if is_highlight else 1.0
            
            # Plot median line
            line = plt.plot(
                stage_medians['stage_bin'], 
                stage_medians['smooth_hypotenuse'], 
                color=color, 
                linewidth=5,
                alpha=line_alpha,
                linestyle='--',  # Use dashed line to distinguish from mean
                label=f"{phenotype} (median, n={len(group['embryo_id'].unique())})"
            )
            
            # Store for legend
            phenotype_med_lines[phenotype] = line[0]
    
    # Set title and labels
    title = title or "Distance from Spline by Developmental Stage"
    plt.title(title, fontsize=14)
    plt.xlabel('Predicted Stage (hpf)', fontsize=12)
    plt.ylabel('Distance from Spline (Hypotenuse)', fontsize=12)
    
    # Add legend
    if show_legend:
        # Combine all lines for the legend
        all_lines = {}
        all_labels = []
        
        # Add mean lines
        if plot_average and phenotype_avg_lines:
            for phenotype, line in phenotype_avg_lines.items():
                all_lines[f"{phenotype} (mean)"] = line
                all_labels.append(f"{phenotype} (mean, n={len(df[df['phenotype']==phenotype]['embryo_id'].unique())})")
        
        # Add median lines
        if plot_median and phenotype_med_lines:
            for phenotype, line in phenotype_med_lines.items():
                all_lines[f"{phenotype} (median)"] = line
                all_labels.append(f"{phenotype} (median, n={len(df[df['phenotype']==phenotype]['embryo_id'].unique())})")
        
        # Add highlighted embryos
        if highlight_embryos and highlight_lines:
            for embryo_id, line in highlight_lines.items():
                all_lines[f"Embryo {embryo_id}"] = line
                all_labels.append(f"Embryo {embryo_id}")
        
        # Create the legend if we have any lines
        if all_lines:
            plt.legend(
                handles=list(all_lines.values()),
                labels=all_labels,
                title="Phenotypes and Highlighted Embryos",
                loc="best",
                fontsize=10
            )
    
    # Set grid
    plt.grid(True, linestyle='--', alpha=0.7)
    
    # Adjust layout
    plt.tight_layout()
    
    # Save if requested
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    return plt.gcf()

In [34]:
# The three cep290 genotypes at 30C, reused by several of the plots below
cep290_temp30_phenotypes = [
    'cep290_het_cep290_temp30',
    'wt_cep290_temp30',
    'cep290_homo_cep290_temp30'
]

# All phenotypes, default settings
fig = plot_hypotenuse_over_stage(df_augmented_projec_wt)

# cep290 genotypes only, homozygous traces emphasised
fig = plot_hypotenuse_over_stage(
    df_augmented_projec_wt,
    phenotypes_to_include=cep290_temp30_phenotypes,
    highlight_phenotypes=['cep290_homo_cep290_temp30'],
    highlight_alpha=0.9
)

# Mean (solid) and median (dashed) side by side
fig = plot_hypotenuse_over_stage(
    df_augmented_projec_wt,
    plot_average=True,
    plot_median=True,
    phenotypes_to_include=cep290_temp30_phenotypes
)

# Median lines only, no per-embryo traces
fig = plot_hypotenuse_over_stage(
    df_augmented_projec_wt,
    plot_average=False,
    plot_median=True,
    plot_individual_embryos=False,
    title="Median Distance from Spline by Phenotype"
)

In [None]:
# Unsmoothed (window_size=1) mean + median lines over faint traces, homozygous traces emphasised
cep290_genotypes_30c = ['cep290_het_cep290_temp30', 'wt_cep290_temp30', 'cep290_homo_cep290_temp30']
fig = plot_hypotenuse_over_stage(
    df_augmented_projec_wt,
    plot_average=True,
    plot_median=True,
    phenotypes_to_include=cep290_genotypes_30c,
    window_size=1,
    alpha=0.1,
    highlight_phenotypes=['cep290_homo_cep290_temp30'],
)

In [None]:
# Phenotypes kept for the point/spline views (labels as built in the filtering cell, e.g. "<name>_temp30")
phenotypes_of_interest = [
    'wt_temp30',
    'cep290_het_cep290_temp30',
    'wt_cep290_temp30',
    'cep290_homo_cep290_temp30',
]

points_mask = df_augmented['phenotype'].isin(phenotypes_of_interest)
splines_mask = pert_splines['phenotype'].isin(phenotypes_of_interest)

filtered_points = df_augmented[points_mask]
filtered_splines = pert_splines[splines_mask]

In [41]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def rank_embryos_by_max_hypotenuse(df, min_stage=30, max_stage=35):
    """
    Rank embryos by their maximum hypotenuse within a predicted-stage window.

    Only rows with min_stage < predicted_stage_hpf < max_stage (strict on both ends)
    and a non-NaN hypotenuse are considered. For each embryo, the first row holding
    the maximum is reported.

    Parameters:
    -----------
    df : pd.DataFrame
        Must have ['hypotenuse', 'embryo_id', 'phenotype', 'predicted_stage_hpf'];
        'snip_id' is optional.
    min_stage : float, default=30
        Exclusive lower bound on predicted stage (hpf).
    max_stage : float, default=35
        Exclusive upper bound on predicted stage (hpf).

    Returns:
    --------
    pd.DataFrame
        Columns ['embryo_id', 'phenotype', 'max_hypotenuse', 'stage_at_max_hpf',
        'snip_id', 'rank'], sorted by max_hypotenuse descending (rank 1 = largest).
        If df has no 'snip_id', a placeholder "unknown_<embryo_id>_<index label>" is used.
        When nothing qualifies, an empty frame with the same columns is returned.

    Raises:
    -------
    ValueError
        If a required column is missing.

    Example:
    --------
    ranked_df = rank_embryos_by_max_hypotenuse(df_augmented_projec_wt, min_stage=28, max_stage=32)
    """
    import pandas as pd

    required_cols = ['hypotenuse', 'embryo_id', 'phenotype', 'predicted_stage_hpf']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    has_snip_id = 'snip_id' in df.columns
    if not has_snip_id:
        print("Warning: 'snip_id' column not found in the DataFrame. Adding placeholder values.")

    # Same schema whether or not the result is empty
    result_columns = ['embryo_id', 'phenotype', 'max_hypotenuse', 'stage_at_max_hpf', 'snip_id', 'rank']

    in_window = (df['predicted_stage_hpf'] > min_stage) & (df['predicted_stage_hpf'] < max_stage)
    # NaN hypotenuse cannot be a maximum; dropping it also avoids idxmax on all-NaN embryos
    stage_filtered = df[in_window & df['hypotenuse'].notna()]

    if stage_filtered.empty:
        return pd.DataFrame(columns=result_columns)

    result_rows = []
    for embryo_id, embryo_data in stage_filtered.groupby('embryo_id'):
        # Positional argmax (first occurrence, like idxmax) so duplicate index labels
        # cannot make the lookup return several rows
        pos = int(embryo_data['hypotenuse'].to_numpy().argmax())
        best_row = embryo_data.iloc[pos]

        if has_snip_id:
            snip_id = best_row['snip_id']
        else:
            snip_id = f"unknown_{embryo_id}_{embryo_data.index[pos]}"

        result_rows.append({
            'embryo_id': embryo_id,
            'phenotype': best_row['phenotype'],
            'max_hypotenuse': best_row['hypotenuse'],
            'stage_at_max_hpf': best_row['predicted_stage_hpf'],
            'snip_id': snip_id,
        })

    result_df = pd.DataFrame(result_rows, columns=result_columns[:-1])
    result_df = result_df.sort_values('max_hypotenuse', ascending=False).reset_index(drop=True)
    result_df['rank'] = result_df.index + 1

    return result_df

# Example usage:
# ranked_embryos = rank_embryos_by_max_hypotenuse(df_augmented_projec_wt)
# display(ranked_embryos.head(10))  # Show top 10 embryos with highest max hypotenuse

# Second function: Visualize the top ranked embryos
def visualize_ranked_embryos(df, ranked_df, min_stage=30, max_stage=35, top_n=10, figsize=(15, 10)):
    """
    Visualize the top-ranked embryos within a stage window.

    Left panel: bar chart of each top embryo's max hypotenuse, coloured by phenotype.
    Right panel: raw hypotenuse trajectories of those embryos inside the window, with
    each trajectory's maximum marked in the same colour as its line.

    Parameters:
    -----------
    df : pd.DataFrame
        Original data with 'embryo_id', 'phenotype', 'predicted_stage_hpf', 'hypotenuse'.
    ranked_df : pd.DataFrame
        Output of rank_embryos_by_max_hypotenuse ('embryo_id', 'phenotype',
        'max_hypotenuse', 'rank').
    min_stage : float, default=30
        Exclusive lower bound on predicted stage (hpf).
    max_stage : float, default=35
        Exclusive upper bound on predicted stage (hpf).
    top_n : int, default=10
        Number of leading rows of ranked_df to show.
    figsize : tuple, default=(15, 10)
        Figure size.

    Returns:
    --------
    matplotlib.figure.Figure
    """
    top_embryos = ranked_df.head(top_n)

    stage = df['predicted_stage_hpf']
    filtered_data = df[
        df['embryo_id'].isin(top_embryos['embryo_id'])
        & (stage > min_stage)
        & (stage < max_stage)
    ]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # 1. Bar chart of max hypotenuse per embryo (seaborn is not given an empty ranking)
    if top_embryos.empty:
        ax1.text(0.5, 0.5, "No ranked embryos in this stage window",
                 horizontalalignment='center', verticalalignment='center')
    else:
        sns.barplot(
            x='embryo_id',
            y='max_hypotenuse',
            hue='phenotype',
            data=top_embryos,
            ax=ax1
        )
    ax1.set_title(f'Maximum Hypotenuse Values (Stage {min_stage}-{max_stage} hpf)')
    ax1.set_xlabel('Embryo ID')
    ax1.set_ylabel('Maximum Hypotenuse')
    ax1.tick_params(axis='x', rotation=45)

    # 2. Hypotenuse trajectories for the top embryos
    for embryo_id, group in filtered_data.groupby('embryo_id'):
        phenotype = group['phenotype'].iloc[0]
        rank = top_embryos.loc[top_embryos['embryo_id'] == embryo_id, 'rank'].iloc[0]
        group = group.sort_values('predicted_stage_hpf')

        (line,) = ax2.plot(
            group['predicted_stage_hpf'],
            group['hypotenuse'],
            label=f"#{rank} - {embryo_id} ({phenotype})"
        )

        # Mark the maximum positionally (robust to duplicate index labels); skip all-NaN traces
        hyp = group['hypotenuse'].to_numpy(dtype=float)
        if np.isnan(hyp).all():
            continue
        peak = int(np.nanargmax(hyp))
        ax2.scatter(
            group['predicted_stage_hpf'].iloc[peak],
            hyp[peak],
            s=100,
            marker='o',
            color=line.get_color(),  # tie the marker to its trajectory's colour
            edgecolors='black',
            zorder=3
        )

    ax2.set_title(f'Hypotenuse Trajectories in Stage {min_stage}-{max_stage} hpf')
    ax2.set_xlabel('Predicted Stage (hpf)')
    ax2.set_ylabel('Hypotenuse')
    ax2.grid(True, linestyle='--', alpha=0.7)
    if len(ax2.lines) > 0:
        ax2.legend(loc='upper right', fontsize=8)

    fig.tight_layout()
    return fig

# Third function: Generate a statistical summary of hypotenuse values by phenotype
def summarize_hypotenuse_by_phenotype(df, min_stage=30, max_stage=35):
    """
    Per-phenotype descriptive statistics of 'hypotenuse' inside a stage window.

    Rows with min_stage < predicted_stage_hpf < max_stage (strict) are grouped by
    'phenotype'.

    Parameters
    ----------
    df : pd.DataFrame
        Needs 'phenotype', 'embryo_id', 'predicted_stage_hpf', 'hypotenuse'.
    min_stage, max_stage : float
        Exclusive stage bounds (hpf).

    Returns
    -------
    pd.DataFrame
        One row per phenotype with columns: phenotype, count, mean, median, std,
        min, max, 25%, 75%, unique_embryos.
    """
    stage = df['predicted_stage_hpf']
    in_window = df[(stage > min_stage) & (stage < max_stage)]

    named_aggs = {
        'count': ('hypotenuse', 'count'),
        'mean': ('hypotenuse', 'mean'),
        'median': ('hypotenuse', 'median'),
        'std': ('hypotenuse', 'std'),
        'min': ('hypotenuse', 'min'),
        'max': ('hypotenuse', 'max'),
        '25%': ('hypotenuse', lambda values: values.quantile(0.25)),
        '75%': ('hypotenuse', lambda values: values.quantile(0.75)),
        'unique_embryos': ('embryo_id', 'nunique'),
    }
    return in_window.groupby('phenotype').agg(**named_aggs).reset_index()

# Convenience wrapper: rank, summarize and visualize embryos in one call
def analyze_embryos(df, min_stage=30, max_stage=35, top_n=8):
    """
    Run the stage-window analysis: rank embryos, summarize by phenotype, plot.

    Parameters:
    -----------
    df : pd.DataFrame
        Embryo data, passed unchanged to the ranking, summary and plotting helpers.
    min_stage : float, default=30
        Lower predicted-stage bound (hpf) forwarded to all three helpers.
    max_stage : float, default=35
        Upper predicted-stage bound (hpf) forwarded to all three helpers.
    top_n : int, default=8
        Number of top-ranked embryos passed to ``visualize_ranked_embryos``.

    Returns:
    --------
    tuple
        (ranked_embryos_df, phenotype_summary_df, visualization_fig)
    """
    ranked = rank_embryos_by_max_hypotenuse(df, min_stage, max_stage)
    phenotype_summary = summarize_hypotenuse_by_phenotype(df, min_stage, max_stage)
    figure = visualize_ranked_embryos(df, ranked, min_stage, max_stage, top_n=top_n)
    return ranked, phenotype_summary, figure

# Usage example: all three cep290 genotypes (temp30), default 30-35 hpf window.

df_analyze = df_augmented_projec_wt.loc[
    df_augmented_projec_wt["phenotype"].isin(
        ['cep290_het_cep290_temp30', 'wt_cep290_temp30', 'cep290_homo_cep290_temp30']
    )
]
ranked_df, summary_df, fig = analyze_embryos(df_analyze)

# Ranked embryos (first 15 rows)
print("Top embryos ranked by maximum hypotenuse value:")
display(ranked_df.head(15))

# Per-phenotype statistics
print("\nSummary statistics by phenotype:")
display(summary_df)

# Render the figure returned by analyze_embryos
plt.show()

In [43]:
# cep290 homozygotes only, default 30-35 hpf window
df_analyze = df_augmented_projec_wt.loc[
    df_augmented_projec_wt["phenotype"].isin(['cep290_homo_cep290_temp30'])
]
ranked_df, summary_df, fig = analyze_embryos(df_analyze)

print("Top embryos ranked by maximum hypotenuse value:")
display(ranked_df.head(15))

print("\nSummary statistics by phenotype:")
display(summary_df)

plt.show()

In [44]:
# Per-timepoint hypotenuse trace for one embryo of interest
df_augmented_projec_wt.loc[
    df_augmented_projec_wt["embryo_id"].eq('20250305_C06_e00'),
    ["predicted_stage_hpf", "snip_id", "hypotenuse"],
]

In [45]:
# cep290 homozygotes, widened 20-70 hpf window
df_analyze = df_augmented_projec_wt.loc[
    df_augmented_projec_wt["phenotype"].isin(['cep290_homo_cep290_temp30'])
]
ranked_df, summary_df, fig = analyze_embryos(df_analyze, min_stage=20, max_stage=70)

print("Top embryos ranked by maximum hypotenuse value:")
display(ranked_df.head(20))

print("\nSummary statistics by phenotype:")
display(summary_df)

plt.show()

In [47]:
# cep290 homozygotes, 20-70 hpf window, plotting the top 18 embryos
df_analyze = df_augmented_projec_wt.loc[
    df_augmented_projec_wt["phenotype"].isin(['cep290_homo_cep290_temp30'])
]
ranked_df, summary_df, fig = analyze_embryos(df_analyze, min_stage=20, max_stage=70, top_n=18)

print("Top embryos ranked by maximum hypotenuse value:")
display(ranked_df.head(20))

print("\nSummary statistics by phenotype:")
display(summary_df)

plt.show()




In [53]:
# Re-plot the ranking from the previous cell with progressively later plot-window starts
for plot_min_stage in (20, 30, 33):
    visualize_ranked_embryos(df_analyze, ranked_df, min_stage=plot_min_stage, max_stage=80, top_n=18)
    plt.show()

## used this for imaging

In [55]:
# cep290 homozygotes ranked over 33-70 hpf
df_analyze = df_augmented_projec_wt.loc[
    df_augmented_projec_wt["phenotype"].isin(['cep290_homo_cep290_temp30'])
]
ranked_df, summary_df, fig = analyze_embryos(df_analyze, min_stage=33, max_stage=70)

print("Top embryos ranked by maximum hypotenuse value:")
display(ranked_df.head(20))

print("\nSummary statistics by phenotype:")
display(summary_df)

plt.show()

# Same ranking, drawn over a wider 30-80 hpf plot window
visualize_ranked_embryos(df_analyze, ranked_df, min_stage=30, max_stage=80, top_n=18)
plt.show()


In [56]:
# All three cep290 genotypes ranked together over 33-70 hpf
df_analyze = df_augmented_projec_wt.loc[
    df_augmented_projec_wt["phenotype"].isin(
        ['wt_cep290_temp30', "cep290_het_cep290_temp30", "cep290_homo_cep290_temp30"]
    )
]
ranked_df, summary_df, fig = analyze_embryos(df_analyze, min_stage=33, max_stage=70)

print("Top embryos ranked by maximum hypotenuse value:")
display(ranked_df.head(20))

print("\nSummary statistics by phenotype:")
display(summary_df)

plt.show()

# Same ranking, drawn over a wider 30-80 hpf plot window
visualize_ranked_embryos(df_analyze, ranked_df, min_stage=30, max_stage=80, top_n=18)
plt.show()


In [81]:
# Per-timepoint trace (with phenotype) for embryo 20250305_F01_e00
df_augmented_projec_wt.loc[
    df_augmented_projec_wt["embryo_id"].eq('20250305_F01_e00'),
    ["predicted_stage_hpf", "snip_id", "hypotenuse", "phenotype"],
]

In [82]:
# Per-timepoint trace (with phenotype) for embryo 20250305_B07_e00
df_augmented_projec_wt.loc[
    df_augmented_projec_wt["embryo_id"].eq('20250305_B07_e00'),
    ["predicted_stage_hpf", "snip_id", "hypotenuse", "phenotype"],
]

In [68]:
# Embryo IDs of the wild-type siblings from the cep290 clutch
df_augmented.loc[df_augmented["phenotype"].eq("wt_cep290_temp30"), "embryo_id"].unique()

In [75]:
# cep290 heterozygotes ranked over 33-70 hpf
df_analyze = df_augmented_projec_wt.loc[
    df_augmented_projec_wt["phenotype"].isin(['cep290_het_cep290_temp30'])
]
ranked_df, summary_df, fig = analyze_embryos(df_analyze, min_stage=33, max_stage=70)

print("Top embryos ranked by maximum hypotenuse value:")
display(ranked_df.head(20))

print("\nSummary statistics by phenotype:")
display(summary_df)

plt.show()

# Same ranking, drawn over a wider 30-80 hpf plot window
visualize_ranked_embryos(df_analyze, ranked_df, min_stage=30, max_stage=80, top_n=18)
plt.show()

In [76]:
# Wild-type cep290 siblings ranked over 33-70 hpf
df_analyze = df_augmented_projec_wt.loc[
    df_augmented_projec_wt["phenotype"].isin(['wt_cep290_temp30'])
]
ranked_df, summary_df, fig = analyze_embryos(df_analyze, min_stage=33, max_stage=70)

print("Top embryos ranked by maximum hypotenuse value:")
display(ranked_df.head(20))

print("\nSummary statistics by phenotype:")
display(summary_df)

plt.show()

# Same ranking, drawn over a wider 30-80 hpf plot window
visualize_ranked_embryos(df_analyze, ranked_df, min_stage=30, max_stage=80, top_n=18)
plt.show()

# Z-score over stage

In [None]:
def plot_hypotenuse_zscore_over_stage(
    df,
    phenotypes_to_include=None,
    window_size=5,  # For smoothing
    stage_bin_width=0.5,  # Size of stage windows for z-score calculation
    figsize=(12, 8),
    palette="tab10",
    alpha=0.5,
    plot_individual_embryos=True,
    plot_average=True,
    plot_median=False,
    highlight_embryos=None,
    highlight_phenotypes=None,
    highlight_alpha=0.9,
    max_hpf=None,
    min_points_per_embryo=5,
    save_path=None,
    show_legend=True,
    title=None,
    z_score_method='all',  # 'all', 'per_phenotype'
    dimmed_summary_alpha=0.6,
):
    """
    Plot z-score of hypotenuse (distance from spline) over predicted stage with smoothing.
    Z-scores are calculated within stage windows to normalize values across developmental progression.

    Each row is assigned to the stage bin ``floor(stage / stage_bin_width) * stage_bin_width``
    and its hypotenuse is z-scored (population std, ddof=0) against the other rows in that bin
    (optionally restricted to the same phenotype). Bins with a single point or zero spread get a
    z-score of 0; NaN hypotenuse values stay NaN.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing the data, must have columns:
        ['hypotenuse', 'embryo_id', 'phenotype', 'predicted_stage_hpf']
    phenotypes_to_include : list, optional
        List of phenotypes to include in the plot. If None, all phenotypes are included.
    window_size : int, default=5
        Window size (in points) for the centred rolling-mean smoothing of the plot lines.
    stage_bin_width : float, default=0.5
        Width of stage windows (in hpf) for z-score calculation. Must be positive.
    figsize : tuple, default=(12, 8)
        Figure size (width, height) in inches.
    palette : str or dict, default="tab10"
        Seaborn palette name, or dict mapping phenotypes to colors (missing ones are gray).
    alpha : float, default=0.5
        Transparency level for individual embryo lines.
    plot_individual_embryos : bool, default=True
        Whether to plot individual embryo traces.
    plot_average : bool, default=True
        Whether to plot average line per phenotype.
    plot_median : bool, default=False
        Whether to plot median line per phenotype (dashed).
    highlight_embryos : list, optional
        List of embryo_ids to highlight with thicker lines (and list in the legend).
    highlight_phenotypes : list, optional
        List of phenotypes to highlight. Their individual traces use ``highlight_alpha``;
        their mean/median lines are fully opaque while the other phenotypes' summary
        lines are dimmed to ``dimmed_summary_alpha``.
    highlight_alpha : float, default=0.9
        Transparency level for individual traces of highlighted embryos/phenotypes.
    max_hpf : float, optional
        Maximum hours post-fertilization to include in the plot.
    min_points_per_embryo : int, default=5
        Minimum number of data points required for an embryo to be included.
    save_path : str, optional
        Path to save the figure, if provided.
    show_legend : bool, default=True
        Whether to show the legend.
    title : str, optional
        Plot title. If None, a default title is used.
    z_score_method : str, default='all'
        Method for z-score calculation:
        - 'all': z-scores calculated across all embryos within each stage window
        - 'per_phenotype': z-scores calculated separately per phenotype within each stage window
    dimmed_summary_alpha : float, default=0.6
        Opacity of mean/median lines for phenotypes NOT in ``highlight_phenotypes``
        (only used when ``highlight_phenotypes`` is non-empty).

    Returns:
    --------
    matplotlib.figure.Figure
        The created figure

    Raises:
    -------
    ValueError
        If ``z_score_method`` is unknown, ``stage_bin_width`` is not positive, or a
        required column is missing. These checks run before any plotting library is imported.

    Example:
    --------
    # Basic usage with all phenotypes
    fig = plot_hypotenuse_zscore_over_stage(my_dataframe)

    # Highlighting specific phenotypes with finer stage binning
    fig = plot_hypotenuse_zscore_over_stage(
        my_dataframe,
        highlight_phenotypes=['cep290_homo_cep290_temp30'],
        phenotypes_to_include=['cep290_het_cep290_temp30', 'wt_cep290_temp30', 'cep290_homo_cep290_temp30'],
        stage_bin_width=0.25
    )

    # Calculate z-scores separately for each phenotype
    fig = plot_hypotenuse_zscore_over_stage(
        my_dataframe,
        z_score_method='per_phenotype'
    )
    """
    # --- Validate arguments up front (fail fast, even when filtering would leave no data) ---
    if z_score_method not in ('all', 'per_phenotype'):
        raise ValueError(f"Invalid z_score_method: {z_score_method}. Use 'all' or 'per_phenotype'.")
    if stage_bin_width <= 0:
        raise ValueError(f"stage_bin_width must be positive, got {stage_bin_width}")
    required_cols = ['hypotenuse', 'embryo_id', 'phenotype', 'predicted_stage_hpf']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns

    def _zscore_within(values):
        """Population z-score of one bin; single-point or zero-spread bins map to 0 (NaN kept)."""
        spread = values.std(ddof=0)
        if len(values) < 2 or not np.isfinite(spread) or spread == 0:
            return values.where(values.isna(), 0.0)
        return (values - values.mean()) / spread

    # --- Filter (copy so helper columns never leak into the caller's frame) ---
    df = df.copy()
    if phenotypes_to_include is not None:
        df = df[df['phenotype'].isin(phenotypes_to_include)]
    if max_hpf is not None:
        df = df[df['predicted_stage_hpf'] <= max_hpf]
    # Rows without a stage cannot be binned
    df = df.dropna(subset=['predicted_stage_hpf'])

    embryo_counts = df.groupby('embryo_id').size()
    valid_embryos = embryo_counts[embryo_counts >= min_points_per_embryo].index
    df = df[df['embryo_id'].isin(valid_embryos)]

    if df.empty:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, "No data available after filtering",
                horizontalalignment='center', verticalalignment='center',
                fontsize=14)
        ax.set_axis_off()
        return fig

    # --- Stage bins and z-scores ---
    df = df.assign(
        stage_bin=np.floor(df['predicted_stage_hpf'] / stage_bin_width) * stage_bin_width
    )
    group_keys = ['stage_bin'] if z_score_method == 'all' else ['stage_bin', 'phenotype']
    df = df.assign(
        hypotenuse_zscore=df.groupby(group_keys)['hypotenuse'].transform(_zscore_within)
    )

    # --- Colors ---
    unique_phenotypes = df['phenotype'].unique()
    if isinstance(palette, str):
        colors = sns.color_palette(palette, n_colors=len(unique_phenotypes))
        phenotype_colors = dict(zip(unique_phenotypes, colors))
    else:
        phenotype_colors = palette

    highlight_embryo_set = set(highlight_embryos) if highlight_embryos is not None else set()
    highlight_phenotype_set = set(highlight_phenotypes) if highlight_phenotypes is not None else set()

    fig, ax = plt.subplots(figsize=figsize)

    # --- Individual embryo traces ---
    highlight_lines = {}
    if plot_individual_embryos:
        for embryo_id, group in df.groupby('embryo_id'):
            phenotype = group['phenotype'].iloc[0]
            group = group.sort_values('predicted_stage_hpf')
            smoothed = group['hypotenuse_zscore'].rolling(
                window=window_size, min_periods=1, center=True
            ).mean()

            is_highlight_embryo = embryo_id in highlight_embryo_set
            is_highlight_phenotype = phenotype in highlight_phenotype_set
            (line,) = ax.plot(
                group['predicted_stage_hpf'],
                smoothed,
                color=phenotype_colors.get(phenotype, 'gray'),
                alpha=highlight_alpha if (is_highlight_embryo or is_highlight_phenotype) else alpha,
                linewidth=2.5 if is_highlight_embryo else 1,
                label=None,  # legend is assembled explicitly below
            )
            if is_highlight_embryo:
                highlight_lines[embryo_id] = line

    # --- Per-phenotype summary lines (mean solid, median dashed) ---
    def _draw_summary_lines(statistic, linestyle):
        """Draw one smoothed per-bin `statistic` line per phenotype; return (handle, label) pairs."""
        entries = []
        for phenotype, group in df.groupby('phenotype'):
            per_bin = group.groupby('stage_bin')['hypotenuse_zscore'].agg(statistic).sort_index()
            smoothed = per_bin.rolling(window=window_size, min_periods=1, center=True).mean()
            # Highlighted (or, with no highlighting, all) summary lines are fully opaque;
            # the others are dimmed so the highlighted ones stand out.
            if highlight_phenotype_set and phenotype not in highlight_phenotype_set:
                line_alpha = dimmed_summary_alpha
            else:
                line_alpha = 1.0
            label = f"{phenotype} ({statistic}, n={group['embryo_id'].nunique()})"
            (line,) = ax.plot(
                smoothed.index,
                smoothed.to_numpy(),
                color=phenotype_colors.get(phenotype, 'gray'),
                linewidth=5,
                alpha=line_alpha,
                linestyle=linestyle,
                label=label,
            )
            entries.append((line, label))
        return entries

    legend_entries = []
    if plot_average:
        legend_entries += _draw_summary_lines('mean', '-')
    if plot_median:
        legend_entries += _draw_summary_lines('median', '--')

    # --- Reference lines at 0, ±1 and ±2 standard deviations ---
    ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    for level, style in ((1, ':'), (2, '--')):
        ax.axhline(y=level, color='red', linestyle=style, alpha=0.5)
        ax.axhline(y=-level, color='red', linestyle=style, alpha=0.5)

    method_label = "across all phenotypes" if z_score_method == 'all' else "within each phenotype"
    ax.set_title(title or f"Z-Score of Distance from Spline ({method_label})", fontsize=14)
    ax.set_xlabel('Predicted Stage (hpf)', fontsize=12)
    ax.set_ylabel('Z-Score of Hypotenuse', fontsize=12)

    # --- Legend: summary lines first, then highlighted embryos ---
    if show_legend:
        legend_entries += [(line, f"Embryo {embryo_id}") for embryo_id, line in highlight_lines.items()]
        if legend_entries:
            handles, labels = zip(*legend_entries)
            ax.legend(
                handles=list(handles),
                labels=list(labels),
                title="Phenotypes and Highlighted Embryos",
                loc="best",
                fontsize=10,
            )

    ax.grid(True, linestyle='--', alpha=0.7)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

def identify_abnormal_embryos(
    df, 
    stage_bin_width=0.5,
    z_threshold=2.0,
    consecutive_bins=2,
    z_score_method='all',
    min_stage=None,
    max_stage=None
):
    """
    Identify embryos with abnormally high or low hypotenuse values
    based on z-scores within development stage windows.

    Each row is assigned to the stage bin ``floor(stage / stage_bin_width) * stage_bin_width``
    and its hypotenuse is z-scored (population std, ddof=0) against the other rows in that
    bin (optionally restricted to the same phenotype). Bins with a single point or zero
    spread get a z-score of 0. A row is abnormal when ``|z| > z_threshold``. An embryo is
    flagged when, after sorting its rows by predicted stage, at least ``consecutive_bins``
    abnormal rows occur back to back.

    NOTE: despite its name, ``consecutive_bins`` counts consecutive *time points* (rows)
    of an embryo, not distinct stage bins; several time points may share one bin.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing the data, must have columns:
        ['hypotenuse', 'embryo_id', 'phenotype', 'predicted_stage_hpf']
    stage_bin_width : float, default=0.5
        Width of stage windows (in hpf) for z-score calculation. Must be positive.
    z_threshold : float, default=2.0
        Z-score threshold to consider abnormal (absolute value)
    consecutive_bins : int, default=2
        Number of consecutive abnormal time points required to flag an embryo
    z_score_method : str, default='all'
        Method for z-score calculation: 
        - 'all': z-scores calculated across all embryos 
        - 'per_phenotype': z-scores calculated separately per phenotype
    min_stage : float, optional
        Minimum stage (hpf) to include in analysis (inclusive)
    max_stage : float, optional
        Maximum stage (hpf) to include in analysis (inclusive)

    Returns:
    --------
    dict
        Dictionary with:
        - 'abnormal_embryos': list of embryo IDs identified as abnormal
        - 'abnormal_summary': DataFrame with one row per abnormal embryo, sorted by
          |max_zscore| (most extreme first, regardless of sign)
        - 'z_score_df': DataFrame of the analysed rows with added columns
          'stage_bin', 'hypotenuse_zscore' and 'is_abnormal'

    Raises:
    -------
    ValueError
        If a required column is missing, ``stage_bin_width`` is not positive,
        or ``z_score_method`` is unknown.

    Example:
    --------
    # Basic usage
    results = identify_abnormal_embryos(my_dataframe)
    abnormal_embryos = results['abnormal_embryos']

    # More stringent criteria
    results = identify_abnormal_embryos(
        my_dataframe,
        z_threshold=2.5,
        consecutive_bins=3
    )

    # Then use the identified embryos in the plotting function
    fig = plot_hypotenuse_zscore_over_stage(
        my_dataframe,
        highlight_embryos=results['abnormal_embryos']
    )
    """
    import numpy as np
    import pandas as pd

    required_cols = ['hypotenuse', 'embryo_id', 'phenotype', 'predicted_stage_hpf']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")
    if z_score_method not in ('all', 'per_phenotype'):
        raise ValueError(f"Invalid z_score_method: {z_score_method}. Use 'all' or 'per_phenotype'.")
    if stage_bin_width <= 0:
        raise ValueError(f"stage_bin_width must be positive, got {stage_bin_width}")

    def _zscore_within(values):
        """Population z-score of one bin; single-point or zero-spread bins map to 0 (NaN kept)."""
        spread = values.std(ddof=0)
        if len(values) < 2 or not np.isfinite(spread) or spread == 0:
            return values.where(values.isna(), 0.0)
        return (values - values.mean()) / spread

    # Copy so helper columns never leak into the caller's frame; rows without a
    # stage cannot be binned.
    df = df.dropna(subset=['predicted_stage_hpf']).copy()
    if min_stage is not None:
        df = df[df['predicted_stage_hpf'] >= min_stage]
    if max_stage is not None:
        df = df[df['predicted_stage_hpf'] <= max_stage]

    df = df.assign(
        stage_bin=np.floor(df['predicted_stage_hpf'] / stage_bin_width) * stage_bin_width
    )
    if df.empty:
        zscores = pd.Series(index=df.index, dtype=float)
    else:
        group_keys = ['stage_bin'] if z_score_method == 'all' else ['stage_bin', 'phenotype']
        zscores = df.groupby(group_keys)['hypotenuse'].transform(_zscore_within)
    df = df.assign(hypotenuse_zscore=zscores)
    # NaN z-scores compare False, so they never count as abnormal
    df = df.assign(is_abnormal=df['hypotenuse_zscore'].abs() > z_threshold)

    summary_records = []
    for embryo_id, embryo_df in df.groupby('embryo_id'):
        embryo_df = embryo_df.sort_values('predicted_stage_hpf')
        flags = embryo_df['is_abnormal'].to_numpy(dtype=bool)
        embryo_z = embryo_df['hypotenuse_zscore'].to_numpy(dtype=float)
        stages = embryo_df['predicted_stage_hpf'].to_numpy(dtype=float)

        # Longest run of back-to-back abnormal time points
        longest_run = current_run = 0
        for flag in flags:
            current_run = current_run + 1 if flag else 0
            longest_run = max(longest_run, current_run)
        if longest_run < consecutive_bins:
            continue

        abnormal_idx = np.flatnonzero(flags)
        if abnormal_idx.size:
            # First abnormal point with the largest |z|
            peak = abnormal_idx[np.argmax(np.abs(embryo_z[abnormal_idx]))]
            max_zscore, max_zscore_stage = float(embryo_z[peak]), float(stages[peak])
        else:  # only reachable when consecutive_bins <= 0
            max_zscore, max_zscore_stage = 0, None

        summary_records.append({
            'embryo_id': embryo_id,
            'phenotype': embryo_df['phenotype'].iloc[0],
            'max_consecutive_abnormal': longest_run,
            'max_zscore': max_zscore,
            'max_zscore_stage': max_zscore_stage,
            'abnormal_stages': stages[abnormal_idx].tolist(),
        })

    if summary_records:
        # Rank by severity in either direction, not by signed value
        abnormal_summary = pd.DataFrame(summary_records).sort_values(
            'max_zscore', key=lambda s: s.abs(), ascending=False
        )
    else:
        abnormal_summary = pd.DataFrame(columns=[
            'embryo_id', 'phenotype', 'max_consecutive_abnormal',
            'max_zscore', 'max_zscore_stage', 'abnormal_stages'
        ])

    return {
        'abnormal_embryos': [record['embryo_id'] for record in summary_records],
        'abnormal_summary': abnormal_summary,
        'z_score_df': df
    }

# display(results['abnormal_summary'])

In [89]:
# Flag embryos whose stage-normalised distance from the WT spline stays extreme,
# then overlay them on the z-score trajectories of the cep290 temp30 groups.
results = identify_abnormal_embryos(
    df_augmented_projec_wt,
    stage_bin_width=0.5,   # hpf per z-score window
    z_threshold=2.0,       # |z| cut-off for an abnormal time point
    consecutive_bins=2,    # required run of consecutive abnormal time points
)
display(results['abnormal_summary'])

fig = plot_hypotenuse_zscore_over_stage(
    df_augmented_projec_wt,
    phenotypes_to_include=['cep290_het_cep290_temp30', 'wt_cep290_temp30', 'cep290_homo_cep290_temp30'],
    highlight_embryos=results['abnormal_embryos'],
    highlight_phenotypes=['cep290_homo_cep290_temp30'],
)

# To keep the figure: fig.savefig('embryo_zscore_analysis.png', dpi=300, bbox_inches='tight')

In [92]:
# Z-score trajectories of the cep290 temp30 groups (no embryo highlighting), swept over
# (z-score window width in hpf, smoothing window in points). (0.5, 5) are the defaults.
for zscore_bin_hpf, smoothing_points in [(0.5, 5), (1, 5), (5, 5), (5, 1)]:
    fig = plot_hypotenuse_zscore_over_stage(
        df_augmented_projec_wt,
        phenotypes_to_include=['cep290_het_cep290_temp30', 'wt_cep290_temp30', 'cep290_homo_cep290_temp30'],
        highlight_phenotypes=['cep290_homo_cep290_temp30'],
        stage_bin_width=zscore_bin_hpf,
        window_size=smoothing_points,
    )

In [93]:
# Reference spline segments from plain wild-type embryos at temp30
# (the "wt" set that excludes the cep290-clutch wild types).
wt_splines_n_planes_not_cep290 = segment_info_df.loc[segment_info_df["phenotype"].eq("wt_temp30")]

# Project every point in df_augmented onto that reference spline.
df_points = project_points_onto_reference_spline(
    df_augmented,
    wt_splines_n_planes_not_cep290,
)

# Attach projection results by snip_id. Columns df_augmented already has (other than
# the key) are dropped from df_points so the merge creates no _x/_y duplicates.
df_augmented_projec_wt_not_cep290 = pd.merge(
    df_augmented,
    df_points.drop(
        columns=df_points.columns.intersection(df_augmented.columns).difference(["snip_id"])
    ),
    on="snip_id",
)

In [96]:
# Look up each point's reference-segment average time:
# ref_seg_id (from the projection) -> segment_avg_time of that segment in the
# wt_temp30 reference table. Unmatched ids become NaN.
# (The stray bare `wt_splines_n_planes_not_cep290` expression that used to open this
# cell was a no-op, since it was not the cell's last line, and has been removed.)
mapping = dict(zip(wt_splines_n_planes_not_cep290["seg_id"], wt_splines_n_planes_not_cep290["segment_avg_time"]))

df_augmented_projec_wt_not_cep290["ref_segment_avg_time"] = df_augmented_projec_wt_not_cep290["ref_seg_id"].map(mapping)

In [97]:
# Column names in a deterministic (sorted) order; a displayed set has arbitrary order
sorted(df_augmented_projec_wt_not_cep290.columns)

In [102]:
import matplotlib.pyplot as plt

# Predicted stage (hpf) vs. the average time of the reference segment each point was
# projected onto ("ref_segment_avg_time"), one colour per phenotype. Points on the
# dashed y = x line have matching predicted and reference times.
fig, ax = plt.subplots(figsize=(10, 8))  # 10 x 8 inches

unique_phenotypes = df_augmented_projec_wt_not_cep290["phenotype"].unique()
for phenotype in unique_phenotypes:
    subset = df_augmented_projec_wt_not_cep290.loc[
        df_augmented_projec_wt_not_cep290["phenotype"] == phenotype
    ]
    ax.scatter(subset["predicted_stage_hpf"], subset["ref_segment_avg_time"], label=phenotype, s=10)

# y = x guide spanning the observed predicted-stage range
x_min, x_max = (
    df_augmented_projec_wt_not_cep290["predicted_stage_hpf"].min(),
    df_augmented_projec_wt_not_cep290["predicted_stage_hpf"].max(),
)
ax.plot([x_min, x_max], [x_min, x_max], 'k--', label='y=x')

ax.set_xlabel("Predicted Stage hpf")
ax.set_ylabel("Reference Segment Avg Time")
ax.legend(title="Phenotype")
ax.set_title("Ref Segment Avg Time by Predicted Stage hpf (Colored by Phenotype)")

plt.show()

In [104]:
import matplotlib.pyplot as plt

# Early-stage (< 35 hpf) version of the previous plot, limited to the WT reference
# and the three cep290 genotypes.
fig, ax = plt.subplots(figsize=(10, 8))  # 10 x 8 inches

unique_phenotypes = [
    "wt_temp30",
    'cep290_het_cep290_temp30',
    'wt_cep290_temp30',
    'cep290_homo_cep290_temp30',
]

for phenotype in unique_phenotypes:
    subset = df_augmented_projec_wt_not_cep290.loc[
        (df_augmented_projec_wt_not_cep290["phenotype"] == phenotype)
        & (df_augmented_projec_wt_not_cep290["predicted_stage_hpf"] < 35)
    ]
    ax.scatter(subset["predicted_stage_hpf"], subset["ref_segment_avg_time"], label=phenotype, s=10)

# y = x guide from the earliest predicted stage (all phenotypes) up to 40 hpf
x_min, x_max = df_augmented_projec_wt_not_cep290["predicted_stage_hpf"].min(), 40
ax.plot([x_min, x_max], [x_min, x_max], 'k--', label='y=x')

ax.set_xlabel("Predicted Stage hpf")
ax.set_ylabel("Reference Segment Avg Time")
ax.legend(title="Phenotype")
ax.set_title("Ref Segment Avg Time by Predicted Stage hpf (Colored by Phenotype)")

plt.show()

In [113]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import Patch

# Phenotypes to compare (this order also fixes the legend order)
unique_phenotypes = [
    'cep290_het_cep290_temp30',
    'wt_cep290_temp30',
    'cep290_homo_cep290_temp30',
    'wt_temp30'
]

# Early stages (< 35 hpf) of the selected phenotypes. `.copy()` makes df_filtered an
# independent frame, so adding 'delta' below does not write into a slice of
# df_augmented_projec_wt_not_cep290 (avoids SettingWithCopyWarning).
df_filtered = df_augmented_projec_wt_not_cep290[
    (df_augmented_projec_wt_not_cep290['predicted_stage_hpf'] < 35)
    & df_augmented_projec_wt_not_cep290["phenotype"].isin(unique_phenotypes)
].copy()

# delta > 0: the matched reference segment's average time is later than the predicted stage
df_filtered['delta'] = df_filtered['ref_segment_avg_time'] - df_filtered['predicted_stage_hpf']

# One bar per embryo (mean delta over its time points), grouped by phenotype then sorted by delta
embryo_deltas = (
    df_filtered.groupby(['embryo_id', 'phenotype'])['delta'].mean().reset_index()
    .sort_values(['phenotype', 'delta'])
)

fig, ax = plt.subplots(figsize=(14, 8))

color_map = {
    'cep290_het_cep290_temp30': '#FF9999',  # Light red
    'wt_cep290_temp30': '#66B2FF',          # Light blue
    'cep290_homo_cep290_temp30': '#FF3333', # Dark red
    'wt_temp30': '#0066CC'                  # Dark blue
}

bar_positions = np.arange(len(embryo_deltas))
bars = ax.bar(
    bar_positions,
    embryo_deltas['delta'],
    color=[color_map.get(p, 'gray') for p in embryo_deltas['phenotype']]
)

ax.set_xlabel('Embryo ID', fontsize=12)
ax.set_ylabel('Avg Delta (ref_segment_avg_time - predicted_stage_hpf)', fontsize=12)
ax.set_title('Average Time Delta by Embryo (<35 hpf)', fontsize=14)

ax.set_xticks(bar_positions)
ax.set_xticklabels(embryo_deltas['embryo_id'], rotation=90)

# Zero reference: no offset between reference time and predicted stage
ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)

legend_elements = [Patch(facecolor=color_map[p], label=p) for p in unique_phenotypes
                   if p in embryo_deltas['phenotype'].values]
ax.legend(handles=legend_elements, title='Phenotype')

ax.grid(axis='y', linestyle='--', alpha=0.7)

# Point-level summary: mean/std are over all time points (not over the per-embryo bars).
# Both the number of time points and the number of embryos are reported, since the
# former was previously shown as a bare "n=" and easy to misread as an embryo count.
summary_stats = df_filtered.groupby('phenotype').agg(
    mean=('delta', 'mean'),
    std=('delta', 'std'),
    n_points=('delta', 'count'),
    n_embryos=('embryo_id', 'nunique'),
).reset_index()
summary_text = "Summary by Phenotype:\n"
for _, row in summary_stats.iterrows():
    summary_text += (
        f"{row['phenotype']}: mean={row['mean']:.2f}, std={row['std']:.2f}, "
        f"n_points={int(row['n_points'])}, n_embryos={int(row['n_embryos'])}\n"
    )

plt.figtext(0.01, 0.01, summary_text, fontsize=10,
            bbox=dict(facecolor='white', alpha=0.8))

plt.tight_layout()
plt.show()

In [112]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import Patch

# Per-embryo staging offset before 35 hpf, cep290 lines only (no wt_temp30).
unique_phenotypes = [
    'cep290_het_cep290_temp30',
    'wt_cep290_temp30',
    'cep290_homo_cep290_temp30',
]

# Keep frames < 35 hpf from the phenotypes above. `.copy()` gives an independent
# frame, so adding the 'delta' column below does not trigger SettingWithCopyWarning
# or write into df_augmented_projec_wt_not_cep290.
df_filtered = df_augmented_projec_wt_not_cep290[
    (df_augmented_projec_wt_not_cep290['predicted_stage_hpf'] < 35)
    & (df_augmented_projec_wt_not_cep290['phenotype'].isin(unique_phenotypes))
].copy()

# delta > 0 means ref_segment_avg_time exceeds predicted_stage_hpf for that frame.
df_filtered['delta'] = df_filtered['ref_segment_avg_time'] - df_filtered['predicted_stage_hpf']

# Average delta per embryo; sorting by phenotype then delta groups each phenotype's bars together.
embryo_deltas = (
    df_filtered.groupby(['embryo_id', 'phenotype'])['delta'].mean()
    .reset_index()
    .sort_values(['phenotype', 'delta'])
)

fig, ax = plt.subplots(figsize=(14, 8))

# Reds for cep290 mutants, blues for wild type.
color_map = {
    'cep290_het_cep290_temp30': '#FF9999',  # Light red
    'wt_cep290_temp30': '#66B2FF',          # Light blue
    'cep290_homo_cep290_temp30': '#FF3333', # Dark red
    'wt_temp30': '#0066CC'                  # Dark blue
}

# One bar per embryo, colored by phenotype (gray for anything not in color_map).
bar_positions = np.arange(len(embryo_deltas))
bars = ax.bar(
    bar_positions,
    embryo_deltas['delta'],
    color=[color_map.get(p, 'gray') for p in embryo_deltas['phenotype']]
)

ax.set_xlabel('Embryo ID', fontsize=12)
ax.set_ylabel('Avg Delta (ref_segment_avg_time - predicted_stage_hpf)', fontsize=12)
ax.set_title('Average Time Delta by Embryo (<35 hpf)', fontsize=14)

ax.set_xticks(bar_positions)
ax.set_xticklabels(embryo_deltas['embryo_id'], rotation=90)

# Reference line at zero offset.
ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)

# Legend only for phenotypes that actually have bars.
legend_elements = [Patch(facecolor=color_map[p], label=p) for p in unique_phenotypes
                   if p in embryo_deltas['phenotype'].values]
ax.legend(handles=legend_elements, title='Phenotype')

ax.grid(axis='y', linestyle='--', alpha=0.7)

# Per-phenotype summary over frames: mean/std are frame-weighted and n counts
# frames (rows of df_filtered), not embryos.
summary_stats = df_filtered.groupby('phenotype')['delta'].agg(['mean', 'std', 'count']).reset_index()
summary_text = "Summary by Phenotype:\n"
for _, row in summary_stats.iterrows():
    summary_text += f"{row['phenotype']}: mean={row['mean']:.2f}, std={row['std']:.2f}, n={int(row['count'])}\n"

print(summary_text)

plt.tight_layout()
plt.show()

In [133]:
# Per-frame projection results for one embryo of interest.
imp_cols = ["snip_id","predicted_stage_hpf","ref_segment_avg_time", "hypotenuse", "ref_seg_id"]
embryo_mask = df_augmented_projec_wt_not_cep290["embryo_id"] == "20250305_A08_e00"
df_augmented_projec_wt_not_cep290.loc[embryo_mask, imp_cols]

In [130]:
# Per-frame projection results for one embryo of interest.
imp_cols = ["snip_id","predicted_stage_hpf","ref_segment_avg_time", "hypotenuse", "ref_seg_id"]
embryo_mask = df_augmented_projec_wt_not_cep290["embryo_id"] == "20250305_H10_e00"
df_augmented_projec_wt_not_cep290.loc[embryo_mask, imp_cols]

In [142]:
import matplotlib.pyplot as plt

# Frames past 50 hpf whose phenotype label mentions cep290.
late_cep290_mask = (
    (df_augmented_projec_wt_not_cep290['predicted_stage_hpf'] > 50)
    & (df_augmented_projec_wt_not_cep290['phenotype'].str.contains("cep290"))
)
df_filtered = df_augmented_projec_wt_not_cep290[late_cep290_mask]

# One phenotype label per embryo (its first value), then tally embryos per phenotype.
embryo_phenotypes = df_filtered.groupby('embryo_id')['phenotype'].first()
phenotype_counts = embryo_phenotypes.value_counts()

fig, ax = plt.subplots(figsize=(8, 5))
phenotype_counts.plot(kind='bar', ax=ax)
ax.set_xlabel('Phenotype')
ax.set_ylabel('Number of Embryos')
ax.set_title('Embryo Counts per Phenotype (predicted_stage_hpf > 50 & phenotype contains "cep290")')
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
fig.tight_layout()
plt.show()

In [144]:
# Persist both projection tables as CSV under data_dir.
wt_not_cep290_csv = os.path.join(data_dir, "df_augmented_projec_wt_not_cep290.csv")
wt_csv = os.path.join(data_dir, "df_augmented_projec_wt.csv")
df_augmented_projec_wt_not_cep290.to_csv(wt_not_cep290_csv)
df_augmented_projec_wt.to_csv(wt_csv)

In [143]:
# Display the output directory that the CSV exports are written to.
data_dir

# Redoing phenotypic severity with the axis of maximum variation (M)

In [5]:
import numpy as np
import pandas as pd

def orient_segment_axes(
    segment_info_df: pd.DataFrame,
    first_n: int = 5,
    reference_point: np.ndarray = None
) -> pd.DataFrame:
    """
    Give every spline segment a consistently signed principal axis (M) and a
    secondary in-plane axis (m).
    
    Steps:
      1) Reference point: `reference_point` if given, otherwise the centroid of
         all segment midpoints.
      2) Sort segments by seg_id (index reset).
      3) Normalize each principal axis to unit length (axes with norm <= 1e-12
         are left unnormalized).
      4) For the first `first_n` segments, flip the principal axis if
         dot(principal_axis, reference_point - midpoint) < 0.
      5) Walk the segments in order and flip principal_axis_i whenever
         dot(principal_axis_{i-1}, principal_axis_i) < 0.
      6) For each segment, project the segment direction (end - start) onto the
         plane normal to M and set m = cross(M, projected direction), normalized.
         If that projection is ~0, m = normalized cross(M, x-hat) (y-hat when M
         is nearly parallel to x-hat); if that is also ~0, m is the zero vector.
      7) Walk the segments again and flip m_i whenever dot(m_{i-1}, m_i) < 0.
    
    NOTE(review): step 5 re-aligns every segment i >= 1 with its predecessor,
    which overwrites the step-4 flips of segments 1..first_n-1 (unless a dot
    product is exactly 0). In effect only segment 0's orientation toward the
    reference point fixes the global sign, so `first_n` > 1 has no effect.
    Likewise, a zero m vector from the degenerate fallback stops step 7 from
    propagating orientation past that segment.
    
    Parameters
    ----------
    segment_info_df : pd.DataFrame
        One row per segment with columns seg_id, principal_axis_{x,y,z},
        segment_midpoint_{x,y,z}, segment_start_{x,y,z}, segment_end_{x,y,z}.
    first_n : int, default 5
        Number of leading segments checked against the reference point (step 4).
    reference_point : np.ndarray, optional
        3-vector used in step 4; defaults to the midpoint centroid.
    
    Returns
    -------
    pd.DataFrame
        A NEW frame (sorted by seg_id, index reset) whose principal_axis_{x,y,z}
        columns hold the oriented axes and with axis_in_plane_{x,y,z} added.
        The caller's DataFrame is not modified.
    """
    # 1) If no reference_point is provided, use centroid of all segment midpoints
    if reference_point is None:
        ref_pt = segment_info_df[["segment_midpoint_x","segment_midpoint_y","segment_midpoint_z"]].mean().values
    else:
        ref_pt = reference_point
    
    # Sort by seg_id so the alignment walks below visit segments in seg_id order
    # (presumably the order along the spline). This rebinds the local name to a
    # sorted copy; the caller's frame is untouched.
    segment_info_df = segment_info_df.sort_values("seg_id").reset_index(drop=True)
    
    # 2) Convert columns to numpy arrays (copies) for easy manipulation
    principal_axes = segment_info_df[[
        "principal_axis_x","principal_axis_y","principal_axis_z"
    ]].values
    
    midpoints = segment_info_df[[
        "segment_midpoint_x","segment_midpoint_y","segment_midpoint_z"
    ]].values
    
    seg_starts = segment_info_df[[
        "segment_start_x","segment_start_y","segment_start_z"
    ]].values
    
    seg_ends = segment_info_df[[
        "segment_end_x","segment_end_y","segment_end_z"
    ]].values
    
    # 3) Normalize each principal_axis in place; near-zero axes are skipped to
    #    avoid dividing by ~0.
    for i in range(len(principal_axes)):
        norm_val = np.linalg.norm(principal_axes[i])
        if norm_val > 1e-12:
            principal_axes[i] /= norm_val
    
    # 4) Point the first `first_n` axes toward the reference point.
    #    NOTE(review): step 5 overrides these flips for i >= 1 (see docstring).
    for i in range(min(first_n, len(principal_axes))):
        axis_i = principal_axes[i]
        midpoint_i = midpoints[i]
        ref_vec = ref_pt - midpoint_i
        # If dot < 0, flip
        if np.dot(axis_i, ref_vec) < 0:
            principal_axes[i] = -axis_i
    
    # 5) Align each subsequent axis with the previous one so the sign is
    #    consistent along the whole chain of segments.
    for i in range(1, len(principal_axes)):
        prev_axis = principal_axes[i-1]
        curr_axis = principal_axes[i]
        if np.dot(prev_axis, curr_axis) < 0:
            principal_axes[i] = -curr_axis
    
    # 6) Compute the in-plane axis for each segment
    #    (stored in an array of shape (n_segments, 3))
    axis_in_plane = np.zeros_like(principal_axes)
    
    for i in range(len(principal_axes)):
        M_axis = principal_axes[i]  # already oriented
        seg_vec = seg_ends[i] - seg_starts[i]
        
        # project seg_vec onto the plane normal to M_axis
        seg_vec_inplane = seg_vec - np.dot(seg_vec, M_axis) * M_axis
        norm_seg_inplane = np.linalg.norm(seg_vec_inplane)
        
        if norm_seg_inplane < 1e-12:
            # Degenerate: the segment direction is (anti)parallel to M_axis or
            # zero-length. Fall back to an arbitrary unit vector orthogonal to
            # M_axis, built from x-hat (or y-hat if M_axis is nearly parallel to x-hat).
            test_vec = np.array([1,0,0], dtype=float)
            if abs(np.dot(test_vec, M_axis)) > 0.9:
                test_vec = np.array([0,1,0], dtype=float)
            plane_vec = np.cross(M_axis, test_vec)
            plane_vec_norm = np.linalg.norm(plane_vec)
            if plane_vec_norm < 1e-12:
                axis_in_plane[i] = np.zeros(3)
            else:
                axis_in_plane[i] = plane_vec / plane_vec_norm
        else:
            seg_vec_inplane /= norm_seg_inplane
            # axis_in_plane = cross(M_axis, seg_vec_inplane): orthogonal to both
            # the principal axis and the projected segment direction.
            plane_vec = np.cross(M_axis, seg_vec_inplane)
            plane_vec_norm = np.linalg.norm(plane_vec)
            if plane_vec_norm < 1e-12:
                axis_in_plane[i] = np.zeros(3)
            else:
                axis_in_plane[i] = plane_vec / plane_vec_norm
    
    # 7) Ensure consecutive in-plane axes have dot>0 to keep consistent orientation
    #    (a zero vector gives dot == 0 and so breaks the propagation chain).
    for i in range(1, len(axis_in_plane)):
        if np.dot(axis_in_plane[i-1], axis_in_plane[i]) < 0:
            axis_in_plane[i] = -axis_in_plane[i]
    
    # 8) Store results back into the (sorted copy of) segment_info_df
    segment_info_df["principal_axis_x"] = principal_axes[:,0]
    segment_info_df["principal_axis_y"] = principal_axes[:,1]
    segment_info_df["principal_axis_z"] = principal_axes[:,2]
    
    segment_info_df["axis_in_plane_x"] = axis_in_plane[:,0]
    segment_info_df["axis_in_plane_y"] = axis_in_plane[:,1]
    segment_info_df["axis_in_plane_z"] = axis_in_plane[:,2]
    
    return segment_info_df

def project_points_onto_reference_spline(
    df_points: pd.DataFrame,
    reference_spline_info: pd.DataFrame,
    reorient_axes: bool = True,
    first_n: int = 5,
    reference_point: np.ndarray = None
) -> pd.DataFrame:
    """
    Project each row of `df_points` (PCA_1..PCA_3) onto the nearest segment of a
    reference spline and express it in that segment's local frame.

    Steps:
      1. If reorient_axes=True, pass the segment table through
         orient_segment_axes(...) so principal_axis (M) and axis_in_plane (m)
         are consistently signed across segments. Otherwise the table must
         already contain axis_in_plane_{x,y,z}.
      2. For each point, pick the segment whose closest point (clamped to the
         segment) is nearest in Euclidean distance; the first segment wins ties.
      3. Record the closest point on the spline, progress_t (clamped fraction
         along the segment), signed coordinates M and m relative to the segment
         midpoint, the plane point, and distances to plane/axis.
      4. Left-merge the results back onto df_points on 'snip_id'.

    Parameters
    ----------
    df_points : pd.DataFrame
        Must contain PCA_1, PCA_2, PCA_3 and snip_id. embryo_id, phenotype and
        predicted_stage_hpf are copied into the projection when present.
    reference_spline_info : pd.DataFrame
        One row per segment: seg_id, principal_axis_{x,y,z},
        segment_midpoint_{x,y,z}, segment_start_{x,y,z}, segment_end_{x,y,z}
        (and axis_in_plane_{x,y,z} when reorient_axes=False).
    reorient_axes, first_n, reference_point
        Forwarded to orient_segment_axes when reorient_axes is True.

    Returns
    -------
    pd.DataFrame
        df_points left-merged with the projection columns; columns that already
        exist in df_points get the "_proj" suffix. Points that could not be
        projected (e.g. an empty reference spline) get NaN projection values.
    """
    # Declared up front so 'snip_id' exists even when nothing was projected
    # (empty df_points or empty reference spline); otherwise the merge below
    # would raise KeyError on an empty, column-less frame.
    projection_columns = [
        "snip_id", "embryo_id", "phenotype", "predicted_stage_hpf",
        "PCA_1", "PCA_2", "PCA_3",
        "ref_seg_id", "progress_t",
        "closest_on_spline_x", "closest_on_spline_y", "closest_on_spline_z",
        "M", "m",
        "plane_point_x", "plane_point_y", "plane_point_z",
        "distance_to_plane", "distance_to_axis", "hypotenuse",
    ]

    # 1) Reorient principal / in-plane axes if requested.
    seg_df = reference_spline_info.copy()
    if reorient_axes:
        seg_df = orient_segment_axes(
            segment_info_df=seg_df,
            first_n=first_n,
            reference_point=reference_point
        )

    def _vec3(row, prefix):
        """Read the <prefix>_x/_y/_z columns of a row as a float 3-vector."""
        return np.array([row[f"{prefix}_x"], row[f"{prefix}_y"], row[f"{prefix}_z"]], dtype=float)

    # Pre-extract per-segment vectors once so the per-point loop stays cheap.
    segment_dicts = []
    for _, seg_row in seg_df.iterrows():
        seg_start = _vec3(seg_row, "segment_start")
        seg_vec = _vec3(seg_row, "segment_end") - seg_start
        segment_dicts.append({
            "seg_id": seg_row["seg_id"],
            "M_axis": _vec3(seg_row, "principal_axis"),   # principal axis (plane normal)
            "m_axis": _vec3(seg_row, "axis_in_plane"),    # in-plane axis
            "midpoint": _vec3(seg_row, "segment_midpoint"),
            "seg_start": seg_start,
            "seg_vec": seg_vec,
            "seg_len_sq": np.dot(seg_vec, seg_vec),
        })

    def _clamped_t(p, seg):
        """Fraction in [0, 1] along seg_start -> seg_start + seg_vec of the point closest to p."""
        if seg["seg_len_sq"] < 1e-12:
            # Zero-length segment: its only point is seg_start.
            return 0.0
        t = np.dot(p - seg["seg_start"], seg["seg_vec"]) / seg["seg_len_sq"]
        return np.clip(t, 0.0, 1.0)

    records = []
    for _, row in df_points.iterrows():
        p = np.array([row["PCA_1"], row["PCA_2"], row["PCA_3"]], dtype=float)

        # 2) Nearest segment; strict '<' keeps the first segment on ties.
        best_seg, best_t, min_dist = None, 0.0, np.inf
        for seg in segment_dicts:
            t = _clamped_t(p, seg)
            dist = np.linalg.norm(p - (seg["seg_start"] + t * seg["seg_vec"]))
            if dist < min_dist:
                best_seg, best_t, min_dist = seg, t, dist

        if best_seg is None:
            # No segments (or every distance was NaN): leave this point unprojected.
            continue

        # 3) Closest point on the chosen segment; progress_t is the same clamped fraction.
        closest_on_spline = best_seg["seg_start"] + best_t * best_seg["seg_vec"]

        # Signed coordinates relative to the segment midpoint.
        M_axis = best_seg["M_axis"]
        midpoint = best_seg["midpoint"]
        delta = p - midpoint
        M_val = np.dot(delta, M_axis)              # signed distance along the principal axis
        m_val = np.dot(delta, best_seg["m_axis"])  # signed distance along the in-plane axis

        plane_point = p - M_val * M_axis
        distance_to_plane = abs(M_val)
        # Distance from p to the line through the midpoint along M_axis.
        closest_on_axis = midpoint + M_val * M_axis
        distance_to_axis = np.linalg.norm(p - closest_on_axis)
        # For a unit-length M_axis this equals ||p - midpoint||.
        hypotenuse = np.sqrt(distance_to_plane**2 + distance_to_axis**2)

        records.append({
            "snip_id": row.get("snip_id", None),
            "embryo_id": row.get("embryo_id", None),
            "phenotype": row.get("phenotype", None),
            "predicted_stage_hpf": row.get("predicted_stage_hpf", None),

            "PCA_1": p[0],
            "PCA_2": p[1],
            "PCA_3": p[2],

            "ref_seg_id": best_seg["seg_id"],
            "progress_t": best_t,   # fraction along segment in [0,1]

            "closest_on_spline_x": closest_on_spline[0],
            "closest_on_spline_y": closest_on_spline[1],
            "closest_on_spline_z": closest_on_spline[2],

            "M": M_val,
            "m": m_val,

            "plane_point_x": plane_point[0],
            "plane_point_y": plane_point[1],
            "plane_point_z": plane_point[2],

            "distance_to_plane": distance_to_plane,
            "distance_to_axis": distance_to_axis,
            "hypotenuse": hypotenuse
        })

    projection_df = pd.DataFrame(records, columns=projection_columns)

    # 4) Left join on snip_id (assumed unique per frame); overlapping columns get "_proj".
    merged_df = pd.merge(df_points, projection_df, on="snip_id", how="left", suffixes=("", "_proj"))

    return merged_df

In [None]:
# Inspect the spline segment table used as the projection reference.
# (A bare `segment` was here; that name is not defined anywhere visible and
# raises NameError on a fresh Run All.)
segment_info_df.head()

In [7]:
# Reference trajectory: spline segments/planes fitted to the wt_cep290 embryos at 30 °C.
wt_splines_n_planes = segment_info_df.loc[segment_info_df["phenotype"] == "wt_cep290_temp30"]

# A) Project every point of df_augmented onto that reference spline.
df_points = project_points_onto_reference_spline(df_augmented, wt_splines_n_planes)

# Keep only columns the projection adds (plus the join key) so the merge does not
# duplicate columns df_augmented already has.
redundant_cols = [
    col for col in df_points.columns
    if col != "snip_id" and col in df_augmented.columns
]
df_augmented_projec_wt = pd.merge(
    df_augmented,
    df_points.drop(columns=redundant_cols),
    on="snip_id",
)

In [147]:
# Merge based on 'snip_id'
df_augmented_projec_wt = df_augmented_projec_wt.merge(
    df_augmented_projec_wt_not_cep290[['snip_id', 'ref_segment_avg_time']],
    on='snip_id',
    how='left'
)

# Rename the newly added column
df_augmented_projec_wt = df_augmented_projec_wt.rename(
    columns={'ref_segment_avg_time': 'ref_segment_avg_time_wt_temp30'}
)

In [148]:
# Column names of df_augmented_projec_wt, as a set.
{*df_augmented_projec_wt.columns}

In [8]:
def plot_metric_over_stage(
    df,
    metric_column='hypotenuse',
    time_column='predicted_stage_hpf',
    phenotypes_to_include=None,
    window_size=5,
    figsize=(12, 8),
    palette="tab10",
    alpha=0.5,
    plot_individual_embryos=True,
    plot_average=True,
    plot_median=False,
    highlight_embryos=None,
    highlight_phenotypes=None,
    highlight_alpha=0.9,
    max_time=None,
    min_time=None,
    min_points_per_embryo=5,
    save_path=None,
    show_legend=True,
    title=None,
    xlabel=None,
    ylabel=None
):
    """
    Plot any metric over developmental stage or time with smoothing.

    Each embryo is drawn as a rolling-mean-smoothed trace. Optionally, per-phenotype
    mean and/or median lines are drawn from values pooled into time bins
    (bin width = max(0.5, time range / 40)) and smoothed with the same window.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing the data
    metric_column : str, default='hypotenuse'
        Column name for the metric to plot on y-axis
    time_column : str, default='predicted_stage_hpf'
        Column name for the time/stage to plot on x-axis
    phenotypes_to_include : list, optional
        List of phenotypes to include in the plot. If None, all phenotypes are included.
    window_size : int, default=5
        Window size (in rows / bins) for centered rolling-average smoothing.
    figsize : tuple, default=(12, 8)
        Figure size (width, height) in inches.
    palette : str or dict, default="tab10"
        Seaborn palette name or dict mapping phenotypes to colors.
    alpha : float, default=0.5
        Transparency level for individual embryo lines.
    plot_individual_embryos : bool, default=True
        Whether to plot individual embryo traces.
    plot_average : bool, default=True
        Whether to plot average line per phenotype.
    plot_median : bool, default=False
        Whether to plot median line (dashed) per phenotype.
    highlight_embryos : list, optional
        List of embryo_ids to highlight with thicker lines.
    highlight_phenotypes : list, optional
        List of phenotypes whose lines are drawn with `highlight_alpha`.
    highlight_alpha : float, default=0.9
        Transparency level for highlighted phenotypes/embryos.
    max_time : float, optional
        Maximum time value to include in the plot.
    min_time : float, optional
        Minimum time value to include in the plot.
    min_points_per_embryo : int, default=5
        Minimum number of data points (after filtering) required for an embryo to be included.
    save_path : str, optional
        Path to save the figure, if provided.
    show_legend : bool, default=True
        Whether to show the legend.
    title : str, optional
        Plot title. If None, a default title is used.
    xlabel : str, optional
        X-axis label. If None, uses time_column.
    ylabel : str, optional
        Y-axis label. If None, uses metric_column.

    Notes:
    ------
    Rows whose `time_column` is missing (NaN) are dropped before any other
    filtering: they cannot be placed on the x-axis (e.g. snips left unmatched by
    a left merge), and the time binning cannot handle them.

    Raises:
    -------
    ValueError
        If any of metric_column, 'embryo_id', 'phenotype', time_column is missing.

    Returns:
    --------
    matplotlib.figure.Figure
        The created figure

    Example:
    --------
    # Basic usage with default columns (hypotenuse over predicted_stage_hpf)
    fig = plot_metric_over_stage(my_dataframe)

    # Plot a different metric over developmental time
    fig = plot_metric_over_stage(
        my_dataframe,
        metric_column='M',  # Axis of maximum variation
        time_column='ref_segment_avg_time_wt_temp30'
    )

    # Customize with phenotype highlighting and time range
    fig = plot_metric_over_stage(
        my_dataframe,
        metric_column='curvature',
        highlight_phenotypes=['cep290_homo_cep290_temp30'],
        phenotypes_to_include=['cep290_het_cep290_temp30', 'wt_cep290_temp30', 'cep290_homo_cep290_temp30'],
        min_time=20,
        max_time=48
    )
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np

    # Make a copy to avoid modifying the original
    df = df.copy()

    # Ensure required columns exist
    required_cols = [metric_column, 'embryo_id', 'phenotype', time_column]
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    # Rows without a time value cannot be positioned or binned.
    df = df[df[time_column].notna()]

    # Filter by phenotype if specified
    if phenotypes_to_include is not None:
        df = df[df['phenotype'].isin(phenotypes_to_include)]

    # Filter by time range if specified
    if min_time is not None:
        df = df[df[time_column] >= min_time]
    if max_time is not None:
        df = df[df[time_column] <= max_time]

    # Filter embryos with too few data points
    embryo_counts = df.groupby('embryo_id').size()
    valid_embryos = embryo_counts[embryo_counts >= min_points_per_embryo].index
    df = df[df['embryo_id'].isin(valid_embryos)]

    # If DataFrame is empty after filtering, return a placeholder figure
    if df.empty:
        plt.figure(figsize=figsize)
        plt.text(0.5, 0.5, "No data available after filtering",
                 horizontalalignment='center', verticalalignment='center',
                 fontsize=14)
        plt.gca().set_axis_off()
        return plt.gcf()

    plt.figure(figsize=figsize)

    # Assign a color per phenotype
    unique_phenotypes = df['phenotype'].unique()
    if isinstance(palette, str):
        color_palette = sns.color_palette(palette, n_colors=len(unique_phenotypes))
        phenotype_colors = {phenotype: color_palette[i] for i, phenotype in enumerate(unique_phenotypes)}
    else:
        # If palette is a dict, use it directly
        phenotype_colors = palette

    # Line handles kept for the legend
    phenotype_avg_lines = {}
    phenotype_med_lines = {}
    highlight_lines = {}

    # Individual embryo traces
    if plot_individual_embryos:
        for embryo_id, group in df.groupby('embryo_id'):
            phenotype = group['phenotype'].iloc[0]
            color = phenotype_colors.get(phenotype, 'gray')

            group = group.sort_values(time_column)

            # Centered rolling mean over consecutive rows (not over fixed time spans)
            group[f'smooth_{metric_column}'] = group[metric_column].rolling(
                window=window_size, min_periods=1, center=True
            ).mean()

            is_highlight_embryo = highlight_embryos is not None and embryo_id in highlight_embryos
            is_highlight_phenotype = highlight_phenotypes is not None and phenotype in highlight_phenotypes

            line_alpha = highlight_alpha if (is_highlight_embryo or is_highlight_phenotype) else alpha
            line_width = 2.5 if is_highlight_embryo else 1

            line = plt.plot(
                group[time_column],
                group[f'smooth_{metric_column}'],
                color=color,
                alpha=line_alpha,
                linewidth=line_width,
                label=None  # legend is assembled explicitly below
            )

            if is_highlight_embryo:
                highlight_lines[embryo_id] = line[0]

    # Time bins for the pooled mean/median lines; bin width scales with data range.
    time_range = df[time_column].max() - df[time_column].min()
    bin_size = max(0.5, time_range / 40)

    # Floor (not truncate toward zero) so negative times land in the correct bin;
    # for non-negative times this equals the integer-cast binning.
    df['time_bin'] = np.floor(df[time_column] / bin_size) * bin_size

    def _plot_binned_statistic(stat, linestyle, store):
        """Draw one smoothed per-bin `stat` ('mean' or 'median') line per phenotype into `store`."""
        for phenotype, group in df.groupby('phenotype'):
            color = phenotype_colors.get(phenotype, 'gray')

            binned = (
                group.groupby('time_bin')[metric_column].agg(stat)
                .reset_index()
                .sort_values('time_bin')
            )
            binned[f'smooth_{metric_column}'] = binned[metric_column].rolling(
                window=window_size, min_periods=1, center=True
            ).mean()

            # NOTE(review): non-highlighted summary lines use alpha 1.0, so with the
            # default highlight_alpha=0.9 a highlighted phenotype's summary line is
            # slightly MORE transparent than the others. Confirm this is intended.
            is_highlight = highlight_phenotypes is not None and phenotype in highlight_phenotypes
            line_alpha = highlight_alpha if is_highlight else 1.0

            line = plt.plot(
                binned['time_bin'],
                binned[f'smooth_{metric_column}'],
                color=color,
                linewidth=5,
                alpha=line_alpha,
                linestyle=linestyle,
                label=f"{phenotype} ({stat}, n={len(group['embryo_id'].unique())})"
            )
            store[phenotype] = line[0]

    if plot_average:
        _plot_binned_statistic('mean', '-', phenotype_avg_lines)
    if plot_median:
        # Dashed to distinguish medians from means
        _plot_binned_statistic('median', '--', phenotype_med_lines)

    # Title and labels
    default_title = f"{metric_column} by {time_column}"
    title = title or default_title
    plt.title(title, fontsize=14)

    xlabel = xlabel or time_column
    ylabel = ylabel or metric_column
    plt.xlabel(xlabel, fontsize=12)
    plt.ylabel(ylabel, fontsize=12)

    # Legend: summary lines first (mean, then median), then highlighted embryos.
    if show_legend:
        handles, labels = [], []

        for stat, enabled, lines in (('mean', plot_average, phenotype_avg_lines),
                                     ('median', plot_median, phenotype_med_lines)):
            if enabled:
                for phenotype, line in lines.items():
                    n_embryos = len(df[df['phenotype'] == phenotype]['embryo_id'].unique())
                    handles.append(line)
                    labels.append(f"{phenotype} ({stat}, n={n_embryos})")

        if highlight_embryos and highlight_lines:
            for embryo_id, line in highlight_lines.items():
                handles.append(line)
                labels.append(f"Embryo {embryo_id}")

        if handles:
            plt.legend(
                handles=handles,
                labels=labels,
                title="Phenotypes and Highlighted Embryos",
                loc="best",
                fontsize=10
            )

    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    return plt.gcf()

In [9]:
# Sign-flipped copy of M (signed distance along the principal axis). The sign of M
# depends on the axis orientation convention, so this is presumably flipped for
# easier visual reading. TODO confirm the intended direction.
df_augmented_projec_wt["M_flipped"] = -df_augmented_projec_wt["M"]

In [155]:
# cep290 genotypes at 30 °C: one figure per projection metric vs predicted stage,
# with the homozygous line highlighted.
cep290_phenotypes = [
    'cep290_het_cep290_temp30',
    'wt_cep290_temp30',
    'cep290_homo_cep290_temp30'
]
for metric in ('M_flipped', 'hypotenuse', 'm'):
    fig = plot_metric_over_stage(
        df_augmented_projec_wt,
        metric_column=metric,
        time_column='predicted_stage_hpf',
        phenotypes_to_include=cep290_phenotypes,
        highlight_phenotypes=['cep290_homo_cep290_temp30'],
        plot_median=True,
    )



In [157]:
# M_flipped vs predicted stage, one genotype per figure (homozygous, then heterozygous).
for single_phenotype in ('cep290_homo_cep290_temp30', 'cep290_het_cep290_temp30'):
    fig = plot_metric_over_stage(
        df_augmented_projec_wt,
        metric_column='M_flipped',
        time_column='predicted_stage_hpf',
        phenotypes_to_include=[single_phenotype],
        highlight_phenotypes=['cep290_homo_cep290_temp30'],
        plot_median=True,
    )


In [153]:
# Same cep290 comparison, but with x = reference-spline time from the wt_temp30 projection.
for metric in ('M_flipped', 'm'):
    fig = plot_metric_over_stage(
        df_augmented_projec_wt,
        metric_column=metric,
        time_column='ref_segment_avg_time_wt_temp30',
        phenotypes_to_include=[
            'cep290_het_cep290_temp30',
            'wt_cep290_temp30',
            'cep290_homo_cep290_temp30'
        ],
        highlight_phenotypes=['cep290_homo_cep290_temp30'],
        plot_median=True,
    )



In [10]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def rank_embryos_by_metric(
    df, 
    metric_column='hypotenuse',
    time_column='predicted_stage_hpf',
    min_time=30, 
    max_time=35,
    ascending=False
):
    """
    Ranks embryos by their maximum or minimum value of a specified metric 
    within a specific time/stage interval.

    Only rows with ``min_time < df[time_column] < max_time`` (strict bounds)
    and a non-missing metric value take part in the ranking.
    
    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing the data. Must have ``metric_column``,
        ``time_column``, ``embryo_id`` and ``phenotype`` columns; ``snip_id``
        is used when present.
    metric_column : str, default='hypotenuse'
        Column name for the metric to rank by
    time_column : str, default='predicted_stage_hpf'
        Column name for the time/stage to filter by
    min_time : float, default=30
        Lower (exclusive) bound of the time/stage window
    max_time : float, default=35
        Upper (exclusive) bound of the time/stage window
    ascending : bool, default=False
        If True, ranks by minimum values (ascending).
        If False, ranks by maximum values (descending).
    
    Returns:
    --------
    pd.DataFrame
        One row per embryo with columns ``embryo_id``, ``phenotype``,
        ``{metric_column}_value``, ``time_at_{metric_column}``, ``snip_id`` and
        ``rank`` (1 = most extreme). The same columns (with no rows) are
        returned when nothing usable falls inside the window. Without a
        ``snip_id`` input column, placeholders ``unknown_<embryo_id>_<row label>``
        are used.

    Raises:
    -------
    ValueError
        If any of the metric, time, ``embryo_id`` or ``phenotype`` columns is missing.
    
    Example:
    --------
    # Rank by maximum hypotenuse (default)
    ranked_df = rank_embryos_by_metric(df_augmented_projec_wt)
    
    # Rank by minimum curvature
    ranked_df = rank_embryos_by_metric(
        df_augmented_projec_wt, 
        metric_column='curvature',
        ascending=True
    )
    
    # Rank by maximum M value using ref_segment_avg_time
    ranked_df = rank_embryos_by_metric(
        df_augmented_projec_wt, 
        metric_column='M',
        time_column='ref_segment_avg_time',
        min_time=32,
        max_time=40
    )
    """
    required_cols = [metric_column, 'embryo_id', 'phenotype', time_column]
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    has_snip_id = 'snip_id' in df.columns
    if not has_snip_id:
        print("Note: 'snip_id' column not found. Adding placeholder values.")

    value_col = f'{metric_column}_value'
    time_at_col = f'time_at_{metric_column}'
    # One fixed schema for empty and non-empty results, so callers such as
    # visualize_embryo_rankings can index 'embryo_id' / 'rank' unconditionally.
    output_columns = ['embryo_id', 'phenotype', value_col, time_at_col, 'snip_id', 'rank']

    in_window = (df[time_column] > min_time) & (df[time_column] < max_time)
    time_filtered = df[in_window]

    if time_filtered.empty:
        print(f"Warning: No data found in the {time_column} range {min_time}-{max_time}.")
        return pd.DataFrame(columns=output_columns)

    # A missing metric can never be the extreme; dropping those rows also keeps
    # embryos whose metric is entirely NaN from breaking the argmax/argmin below.
    time_filtered = time_filtered[time_filtered[metric_column].notna()]

    result_rows = []
    for embryo_id, embryo_data in time_filtered.groupby('embryo_id'):
        values = embryo_data[metric_column].to_numpy()
        # Positional lookup: correct even when df has duplicate index labels
        # (e.g. after pd.concat without ignore_index), where .loc[label] would
        # return several rows. argmax/argmin return the first extreme, as
        # idxmax/idxmin do.
        pos = values.argmin() if ascending else values.argmax()
        extreme_row = embryo_data.iloc[pos]

        if has_snip_id:
            snip_id = extreme_row['snip_id']
        else:
            snip_id = f"unknown_{embryo_id}_{embryo_data.index[pos]}"

        result_rows.append({
            'embryo_id': embryo_id,
            'phenotype': extreme_row['phenotype'],
            value_col: extreme_row[metric_column],
            time_at_col: extreme_row[time_column],
            'snip_id': snip_id,
        })

    if not result_rows:
        print(f"Warning: No non-missing {metric_column} values in the {time_column} range {min_time}-{max_time}.")
        return pd.DataFrame(columns=output_columns)

    # Stable sort so tied embryos keep their groupby (embryo_id) order.
    result_df = (
        pd.DataFrame(result_rows)
        .sort_values(value_col, ascending=ascending, kind='mergesort')
        .reset_index(drop=True)
    )
    result_df['rank'] = result_df.index + 1

    return result_df


def visualize_embryo_rankings(
    df, 
    ranked_df, 
    metric_column='hypotenuse',
    time_column='predicted_stage_hpf',
    min_time=30, 
    max_time=35, 
    top_n=10, 
    figsize=(15, 10),
    title=None
):
    """
    Visualizes the top ranked embryos in the specified time/stage interval.

    Left panel: bar chart of each top embryo's ``{metric_column}_value`` from
    ``ranked_df`` (hue = phenotype). Right panel: each top embryo's metric
    trajectory restricted to ``min_time < df[time_column] < max_time``, with
    the sample closest to its ranked extreme value marked in the line's colour.
    
    Parameters:
    -----------
    df : pd.DataFrame
        Original DataFrame with all embryo data
    ranked_df : pd.DataFrame
        DataFrame from rank_embryos_by_metric function
    metric_column : str, default='hypotenuse'
        Column name for the metric being analyzed
    time_column : str, default='predicted_stage_hpf'
        Column name for the time/stage
    min_time : float, default=30
        Lower (exclusive) bound of the time/stage window
    max_time : float, default=35
        Upper (exclusive) bound of the time/stage window
    top_n : int, default=10
        Number of top-ranked embryos to visualize
    figsize : tuple, default=(15, 10)
        Figure size
    title : str, optional
        Custom title for the plot
        
    Returns:
    --------
    matplotlib.figure.Figure
        The created figure
        
    Example:
    --------
    # First rank the embryos
    ranked_df = rank_embryos_by_metric(df_augmented_projec_wt)
    
    # Then visualize the top 10
    fig = visualize_embryo_rankings(df_augmented_projec_wt, ranked_df)
    
    # Visualize different metric
    ranked_df = rank_embryos_by_metric(df, metric_column='curvature')
    fig = visualize_embryo_rankings(df, ranked_df, metric_column='curvature')
    """
    top_embryos = ranked_df.head(top_n)
    value_col = f'{metric_column}_value'

    # Rows of the top embryos inside the (exclusive) time window.
    filtered_data = df[
        df['embryo_id'].isin(top_embryos['embryo_id'])
        & (df[time_column] > min_time)
        & (df[time_column] < max_time)
    ]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # 1. Bar chart of the ranked extreme values, in rank order
    sns.barplot(
        x='embryo_id', 
        y=value_col, 
        hue='phenotype',
        data=top_embryos,
        ax=ax1
    )
    ax1.set_title(f'Top {metric_column.capitalize()} Values ({min_time}-{max_time})')
    ax1.set_xlabel('Embryo ID')
    ax1.set_ylabel(f'{metric_column.capitalize()} Value')
    ax1.tick_params(axis='x', rotation=45)

    # 2. Trajectories of the metric over time for the top embryos
    for embryo_id, group in filtered_data.groupby('embryo_id'):
        ranking = top_embryos[top_embryos['embryo_id'] == embryo_id]
        phenotype = group['phenotype'].iloc[0]
        rank = ranking['rank'].iloc[0]

        # reset_index gives unique labels, so the .loc lookup below always
        # yields a single row even if df carries duplicate index labels.
        group = group.sort_values(time_column).reset_index(drop=True)

        (line,) = ax2.plot(
            group[time_column], 
            group[metric_column], 
            label=f"#{rank} - {embryo_id} ({phenotype})"
        )

        # Mark the sample closest to the ranked extreme (max or min) value
        if value_col in top_embryos.columns:
            target_value = ranking[value_col].iloc[0]
            distance = (group[metric_column] - target_value).abs()
            if distance.notna().any():  # all-NaN trajectory: nothing to mark
                extreme_point = group.loc[distance.idxmin()]
                ax2.scatter(
                    extreme_point[time_column],
                    extreme_point[metric_column],
                    s=100, 
                    marker='o',
                    # Tie the marker to its own line rather than to matplotlib's
                    # default colour cycle, which need not stay in step with it.
                    color=line.get_color(),
                    edgecolors='black'
                )

    ax2.set_title(f'{metric_column.capitalize()} Trajectories ({min_time}-{max_time})')
    ax2.set_xlabel(f'{time_column}')
    ax2.set_ylabel(metric_column.capitalize())
    ax2.grid(True, linestyle='--', alpha=0.7)
    ax2.legend(loc='best', fontsize=8)

    if title:
        fig.suptitle(title, fontsize=16, y=1.05)

    plt.tight_layout()
    return fig


def summarize_metric_by_phenotype(
    df, 
    metric_column='hypotenuse',
    time_column='predicted_stage_hpf',
    min_time=30, 
    max_time=35
):
    """
    Per-phenotype descriptive statistics of a metric inside a time/stage window.

    Only rows with ``min_time < df[time_column] < max_time`` are used.

    Parameters:
    -----------
    df : pd.DataFrame
        Data with ``phenotype``, ``embryo_id``, ``metric_column`` and ``time_column``
    metric_column : str, default='hypotenuse'
        Column to summarise
    time_column : str, default='predicted_stage_hpf'
        Column holding the time/stage used for the window
    min_time : float, default=30
        Lower (exclusive) window bound
    max_time : float, default=35
        Upper (exclusive) window bound

    Returns:
    --------
    pd.DataFrame
        One row per phenotype with columns ``phenotype``, ``count``, ``mean``,
        ``median``, ``std``, ``min``, ``max``, ``25%``, ``75%`` and
        ``unique_embryos`` (number of distinct ``embryo_id`` values). An empty
        DataFrame, after printing a warning, if no rows fall in the window.

    Example:
    --------
    # Summarize hypotenuse statistics by phenotype
    summary_df = summarize_metric_by_phenotype(df_augmented_projec_wt)
    
    # Summarize M values using a different time column
    summary_df = summarize_metric_by_phenotype(
        df_augmented_projec_wt, 
        metric_column='M',
        time_column='ref_segment_avg_time'
    )
    """
    after_start = df[time_column] > min_time
    before_end = df[time_column] < max_time
    window = df[after_start & before_end]

    if window.empty:
        print(f"Warning: No data found in the {time_column} range {min_time}-{max_time}.")
        return pd.DataFrame()

    by_phenotype = window.groupby('phenotype')

    stats = by_phenotype[metric_column].agg([
        ('count', 'count'),
        ('mean', 'mean'),
        ('median', 'median'),
        ('std', 'std'),
        ('min', 'min'),
        ('max', 'max'),
        ('25%', lambda s: s.quantile(0.25)),
        ('75%', lambda s: s.quantile(0.75)),
    ])
    # Both results are indexed by phenotype, so this assignment aligns row-for-row.
    stats['unique_embryos'] = by_phenotype['embryo_id'].nunique()

    return stats.reset_index()


def analyze_embryo_metric(
    df, 
    metric_column='hypotenuse',
    time_column='predicted_stage_hpf',
    min_time=10, 
    max_time=70, 
    top_n=8,
    ascending=False,
    title=None
):
    """
    Complete workflow to analyze embryos in the given time/stage interval
    for a specified metric: rank embryos, summarise by phenotype, and plot
    the top-ranked embryos.

    Note that the default window here (10-70) is much wider than the 30-35
    default of the individual helper functions.
    
    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing the embryo data
    metric_column : str, default='hypotenuse'
        Column name for the metric to analyze
    time_column : str, default='predicted_stage_hpf'
        Column name for the time/stage
    min_time : float, default=10
        Lower (exclusive) bound of the time/stage window
    max_time : float, default=70
        Upper (exclusive) bound of the time/stage window
    top_n : int, default=8
        Number of top-ranked embryos to visualize
    ascending : bool, default=False
        If True, ranks by minimum values (ascending).
        If False, ranks by maximum values (descending).
    title : str, optional
        Custom title for the visualization
    
    Returns:
    --------
    tuple
        (ranked_embryos_df, phenotype_summary_df, visualization_fig), as
        returned by rank_embryos_by_metric, summarize_metric_by_phenotype and
        visualize_embryo_rankings respectively.
        
    Example:
    --------
    # Analyze hypotenuse (default)
    ranked_df, summary_df, fig = analyze_embryo_metric(df_augmented_projec_wt)
    
    # Analyze curvature, finding minimums
    ranked_df, summary_df, fig = analyze_embryo_metric(
        df_augmented_projec_wt,
        metric_column='curvature',
        ascending=True,
        title='Minimum Curvature Analysis'
    )
    
    # Analyze axis of max variation (M) against reference time
    ranked_df, summary_df, fig = analyze_embryo_metric(
        df_augmented_projec_wt,
        metric_column='M',
        time_column='ref_segment_avg_time',
        min_time=32,
        max_time=42
    )
    """
    window = dict(time_column=time_column, min_time=min_time, max_time=max_time)

    # 1. Rank embryos by the metric
    ranked_df = rank_embryos_by_metric(
        df,
        metric_column=metric_column,
        ascending=ascending,
        **window
    )

    # 2. Summarize by phenotype
    summary_df = summarize_metric_by_phenotype(
        df,
        metric_column=metric_column,
        **window
    )

    # 3. Visualize top embryos
    fig = visualize_embryo_rankings(
        df,
        ranked_df,
        metric_column=metric_column,
        top_n=top_n,
        title=title,
        **window
    )

    return ranked_df, summary_df, fig


In [11]:
# All three cep290 genotypes at temp30, analysed with analyze_embryo_metric's
# default metric and time window.
df_analyze = df_augmented_projec_wt.loc[
    df_augmented_projec_wt["phenotype"].isin(
        ['cep290_het_cep290_temp30', 'wt_cep290_temp30', 'cep290_homo_cep290_temp30']
    )
]

ranked_df, summary_df, fig = analyze_embryo_metric(df_analyze)

display(ranked_df.head(15), summary_df)
plt.show()

In [14]:
# Rank by metric 'M' against predicted stage (analyze_embryo_metric's default
# time column and window).
ranked_df, summary_df, fig = analyze_embryo_metric(
    df_analyze,
    title='Axis of Maximum Variation Analysis',
    metric_column='M',
)

display(ranked_df.head(15), summary_df)
plt.show()

In [20]:
# Per-timepoint hypotenuse for a single embryo of interest.
df_augmented_projec_wt.loc[
    df_augmented_projec_wt["embryo_id"].isin(['20250305_H08_e00']),
    ["predicted_stage_hpf", "snip_id", "hypotenuse", "phenotype"],
]

In [None]:
# NOTE(review): this unexecuted cell repeats the In [17] analysis that follows;
# consider deleting one of the two.
# Homozygous cep290 embryos only: rank by metric 'M' using stages after 55
# (max_time stays at analyze_embryo_metric's default).
df_analyze = df_augmented_projec_wt[df_augmented_projec_wt["phenotype"].isin([
    'cep290_homo_cep290_temp30'
])]

ranked_df, summary_df, fig = analyze_embryo_metric(
    df_analyze,
    min_time=55,
    metric_column='M',
    title='Axis of Maximum Variation Analysis'
)

# Display results
display(ranked_df.head(15))
display(summary_df)
plt.show()

In [17]:
# Homozygous cep290 embryos at temp30, ranked by metric 'M' over stages after 55.
df_analyze = df_augmented_projec_wt.loc[
    df_augmented_projec_wt["phenotype"].isin(['cep290_homo_cep290_temp30'])
]

ranked_df, summary_df, fig = analyze_embryo_metric(
    df_analyze,
    title='Axis of Maximum Variation Analysis',
    metric_column='M',
    min_time=55,
)

display(ranked_df.head(15), summary_df)
plt.show()

In [18]:
# Filter to specific phenotypes
df_analyze = df_augmented_projec_wt[df_augmented_projec_wt["phenotype"].isin([
    'cep290_het_cep290_temp30',
    'wt_cep290_temp30',
    'cep290_homo_cep290_temp30'
])]

# Run the complete analysis with default settings
ranked_df, summary_df, fig = analyze_embryo_metric(df_analyze, min_time=55)

# Display results
display(ranked_df.head(15))
display(summary_df)
plt.show()

In [22]:
# Rank embryos by their *minimum* hypotenuse over stages after 55 (ascending=True).
# df_analyze comes from an earlier cell (in execution order, In [18]: the three
# cep290 temp30 genotypes). The title names the metric actually ranked.
ranked_df, summary_df, fig = analyze_embryo_metric(
    df_analyze,
    metric_column='hypotenuse',
    ascending=True,
    title='Minimum Hypotenuse Analysis',
    min_time=55
)
plt.show()

In [23]:
def visualize_extreme_embryos(
    df,
    metric_column='hypotenuse',
    time_column='predicted_stage_hpf',
    time_threshold=55,
    top_n=5,
    bottom_n=5,
    window_size=5,
    figsize=(15, 8),
    color_top='#1f77b4',  # Blue
    color_bottom='#d62728',  # Red
    include_phenotypes=True,
    title=None,
    save_path=None
):
    """
    Visualizes the trajectories of embryos with extreme (highest/lowest) metric values
    over time, focusing on values after a specified time threshold.

    Embryos are ranked by the mean of ``metric_column`` over their rows with
    ``time_column > time_threshold``; embryos whose metric is entirely missing
    there are not ranked. The ``bottom_n`` lowest means are labelled "Bottom"
    and the ``top_n`` highest "Top". When there are fewer than
    ``top_n + bottom_n`` ranked embryos, an embryo that would be in both groups
    is drawn once, as "Bottom". Each embryo's full trajectory (all times) is
    drawn, smoothed with a centred rolling mean.
    
    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing the embryo data
    metric_column : str, default='hypotenuse'
        Column name for the metric to analyze
    time_column : str, default='predicted_stage_hpf'
        Column name for the time/stage
    time_threshold : float, default=55
        Only rows with time strictly above this value are averaged for ranking
    top_n : int, default=5
        Number of top-scoring embryos to visualize
    bottom_n : int, default=5
        Number of bottom-scoring embryos to visualize
    window_size : int, default=5
        Window size (in rows) for rolling average smoothing
    figsize : tuple, default=(15, 8)
        Figure size
    color_top : str, default='#1f77b4'
        Color for top-scoring embryos
    color_bottom : str, default='#d62728'
        Color for bottom-scoring embryos
    include_phenotypes : bool, default=True
        Whether to include phenotype in the legend labels (ignored, with a
        warning, when there is no 'phenotype' column)
    title : str, optional
        Custom title for the plot
    save_path : str, optional
        Path to save the figure, if provided
        
    Returns:
    --------
    matplotlib.figure.Figure
        The created figure

    Raises:
    -------
    ValueError
        If a required column is missing, no rows lie after the threshold, or
        no embryo has a non-missing metric value after the threshold.
        
    Example:
    --------
    # Basic usage showing top 5 and bottom 5 embryos by hypotenuse after 55 hpf
    fig = visualize_extreme_embryos(df_augmented_projec_wt)
    
    # Custom metric and larger window for smoother lines
    fig = visualize_extreme_embryos(
        df_augmented_projec_wt,
        metric_column='M',
        time_threshold=48,
        window_size=7
    )
    """
    # Validate inputs before importing matplotlib so bad arguments fail fast.
    required_cols = [metric_column, 'embryo_id', time_column]
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")
        
    if include_phenotypes and 'phenotype' not in df.columns:
        print("Warning: 'phenotype' column not found. Setting include_phenotypes=False.")
        include_phenotypes = False

    filtered_df = df[df[time_column] > time_threshold]
    if filtered_df.empty:
        raise ValueError(f"No data found after {time_column} > {time_threshold}")

    # Mean metric per embryo after the threshold. dropna() keeps embryos with no
    # usable values from sorting to the end and being reported as "Top".
    embryo_means = (
        filtered_df.groupby('embryo_id')[metric_column]
        .mean()
        .dropna()
        .sort_values(kind='mergesort')
    )
    if embryo_means.empty:
        raise ValueError(
            f"No non-missing {metric_column} values after {time_column} > {time_threshold}"
        )

    bottom_embryos = embryo_means.head(bottom_n).index.tolist()
    # Never list an embryo in both groups (possible when there are fewer than
    # top_n + bottom_n embryos); otherwise it would be drawn twice.
    bottom_set = set(bottom_embryos)
    top_embryos = [e for e in embryo_means.tail(top_n).index if e not in bottom_set]

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=figsize)
    line_styles = ['-', '--', '-.', ':', (0, (3, 1, 1, 1))]

    # Bottom group first, then top (this fixes the legend order). Line styles
    # restart at '-' within each group.
    for group_name, embryo_ids, color in (
        ('Bottom', bottom_embryos, color_bottom),
        ('Top', top_embryos, color_top),
    ):
        for j, embryo_id in enumerate(embryo_ids):
            embryo_data = df[df['embryo_id'] == embryo_id].sort_values(time_column)
            smoothed = embryo_data[metric_column].rolling(
                window=window_size, min_periods=1, center=True
            ).mean()

            if include_phenotypes:
                label = f"{group_name} - {embryo_id} ({embryo_data['phenotype'].iloc[0]})"
            else:
                label = f"{group_name} - {embryo_id}"

            ax.plot(
                embryo_data[time_column],
                smoothed,
                linestyle=line_styles[j % len(line_styles)],
                color=color,
                linewidth=2,
                label=label
            )

    # Vertical marker at the ranking threshold (drawn once)
    ax.axvline(x=time_threshold, color='gray', linestyle='--', alpha=0.5)
    ax.text(
        time_threshold + 0.5, 
        ax.get_ylim()[0] + 0.05 * (ax.get_ylim()[1] - ax.get_ylim()[0]),
        f"Threshold: {time_threshold}",
        color='gray', fontsize=10
    )

    ax.set_xlabel(time_column, fontsize=12)
    ax.set_ylabel(metric_column, fontsize=12)
    if title:
        ax.set_title(title, fontsize=14)
    else:
        ax.set_title(f"Top {top_n} and Bottom {bottom_n} Embryos by {metric_column} after {time_column} > {time_threshold}", fontsize=14)

    # Two-column legend placed outside the axes on the right
    ax.legend(
        bbox_to_anchor=(1.05, 1),
        loc='upper left',
        ncol=2,
        fontsize=10
    )
    ax.grid(True, linestyle='--', alpha=0.7)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

# Example: top/bottom 5 embryos (all phenotypes in df_augmented_projec_wt) by
# mean hypotenuse after 55 hpf, smoothed over 5 rows.
fig = visualize_extreme_embryos(
    df_augmented_projec_wt,
    time_column='predicted_stage_hpf',
    metric_column='hypotenuse',
    time_threshold=55,
    window_size=5,
    top_n=5,
    bottom_n=5,
)
plt.show()


In [25]:
def visualize_extreme_embryos_dual(
    df,
    metric_column='hypotenuse',
    time_column='predicted_stage_hpf',
    time_threshold=55,
    top_n=5,
    bottom_n=5,
    window_size=5,
    figsize=(18, 8),
    extreme_colors={'top': '#1f77b4', 'bottom': '#d62728'},  # Blue, Red
    save_path=None
):
    """
    Creates two side-by-side plots of embryos with extreme metric values:
    1. Left plot: Colored by top vs. bottom status
    2. Right plot: Colored by phenotype

    Embryos are ranked by the mean of ``metric_column`` over their rows with
    ``time_column > time_threshold``; embryos whose metric is entirely missing
    there are not ranked. When there are fewer than ``top_n + bottom_n`` ranked
    embryos, an embryo that would be in both groups is drawn once, as "bottom".
    Phenotype colours are assigned in sorted phenotype order, so they are the
    same on every run.
    
    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing the embryo data (needs a 'phenotype' column)
    metric_column : str, default='hypotenuse'
        Column name for the metric to analyze
    time_column : str, default='predicted_stage_hpf'
        Column name for the time/stage
    time_threshold : float, default=55
        Only rows with time strictly above this value are averaged for ranking
    top_n : int, default=5
        Number of top-scoring embryos to visualize
    bottom_n : int, default=5
        Number of bottom-scoring embryos to visualize
    window_size : int, default=5
        Window size (in rows) for rolling average smoothing
    figsize : tuple, default=(18, 8)
        Figure size (width, height)
    extreme_colors : dict, default={'top': '#1f77b4', 'bottom': '#d62728'}
        Colors for top and bottom embryos in the first plot (read only)
    save_path : str, optional
        Path to save the figure, if provided
        
    Returns:
    --------
    matplotlib.figure.Figure
        The created figure

    Raises:
    -------
    ValueError
        If a required column is missing, no rows lie after the threshold, or
        no embryo has a non-missing metric value after the threshold.
        
    Example:
    --------
    # Basic usage showing top 5 and bottom 5 embryos by hypotenuse after 55 hpf
    fig = visualize_extreme_embryos_dual(df_augmented_projec_wt)
    
    # Customize with different threshold and window size
    fig = visualize_extreme_embryos_dual(
        df_augmented_projec_wt,
        time_threshold=48,
        window_size=7
    )
    """
    # Validate inputs before importing the plotting libraries so bad arguments fail fast.
    required_cols = [metric_column, 'embryo_id', time_column, 'phenotype']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    filtered_df = df[df[time_column] > time_threshold]
    if filtered_df.empty:
        raise ValueError(f"No data found after {time_column} > {time_threshold}")

    # Mean metric per embryo after the threshold. dropna() keeps embryos with no
    # usable values from sorting to the end and being reported as "top".
    embryo_means = (
        filtered_df.groupby('embryo_id')[metric_column]
        .mean()
        .dropna()
        .sort_values(kind='mergesort')
    )
    if embryo_means.empty:
        raise ValueError(
            f"No non-missing {metric_column} values after {time_column} > {time_threshold}"
        )

    bottom_embryos = embryo_means.head(bottom_n).index.tolist()
    # Never list an embryo in both groups (possible when there are fewer than
    # top_n + bottom_n embryos); otherwise it would be drawn twice.
    bottom_set = set(bottom_embryos)
    top_embryos = [e for e in embryo_means.tail(top_n).index if e not in bottom_set]

    import matplotlib.pyplot as plt
    import seaborn as sns

    groups = (('bottom', bottom_embryos), ('top', top_embryos))

    # Phenotype of each selected embryo, taken from its first row.
    embryo_phenotypes = {
        embryo_id: df.loc[df['embryo_id'] == embryo_id, 'phenotype'].iloc[0]
        for _, embryo_ids in groups
        for embryo_id in embryo_ids
    }

    # sorted() instead of list(set(...)): set iteration order of strings varies
    # between interpreter runs (hash randomisation), which would reshuffle colours.
    unique_phenotypes = sorted(set(embryo_phenotypes.values()), key=str)
    phenotype_colors = dict(zip(unique_phenotypes,
                                sns.color_palette("husl", len(unique_phenotypes))))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize, sharey=True)
    line_styles = ['-', '--', '-.', ':', (0, (3, 1, 1, 1))]

    # Bottom group first, then top; line styles restart at '-' within each group.
    for status, embryo_ids in groups:
        for j, embryo_id in enumerate(embryo_ids):
            embryo_data = df[df['embryo_id'] == embryo_id].sort_values(time_column)
            smoothed = embryo_data[metric_column].rolling(
                window=window_size, min_periods=1, center=True
            ).mean()
            line_style = line_styles[j % len(line_styles)]
            phenotype = embryo_phenotypes[embryo_id]

            # Left: coloured by top/bottom status
            ax1.plot(
                embryo_data[time_column],
                smoothed,
                linestyle=line_style,
                color=extreme_colors[status],
                linewidth=2,
                label=f"{status.capitalize()} - {embryo_id} ({phenotype})"
            )

            # Right: coloured by phenotype
            ax2.plot(
                embryo_data[time_column],
                smoothed,
                linestyle=line_style,
                color=phenotype_colors[phenotype],
                linewidth=2,
                label=f"{embryo_id} ({phenotype})"
            )

    # Threshold marker and annotation on both panels
    for ax in (ax1, ax2):
        ax.axvline(x=time_threshold, color='gray', linestyle='--', alpha=0.5)
        ax.text(
            time_threshold + 0.5, 
            ax.get_ylim()[0] + 0.05 * (ax.get_ylim()[1] - ax.get_ylim()[0]),
            f"Threshold: {time_threshold}",
            color='gray', fontsize=10
        )

    ax1.set_title(f"Top {top_n} and Bottom {bottom_n} Embryos\n(Colored by Rank)", fontsize=14)
    ax2.set_title("Same Embryos\n(Colored by Phenotype)", fontsize=14)
    ax1.set_xlabel(time_column, fontsize=12)
    ax2.set_xlabel(time_column, fontsize=12)
    ax1.set_ylabel(metric_column, fontsize=12)
    ax1.grid(True, linestyle='--', alpha=0.7)
    ax2.grid(True, linestyle='--', alpha=0.7)

    # Left legend: one entry per status rather than per embryo
    top_line = plt.Line2D([0], [0], color=extreme_colors['top'], lw=2)
    bottom_line = plt.Line2D([0], [0], color=extreme_colors['bottom'], lw=2)
    ax1.legend(
        [top_line, bottom_line],
        [f'Top {top_n} embryos', f'Bottom {bottom_n} embryos'],
        loc='best', fontsize=10
    )

    # Right legend: one entry per phenotype
    phenotype_lines = [plt.Line2D([0], [0], color=color, lw=2)
                       for color in phenotype_colors.values()]
    ax2.legend(
        phenotype_lines,
        list(phenotype_colors.keys()),
        title="Phenotypes",
        loc='best', fontsize=10
    )

    fig.suptitle(
        f"{metric_column.capitalize()} Trajectories - Extreme Values after {time_column} > {time_threshold}", 
        fontsize=16, y=1.02
    )

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

# Example: dual view of the top/bottom 5 embryos by mean hypotenuse after 55 hpf,
# restricted to the three cep290 genotypes at temp30.
df_analyze = df_augmented_projec_wt.loc[
    df_augmented_projec_wt["phenotype"].isin(
        ['cep290_het_cep290_temp30', 'wt_cep290_temp30', 'cep290_homo_cep290_temp30']
    )
]

fig = visualize_extreme_embryos_dual(
    df_analyze,
    time_column='predicted_stage_hpf',
    metric_column='hypotenuse',
    time_threshold=55,
    window_size=5,
    top_n=5,
    bottom_n=5,
)
plt.show()


In [28]:

# Homozygous cep290 embryos at temp30 ...
df_analyze = df_augmented_projec_wt[df_augmented_projec_wt["phenotype"].isin([
    'cep290_homo_cep290_temp30'
])]

# ... restricted to embryos that reach predicted stage >= 60 hpf. The phenotype
# filter must come first: re-filtering df_augmented_projec_wt afterwards would
# silently discard this late-stage restriction.
embryos_with_late_stages = df_analyze.groupby('embryo_id')['predicted_stage_hpf'].max() >= 60
valid_embryo_ids = embryos_with_late_stages[embryos_with_late_stages].index.tolist()
df_analyze = df_analyze[df_analyze['embryo_id'].isin(valid_embryo_ids)]

# Extremes ranked on the late window (after 55 hpf) ...
fig = visualize_extreme_embryos_dual(
    df_analyze,
    metric_column='hypotenuse',
    time_column='predicted_stage_hpf',
    time_threshold=55,
    top_n=5,
    bottom_n=5,
    window_size=5
)
plt.show()

# ... and on a much wider window (after 30 hpf) for comparison.
fig = visualize_extreme_embryos_dual(
    df_analyze,
    metric_column='hypotenuse',
    time_column='predicted_stage_hpf',
    time_threshold=30,
    top_n=5,
    bottom_n=5,
    window_size=5
)
plt.show()


In [30]:
def select_extreme_embryos(df, metric_column, time_column, time_range, top_n, bottom_n):
    """Rank embryos by their mean metric inside a time window and pick the extremes.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``embryo_id``, ``time_column`` and ``metric_column``.
    metric_column, time_column : str
        Columns holding the metric to rank on and the time/stage axis.
    time_range : tuple
        Inclusive (min, max) window over which each embryo's metric is averaged.
    top_n, bottom_n : int
        Requested number of highest / lowest embryos. Counts are clipped so the
        two selections never overlap; the bottom selection takes priority when
        fewer than ``top_n + bottom_n`` embryos are available.

    Returns
    -------
    tuple
        ``(embryo_metrics, bottom_embryos, top_embryos)``. ``embryo_metrics`` has
        one row per embryo (``embryo_id`` and its window mean), sorted ascending
        by the mean. The two ID lists are disjoint and ordered by ascending mean.

    Raises
    ------
    ValueError
        If no rows fall inside the window, or no embryo has a non-NaN mean there.
    """
    min_time, max_time = time_range
    filtered_df = df[(df[time_column] >= min_time) & (df[time_column] <= max_time)]
    if filtered_df.empty:
        raise ValueError(f"No data found within {time_column} range {min_time}-{max_time}")

    # An embryo whose metric is all-NaN inside the window averages to NaN.
    # sort_values puts NaN last, so it would otherwise be reported as "top".
    embryo_metrics = (
        filtered_df.groupby('embryo_id')[metric_column]
        .mean()
        .dropna()
        .reset_index()
        .sort_values(metric_column)
    )
    if embryo_metrics.empty:
        raise ValueError(
            f"No non-NaN {metric_column} values within {time_column} range {min_time}-{max_time}"
        )

    # head()/tail() overlap whenever there are fewer than top_n + bottom_n
    # embryos, so clip the counts to keep the two groups disjoint.
    n_embryos = len(embryo_metrics)
    n_bottom = max(0, min(bottom_n, n_embryos))
    n_top = max(0, min(top_n, n_embryos - n_bottom))
    ordered_ids = embryo_metrics['embryo_id']
    bottom_embryos = ordered_ids.iloc[:n_bottom].tolist()
    top_embryos = ordered_ids.iloc[n_embryos - n_top:].tolist()
    return embryo_metrics, bottom_embryos, top_embryos


def visualize_extreme_embryos_dual_range(
    df,
    metric_column='hypotenuse',
    time_column='predicted_stage_hpf',
    time_range=(55, 70),  # Calculate top/bottom within this range
    display_range=None,   # Optional different range for display
    top_n=5,
    bottom_n=5,
    window_size=5,
    figsize=(18, 8),
    extreme_colors={'top': '#1f77b4', 'bottom': '#d62728'},  # Blue, Red
    save_path=None
):
    """
    Creates two side-by-side plots of embryos with extreme metric values within a specific time range:
    1. Left plot: Colored by top vs. bottom status
    2. Right plot: Colored by phenotype

    Embryos are ranked by their mean ``metric_column`` inside ``time_range``
    (inclusive). Embryos with no non-NaN values in that window are not ranked.
    The top and bottom groups are always disjoint. If fewer than
    ``top_n + bottom_n`` embryos are available, the top group is shortened, and
    titles/legends report the counts actually plotted.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing the embryo data (needs embryo_id, phenotype,
        metric_column and time_column)
    metric_column : str, default='hypotenuse'
        Column name for the metric to analyze
    time_column : str, default='predicted_stage_hpf'
        Column name for the time/stage
    time_range : tuple, default=(55, 70)
        (min, max) time range used to calculate average metrics for ranking
    display_range : tuple, optional
        (min, max) time range for display. If None, shows all data of the
        selected embryos
    top_n : int, default=5
        Number of top-scoring embryos to visualize
    bottom_n : int, default=5
        Number of bottom-scoring embryos to visualize
    window_size : int, default=5
        Window size for the centered rolling-mean smoothing
    figsize : tuple, default=(18, 8)
        Figure size (width, height)
    extreme_colors : dict, default={'top': '#1f77b4', 'bottom': '#d62728'}
        Colors for top and bottom embryos in the first plot
    save_path : str, optional
        Path to save the figure, if provided

    Returns:
    --------
    tuple
        (fig, top_embryos_list, bottom_embryos_list), both lists ordered by
        ascending window mean

    Raises:
    -------
    ValueError
        If required columns are missing, the ranking window holds no rows, or
        no embryo has a non-NaN mean inside it.

    Example:
    --------
    # Analyze embryos in early development (33-55 hpf)
    fig, top_early, bottom_early = visualize_extreme_embryos_dual_range(
        df_augmented_projec_wt,
        time_range=(33, 55)
    )

    # Analyze embryos in late development (55-70 hpf)
    fig, top_late, bottom_late = visualize_extreme_embryos_dual_range(
        df_augmented_projec_wt,
        time_range=(55, 70)
    )

    # Get overlapping embryos to see which are consistently extreme
    early_extreme = set(top_early + bottom_early)
    late_extreme = set(top_late + bottom_late)
    consistent_embryos = early_extreme.intersection(late_extreme)
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Make a copy to avoid modifying the original
    df = df.copy()

    # Ensure required columns exist
    required_cols = [metric_column, 'embryo_id', time_column, 'phenotype']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    min_time, max_time = time_range
    embryo_metrics, bottom_embryos, top_embryos = select_extreme_embryos(
        df, metric_column, time_column, time_range, top_n, bottom_n
    )
    extreme_embryos = bottom_embryos + top_embryos
    # Window mean per embryo, reused in the legend labels
    avg_by_embryo = embryo_metrics.set_index('embryo_id')[metric_column]

    # Phenotype of each embryo taken from its first row in df
    first_phenotype = df.drop_duplicates('embryo_id').set_index('embryo_id')['phenotype']
    embryo_phenotypes = {embryo_id: first_phenotype[embryo_id] for embryo_id in extreme_embryos}

    # Sort so the phenotype -> color mapping is the same on every kernel run
    # (set iteration order of strings varies with hash randomization).
    unique_phenotypes = sorted(set(embryo_phenotypes.values()), key=str)
    phenotype_colors = dict(zip(unique_phenotypes,
                                sns.color_palette("husl", len(unique_phenotypes))))

    # Determine display range
    if display_range is None:
        # Show all data for the selected embryos
        extreme_times = df.loc[df['embryo_id'].isin(extreme_embryos), time_column]
        min_display, max_display = extreme_times.min(), extreme_times.max()
    else:
        min_display, max_display = display_range

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize, sharey=True)

    line_styles = ['-', '--', '-.', ':', (0, (3, 1, 1, 1))]

    # Plot bottom embryos first, then top embryos. Line styles cycle within
    # each group, so the k-th bottom and k-th top embryo share a style.
    for status, group_ids in (("bottom", bottom_embryos), ("top", top_embryos)):
        for position, embryo_id in enumerate(group_ids):
            embryo_data = df[(df['embryo_id'] == embryo_id) &
                             (df[time_column] >= min_display) &
                             (df[time_column] <= max_display)].sort_values(time_column)
            if embryo_data.empty:
                continue

            smoothed = embryo_data[metric_column].rolling(
                window=window_size, min_periods=1, center=True
            ).mean()
            line_style = line_styles[position % len(line_styles)]
            phenotype = embryo_phenotypes[embryo_id]

            # Left: colored by top/bottom status
            ax1.plot(
                embryo_data[time_column],
                smoothed,
                linestyle=line_style,
                color=extreme_colors[status],
                linewidth=2,
                label=f"{status.capitalize()} - {embryo_id} (avg: {avg_by_embryo[embryo_id]:.2f})"
            )

            # Right: colored by phenotype
            ax2.plot(
                embryo_data[time_column],
                smoothed,
                linestyle=line_style,
                color=phenotype_colors[phenotype],
                linewidth=2,
                label=f"{embryo_id} ({phenotype})"
            )

    # Shade and label the window used for ranking
    for ax in [ax1, ax2]:
        ax.axvspan(min_time, max_time, alpha=0.2, color='gray')
        y_low, y_high = ax.get_ylim()
        ax.text(
            (min_time + max_time) / 2,
            y_low + 0.05 * (y_high - y_low),
            f"Ranking range: {min_time}-{max_time}",
            horizontalalignment='center',
            bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'),
            fontsize=10
        )

    # Titles report the counts actually selected (may be clipped, see docstring)
    n_top, n_bottom = len(top_embryos), len(bottom_embryos)
    ax1.set_title(f"Top {n_top} and Bottom {n_bottom} Embryos\n(Colored by Rank)", fontsize=14)
    ax2.set_title("Same Embryos\n(Colored by Phenotype)", fontsize=14)

    ax1.set_xlabel(time_column, fontsize=12)
    ax2.set_xlabel(time_column, fontsize=12)
    ax1.set_ylabel(metric_column, fontsize=12)

    ax1.grid(True, linestyle='--', alpha=0.7)
    ax2.grid(True, linestyle='--', alpha=0.7)

    # Left legend: just top/bottom
    top_line = plt.Line2D([0], [0], color=extreme_colors['top'], lw=2)
    bottom_line = plt.Line2D([0], [0], color=extreme_colors['bottom'], lw=2)
    ax1.legend(
        [top_line, bottom_line],
        [f'Top {n_top} embryos', f'Bottom {n_bottom} embryos'],
        loc='best', fontsize=10
    )

    # Right legend: phenotypes
    phenotype_lines = [plt.Line2D([0], [0], color=color, lw=2)
                       for color in phenotype_colors.values()]
    ax2.legend(
        phenotype_lines,
        list(phenotype_colors.keys()),
        title="Phenotypes",
        loc='best', fontsize=10
    )

    fig.suptitle(
        f"{metric_column.capitalize()} Trajectories - Ranked within {time_column} range {min_time}-{max_time}",
        fontsize=16, y=1.02
    )

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig, top_embryos, bottom_embryos

# Example usage for different developmental windows.
# Kept as comments: a bare string literal at the end of the cell would be
# echoed as the cell's output.
#
# # Early development window
# fig_early, top_early, bottom_early = visualize_extreme_embryos_dual_range(
#     df_augmented_projec_wt,
#     metric_column='hypotenuse',
#     time_range=(33, 55),  # Calculate top/bottom within this range
#     display_range=(30, 70),  # Show the entire timeline
#     top_n=5,
#     bottom_n=5
# )
#
# # Late development window
# fig_late, top_late, bottom_late = visualize_extreme_embryos_dual_range(
#     df_augmented_projec_wt,
#     metric_column='hypotenuse',
#     time_range=(55, 70),  # Calculate top/bottom within this range
#     display_range=(30, 70),  # Show the entire timeline
#     top_n=5,
#     bottom_n=5
# )
#
# # Check which embryos are consistently in the extremes
# early_extreme = set(top_early + bottom_early)
# late_extreme = set(top_late + bottom_late)
# consistent = early_extreme.intersection(late_extreme)
# print(f"Embryos that are extreme in both early and late windows: {consistent}")

In [33]:
# Homozygous cep290 embryos raised at 30 °C
df_analyze = df_augmented_projec_wt[df_augmented_projec_wt["phenotype"].isin([
    'cep290_homo_cep290_temp30'
])]

# Run the early/late extreme-embryo comparison twice. The first pass keeps only
# embryos whose latest observation reaches 60 hpf; the second relaxes that to 30 hpf.
for min_last_stage_hpf in (60, 30):
    # Keep embryos whose latest observation is at or beyond the cutoff
    embryos_with_late_stages = (
        df_analyze.groupby('embryo_id')['predicted_stage_hpf'].max() >= min_last_stage_hpf
    )
    valid_embryo_ids = embryos_with_late_stages[embryos_with_late_stages].index.tolist()
    df_analyze_filtered = df_analyze[df_analyze['embryo_id'].isin(valid_embryo_ids)]

    # Early development window (33-55 hpf)
    fig_early, top_early, bottom_early = visualize_extreme_embryos_dual_range(
        df_analyze_filtered,
        metric_column='hypotenuse',
        time_range=(33, 55),
        display_range=(30, 70),  # Show full timeline for context
        top_n=5,
        bottom_n=5
    )

    # Late development window (55-70 hpf)
    fig_late, top_late, bottom_late = visualize_extreme_embryos_dual_range(
        df_analyze_filtered,
        metric_column='hypotenuse',
        time_range=(55, 70),
        display_range=(30, 70),  # Show full timeline for context
        top_n=5,
        bottom_n=5
    )

In [36]:
# Het, wt and homozygous cep290 embryos raised at 30 °C
df_analyze = df_augmented_projec_wt[df_augmented_projec_wt["phenotype"].isin([
    'cep290_het_cep290_temp30',
    'wt_cep290_temp30',
    'cep290_homo_cep290_temp30'
])]


# Keep embryos whose latest observation is at or beyond 55 hpf
embryos_with_late_stages = df_analyze.groupby('embryo_id')['predicted_stage_hpf'].max() >= 55
valid_embryo_ids = embryos_with_late_stages[embryos_with_late_stages].index.tolist()
df_analyze_filtered = df_analyze[df_analyze['embryo_id'].isin(valid_embryo_ids)]

# Early development window (33-55 hpf)
fig_early, top_early, bottom_early = visualize_extreme_embryos_dual_range(
    df_analyze_filtered,
    metric_column='hypotenuse',
    time_range=(33, 55),
    display_range=(20, 70),  # Show full timeline for context
    top_n=5,
    bottom_n=5
)

# Late development window (55-70 hpf)
fig_late, top_late, bottom_late = visualize_extreme_embryos_dual_range(
    df_analyze_filtered,
    metric_column='hypotenuse',
    time_range=(55, 70),
    display_range=(20, 70),  # Show full timeline for context
    top_n=5,
    bottom_n=5
)

# Earliest window (20-26 hpf). Stored under its own names so the late-window
# results above are not overwritten.
fig_20_26, top_20_26, bottom_20_26 = visualize_extreme_embryos_dual_range(
    df_analyze_filtered,
    metric_column='hypotenuse',
    time_range=(20, 26),
    display_range=(20, 70),  # Show full timeline for context
    top_n=5,
    bottom_n=5
)

In [37]:
# Het and wt cep290 embryos raised at 30 °C (homozygous excluded)
df_analyze = df_augmented_projec_wt[df_augmented_projec_wt["phenotype"].isin([
    'cep290_het_cep290_temp30',
    'wt_cep290_temp30',
])]


# Keep embryos whose latest observation is at or beyond 55 hpf
embryos_with_late_stages = df_analyze.groupby('embryo_id')['predicted_stage_hpf'].max() >= 55
valid_embryo_ids = embryos_with_late_stages[embryos_with_late_stages].index.tolist()
df_analyze_filtered = df_analyze[df_analyze['embryo_id'].isin(valid_embryo_ids)]

# Early development window (33-55 hpf)
fig_early, top_early, bottom_early = visualize_extreme_embryos_dual_range(
    df_analyze_filtered,
    metric_column='hypotenuse',
    time_range=(33, 55),
    display_range=(20, 70),  # Show full timeline for context
    top_n=5,
    bottom_n=5
)

# Late development window (55-70 hpf)
fig_late, top_late, bottom_late = visualize_extreme_embryos_dual_range(
    df_analyze_filtered,
    metric_column='hypotenuse',
    time_range=(55, 70),
    display_range=(20, 70),  # Show full timeline for context
    top_n=5,
    bottom_n=5
)

# Earliest window (20-26 hpf). Stored under its own names so the late-window
# results above are not overwritten.
fig_20_26, top_20_26, bottom_20_26 = visualize_extreme_embryos_dual_range(
    df_analyze_filtered,
    metric_column='hypotenuse',
    time_range=(20, 26),
    display_range=(20, 70),  # Show full timeline for context
    top_n=5,
    bottom_n=5
)

In [38]:
def rank_embryos_by_window_mean(df, metric_column, time_column, time_range, reverse=False):
    """Average ``metric_column`` per embryo inside an inclusive time window and rank embryos.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``embryo_id``, ``time_column`` and ``metric_column``.
    metric_column, time_column : str
        Metric to average and the time/stage column used for the window.
    time_range : tuple
        Inclusive (min, max) window.
    reverse : bool, default=False
        If True, ``normalized_rank`` is flipped to ``1 - normalized_rank``.

    Returns
    -------
    pd.DataFrame
        Columns ``embryo_id``, ``metric_column`` (window mean), ``rank`` (1 = lowest
        mean) and ``normalized_rank`` (rank rescaled to [0, 1]), sorted ascending by
        the mean. Embryos whose window mean is NaN are dropped.

    Raises
    ------
    ValueError
        If no rows fall inside the window, or no embryo has a non-NaN mean there.
    """
    min_time, max_time = time_range
    filtered_df = df[(df[time_column] >= min_time) & (df[time_column] <= max_time)]
    if filtered_df.empty:
        raise ValueError(f"No data found within {time_column} range {min_time}-{max_time}")

    # NaN means would sort last and be treated as the highest rank, so drop them
    embryo_metrics = (
        filtered_df.groupby('embryo_id')[metric_column]
        .mean()
        .dropna()
        .reset_index()
        .sort_values(metric_column)
    )
    if embryo_metrics.empty:
        raise ValueError(
            f"No non-NaN {metric_column} values within {time_column} range {min_time}-{max_time}"
        )

    n_embryos = len(embryo_metrics)
    embryo_metrics['rank'] = range(1, n_embryos + 1)
    embryo_metrics['normalized_rank'] = (embryo_metrics['rank'] - 1) / max(1, n_embryos - 1)
    if reverse:
        embryo_metrics['normalized_rank'] = 1 - embryo_metrics['normalized_rank']
    return embryo_metrics


def visualize_all_embryos_ranked(
    df,
    metric_column='hypotenuse',
    time_column='predicted_stage_hpf',
    time_range=(55, 70),  # Calculate ranking within this range
    display_range=None,   # Optional different range for display
    window_size=5,
    figsize=(18, 8),
    cmap='viridis',       # Colormap for ranking
    reverse_cmap=False,   # If True, higher values are at the bottom of the colormap
    save_path=None
):
    """
    Creates two side-by-side plots of all embryos with gradient coloring based on ranking:
    1. Left plot: Colored by rank of the window-mean metric using ``cmap``
    2. Right plot: Colored by phenotype

    Line colors on the left encode each embryo's RANK, not the raw mean value.
    The colorbar is therefore laid out on the rank axis, and its tick labels show
    the window-mean metric value of the embryo at that rank.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing the embryo data (needs embryo_id, phenotype,
        metric_column and time_column)
    metric_column : str, default='hypotenuse'
        Column name for the metric to analyze
    time_column : str, default='predicted_stage_hpf'
        Column name for the time/stage
    time_range : tuple, default=(55, 70)
        (min, max) time range (inclusive) used to calculate average metrics for ranking
    display_range : tuple, optional
        (min, max) time range for display. If None, shows all data of the ranked embryos
    window_size : int, default=5
        Window size for the centered rolling-mean smoothing
    figsize : tuple, default=(18, 8)
        Figure size (width, height)
    cmap : str, default='viridis'
        Matplotlib colormap name to use for ranking
    reverse_cmap : bool, default=False
        If True, reverses the colormap so higher values are at the bottom
    save_path : str, optional
        Path to save the figure, if provided

    Returns:
    --------
    tuple
        (fig, embryo_rankings_df). embryo_rankings_df has columns embryo_id,
        metric_column, rank and normalized_rank; embryos with a NaN window mean
        are excluded.

    Raises:
    -------
    ValueError
        If required columns are missing, the ranking window holds no rows, or
        no embryo has a non-NaN mean inside it.

    Example:
    --------
    # Visualize all embryos with gradient coloring by rank
    fig, rankings = visualize_all_embryos_ranked(
        df_analyze_filtered,
        time_range=(55, 70)
    )

    # Use a different colormap and reverse it (so highest values are blue)
    fig, rankings = visualize_all_embryos_ranked(
        df_analyze_filtered,
        cmap='plasma',
        reverse_cmap=True
    )
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    from matplotlib.colors import Normalize
    from matplotlib.cm import ScalarMappable

    # Make a copy to avoid modifying the original
    df = df.copy()

    # Ensure required columns exist
    required_cols = [metric_column, 'embryo_id', time_column, 'phenotype']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    min_time, max_time = time_range
    embryo_metrics = rank_embryos_by_window_mean(
        df, metric_column, time_column, time_range, reverse=reverse_cmap
    )
    n_embryos = len(embryo_metrics)
    embryo_ids = embryo_metrics['embryo_id'].tolist()

    # Phenotype of each embryo taken from its first row in df
    first_phenotype = df.drop_duplicates('embryo_id').set_index('embryo_id')['phenotype']
    embryo_phenotypes = {embryo_id: first_phenotype[embryo_id] for embryo_id in embryo_ids}

    # Sorted so the phenotype -> color mapping is the same on every kernel run
    unique_phenotypes = sorted(set(embryo_phenotypes.values()), key=str)
    phenotype_colors = dict(zip(unique_phenotypes,
                                sns.color_palette("husl", len(unique_phenotypes))))

    # Determine display range
    if display_range is None:
        # Show all data for the ranked embryos
        display_times = df.loc[df['embryo_id'].isin(embryo_ids), time_column]
        min_display, max_display = display_times.min(), display_times.max()
    else:
        min_display, max_display = display_range

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize, sharey=True)
    cmap_obj = plt.get_cmap(cmap)

    # Plot each embryo's smoothed trajectory (lowest mean first)
    for embryo_id, normalized_rank in zip(embryo_ids, embryo_metrics['normalized_rank']):
        embryo_data = df[(df['embryo_id'] == embryo_id) &
                         (df[time_column] >= min_display) &
                         (df[time_column] <= max_display)].sort_values(time_column)
        if embryo_data.empty:
            continue

        smoothed = embryo_data[metric_column].rolling(
            window=window_size, min_periods=1, center=True
        ).mean()

        # Left: colored by rank
        ax1.plot(
            embryo_data[time_column],
            smoothed,
            color=cmap_obj(normalized_rank),
            linewidth=1.5,
            alpha=0.7
        )

        # Right: colored by phenotype
        ax2.plot(
            embryo_data[time_column],
            smoothed,
            color=phenotype_colors[embryo_phenotypes[embryo_id]],
            linewidth=1.5,
            alpha=0.7
        )

    # Shade and label the window used for ranking
    for ax in [ax1, ax2]:
        ax.axvspan(min_time, max_time, alpha=0.2, color='gray')
        y_low, y_high = ax.get_ylim()
        ax.text(
            (min_time + max_time) / 2,
            y_low + 0.05 * (y_high - y_low),
            f"Ranking range: {min_time}-{max_time}",
            horizontalalignment='center',
            bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'),
            fontsize=10
        )

    ax1.set_title(f"All Embryos (n={n_embryos})\nColored by {metric_column} Rank", fontsize=14)
    ax2.set_title("All Embryos\nColored by Phenotype", fontsize=14)

    ax1.set_xlabel(time_column, fontsize=12)
    ax2.set_xlabel(time_column, fontsize=12)
    ax1.set_ylabel(metric_column, fontsize=12)

    ax1.grid(True, linestyle='--', alpha=0.7)
    ax2.grid(True, linestyle='--', alpha=0.7)

    # Colorbar on the rank axis so it matches the line colors above.
    # Rank r maps to (r-1)/(n-1) == normalized_rank. When reverse_cmap flipped
    # normalized_rank, the reversed colormap reproduces cmap(1 - normalized_rank).
    # vmax >= 2 avoids a degenerate norm when only one embryo is ranked.
    bar_cmap = cmap_obj.reversed() if reverse_cmap else cmap_obj
    sm = ScalarMappable(cmap=bar_cmap, norm=Normalize(vmin=1, vmax=max(2, n_embryos)))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax1)
    tick_ranks = np.unique(np.linspace(1, n_embryos, num=min(n_embryos, 6)).round().astype(int))
    mean_by_rank = embryo_metrics.set_index('rank')[metric_column]
    cbar.set_ticks(tick_ranks)
    cbar.set_ticklabels([f"{mean_by_rank[r]:.2f}" for r in tick_ranks])
    cbar.set_label(f'Average {metric_column} ({min_time}-{max_time}), spaced by rank', fontsize=10)

    # Phenotype legend on the right plot
    phenotype_lines = [plt.Line2D([0], [0], color=color, lw=2)
                       for color in phenotype_colors.values()]
    ax2.legend(
        phenotype_lines,
        list(phenotype_colors.keys()),
        title="Phenotypes",
        loc='best', fontsize=10
    )

    fig.suptitle(
        f"{metric_column.capitalize()} Trajectories - All Embryos Ranked by {min_time}-{max_time} {time_column}",
        fontsize=16, y=1.02
    )

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig, embryo_metrics


In [39]:
# Homozygous cep290 embryos raised at 30 °C whose latest observation reaches 60 hpf
df_analyze = df_augmented_projec_wt[df_augmented_projec_wt["phenotype"].isin([
    'cep290_homo_cep290_temp30'
])]
embryos_with_late_stages = df_analyze.groupby('embryo_id')['predicted_stage_hpf'].max() >= 60
valid_embryo_ids = embryos_with_late_stages[embryos_with_late_stages].index.tolist()
df_analyze_filtered = df_analyze[df_analyze['embryo_id'].isin(valid_embryo_ids)]

# Rank all embryos by mean hypotenuse in the late, mid and early windows,
# always showing the full 20-75 hpf trajectories for context.
for ranking_window in [(55, 70), (33, 55), (20, 26)]:
    fig, rankings = visualize_all_embryos_ranked(
        df_analyze_filtered,
        metric_column='hypotenuse',
        time_range=ranking_window,  # Calculate ranking in this window
        display_range=(20, 75),     # Show this full time range
        cmap='plasma',              # Plasma colormap (the function's default is viridis)
        window_size=5               # Rolling-mean window for smoothing
    )
    plt.show()

In [40]:
# Het and wt cep290 embryos raised at 30 °C whose latest observation reaches 60 hpf
df_analyze = df_augmented_projec_wt[df_augmented_projec_wt["phenotype"].isin([
    'cep290_het_cep290_temp30',
    'wt_cep290_temp30',
])]
embryos_with_late_stages = df_analyze.groupby('embryo_id')['predicted_stage_hpf'].max() >= 60
valid_embryo_ids = embryos_with_late_stages[embryos_with_late_stages].index.tolist()
df_analyze_filtered = df_analyze[df_analyze['embryo_id'].isin(valid_embryo_ids)]

# Rank all embryos by mean hypotenuse in the late, mid and early windows,
# always showing the full 20-75 hpf trajectories for context.
for ranking_window in [(55, 70), (33, 55), (20, 26)]:
    fig, rankings = visualize_all_embryos_ranked(
        df_analyze_filtered,
        metric_column='hypotenuse',
        time_range=ranking_window,  # Calculate ranking in this window
        display_range=(20, 75),     # Show this full time range
        cmap='plasma',              # Plasma colormap (the function's default is viridis)
        window_size=5               # Rolling-mean window for smoothing
    )
    plt.show()

In [41]:
# Het, wt and homozygous cep290 embryos raised at 30 °C whose latest observation reaches 60 hpf
df_analyze = df_augmented_projec_wt[df_augmented_projec_wt["phenotype"].isin([
    'cep290_het_cep290_temp30',
    'wt_cep290_temp30',
    'cep290_homo_cep290_temp30'
])]
embryos_with_late_stages = df_analyze.groupby('embryo_id')['predicted_stage_hpf'].max() >= 60
valid_embryo_ids = embryos_with_late_stages[embryos_with_late_stages].index.tolist()
df_analyze_filtered = df_analyze[df_analyze['embryo_id'].isin(valid_embryo_ids)]

# Rank all embryos by mean hypotenuse in the late, mid and early windows,
# always showing the full 20-75 hpf trajectories for context.
for ranking_window in [(55, 70), (33, 55), (20, 26)]:
    fig, rankings = visualize_all_embryos_ranked(
        df_analyze_filtered,
        metric_column='hypotenuse',
        time_range=ranking_window,  # Calculate ranking in this window
        display_range=(20, 75),     # Show this full time range
        cmap='plasma',              # Plasma colormap (the function's default is viridis)
        window_size=5               # Rolling-mean window for smoothing
    )
    plt.show()

In [43]:

# Re-rank the embryos in df_analyze_filtered as left by the cell above
# (NOTE: depends on that earlier cell having run) using a narrow 55-60 hpf
# window, and zoom the display onto the late timeline.
fig, rankings = visualize_all_embryos_ranked(
    df_analyze_filtered,
    metric_column='hypotenuse',
    time_range=(55, 60),     # Calculate ranking in this window
    display_range=(55, 75),  # Zoom onto 55-75 hpf
    cmap='plasma',           # Plasma colormap (the function's default is viridis)
    window_size=5            # Rolling-mean window for smoothing
)
plt.show()


In [44]:


# Re-filter the current df_analyze (set by an earlier cell — NOTE: depends on
# that cell having run) to embryos whose latest observation is at or beyond 30 hpf.
embryos_with_late_stages = df_analyze.groupby('embryo_id')['predicted_stage_hpf'].max() >= 30
valid_embryo_ids = embryos_with_late_stages[embryos_with_late_stages].index.tolist()
df_analyze_filtered = df_analyze[df_analyze['embryo_id'].isin(valid_embryo_ids)]

# Rank by two narrow mid-development windows, showing 28-55 hpf for context
for ranking_window in [(35, 38), (45, 50)]:
    fig, rankings = visualize_all_embryos_ranked(
        df_analyze_filtered,
        metric_column='hypotenuse',
        time_range=ranking_window,  # Calculate ranking in this window
        display_range=(28, 55),     # Show this full time range
        cmap='plasma',              # Plasma colormap (the function's default is viridis)
        window_size=5               # Rolling-mean window for smoothing
    )
    plt.show()
