In [None]:
import pandas as pd
import os
import numpy as np
import seaborn as sns 

import plotly.graph_objects as go
from matplotlib import pyplot as plt
import os
import plotly.express as px

from scipy.interpolate import CubicSpline
from sklearn.metrics.pairwise import cosine_similarity
from scipy.interpolate import CubicSpline, interp1d
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression




os.chdir("/net/trapnell/vol1/home/mdcolon/proj/morphseq")

# pert_comparisons = ["wnt-i", "wt"]
pert_comparisons = ["wnt-i", "tgfb-i", "wt", "lmx1b", "gdf3"]

color_map = {
    "wnt-i": "red",
    "tgfb-i": "green",
    "wt": "blue",
    "lmx1b": "orange",
    "gdf3": "purple"
}



In [21]:
def plot_trajectories_3d(splines_final):
    """
    Plots PCA trajectories for different perturbations and datasets in a 3D Plotly plot.

    Parameters:
    splines_final (pd.DataFrame): DataFrame containing the trajectory data with columns
                                  ['dataset', 'Perturbation', 'point_index', 'PCA_1', 'PCA_2', 'PCA_3']

    Returns:
    None
    """
    # Define perturbations and their corresponding colors
    pert_comparisons = ["wnt-i", "tgfb-i", "wt", "lmx1b", "gdf3"]
    
    color_map = {
        "wnt-i": "red",
        "tgfb-i": "green",
        "wt": "blue",
        "lmx1b": "orange",
        "gdf3": "purple"
    }
    
    # Define dataset styles with dash styles
    dataset_styles = {
        "all": {"dash": "solid", "name": "all"},
        "hld": {"dash": "dash", "name": "hld"},
        "hld_aligned": {"dash": "dot", "name": "hld aligned"}
    }
    
    # Initialize the figure
    fig = go.Figure()

    # Iterate over each perturbation
    for pert in pert_comparisons:
        pert_data = splines_final[splines_final['Perturbation'] == pert]
        color = color_map.get(pert, "black")  # Default to black if perturbation not found
        
        # Iterate over each dataset
        for dataset, style in dataset_styles.items():
            dataset_data = pert_data[pert_data['dataset'] == dataset]
            
            if dataset_data.empty:
                continue  # Skip if there's no data for this dataset
            
            # Sort by point_index to ensure proper trajectory
            dataset_data = dataset_data.sort_values(by='point_index')
            
            # Add trace
            fig.add_trace(
                go.Scatter3d(
                    x=dataset_data['PCA_1'],
                    y=dataset_data['PCA_2'],
                    z=dataset_data['PCA_3'],
                    mode='lines',
                    name=f"{pert} - {style['name']}",
                    line=dict(color=color, dash=style['dash'], width=4),
                    )
                )
            
    
    # Show the plot
    return fig


In [4]:
class LocalPrincipalCurve:
    def __init__(self, bandwidth=0.5, max_iter=100, tol=1e-4, angle_penalty_exp=2, h=None):
        """
        Initialize the Local Principal Curve solver.
        """
        self.bandwidth = bandwidth
        self.h = h if h is not None else self.bandwidth
        self.max_iter = max_iter
        self.tol = tol
        self.angle_penalty_exp = angle_penalty_exp

        self.initializations = []
        self.paths = []
        self.cubic_splines_eq = []
        self.cubic_splines = []

    def _kernel_weights(self, dataset, x):
        """Compute Gaussian kernel weights w_i = K_h(X_i - x)."""
        dists = np.linalg.norm(dataset - x, axis=1)
        weights = np.exp(- (dists**2) / (2 * self.bandwidth**2))
        w = weights / np.sum(weights)
        return w

    def _local_center_of_mass(self, dataset, x):
        """Compute µ_x, the local center of mass around x."""
        w = self._kernel_weights(dataset, x)
        mu = np.sum(dataset.T * w, axis=1)
        return mu

    def _local_covariance(self, dataset, x, mu):
        """Compute the local covariance matrix Σ_x."""
        w = self._kernel_weights(dataset, x)
        centered = dataset - mu
        cov = np.zeros((dataset.shape[1], dataset.shape[1]))
        for i in range(len(dataset)):
            cov += w[i] * np.outer(centered[i], centered[i])
        return cov

    def _principal_component(self, cov, prev_vec=None):
        """Compute the first local principal component γ_x, with angle penalization."""
        vals, vecs = np.linalg.eig(cov)
        idx = np.argsort(vals)[::-1]
        vals = vals[idx]
        vecs = vecs[:, idx]

        gamma = vecs[:, 0]  # first principal component

        # Sign flipping to maintain direction if prev_vec is given
        if prev_vec is not None and np.linalg.norm(prev_vec) != 0:
            cos_alpha = np.dot(gamma, prev_vec) / (np.linalg.norm(gamma)*np.linalg.norm(prev_vec))
            if cos_alpha < 0:
                gamma = -gamma

            # Angle penalization
            cos_alpha = np.dot(gamma, prev_vec) / (np.linalg.norm(gamma)*np.linalg.norm(prev_vec))
            a_x = (abs(cos_alpha))**self.angle_penalty_exp
            gamma = a_x * gamma + (1 - a_x) * prev_vec
            gamma /= np.linalg.norm(gamma)

        return gamma

    def _forward_run(self, dataset, x_start):
        """Run the algorithm forward from a starting point using the full dataset."""
        x = x_start
        path_x = [x]
        prev_gamma = None

        for _ in range(self.max_iter):
            mu = self._local_center_of_mass(dataset, x)
            cov = self._local_covariance(dataset, x, mu)
            gamma = self._principal_component(cov, prev_vec=prev_gamma)

            x_new = mu + self.h * gamma

            # Check convergence
            if np.linalg.norm(mu - x) < self.tol:
                path_x.append(x_new)
                break

            path_x.append(x_new)
            x = x_new
            prev_gamma = gamma

        return np.array(path_x)

    def _backward_run(self, dataset, x0, gamma0):
        """Run the algorithm backwards from x(0) along -γ_x(0)."""
        x = x0
        path_x = [x]
        prev_gamma = -gamma0

        for _ in range(self.max_iter):
            mu = self._local_center_of_mass(dataset, x)
            cov = self._local_covariance(dataset, x, mu)
            gamma = self._principal_component(cov, prev_vec=prev_gamma)

            x_new = mu + self.h * gamma

            if np.linalg.norm(mu - x) < self.tol:
                path_x.append(x_new)
                break

            path_x.append(x_new)
            x = x_new
            prev_gamma = gamma

        return np.array(path_x)

    def _find_starting_point(self, dataset, start_point):
        """Ensure starting point is in dataset or choose closest."""
        if start_point is None:
            idx = np.random.choice(len(dataset))
            return dataset[idx], idx
        else:
            diffs = dataset - start_point
            dists = np.linalg.norm(diffs, axis=1)
            min_idx = np.argmin(dists)
            closest_pt = dataset[min_idx]
            if not np.allclose(closest_pt, start_point):
                print(f"Starting point not in dataset. Using closest point: {closest_pt}")
            return closest_pt, min_idx

    def fit(self, dataset, start_points=None, remove_similar_end_start_points=True):
        """
        Fit LPC on the dataset using possibly multiple starting points.
        """
        dataset = np.array(dataset)
        self.paths = []
        self.initializations = []

        if start_points is None:
            start_points = [None]

        for sp in start_points:
            x0, _ = self._find_starting_point(dataset, sp)

            forward_path = self._forward_run(dataset, x0)
            if len(forward_path) > 1:
                initial_gamma_direction = (forward_path[1] - forward_path[0]) / self.h
            else:
                initial_gamma_direction = np.zeros(dataset.shape[1])

            if np.linalg.norm(initial_gamma_direction) > 0:
                backward_path = self._backward_run(dataset, x0, initial_gamma_direction)
                full_path = np.vstack([backward_path[::-1], forward_path[1:]])
            else:
                full_path = forward_path

            # Check orientation: which end of the path is closer to the starting point x0?
            dist_start_to_first = np.linalg.norm(x0 - full_path[0])
            dist_start_to_last = np.linalg.norm(x0 - full_path[-1])

            if dist_start_to_last < dist_start_to_first:
                # Reverse the path if the last point is closer to the starting point
                full_path = full_path[::-1]

            if remove_similar_end_start_points:
                start_pt = full_path[0]
                end_pt = full_path[-1]

                dist_to_start = np.linalg.norm(full_path - start_pt, axis=1)
                dist_to_end = np.linalg.norm(full_path - end_pt, axis=1)

                mask = np.ones(len(full_path), dtype=bool)
                mask[(dist_to_start < self.tol) | (dist_to_end < self.tol)] = False
                mask[0] = True  # Keep the first point
                mask[-1] = True # Keep the last point

                full_path = full_path[mask]

            self.paths.append(full_path)
            self.initializations.append(x0)

        self._fit_cubic_splines_eq()
        self._compute_equal_arc_length_spline_points()  # Compute equal arc-length points
        return self.paths

    def _fit_cubic_splines_eq(self):
        """Fit cubic splines (equations) for all paths."""
        self.cubic_splines_eq = []
        for path in self.paths:
            if len(path) < 4:
                self.cubic_splines_eq.append(None)
                continue
            t = np.arange(len(path))
            splines_dict = {}
            for dim in range(path.shape[1]):
                splines_dict[dim] = CubicSpline(t, path[:, dim])
            self.cubic_splines_eq.append(splines_dict)

    def _compute_cubic_spline_points(self, num_points=500):
        """
        Compute parameterized points from each cubic spline equation.
        This fills self.cubic_splines with arrays of evaluated points.
        """
        self.cubic_splines = []
        for i, eq in enumerate(self.cubic_splines_eq):
            if eq is None:
                self.cubic_splines.append(None)
                continue
            path = self.paths[i]
            t_values = np.linspace(0, len(path) - 1, num_points)
            spline_points = self.evaluate_cubic_spline(i, t_values)
            self.cubic_splines.append(spline_points)

    def evaluate_cubic_spline(self, path_idx, t_values):
        """Evaluate the cubic spline equation for a specific path at given parameter values."""
        if path_idx >= len(self.cubic_splines_eq) or self.cubic_splines_eq[path_idx] is None:
            raise ValueError(f"No cubic spline found for path index {path_idx}.")
        spline = self.cubic_splines_eq[path_idx]
        points = np.array([spline[dim](t_values) for dim in range(len(spline))]).T
        return points

    def compute_arc_length(self, spline, t_min, t_max, num_samples=10000):
        """
        Approximate the arc length of the spline by sampling it finely.
        spline: dict of {dim: CubicSpline} for the given path.
        t_min, t_max: the parameter range of the spline.
        num_samples: number of samples to approximate arc length.
        Returns:
            t_values: the parameter values used for sampling
            cumulative_length: cumulative arc length corresponding to t_values
        """
        t_values = np.linspace(t_min, t_max, num_samples)
        points = np.array([spline[dim](t_values) for dim in range(len(spline))]).T

        distances = np.sqrt(np.sum(np.diff(points, axis=0)**2, axis=1))
        cumulative_length = np.insert(np.cumsum(distances), 0, 0.0)
        return t_values, cumulative_length

    def get_uniformly_spaced_points(self, spline, num_points):
        """
        Given a fitted spline (as a dict {dim: CubicSpline}), 
        return points spaced equally by arc length.
        """
        # Parameter range: assuming t goes from 0 to (path_length - 1)
        path_length = len(spline[0].x)
        t_min = 0
        t_max = path_length - 1

        t_vals_dense, cum_length = self.compute_arc_length(spline, t_min, t_max, num_samples=5000)
        total_length = cum_length[-1]

        # Desired distances
        desired_distances = np.linspace(0, total_length, num_points)

        # Interpolate t-values for these distances
        t_for_dist = interp1d(cum_length, t_vals_dense, kind='linear')(desired_distances)

        # Evaluate spline at these t-values
        uniform_points = np.array([spline[dim](t_for_dist) for dim in range(len(spline))]).T

        return uniform_points

    def _compute_equal_arc_length_spline_points(self, num_points=500):
        """
        Compute points along each cubic spline that are equally spaced by arc length.
        """
        self.cubic_splines = []
        for i, eq in enumerate(self.cubic_splines_eq):
            if eq is None:
                self.equal_spaced_splines.append(None)
                continue
            spline_points = self.get_uniformly_spaced_points(eq, num_points)
            self.cubic_splines.append(spline_points)

    def plot_path_3d(self, path_idx=0,dataset=None):
        """Plot dataset and one LPC path in 3D for visualization."""
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D

        dataset = np.array(dataset)
        path = self.paths[path_idx]
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        if dataset: #plot dataset if one is give, usually what you fit onto
            ax.scatter(dataset[:,0], dataset[:,1], dataset[:,2], alpha=0.5, label='Data')
        ax.plot(path[:,0], path[:,1], path[:,2], 'r-', label='Local Principal Curve')
        ax.legend()
        plt.show()

    def plot_cubic_spline_3d(self, path_idx, show_path=True):
        """
        Plot the cubic spline for a specific path in 3D.
        """
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D

        if path_idx >= len(self.paths):
            raise IndexError(f"Path index {path_idx} is out of range. Total paths: {len(self.paths)}.")

        path = self.cubic_splines[path_idx]

        spline_points = self.cubic_splines[path_idx]

        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")

        if show_path:
            ax.scatter(path[:, 0], path[:, 1], path[:, 2], label="LPC Path", alpha=0.5)

        ax.plot(spline_points[:, 0], spline_points[:, 1], spline_points[:, 2], color="red", label="Cubic Spline")
        ax.legend()
        plt.show()

# Functions to Extract the Splines and perform Alignment

In [5]:
def extract_spline(splines_df, dataset_label, perturbation):
    sdf = splines_df[(splines_df["dataset"] == dataset_label) & (splines_df["Perturbation"] == perturbation)]
    sdf = sdf.sort_values("point_index")
    points = sdf[["PCA_1", "PCA_2", "PCA_3"]].values
    return points

def rmse(a, b):
    return np.sqrt(np.mean((a - b)**2))

def mean_l1_error(a, b):
# a and b are Nx3 arrays of points.
# Compute L1 distance for each point pair: sum of absolute differences across coordinates
# Then take the mean over all points.
    return np.mean(np.sum(np.abs(a - b), axis=1))
    
def centroid(X):
    return np.mean(X, axis=0)

def rmse(X, Y):
    return np.sqrt(np.mean(np.sum((X - Y)**2, axis=1)))

def quaternion_alignment(P, Q):
    """
    Compute the optimal rotation using quaternions that aligns Q onto P.
    Returns rotation matrix R and translation vector t.
    """
    # Ensure P and Q have the same shape
    assert P.shape == Q.shape, "P and Q must have the same shape"
    
    # 1. Compute centroids and center the points
    P_cent = centroid(P)
    Q_cent = centroid(Q)
    P_prime = P - P_cent
    Q_prime = Q - Q_cent
    
    # 2. Construct correlation matrix M
    M = Q_prime.T @ P_prime
    
    # 3. Construct the Kearsley (Davenport) 4x4 matrix K
    # Refer to the equations above
    A = np.array([
        [ M[0,0]+M[1,1]+M[2,2],   M[1,2]-M[2,1],         M[2,0]-M[0,2],         M[0,1]-M[1,0]       ],
        [ M[1,2]-M[2,1],         M[0,0]-M[1,1]-M[2,2],  M[0,1]+M[1,0],         M[0,2]+M[2,0]       ],
        [ M[2,0]-M[0,2],         M[0,1]+M[1,0],         M[1,1]-M[0,0]-M[2,2],  M[1,2]+M[2,1]       ],
        [ M[0,1]-M[1,0],         M[0,2]+M[2,0],         M[1,2]+M[2,1],         M[2,2]-M[0,0]-M[1,1]]
    ], dtype=np.float64)
    A = A / 3.0
    
    # 4. Find the eigenvector of A with the highest eigenvalue
    eigenvalues, eigenvectors = np.linalg.eigh(A)
    max_idx = np.argmax(eigenvalues)
    q = eigenvectors[:, max_idx]
    q = q / np.linalg.norm(q)
    
    # 5. Convert quaternion q into rotation matrix R
    # Quaternion format: q = [q0, q1, q2, q3]
    q0, q1, q2, q3 = q
    R = np.array([
        [q0**2 + q1**2 - q2**2 - q3**2, 2*(q1*q2 - q0*q3),         2*(q1*q3 + q0*q2)],
        [2*(q2*q1 + q0*q3),             q0**2 - q1**2 + q2**2 - q3**2, 2*(q2*q3 - q0*q1)],
        [2*(q3*q1 - q0*q2),             2*(q3*q2 + q0*q1),             q0**2 - q1**2 - q2**2 + q3**2]
    ])
    
    # 6. Compute translation
    t = P_cent - R @ Q_cent
    
    return R, t


In [6]:
merged_df = pd.read_csv("/net/trapnell/vol1/home/mdcolon/proj/morphseq/results/20241130/sweep_analysis/paired_models_and_metrics_df.csv")
merged_df_avg = merged_df[merged_df["Perturbation"]=="avg_pert"]

In [None]:
import numpy as np
import plotly.graph_objects as go
from matplotlib import pyplot as plt
import plotly.express as px

# pert_comparisons = ["wnt-i", "wt"]
pert_comparisons = ["wnt-i", "tgfb-i", "wt", "lmx1b", "gdf3"]

color_map = {
    "wnt-i": "red",
    "tgfb-i": "green",
    "wt": "blue",
    "lmx1b": "orange",
    "gdf3": "purple"
}

####                        ####
#### HERE WE LOAD  DATA     ####
####                        ####    
splines_final_dict = {}
scaffold_align_metrics = []

for model_index in [71,77,78]:
    print(model_index)
    path_all = merged_df_avg[merged_df_avg["model_index"]==model_index]["embryo_df_path_nohld"].iloc[0]
    path_hld = merged_df_avg[merged_df_avg["model_index"]==model_index]["embryo_df_path_hld"].iloc[0]

    print(merged_df_avg[merged_df_avg["model_index"]==model_index]["F1_score_all"].iloc[0])

    score    = merged_df_avg[merged_df_avg["model_index"]==model_index]["F1_score_all"].iloc[0]
    mweight  = merged_df_avg[merged_df_avg["model_index"]==model_index]["metric_weight"].iloc[0]
    timeonly = merged_df_avg[merged_df_avg["model_index"]==model_index]["time_only_flag"].iloc[0]
 
    df_all = pd.read_csv(path_all)
    df_hld = pd.read_csv(path_hld)

    title = f"PCA plot model_idx {model_index}: F1 score {score:.2f},mweight {mweight}, timeonly {timeonly}"

    print(title)

    # Define the comparisons (Multiclass) and obtain coloumns for data

    z_mu_columns = [col for col in df_all.columns if 'z_mu' in col]    
    z_mu_biological_columns = [col for col in z_mu_columns if "b" in col]


    # Dictionary to store spline points for each dataset and perturbation
    # Key: (dataset_label, perturbation), Value: array of spline points shape (num_spline_points, 3)
    splines_dict = {}


    # Replace this with your actual class or method for fitting local principal curves

    ####                               ####
    ####   HERE WE FIT THE SPLINE      ####
    ####   (start from ealy timepoint) ####    

    for df_label, df in [("all", df_all), ("hld", df_hld)]:
    
        X = df[z_mu_biological_columns].values
        pca = PCA(n_components=3)
        pcs = pca.fit_transform(X)

        perturbations = pert_comparisons

        # Map perturbations to colors
        if perturbations is None:
            perturbations = df['phenotype'].unique()
        color_discrete_map = {pert: px.colors.qualitative.Plotly[i % 10] for i, pert in enumerate(perturbations)}

        # Prepare the color array
        df['color'] = df['phenotype'].map(color_discrete_map)


        df["PCA_1"] = pcs[:,0]
        df["PCA_2"] = pcs[:,1]
        df["PCA_3"] = pcs[:,2]

        for pert in pert_comparisons:
            print(f"Processing {pert} in {df_label} dataset...")
            
            
            ##################
            #### First extact  pert and subset points for computational reason
            pert_df = df[df["phenotype"] == pert].reset_index(drop=True)

            # Calcualte ealy time point
            avg_early_timepoint = pert_df[
                (pert_df["predicted_stage_hpf"] >= pert_df["predicted_stage_hpf"].min()) &
                (pert_df["predicted_stage_hpf"] < pert_df["predicted_stage_hpf"].min() + 1)
            ][["PCA_1", "PCA_2", "PCA_3"]].mean().values

            # Downsampling logic
            if pert == "wt":
                pert_df_subset = pert_df.sample(frac=0.05, random_state=42)
            else:
                pert_df_subset = pert_df.sample(frac=.1, random_state=42)
            
            print(f"Subset size: {len(pert_df_subset)}")

            pert_3d_subset = pert_df_subset[["PCA_1", "PCA_2", "PCA_3"]].values
            ##################

            # Fit the Local Principal Curve on the subset
            lpc = LocalPrincipalCurve(bandwidth=.5, max_iter=500, tol=1e-4, angle_penalty_exp=2)
            paths = lpc.fit(pert_3d_subset, start_points=[avg_early_timepoint],remove_similar_end_start_points=True)

            # Extract the first path (assuming one main path)
            spline_points = lpc.cubic_splines[0]  # shape: (num_points, 3)
            # lpc.plot_cubic_spline_3d(0)

            # Store the spline points in the dictionary
            splines_dict[(df_label, pert)] = spline_points

    rows = []

    for (df_label, pert), spline_points in splines_dict.items():
        # spline_points is an array of shape (num_points, 3)
        for i, point in enumerate(spline_points[::-1], start=1):
            # point is [PCA_1, PCA_2, PCA_3]
            num_points = len(spline_points)
            rows.append({
                "dataset": df_label,
                "Perturbation": pert,
                "point_index":  num_points - i,
                "PCA_1": point[0],
                "PCA_2": point[1],
                "PCA_3": point[2]
            })

    # Convert to DataFrame
    splines_df = pd.DataFrame(rows)


    ####                               ####
    #### HERE WE ALIGN AND MEASURE FIT ####
    ####                               ####  

    splines_dict_aligned = []

    all_combined = []
    hld_combined = []
    hld_aligned_combined = []


    #### calculate values for all perts and then the scaffold ####
    for pert in pert_comparisons:
        all_points = extract_spline(splines_df, "all", pert)
        hld_points = extract_spline(splines_df, "hld", pert)

        # Perform Kabsch alignment with quaternion calcs
        R, t = quaternion_alignment(all_points, hld_points)
        hld_aligned = (hld_points @ R.T) + t  # Alignment transformation

        # Compute initial errors before alignment
        initial_rmse = rmse(all_points, hld_points)
        # Compute errors after alignment
        aligned_rmse = rmse(all_points, hld_aligned)
        
        # Accumulate for scaffold comparison
        all_combined.append(all_points)
        hld_combined.append(hld_points)
        hld_aligned_combined.append(hld_aligned)
        
        splines_dict_aligned.append({"Perturbation":pert, "spline":hld_aligned})
        # Record metrics for this perturbation
        scaffold_align_metrics.append({
            "model_index": model_index,
            'Perturbation': pert,
            'Initial_RMSE': initial_rmse,
            'Aligned_RMSE': aligned_rmse
        })


    # Concatenate all perturbation data for scaffold metrics
    all_combined = np.concatenate(all_combined, axis=0)
    hld_combined = np.concatenate(hld_combined, axis=0)
    hld_aligned_combined = np.concatenate(hld_aligned_combined, axis=0)

    # Compute scaffold-level errors
    scaffold_initial_rmse = rmse(all_combined, hld_combined)
    scaffold_aligned_rmse = rmse(all_combined, hld_aligned_combined)

    #apend this 
    scaffold_align_metrics.append({
        "model_index": model_index,
        'Perturbation': 'avg_pert',
        'Initial_RMSE': scaffold_initial_rmse,
        'Aligned_RMSE': scaffold_aligned_rmse
        })

    # @chatgpt make a warning  message if the error doesnt decrease but i dont need to print this out live:
    # # Print in the organized format (optional)
    # print(f"Perturbation: {pert}")
    # print(f"  Initial RMSE: {initial_rmse:.4f}, Aligned RMSE: {aligned_rmse:.4f}")
    # print(f"  Initial L1:   {initial_l1:.4f}, Aligned L1:   {aligned_l1:.4f}")
    # print("\nOverall Scaffold (Individual Kabsch per Perturbation):")
    # print(f"  Initial RMSE: {scaffold_initial_rmse:.4f}, Aligned RMSE: {scaffold_aligned_rmse:.4f}")
    # print(f"  Initial L1:   {scaffold_initial_l1:.4f}, Aligned L1:   {scaffold_aligned_l1:.4f}")

##########

    # Convert per-perturbation metrics to a DataFrame
    scaffold_align_metrics_df = pd.DataFrame(scaffold_align_metrics)

    # row object already created, heres logic to add it to the 
    for spline in splines_dict_aligned:
        # spline_points is an array of shape (num_points, 3)
        for i, point in enumerate(spline["spline"]):
            # point is [PCA_1, PCA_2, PCA_3]
            rows.append({
                "dataset": "hld_aligned",
                "Perturbation": spline["Perturbation"],
                "point_index": i,
                "PCA_1": point[0],
                "PCA_2": point[1],
                "PCA_3": point[2],
            })

    # Convert to DataFrame
    splines_final_df = pd.DataFrame(rows)

    splines_final_dict[model_index] = splines_final_df


splines_final_df_model_index = pd.concat(
[df.assign(model_index=model_index) for model_index, df in splines_final_dict.items()],
ignore_index=True
)
scaffold_align_metrics_df = pd.DataFrame(scaffold_align_metrics)




71
0.713932407655193


  df_all = pd.read_csv(path_all)
  df_hld = pd.read_csv(path_hld)


PCA plot model_idx 71: F1 score 0.71,mweight 25, timeonly 1
Processing wnt-i in all dataset...
Subset size: 187
Starting point not in dataset. Using closest point: [ 1.82168643  1.29012443 -0.26144278]
Processing tgfb-i in all dataset...
Subset size: 245
Starting point not in dataset. Using closest point: [ 1.83927206  1.28596941 -0.29201842]
Processing wt in all dataset...
Subset size: 2437
Starting point not in dataset. Using closest point: [ 1.76878228  1.7100162  -0.174527  ]
Processing lmx1b in all dataset...
Subset size: 776
Starting point not in dataset. Using closest point: [ 1.84002696  1.34803321 -0.23661037]
Processing gdf3 in all dataset...
Subset size: 747
Starting point not in dataset. Using closest point: [ 1.82765235  1.39959051 -0.05014554]
Processing wnt-i in hld dataset...
Subset size: 187
Starting point not in dataset. Using closest point: [ 1.99630838  1.3154439  -0.20895961]
Processing tgfb-i in hld dataset...
Subset size: 245
Starting point not in dataset. Using 

  df_all = pd.read_csv(path_all)
  df_hld = pd.read_csv(path_hld)


PCA plot model_idx 77: F1 score 0.78,mweight 50, timeonly 0
Processing wnt-i in all dataset...
Subset size: 187
Starting point not in dataset. Using closest point: [ 1.95702147  1.03924324 -0.65932128]
Processing tgfb-i in all dataset...
Subset size: 245
Starting point not in dataset. Using closest point: [ 2.00553128  1.20161037 -0.74285054]
Processing wt in all dataset...
Subset size: 2437
Starting point not in dataset. Using closest point: [ 1.64973178  1.44043901 -0.50209158]
Processing lmx1b in all dataset...
Subset size: 776
Starting point not in dataset. Using closest point: [ 1.92653635  1.1202651  -0.60469241]
Processing gdf3 in all dataset...
Subset size: 747
Starting point not in dataset. Using closest point: [ 1.67751863  0.91676987 -0.87060139]
Processing wnt-i in hld dataset...
Subset size: 187
Starting point not in dataset. Using closest point: [-1.76951613  1.13909611  1.3465403 ]
Processing tgfb-i in hld dataset...
Subset size: 245
Starting point not in dataset. Using 

  df_all = pd.read_csv(path_all)
  df_hld = pd.read_csv(path_hld)


PCA plot model_idx 78: F1 score 0.76,mweight 50, timeonly 1
Processing wnt-i in all dataset...
Subset size: 187
Starting point not in dataset. Using closest point: [-1.54920605  6.35138069  1.18860077]
Processing tgfb-i in all dataset...
Subset size: 245
Starting point not in dataset. Using closest point: [-1.69979066  5.84492894  0.92313777]
Processing wt in all dataset...
Subset size: 2437
Starting point not in dataset. Using closest point: [-3.93511559  7.42517696  7.88852508]
Processing lmx1b in all dataset...
Subset size: 776
Starting point not in dataset. Using closest point: [-0.78182236  6.34060658  4.69437493]
Processing gdf3 in all dataset...
Subset size: 747
Starting point not in dataset. Using closest point: [-2.6893847   6.23954375  3.61233928]
Processing wnt-i in hld dataset...
Subset size: 187
Starting point not in dataset. Using closest point: [-2.08180227  6.37145221  0.02439552]
Processing tgfb-i in hld dataset...
Subset size: 245
Starting point not in dataset. Using 

# Compute metrics other than MSE alignment 

In [28]:
def _segment_direction_metrics(data_a, data_b, k=10):
    """
    Compute SegmentColinearity and SegmentCovariance for two given sets of points `data_a` and `data_b`.
    Both data_a and data_b are np.ndarray of shape (n, 3).

    If there aren't enough points for k segments, returns (np.nan, np.nan).
    """
    min_len = min(len(data_a), len(data_b))
    data_a = data_a[:min_len]
    data_b = data_b[:min_len]

    if min_len < k + 1 or min_len == 0:
        return (np.nan, np.nan)

    # Define segments using data_b
    segment_indices = np.linspace(0, min_len - 1, k + 1, dtype=int)

    aligned_segment_vecs = []
    all_segment_vecs = []

    for i in range(k):
        start_idx = segment_indices[i]
        end_idx = segment_indices[i + 1]

        start_b = data_b[start_idx]
        end_b = data_b[end_idx]

        # Find closest points in data_a to start_b and end_b
        start_dists = np.linalg.norm(data_a - start_b, axis=1)
        closest_start_idx = np.argmin(start_dists)
        closest_start_a = data_a[closest_start_idx]

        end_dists = np.linalg.norm(data_a - end_b, axis=1)
        closest_end_idx = np.argmin(end_dists)
        closest_end_a = data_a[closest_end_idx]

        # Construct vectors
        vec_a = closest_end_a - closest_start_a
        vec_b = end_b - start_b

        # Normalize
        norm_a = np.linalg.norm(vec_a)
        norm_b = np.linalg.norm(vec_b)
        if norm_a > 0:
            vec_a = vec_a / norm_a
        else:
            vec_a = np.zeros(3)
        if norm_b > 0:
            vec_b = vec_b / norm_b
        else:
            vec_b = np.zeros(3)

        aligned_segment_vecs.append(vec_a)
        all_segment_vecs.append(vec_b)

    aligned_segment_vecs = np.array(aligned_segment_vecs)
    all_segment_vecs = np.array(all_segment_vecs)

    # Cosine similarities
    cos_sims = []
    for i in range(len(aligned_segment_vecs)):
        va = aligned_segment_vecs[i].reshape(1, -1)
        vb = all_segment_vecs[i].reshape(1, -1)
        sim = cosine_similarity(va, vb)[0][0]
        cos_sims.append(sim)

    avg_cosine_sim = np.mean(cos_sims) if len(cos_sims) > 0 else np.nan

    # Covariances
    covariances = []
    for dim_idx in range(3):
        dim_a = aligned_segment_vecs[:, dim_idx]
        dim_b = all_segment_vecs[:, dim_idx]
        if len(dim_a) > 1:
            cov = np.cov(dim_a, dim_b, bias=True)[0, 1]
        else:
            cov = np.nan
        covariances.append(cov)
    avg_cov = np.nanmean(covariances) if len(covariances) > 0 else np.nan

    return (avg_cosine_sim, avg_cov)


    # Split the dataset into 'all' and 'hld_aligned'
    splines_all = splines_final_df[splines_final_df["dataset"] == "all"]
    splines_hld_aligned = splines_final_df[splines_final_df["dataset"] == "hld_aligned"]

def segment_direction_consistency(splines_final_df, k=10):
    """
    Step 1 (Across): For each perturbation present in both datasets, compute SegmentColinearity and SegmentCovariance
    by comparing splines_hld_aligned and splines_all.

    Step 2 (Within): Compute these metrics for all unique pairs of perturbations within each dataset
    (both splines_hld_aligned and splines_all separately).
    Then compute the mean and std of these pairwise metrics for each dataset.

    Returns:
    - across_df: DataFrame with ['Perturbation', 'SegmentColinearity', 'SegmentCovariance']
    - within_hld_aligned_df: DataFrame with ['Metric', 'Mean', 'Std'] for pairwise metrics within splines_hld_aligned
    - within_all_df: DataFrame with ['Metric', 'Mean', 'Std'] for pairwise metrics within splines_all
    """

    splines_all = splines_final_df[splines_final_df["dataset"] == "all"]
    splines_hld = splines_final_df[splines_final_df["dataset"] == "hld"]
    splines_hld_aligned = splines_final_df[splines_final_df["dataset"] == "hld_aligned"]
    
    pca_columns = ["PCA_1", "PCA_2", "PCA_3"]
    for col in pca_columns:
        if col not in splines_hld_aligned.columns or col not in splines_all.columns:
            raise ValueError(f"Missing required PCA column: {col}")

            

    # Across computations
    perts_aligned = set(splines_hld_aligned["Perturbation"].unique())
    perts_all = set(splines_all["Perturbation"].unique())
    common_perts = perts_aligned.intersection(perts_all)

    across_results = []
    for pert in common_perts:
        data_a_df = splines_hld_aligned[splines_hld_aligned["Perturbation"] == pert].sort_values("point_index")
        data_b_df = splines_all[splines_all["Perturbation"] == pert].sort_values("point_index")
        data_a = data_a_df[pca_columns].values
        data_b = data_b_df[pca_columns].values

        sim, cov = _segment_direction_metrics(data_a, data_b, k=k)
        across_results.append({"Perturbation": pert, "SegmentColinearity": sim, "SegmentCovariance": cov})

    across_df = pd.DataFrame(across_results)

    # Calculate column means (excluding the Perturbation column)
    mean_row = across_df.iloc[:, 1:].mean()
    mean_row["Perturbation"] = "avg_pert"

    # Append the mean row to the DataFrame
    across_df = pd.concat([across_df, pd.DataFrame([mean_row])], ignore_index=True)


    # Within computations for splines_hld
    perts_in_aligned = list(perts_aligned)
    within_values_colinearity_hld = []
    within_values_covariance_hld  = []

    for i in range(len(perts_in_aligned)):
        for j in range(i+1, len(perts_in_aligned)):
            pert1 = perts_in_aligned[i]
            pert2 = perts_in_aligned[j]

            data_pert1 = splines_hld[splines_hld["Perturbation"] == pert1].sort_values("point_index")[pca_columns].values
            data_pert2 = splines_hld[splines_hld["Perturbation"] == pert2].sort_values("point_index")[pca_columns].values

            sim, cov = _segment_direction_metrics(data_pert1, data_pert2, k=k)
            if not np.isnan(sim):
                within_values_colinearity_hld.append(sim)
            if not np.isnan(cov):
                within_values_covariance_hld.append(cov)

    metrics_hld = []
    for metric_name, vals in [("SegmentColinearity", within_values_colinearity_hld), 
                              ("SegmentCovariance",  within_values_covariance_hld)]:
        mean_val = np.nanmean(vals) if len(vals) > 0 else np.nan
        std_val = np.nanstd(vals) if len(vals) > 0 else np.nan
        metrics_hld.append({"Metric": metric_name, "Mean": mean_val, "Std": std_val})

    within_hld_df = pd.DataFrame(metrics_hld)

    # Within computations for splines_all
    perts_in_all = list(perts_all)
    within_values_colinearity_all = []
    within_values_covariance_all = []

    for i in range(len(perts_in_all)):
        for j in range(i+1, len(perts_in_all)):
            pert1 = perts_in_all[i]
            pert2 = perts_in_all[j]

            data_pert1 = splines_all[splines_all["Perturbation"] == pert1].sort_values("point_index")[pca_columns].values
            data_pert2 = splines_all[splines_all["Perturbation"] == pert2].sort_values("point_index")[pca_columns].values

            sim, cov = _segment_direction_metrics(data_pert1, data_pert2, k=k)
            if not np.isnan(sim):
                within_values_colinearity_all.append(sim)
            if not np.isnan(cov):
                within_values_covariance_all.append(cov)

    metrics_all_list = []
    for metric_name, vals in [("SegmentColinearity", within_values_colinearity_all), 
                              ("SegmentCovariance", within_values_covariance_all)]:
        mean_val = np.nanmean(vals) if len(vals) > 0 else np.nan
        std_val = np.nanstd(vals) if len(vals) > 0 else np.nan
        metrics_all_list.append({"Metric": metric_name, "Mean": mean_val, "Std": std_val})

    within_all_df = pd.DataFrame(metrics_all_list)

    return across_df, within_hld_df, within_all_df

def calculate_dispersion_metrics(splines_final_df, n=5):
    """
    Calculates dispersion metrics for each dataset, including:
    - Dispersion Coefficient (slope of dispersion vs. point_index, normalized to [0, 1])
    - Initial Dispersion (average dispersion of the first n points)
    - Last Dispersion (average dispersion of the last n points)

    Parameters:
    - splines_final_df (pd.DataFrame): DataFrame containing all PCA trajectories with 'dataset' column.
    - n (int): Number of initial and last points to consider for initial and last dispersion.

    Returns:
    - pd.DataFrame: DataFrame with columns ['Dataset', 'disp_coefficient', 'dispersion_first_n', 'dispersion_last_n'].
    """
    # Extract subsets
    splines_all = splines_final_df[splines_final_df["dataset"] == "all"]
    splines_hld = splines_final_df[splines_final_df["dataset"] == "hld"]
    splines_hld_aligned = splines_final_df[splines_final_df["dataset"] == "hld_aligned"]

    # Ensure PCA columns are present
    pca_columns = ["PCA_1", "PCA_2", "PCA_3"]
    for col in pca_columns:
        if col not in splines_final_df.columns:
            raise ValueError(f"Missing required PCA column: {col}")

    # Get unique datasets
    datasets = splines_final_df["dataset"].unique()

    # Initialize list to store results
    results = []

    for dataset in datasets:
        if dataset == "hld_aligned":
            continue
        # Filter data for the current dataset
        dataset_df = splines_final_df[splines_final_df["dataset"] == dataset]

        # Get unique point_indices
        point_indices = sorted(dataset_df["point_index"].unique())

        # Initialize lists to store dispersion and point_index
        dispersion_list = []
        point_index_list = []

        # Initialize lists to store initial and last dispersions
        initial_dispersions = []
        last_dispersions = []

        for pid in point_indices:
            # Filter data for the current point_index
            point_df = dataset_df[dataset_df["point_index"] == pid]

            # Calculate dispersion: average Euclidean distance from centroid
            dispersion = compute_dispersion(point_df, pca_columns)

            # Append to lists
            dispersion_list.append(dispersion)
            point_index_list.append(pid)

            # If within first n points, store for initial dispersion
            if pid < n:
                initial_dispersions.append(dispersion)

            # If within last n points, store for last dispersion
            if pid >= max(point_indices) - n + 1:
                last_dispersions.append(dispersion)

        # Check if there are enough points for regression
        if len(point_index_list) < 2:
            print(f"Warning: Dataset '{dataset}' has less than 2 unique point_indices. Setting disp_coefficient to NaN.")
            disp_coefficient = np.nan
        else:
            # Prepare data for linear regression
            X = np.array(point_index_list).reshape(-1, 1)  # Shape: (num_points, 1)
            y = np.array(dispersion_list)  # Shape: (num_points,)

            # Fit linear regression
            reg = LinearRegression().fit(X, y)
            disp_coefficient = reg.coef_[0]
            disp_coefficient *= len(point_indices)  # Normalize to [0, 1]

        # Calculate average initial dispersion
        dispersion_first_n = np.mean(initial_dispersions) if initial_dispersions else np.nan
        if np.isnan(dispersion_first_n):
            print(f"Warning: Dataset '{dataset}' has no points within the first {n} point_indices.")

        # Calculate average last dispersion
        dispersion_last_n = np.mean(last_dispersions) if last_dispersions else np.nan
        if np.isnan(dispersion_last_n):
            print(f"Warning: Dataset '{dataset}' has no points within the last {n} point_indices.")

        # Append results
        results.append({
            "Dataset": dataset,
            "disp_coefficient": disp_coefficient,
            "dispersion_first_n": dispersion_first_n,
            "dispersion_last_n": dispersion_last_n
        })

    # Convert results to DataFrame
    results_df = pd.DataFrame(results)



    return results_df

def compute_dispersion(df, pca_columns):
    """
    Computes the average Euclidean distance of points from their centroid.
    
    Parameters:
    - df (pd.DataFrame): DataFrame containing PCA coordinates.
    - pca_columns (list): List of PCA column names.
    
    Returns:
    - float: Average Euclidean distance (dispersion).
    """
    if df.empty:
        return np.nan
    
    # Calculate centroid
    centroid = df[pca_columns].mean().values
    
    # Calculate Euclidean distances from centroid
    distances = np.linalg.norm(df[pca_columns].values - centroid, axis=1)
    
    # Return average distance
    return distances.mean()


In [41]:

for key in splines_final_df_model_index["model_index"].unique():
    print(key)
    splines_final_df = splines_final_df_model_index[splines_final_df_model_index["model_index"] == key]

    splines_all = splines_final_df[splines_final_df["dataset"]=="all"]
    splines_hld_aligned = splines_final_df[splines_final_df["dataset"]=="hld_aligned"]

    # -------------------------------------
    # Analyze Datasets
    # -------------------------------------
    # Segment Direction Consistency
    # After loading or creating splines_hld_aligned and splines_all DataFrames:
    across_seg_df, within_hld_seg_df, within_all_seg_df = segment_direction_consistency(splines_final_df, k=100)
    

    # Add the model_index column with the key value
    across_seg_df.insert(0, "model_index", key)

    # Create a new DataFrame to collect renamed columns
    within_hld_renamed = within_hld_seg_df[["Metric", "Mean"]].copy()
    within_hld_renamed["Metric"] += "_mean_within_hld"  # Add suffix
    within_hld_renamed = within_hld_renamed.set_index("Metric").T  # Transpose for easy appending

    within_all_renamed = within_all_seg_df[["Metric", "Mean"]].copy()
    within_all_renamed["Metric"] += "_mean_within_all"  # Add suffix
    within_all_renamed = within_all_renamed.set_index("Metric").T  # Transpose for easy appending

    # Combine the renamed DataFrames into one row
    within_seg_measures = pd.concat([within_hld_renamed, within_all_renamed], axis=1)

    within_seg_measures.insert(0, "model_index", key)
    print(within_seg_measures)

    # Calculate Dispersion Metrics
    dispersion_metrics_df = calculate_dispersion_metrics(splines_final_df, n=5)


    # Split and rename columns based on Dataset
    disp_all = dispersion_metrics_df[dispersion_metrics_df["Dataset"] == "all"].drop("Dataset", axis=1)
    disp_all.columns = [col + "_all" for col in disp_all.columns]
    disp_all = disp_all.reset_index(drop=True)

    disp_hld = dispersion_metrics_df[dispersion_metrics_df["Dataset"] == "hld"].drop("Dataset", axis=1)
    disp_hld.columns = [col + "_hld" for col in disp_hld.columns]
    disp_hld = disp_hld.reset_index(drop=True)

    # Combine into a single row
    combined_dispersion_df = pd.concat([disp_all, disp_hld], axis=1)

    # Add the model_index column
    combined_dispersion_df.insert(0, "model_index", key)

    # Final output
    print("Dispersion Metrics for Each Dataset:")
    print(combined_dispersion_df)




    # fig = plot_trajectories_3d(splines_final_df)
    # fig.write_html(f"/net/trapnell/vol1/home/mdcolon/proj/morphseq/results/20241211/splines_aligned_{key}_quaternion_2.html")


71
Metric  model_index  SegmentColinearity_mean_within_hld  \
Mean             71                            0.746983   

Metric  SegmentCovariance_mean_within_hld  SegmentColinearity_mean_within_all  \
Mean                             0.172636                            0.777472   

Metric  SegmentCovariance_mean_within_all  
Mean                             0.165758  
Dispersion Metrics for Each Dataset:
   model_index  disp_coefficient_all  dispersion_first_n_all  \
0           71              1.358857                0.116642   

   dispersion_last_n_all  disp_coefficient_hld  dispersion_first_n_hld  \
0               1.357972              1.514239                0.101214   

   dispersion_last_n_hld  
0               1.592155  
77
Metric  model_index  SegmentColinearity_mean_within_hld  \
Mean             77                            0.500365   

Metric  SegmentCovariance_mean_within_hld  SegmentColinearity_mean_within_all  \
Mean                             0.106559              

In [None]:
import pandas as pd

# -------------------------------
# Helper Functions
# -------------------------------

def rename_within_metrics(df, suffix, key):
    """Renames columns in within metrics DataFrame with a given suffix."""
    renamed_df = df[["Metric", "Mean"]].copy()
    renamed_df["Metric"] += suffix  # Add suffix
    renamed_df = renamed_df.set_index("Metric").T  # Transpose for easy appending
    renamed_df.insert(0, "model_index", key)  # Add model_index
    return renamed_df

def process_dispersion_metrics(df, key):
    """Processes and renames dispersion metrics DataFrame."""
    disp_all = df[df["Dataset"] == "all"].drop("Dataset", axis=1)
    disp_all.columns = [col + "_all" for col in disp_all.columns]
    disp_hld = df[df["Dataset"] == "hld"].drop("Dataset", axis=1)
    disp_hld.columns = [col + "_hld" for col in disp_hld.columns]
    
    combined_df = pd.concat([disp_all.reset_index(drop=True), disp_hld.reset_index(drop=True)], axis=1)
    combined_df.insert(0, "model_index", key)  # Add model_index
    return combined_df

def process_segment_direction(splines_final_df, key):
    """Calculates and processes segment direction consistency metrics."""
    across_seg_df, within_hld_seg_df, within_all_seg_df = segment_direction_consistency(splines_final_df, k=100)
    across_seg_df.insert(0, "model_index", key)  # Add model_index
    
    within_hld_renamed = rename_within_metrics(within_hld_seg_df, "_mean_within_hld", key)
    within_all_renamed = rename_within_metrics(within_all_seg_df, "_mean_within_all", key)
    
    within_seg_measures = pd.concat([within_hld_renamed, within_all_renamed], axis=1)
    return across_seg_df, within_seg_measures

def combine_results_dict(results_dict):
    """
    Combines the results dictionary into a single DataFrame.
    Handles duplicate 'model_index' columns by ensuring uniqueness during merge.
    """
    final_list_of_dfs = []
    
    for model_index, metrics in results_dict.items():
        # Start with across_seg_df as the base since it has multiple perturbations
        if "across_seg_df" not in metrics:
            continue  # If for some reason this key doesn't have across_seg_df, skip
        
        base_df = metrics["across_seg_df"].copy()

        # Drop duplicate 'model_index' columns from other metrics before merging
        if "within_seg_measures" in metrics:
            temp_within = metrics["within_seg_measures"].copy()
            temp_within = temp_within.loc[:, ~temp_within.columns.duplicated()]  # Remove duplicate columns
            base_df = base_df.merge(temp_within, on="model_index", how="left")
        
        if "dispersion_metrics" in metrics:
            temp_disp = metrics["dispersion_metrics"].copy()
            temp_disp = temp_disp.loc[:, ~temp_disp.columns.duplicated()]  # Remove duplicate columns
            base_df = base_df.merge(temp_disp, on="model_index", how="left")

        # Append to list
        final_list_of_dfs.append(base_df)

    # Concatenate all model results
    if final_list_of_dfs:
        final_results_df = pd.concat(final_list_of_dfs, ignore_index=True)
    else:
        final_results_df = pd.DataFrame()

    return final_results_df

# Example usage:
# results_df = combine_results_dict(results_dict)
# This will produce a DataFrame with each row corresponding to a (model_index, Perturbation) pair,
# and columns from across_seg_df, within_seg_measures, and dispersion_metrics.

# -------------------------------
# Main Loop
# -------------------------------

results_dict = {}

for key in splines_final_df_model_index["model_index"].unique():
    print(f"Processing model_index: {key}")
    
    # Filter data for the current model index
    splines_final_df = splines_final_df_model_index[splines_final_df_model_index["model_index"] == key]
    
    # Process segment direction consistency
    across_seg_df, within_seg_measures = process_segment_direction(splines_final_df, key)
    print("Segment Direction Measures:")
    print(within_seg_measures)

    # Calculate dispersion metrics
    dispersion_metrics_df = calculate_dispersion_metrics(splines_final_df, n=5)
    combined_dispersion_df = process_dispersion_metrics(dispersion_metrics_df, key)
    print("Dispersion Metrics for Each Dataset:")
    print(combined_dispersion_df)

    # Store results in a dictionary
    results_dict[key] = {
        "across_seg_df": across_seg_df,
        "within_seg_measures": within_seg_measures,
        "dispersion_metrics": combined_dispersion_df
    }

# -------------------------------
# Access Results Example
# -------------------------------
results_df = combine_results_dict(results_dict)
results_df

scaffold_align_metrics_df.merge(results_df, on=["model_index", "Perturbation"], how="left")

Processing model_index: 71
Segment Direction Measures:
Metric  model_index  SegmentColinearity_mean_within_hld  \
Mean             71                            0.746983   

Metric  SegmentCovariance_mean_within_hld  model_index  \
Mean                             0.172636           71   

Metric  SegmentColinearity_mean_within_all  SegmentCovariance_mean_within_all  
Mean                              0.777472                           0.165758  
Dispersion Metrics for Each Dataset:
   model_index  disp_coefficient_all  dispersion_first_n_all  \
0           71              1.358857                0.116642   

   dispersion_last_n_all  disp_coefficient_hld  dispersion_first_n_hld  \
0               1.357972              1.514239                0.101214   

   dispersion_last_n_hld  
0               1.592155  
Processing model_index: 77
Segment Direction Measures:
Metric  model_index  SegmentColinearity_mean_within_hld  \
Mean             77                            0.500365   

Metric

Unnamed: 0,model_index,Perturbation,SegmentColinearity,SegmentCovariance,SegmentColinearity_mean_within_hld,SegmentCovariance_mean_within_hld,SegmentColinearity_mean_within_all,SegmentCovariance_mean_within_all,disp_coefficient_all,dispersion_first_n_all,dispersion_last_n_all,disp_coefficient_hld,dispersion_first_n_hld,dispersion_last_n_hld
0,71,tgfb-i,0.979616,0.183081,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
1,71,wnt-i,0.961767,0.145345,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
2,71,gdf3,0.969415,0.195387,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
3,71,lmx1b,0.919742,0.245564,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
4,71,wt,0.993592,0.285032,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
5,71,avg_pert,0.964826,0.210882,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
6,77,tgfb-i,0.58198,0.132881,0.500365,0.106559,0.600711,0.135649,1.751113,0.34721,1.716981,1.388928,0.574597,2.06553
7,77,wnt-i,0.878107,0.066824,0.500365,0.106559,0.600711,0.135649,1.751113,0.34721,1.716981,1.388928,0.574597,2.06553
8,77,gdf3,0.716081,0.155529,0.500365,0.106559,0.600711,0.135649,1.751113,0.34721,1.716981,1.388928,0.574597,2.06553
9,77,lmx1b,0.521959,0.159994,0.500365,0.106559,0.600711,0.135649,1.751113,0.34721,1.716981,1.388928,0.574597,2.06553


In [51]:
# # Perform the merge on the "model_index" column
# merged_df = scaffold_align_metrics_df.merge(results_df, on="model_index", how="left")

# # Display the merged DataFrame
# print("Merged DataFrame:")
# print(merged_df)

In [53]:
scaffold_align_metrics_df.merge(results_df, on=["model_index", "Perturbation"], how="left")

Unnamed: 0,model_index,Perturbation,Initial_RMSE,Aligned_RMSE,SegmentColinearity,SegmentCovariance,SegmentColinearity_mean_within_hld,SegmentCovariance_mean_within_hld,SegmentColinearity_mean_within_all,SegmentCovariance_mean_within_all,disp_coefficient_all,dispersion_first_n_all,dispersion_last_n_all,disp_coefficient_hld,dispersion_first_n_hld,dispersion_last_n_hld
0,71,wnt-i,0.234647,0.186641,0.961767,0.145345,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
1,71,tgfb-i,0.292043,0.190456,0.979616,0.183081,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
2,71,wt,0.194967,0.147977,0.993592,0.285032,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
3,71,lmx1b,0.452274,0.289614,0.919742,0.245564,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
4,71,gdf3,0.28164,0.124806,0.969415,0.195387,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
5,71,avg_pert,0.30405,0.196193,0.964826,0.210882,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
6,77,wnt-i,4.639583,0.312645,0.878107,0.066824,0.500365,0.106559,0.600711,0.135649,1.751113,0.34721,1.716981,1.388928,0.574597,2.06553
7,77,tgfb-i,3.798049,0.830231,0.58198,0.132881,0.500365,0.106559,0.600711,0.135649,1.751113,0.34721,1.716981,1.388928,0.574597,2.06553
8,77,wt,3.238206,0.821606,0.789793,0.23671,0.500365,0.106559,0.600711,0.135649,1.751113,0.34721,1.716981,1.388928,0.574597,2.06553
9,77,lmx1b,2.721069,1.372218,0.521959,0.159994,0.500365,0.106559,0.600711,0.135649,1.751113,0.34721,1.716981,1.388928,0.574597,2.06553


In [74]:
import numpy as np
import plotly.graph_objects as go
from matplotlib import pyplot as plt
import plotly.express as px
import pandas as pd
from sklearn.decomposition import PCA

# pert_comparisons = ["wnt-i", "wt"]
pert_comparisons = ["wnt-i", "tgfb-i", "wt", "lmx1b", "gdf3"]

color_map = {
    "wnt-i": "red",
    "tgfb-i": "green",
    "wt": "blue",
    "lmx1b": "orange",
    "gdf3": "purple"
}

splines_final_dict = {}
scaffold_align_metrics = []
results_dict = {}

# Example model_indices, adjust as needed
model_indices = [71]

for model_index in model_indices:
    print(model_index)
    path_all = merged_df_avg[merged_df_avg["model_index"] == model_index]["embryo_df_path_nohld"].iloc[0]
    path_hld = merged_df_avg[merged_df_avg["model_index"] == model_index]["embryo_df_path_hld"].iloc[0]

    score = merged_df_avg[merged_df_avg["model_index"] == model_index]["F1_score_all"].iloc[0]
    mweight = merged_df_avg[merged_df_avg["model_index"] == model_index]["metric_weight"].iloc[0]
    timeonly = merged_df_avg[merged_df_avg["model_index"] == model_index]["time_only_flag"].iloc[0]

    df_all = pd.read_csv(path_all)
    df_hld = pd.read_csv(path_hld)

    # Identify biological Z columns
    z_mu_columns = [col for col in df_all.columns if 'z_mu' in col]    
    z_mu_biological_columns = [col for col in z_mu_columns if "b" in col]

    # Dictionary to store dataframes with PCA columns added
    data_dict = {}

    # Compute PCA and augment dataframes for both "all" and "hld"
    for df_label, df_raw in [("all", df_all), ("hld", df_hld)]:
        X = df_raw[z_mu_biological_columns].values
        pca = PCA(n_components=3)
        pcs = pca.fit_transform(X)

        df_raw["PCA_1"] = pcs[:,0]
        df_raw["PCA_2"] = pcs[:,1]
        df_raw["PCA_3"] = pcs[:,2]

        # Color mapping for perturbations
        perturbations = pert_comparisons
        color_discrete_map = {pert: px.colors.qualitative.Plotly[i % 10] for i, pert in enumerate(perturbations)}
        df_raw['color'] = df_raw['phenotype'].map(color_discrete_map)

        # Store the augmented dataframe
        data_dict[df_label] = df_raw

    # Dictionary to store spline points for each dataset and perturbation
    # Key: (df_label, pert), Value: array of spline points shape (num_spline_points, 3)
    splines_dict = {}

    # Fit splines for each perturbation and dataset in a single combined loop
    for pert in pert_comparisons:
        for df_label, df in data_dict.items():
            print(f"Processing {pert} in {df_label} dataset...")

            pert_df = df[df["phenotype"] == pert].reset_index(drop=True)

            # Calculate early time point
            avg_early_timepoint = pert_df[
                (pert_df["predicted_stage_hpf"] >= pert_df["predicted_stage_hpf"].min()) &
                (pert_df["predicted_stage_hpf"] < pert_df["predicted_stage_hpf"].min() + 1)
            ][["PCA_1", "PCA_2", "PCA_3"]].mean().values

            # Downsampling logic
            if pert == "wt":
                pert_df_subset = pert_df.sample(frac=0.05, random_state=42)
            else:
                pert_df_subset = pert_df.sample(frac=0.1, random_state=42)

            print(f"Subset size: {len(pert_df_subset)}")

            pert_3d_subset = pert_df_subset[["PCA_1", "PCA_2", "PCA_3"]].values

            # Fit the Local Principal Curve on the subset
            lpc = LocalPrincipalCurve(bandwidth=.5, max_iter=500, tol=1e-4, angle_penalty_exp=2)
            paths = lpc.fit(pert_3d_subset, start_points=[avg_early_timepoint], remove_similar_end_start_points=True)

            # Extract the first path (assuming one main path)
            spline_points = lpc.cubic_splines[0]  # shape: (num_points, 3)
            splines_dict[(df_label, pert)] = spline_points

    # Convert spline data to a DataFrame
    rows = []
    for (df_label, pert), spline_points in splines_dict.items():
        num_points = len(spline_points)
        for i, point in enumerate(spline_points[::-1], start=1):
            rows.append({
                "dataset": df_label,
                "Perturbation": pert,
                "point_index": num_points - i,
                "PCA_1": point[0],
                "PCA_2": point[1],
                "PCA_3": point[2]
            })

    splines_df = pd.DataFrame(rows)

    # Alignment and scaffold metrics
    splines_dict_aligned = []
    all_combined = []
    hld_combined = []
    hld_aligned_combined = []

    for pert in pert_comparisons:
        all_points = extract_spline(splines_df, "all", pert)
        hld_points = extract_spline(splines_df, "hld", pert)

        # Perform Kabsch alignment
        R, t = quaternion_alignment(all_points, hld_points)
        hld_aligned = (hld_points @ R.T) + t

        # Compute errors
        initial_rmse = rmse(all_points, hld_points)
        aligned_rmse = rmse(all_points, hld_aligned)

        # Accumulate for scaffold comparison
        all_combined.append(all_points)
        hld_combined.append(hld_points)
        hld_aligned_combined.append(hld_aligned)

        splines_dict_aligned.append({"Perturbation": pert, "spline": hld_aligned})
        scaffold_align_metrics.append({
            "model_index": model_index,
            'Perturbation': pert,
            'Initial_RMSE': initial_rmse,
            'Aligned_RMSE': aligned_rmse
        })

    # Compute scaffold-level metrics
    all_combined = np.concatenate(all_combined, axis=0)
    hld_combined = np.concatenate(hld_combined, axis=0)
    hld_aligned_combined = np.concatenate(hld_aligned_combined, axis=0)

    scaffold_initial_rmse = rmse(all_combined, hld_combined)
    scaffold_aligned_rmse = rmse(all_combined, hld_aligned_combined)

    scaffold_align_metrics.append({
        "model_index": model_index,
        'Perturbation': 'avg_pert',
        'Initial_RMSE': scaffold_initial_rmse,
        'Aligned_RMSE': scaffold_aligned_rmse
    })

    scaffold_align_metrics_df = pd.DataFrame(scaffold_align_metrics)

    # Add aligned spline points to DataFrame
    for spline in splines_dict_aligned:
        for i, point in enumerate(spline["spline"]):
            rows.append({
                "dataset": "hld_aligned",
                "Perturbation": spline["Perturbation"],
                "point_index": i,
                "PCA_1": point[0],
                "PCA_2": point[1],
                "PCA_3": point[2],
            })

    #store the spline in dict
    splines_final_df = pd.DataFrame(rows)
    splines_final_dict[model_index] = splines_final_df
    
    # Process segment direction consistency
    across_seg_df, within_seg_measures = process_segment_direction(splines_final_df, model_index)
    print("Segment Direction Measures:")
    print(within_seg_measures)

    # Calculate dispersion metrics
    dispersion_metrics_df = calculate_dispersion_metrics(splines_final_df, n=5)
    combined_dispersion_df = process_dispersion_metrics(dispersion_metrics_df, model_index)
    print("Dispersion Metrics for Each Dataset:")
    print(combined_dispersion_df)

    # Store results in a dictionary
    results_dict[model_index] = {
        "across_seg_df": across_seg_df,
        "within_seg_measures": within_seg_measures,
        "dispersion_metrics": combined_dispersion_df
    }
    
    results_df = combine_results_dict(results_dict)

    scaffold_align_metrics_df_final = scaffold_align_metrics_df.merge(results_df, on=["model_index", "Perturbation"], how="left")

scaffold_align_metrics_df_final
    


71



Columns (2) have mixed types. Specify dtype option on import or set low_memory=False.


Columns (2) have mixed types. Specify dtype option on import or set low_memory=False.



PCA plot model_idx 71: F1 score 0.71, mweight 25, timeonly 1
Processing wnt-i in all dataset...
Subset size: 187
Starting point not in dataset. Using closest point: [ 1.82168643  1.29012443 -0.26144278]
Processing wnt-i in hld dataset...
Subset size: 187
Starting point not in dataset. Using closest point: [ 1.99630838  1.3154439  -0.20895961]
Processing tgfb-i in all dataset...
Subset size: 245
Starting point not in dataset. Using closest point: [ 1.83927206  1.28596941 -0.29201842]
Processing tgfb-i in hld dataset...
Subset size: 245
Starting point not in dataset. Using closest point: [ 2.04486115  1.44040819 -0.25376278]
Processing wt in all dataset...
Subset size: 2437
Starting point not in dataset. Using closest point: [ 1.76878228  1.7100162  -0.174527  ]
Processing wt in hld dataset...
Subset size: 2437
Starting point not in dataset. Using closest point: [ 1.96912577  1.7797467  -0.14410615]
Processing lmx1b in all dataset...
Subset size: 776
Starting point not in dataset. Using 

Unnamed: 0,model_index,Perturbation,Initial_RMSE,Aligned_RMSE,SegmentColinearity,SegmentCovariance,SegmentColinearity_mean_within_hld,SegmentCovariance_mean_within_hld,SegmentColinearity_mean_within_all,SegmentCovariance_mean_within_all,disp_coefficient_all,dispersion_first_n_all,dispersion_last_n_all,disp_coefficient_hld,dispersion_first_n_hld,dispersion_last_n_hld
0,71,wnt-i,0.234647,0.186641,0.961767,0.145345,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
1,71,tgfb-i,0.292043,0.190456,0.979616,0.183081,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
2,71,wt,0.194967,0.147977,0.993592,0.285032,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
3,71,lmx1b,0.452274,0.289614,0.919742,0.245564,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
4,71,gdf3,0.28164,0.124806,0.969415,0.195387,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
5,71,avg_pert,0.30405,0.196193,0.964826,0.210882,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155


In [71]:
scaffold_align_metrics_df.merge(results_df, on=["model_index", "Perturbation"], how="left")

Unnamed: 0,model_index,Perturbation,Initial_RMSE,Aligned_RMSE,SegmentColinearity,SegmentCovariance,SegmentColinearity_mean_within_hld,SegmentCovariance_mean_within_hld,SegmentColinearity_mean_within_all,SegmentCovariance_mean_within_all,disp_coefficient_all,dispersion_first_n_all,dispersion_last_n_all,disp_coefficient_hld,dispersion_first_n_hld,dispersion_last_n_hld
0,71,wnt-i,0.234647,0.186641,,,,,,,,,,,,
1,71,tgfb-i,0.292043,0.190456,,,,,,,,,,,,
2,71,wt,0.194967,0.147977,,,,,,,,,,,,
3,71,lmx1b,0.452274,0.289614,,,,,,,,,,,,
4,71,gdf3,0.28164,0.124806,,,,,,,,,,,,
5,71,avg_pert,0.30405,0.196193,,,,,,,,,,,,


In [None]:
scaffold_align_metrics_df_final["mo"]

Unnamed: 0,model_index,Perturbation,Initial_RMSE,Aligned_RMSE,SegmentColinearity,SegmentCovariance,SegmentColinearity_mean_within_hld,SegmentCovariance_mean_within_hld,SegmentColinearity_mean_within_all,SegmentCovariance_mean_within_all,disp_coefficient_all,dispersion_first_n_all,dispersion_last_n_all,disp_coefficient_hld,dispersion_first_n_hld,dispersion_last_n_hld
0,71,wnt-i,0.234647,0.186641,0.961767,0.145345,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
1,71,tgfb-i,0.292043,0.190456,0.979616,0.183081,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
2,71,wt,0.194967,0.147977,0.993592,0.285032,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
3,71,lmx1b,0.452274,0.289614,0.919742,0.245564,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
4,71,gdf3,0.28164,0.124806,0.969415,0.195387,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155
5,71,avg_pert,0.30405,0.196193,0.964826,0.210882,0.746983,0.172636,0.777472,0.165758,1.358857,0.116642,1.357972,1.514239,0.101214,1.592155


In [62]:
splines_final_df_model_index

Unnamed: 0,dataset,Perturbation,point_index,PCA_1,PCA_2,PCA_3,model_index
0,all,wnt-i,499,0.512206,-0.449847,1.558961,71
1,all,wnt-i,498,0.517657,-0.454282,1.554024,71
2,all,wnt-i,497,0.522607,-0.458513,1.548412,71
3,all,wnt-i,496,0.527217,-0.462680,1.542470,71
4,all,wnt-i,495,0.531631,-0.466741,1.536310,71
...,...,...,...,...,...,...,...
22495,hld_aligned,gdf3,495,-3.847806,0.444283,-2.961415,78
22496,hld_aligned,gdf3,496,-3.818904,0.437477,-3.071533,78
22497,hld_aligned,gdf3,497,-3.793124,0.414569,-3.180309,78
22498,hld_aligned,gdf3,498,-3.766719,0.386207,-3.287675,78
