In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
import pandas as pd
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from colorsys import rgb_to_hsv, hsv_to_rgb

# Combine all output CSVs into one

In [None]:

def combine_csvs(input_folder, output_file, keep_source=False):
    """
    Concatenate every CSV file in a folder into a single CSV.

    Files are read in sorted filename order so the row order of the result is
    reproducible. If ``output_file`` itself lives inside ``input_folder`` it is
    skipped, so re-running the function does not feed the previous combined
    file back into the new one.

    Args:
        input_folder: Path to folder containing CSV files
        output_file: Path for combined output CSV
        keep_source: Add source filename column (``source_file``) if True

    Returns:
        None. The combined table is written to ``output_file``; progress and
        per-file read errors are reported with ``print``.
    """
    # Normalised absolute path of the output, used to exclude it from the inputs.
    output_abspath = os.path.normcase(os.path.abspath(output_file))

    csv_files = sorted(
        name for name in os.listdir(input_folder)
        if name.lower().endswith('.csv')
        and os.path.normcase(os.path.abspath(os.path.join(input_folder, name))) != output_abspath
    )

    if not csv_files:
        print(f"No CSV files found in {input_folder}")
        return

    dfs = []

    for csv_file in csv_files:
        file_path = os.path.join(input_folder, csv_file)
        try:
            df = pd.read_csv(file_path)

            # An index that was saved with the data comes back as 'Unnamed: 0'.
            if 'Unnamed: 0' in df.columns:
                df = df.drop(columns=['Unnamed: 0'])

            if keep_source:
                df['source_file'] = csv_file

            dfs.append(df)
            print(f"Processed {csv_file} (shape: {df.shape})")

        except Exception as e:
            # Best effort: report the unreadable file and carry on with the rest.
            print(f"Error processing {csv_file}: {str(e)}")
            continue

    if not dfs:
        print("No valid data to combine")
        return

    # ignore_index gives the combined frame a fresh 0..n-1 index.
    combined = pd.concat(dfs, ignore_index=True)

    # Drop any remaining auto-named index columns ('Unnamed: 1', ...).
    combined = combined.loc[:, ~combined.columns.str.contains('^Unnamed')]

    combined.to_csv(output_file, index=False)

    print(f"\nSuccessfully combined {len(dfs)} files")
    print(f"Final output shape: {combined.shape}")
    print(f"Saved to: {output_file}")

In [None]:
# Merge all per-image CSVs of differentiation batch 2 into one combined CSV.
# NOTE: the output path lies inside the input folder, and both paths are absolute, machine-specific locations.
combine_csvs('/Users/ischneider/Switzerland/ETH/Thesis/ThesisSimuCell3D/lumenoids_image_analysis_output_differentiation_batch2', '/Users/ischneider/Switzerland/ETH/Thesis/ThesisSimuCell3D/lumenoids_image_analysis_output_differentiation_batch2/lumenoids_combined_data_differentiation_batch2.csv')

# Calculate additional data

In [None]:
# Load the combined batch-2 table; every derived column below is added to this frame in place.
lumenoids_path = "/Users/ischneider/Switzerland/ETH/Thesis/ThesisSimuCell3D/lumenoids_image_analysis_output_differentiation_batch2/lumenoids_combined_data_differentiation_batch2.csv"
lumenoids_output_diff_batch_2 = pd.read_csv(lumenoids_path)


In [None]:
# Sanity check: (rows, columns) of the loaded table.
lumenoids_output_diff_batch_2.shape

In [None]:
def get_category_color_mapping(
    df, 
    control_column="control", 
    image_day_column="imaging_day", 
    color_palette="viridis",  # Now accepts either palette name or manual color mapping
    manual_colors={
            'Day 2': '#440154',
            'Day 3': '#2a788e',
            'Day 4': '#7ad151',
            'Control': '#414487', 
            'Treatment': '#22a884'
        }   # Optional manual color specification
):
    """
    Build a ``{legend label: color}`` mapping for the categories present in ``df``.

    Labels have the form ``'Day <d> (N=<count>)'`` for rows whose control value
    is -1 (one label per imaging day, in sorted order), followed by
    ``'Control (N=<count>)'`` (control value 1) and ``'Treatment (N=<count>)'``
    (control value 0). Categories with no rows are omitted.

    Parameters:
    - df: DataFrame containing the data
    - control_column: Column name indicating control/treatment status
    - image_day_column: Column name indicating imaging day
    - color_palette: Either:
        - String name of matplotlib colormap (e.g., "viridis")
        - Dictionary of {base_category: color} mappings (never modified here)
        - None (uses a built-in default table)
    - manual_colors: Dictionary of specific color assignments
      (e.g., {'Day 2': '#ff0000', 'Control': '#00ff00'}).
      Overrides color_palette for the specified categories; pass None or {}
      to disable the overrides. The default dict is only read, never mutated.

    Returns:
    - dict mapping each present label to a color; base categories without an
      assigned color fall back to '#333333'.
    """
    categories = []

    # Image days (control = -1)
    is_day_group = df[control_column] == -1
    for day in sorted(df.loc[is_day_group, image_day_column].unique()):
        day_mask = is_day_group & (df[image_day_column] == day)
        if day_mask.any():
            categories.append(f'Day {day} (N={day_mask.sum()})')

    # Control and Treatment
    control_mask_0 = df[control_column] == 0
    control_mask_1 = df[control_column] == 1
    if control_mask_1.any():
        categories.append(f'Control (N={control_mask_1.sum()})')
    if control_mask_0.any():
        categories.append(f'Treatment (N={control_mask_0.sum()})')

    if isinstance(color_palette, dict):
        # Copy so the manual overrides below never leak into the caller's dict.
        base_colors = dict(color_palette)
    elif isinstance(color_palette, str):
        # Sample the colormap at evenly spaced positions, one per base category.
        cmap = plt.get_cmap(color_palette)
        all_base_categories = [f'Day {day}' for day in range(8)] + ['Control', 'Treatment']
        base_colors = {cat: to_hex(cmap(i/len(all_base_categories))) 
                      for i, cat in enumerate(all_base_categories)}
    else:
        base_colors = {
            'Day 0': '#440154', 'Day 1': '#3b528b', 'Day 2': '#21918c',
            'Day 3': '#5ec962', 'Day 4': '#fde725', 'Day 5': '#e66101',
            'Day 6': '#f781bf', 'Day 7': '#999999',
            'Control': '#000000', 'Treatment': '#ff0000'
        }

    # Apply manual color overrides if provided
    if manual_colors:
        base_colors.update(manual_colors)

    color_mapping = {}
    for category in categories:
        # Strip the ' (N=...)' suffix to get the base category name.
        base_category = category.split(' (N=')[0]
        if 'Day' in base_category:
            day_token = base_category.split()[1]
            try:
                # Days read as floats render as '2.0'; normalise to 'Day 2'.
                base_category = f'Day {int(float(day_token))}'
            except (ValueError, OverflowError):
                # Non-numeric day: look it up under its literal label.
                pass

        color_mapping[category] = base_colors.get(base_category, '#333333')

    return color_mapping

In [None]:
def plot_stacked_histogram(
    df, 
    value_column, 
    control_column="control", 
    image_day_column="imaging_day", 
    color_palette="viridis",
    mean_line_color="red",
    mean_line_style="--",
    mean_line_width=1.5,
    annotate_mean=True,
    export_path=None,        
    export_dpi=300,
    units_x=None,
    units_y=None           
):
    """
    Draw a stacked histogram of ``value_column``, one stack layer per category.

    Layers are added in the order Treatment (control == 0), Control
    (control == 1), then one layer per imaging day for rows with control == -1.
    A vertical line marks the mean of the whole column.

    Parameters:
    - df: DataFrame containing the data
    - value_column: Column to histogram
    - control_column / image_day_column: Columns defining the categories
    - color_palette: Forwarded to ``get_category_color_mapping``
    - mean_line_color / mean_line_style / mean_line_width: Style of the mean line
    - annotate_mean: Accepted for compatibility; not read by this function
    - export_path: If given, the figure is saved there
    - export_dpi: Resolution used when saving
    - units_x / units_y: Optional unit strings appended to the axis labels

    Returns:
    - (fig, ax) of the created plot
    """
    data_to_plot = []
    labels = []
    colors = []
    total_n = len(df)

    # Honour the caller's palette instead of always forcing "viridis".
    category_color_map = get_category_color_mapping(df, control_column, image_day_column, color_palette=color_palette)

    # Control 0 & 1
    for control_val in [0, 1]:
        mask = df[control_column] == control_val
        if mask.sum() > 0:
            label = f'Treatment (N={mask.sum()})' if control_val == 0 else f'Control (N={mask.sum()})'
            data_to_plot.append(df.loc[mask, value_column])
            labels.append(label)
            colors.append(category_color_map[label])

    # Image days (control = -1)
    image_day_cats = sorted(df[df[control_column] == -1][image_day_column].unique())
    for day_val in image_day_cats:
        mask = (df[control_column] == -1) & (df[image_day_column] == day_val)
        if mask.sum() > 0:
            # Label text must match the keys built by get_category_color_mapping.
            label = f'Day {day_val} (N={mask.sum()})'
            data_to_plot.append(df.loc[mask, value_column])
            labels.append(label)
            colors.append(category_color_map[label])

    # Create figure with white background
    fig, ax = plt.subplots(figsize=(10, 6), facecolor='white')
    ax.set_facecolor('white')

    ax.hist(
        data_to_plot,
        bins='auto',
        stacked=True,
        alpha=0.7,
        label=labels,
        color=colors,
        edgecolor='white',
        linewidth=0.5
    )

    # Mean over all rows of the column, not per category.
    mean_val = df[value_column].mean()
    ax.axvline(
        mean_val,
        color=mean_line_color,
        linestyle=mean_line_style,
        linewidth=mean_line_width,
        label=f'Mean = {mean_val:.2f}'
    )

    x_label = f"{value_column} [{units_x}]" if units_x else value_column
    y_label = f"Frequency [{units_y}]" if units_y else "Frequency"

    ax.set_xlabel(x_label, fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(f'Stacked Histogram of {value_column}', fontsize=14)

    # Remove grid
    ax.grid(False)

    # Keep all four borders (complete rectangle)
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(1.0)

    # Legend sits outside the axes, to the right.
    ax.legend(
        title=f'Labels (Total N={total_n})',
        bbox_to_anchor=(1.02, 1),
        loc='upper left',
        framealpha=0.9
    )

    plt.tight_layout()

    if export_path:
        fig.savefig(export_path, dpi=export_dpi, bbox_inches='tight', transparent=False)
        print(f"Plot saved to {export_path}")

    return fig, ax

In [None]:
def plot_categorized_scatter(
    df, 
    x_col, 
    y_col, 
    control_column="control",
    image_day_column="imaging_day",
    x_equals_y=False,
    plot_title=None,
    color_palette="viridis",
    show_means=False,
    export_path=None,
    export_dpi=300,
    units_x=None,
    units_y=None,
    fit_line=False,
    line_color='red',
    line_style='--',
    line_alpha=0.5,
    show_r_squared=False,
    log_scale=False,  # New parameter for log scaling
    log_threshold=1e-20  # Small value to replace zeros in log scale
):
    """
    Scatter ``y_col`` against ``x_col`` on a new figure, colored by category.

    Categories and colors come from ``get_category_color_mapping``. Optional
    overlays: an x = y reference line, mean lines, and a least-squares fit
    (fitted on log-transformed values when ``log_scale`` is True, i.e. a power
    law).

    NOTE: a later cell in this notebook defines another function with the same
    name and a different signature; once that cell runs, it replaces this one.

    Parameters:
    - df: DataFrame containing the data (not modified; it is copied before clipping)
    - x_col / y_col: Columns to plot
    - control_column / image_day_column: Columns defining the categories
    - x_equals_y: Draw the x = y reference line
    - plot_title: Title; defaults to "<y_col> vs <x_col> by Categories"
    - color_palette: Forwarded to ``get_category_color_mapping``
    - show_means: Draw vertical/horizontal lines at the column means
    - export_path / export_dpi: Save location and resolution (optional)
    - units_x / units_y: Optional unit strings appended to the axis labels
    - fit_line, line_color, line_style, line_alpha: Fit line switch and style
    - show_r_squared: Append R² to the fit line's legend label
    - log_scale: Use log axes; non-positive values are clipped to ``log_threshold``
    - log_threshold: Lower bound applied when clipping for log scale

    Returns:
    - (fig, ax) of the created plot

    Raises:
    - ValueError: if any required column is missing from ``df``
    """
    # Check for required columns
    for col in [x_col, y_col, control_column, image_day_column]:
        if col not in df.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame")

    # Handle log scale requirements
    if log_scale:
        if (df[x_col] <= 0).any() or (df[y_col] <= 0).any():
            print(f"Warning: Log scale requested but data contains non-positive values. "
                  f"Values <= {log_threshold} will be set to {log_threshold}.")
            df = df.copy()  # Avoid modifying original dataframe
            df[x_col] = df[x_col].clip(lower=log_threshold)
            df[y_col] = df[y_col].clip(lower=log_threshold)

    total_n = len(df)
    # Honour the caller's palette instead of always forcing "viridis".
    category_color_map = get_category_color_mapping(df, control_column, image_day_column, color_palette=color_palette)

    # Rebuild a row mask for each legend label.
    color_masks = []
    for label, color in category_color_map.items():
        if "Treatment" in label:
            mask = df[control_column] == 0
        elif "Control" in label:
            mask = df[control_column] == 1
        elif "Day" in label:
            # Days may be rendered as '2' or '2.0' depending on the column dtype.
            day_val = int(float(label.split()[1]))
            mask = (df[control_column] == -1) & (df[image_day_column] == day_val)
        else:
            continue
        color_masks.append((mask, color, label))

    fig, ax = plt.subplots(figsize=(10, 6))

    # Apply log scale if requested
    if log_scale:
        ax.set_xscale('log')
        ax.set_yscale('log')

    # Plot the scatter points
    for mask, color, label in color_masks:
        ax.scatter(
            df.loc[mask, x_col],
            df.loc[mask, y_col],
            c=[color],
            label=label,
            alpha=0.6,
            edgecolors='w',
            linewidth=0.5
        )

    if x_equals_y:
        if log_scale:
            # On log axes the line spans only the range shared by both columns.
            lim_min = max(df[x_col].min(), df[y_col].min())
            lim_max = min(df[x_col].max(), df[y_col].max())
        else:
            lim_min = min(df[x_col].min(), df[y_col].min())
            lim_max = max(df[x_col].max(), df[y_col].max())

        ax.plot([lim_min, lim_max], [lim_min, lim_max], 
                '--', color='gray', alpha=0.7, label='x = y')

    if show_means:
        x_mean = df[x_col].mean()
        y_mean = df[y_col].mean()
        ax.axvline(x_mean, color='red', linestyle='--', alpha=0.7, label=f'Mean {x_col} = {x_mean:.2f}')
        ax.axhline(y_mean, color='blue', linestyle='--', alpha=0.7, label=f'Mean {y_col} = {y_mean:.2f}')

    if fit_line:
        x = df[x_col].values
        y = df[y_col].values
        # Regress only on rows where both values are present.
        valid = ~np.isnan(x) & ~np.isnan(y)
        x = x[valid]
        y = y[valid]

        if len(x) > 1:
            if log_scale:
                # Fit line in log space but display in linear coordinates (power law)
                slope, intercept, r_value, p_value, std_err = stats.linregress(np.log(x), np.log(y))
                line_x = np.array([x.min(), x.max()])
                line_y = np.exp(intercept) * (line_x ** slope)
            else:
                # Regular linear fit
                slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
                line_x = np.array([x.min(), x.max()])
                line_y = intercept + slope * line_x
            # R² refers to the space the regression was run in (log or linear).
            r_squared = r_value**2

            line_label = 'Fit line'
            if show_r_squared:
                line_label += f' (R² = {r_squared:.2f})'
                if log_scale:
                    line_label += ' [power law fit]'

            ax.plot(line_x, line_y, 
                    linestyle=line_style, 
                    color=line_color, 
                    alpha=line_alpha, 
                    label=line_label)

    # Format axis labels
    x_label = f"{x_col} [{units_x}]" if units_x else x_col
    y_label = f"{y_col} [{units_y}]" if units_y else y_col

    ax.set_xlabel(x_label, fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(plot_title or f"{y_col} vs {x_col} by Categories", fontsize=14)
    ax.grid(True, alpha=0.3)

    # Legend sits outside the axes, to the right.
    ax.legend(
        title=f'Categories (Total N={total_n})',
        bbox_to_anchor=(1.05, 1),
        loc='upper left',
        framealpha=0.9
    )

    plt.tight_layout()

    if export_path:
        fig.savefig(export_path, dpi=export_dpi, bbox_inches='tight', transparent=False)
        print(f"Plot saved to {export_path}")

    return fig, ax

In [None]:
import numpy as np
# Ellipsoid volumes from the fitted axes: V = 4/3 * pi * a * b * c.
# assumes a/b/c are semi-axis lengths — TODO confirm against the image-analysis output.
lumenoids_output_diff_batch_2["v_outer_ellipsoid"] = np.pi * 4 / 3 * lumenoids_output_diff_batch_2.a_outer_ellipsoid * lumenoids_output_diff_batch_2.b_outer_ellipsoid * lumenoids_output_diff_batch_2.c_outer_ellipsoid
lumenoids_output_diff_batch_2["v_inner_ellipsoid"] = np.pi * 4 / 3 * lumenoids_output_diff_batch_2.a_inner_ellipsoid * lumenoids_output_diff_batch_2.b_inner_ellipsoid * lumenoids_output_diff_batch_2.c_inner_ellipsoid

In [None]:
# Epithelium volume = outer volume minus inner (lumen) volume, for both the ellipsoid fit and the mesh.
lumenoids_output_diff_batch_2["v_epithelium_ellipsoid"] = lumenoids_output_diff_batch_2.v_outer_ellipsoid - lumenoids_output_diff_batch_2.v_inner_ellipsoid
lumenoids_output_diff_batch_2["v_epithelium_mesh"] = lumenoids_output_diff_batch_2.v_outer_mesh - lumenoids_output_diff_batch_2.v_inner_mesh

In [None]:
# Mean epithelium height from the ellipsoid fits: average of the outer-minus-inner difference over the a, b and c axes.
lumenoids_output_diff_batch_2["mean_height_ellipsoid"] = ((lumenoids_output_diff_batch_2.a_outer_ellipsoid - lumenoids_output_diff_batch_2.a_inner_ellipsoid) + (lumenoids_output_diff_batch_2.b_outer_ellipsoid - lumenoids_output_diff_batch_2.b_inner_ellipsoid) + (lumenoids_output_diff_batch_2.c_outer_ellipsoid - lumenoids_output_diff_batch_2.c_inner_ellipsoid)) / 3

In [None]:
# Ratio of epithelium volume to lumen (inner) volume, both taken from the mesh.
lumenoids_output_diff_batch_2["epithelium_lumen_v_ratio_mesh"] = lumenoids_output_diff_batch_2.v_epithelium_mesh / lumenoids_output_diff_batch_2.v_inner_mesh

In [None]:
# Shortest / longest fitted axis per row (1 means all three axes are equal), for the outer and inner ellipsoid.
lumenoids_output_diff_batch_2["outer_ellipsoid_shortest_longest_axis_ratio"] = lumenoids_output_diff_batch_2[["a_outer_ellipsoid", "b_outer_ellipsoid","c_outer_ellipsoid"]].min(axis=1) / lumenoids_output_diff_batch_2[["a_outer_ellipsoid", "b_outer_ellipsoid","c_outer_ellipsoid"]].max(axis=1)
lumenoids_output_diff_batch_2["inner_ellipsoid_shortest_longest_axis_ratio"] = lumenoids_output_diff_batch_2[["a_inner_ellipsoid", "b_inner_ellipsoid","c_inner_ellipsoid"]].min(axis=1) / lumenoids_output_diff_batch_2[["a_inner_ellipsoid", "b_inner_ellipsoid","c_inner_ellipsoid"]].max(axis=1)

In [None]:
# Lumen sphericity from the mesh: pi^(1/3) * (6 V)^(2/3) / A, with V = v_inner_mesh and A = area_inner_mesh.
lumenoids_output_diff_batch_2["lumen_sphericity_mesh"] = ((np.pi ** (1/3)) * ((6 * lumenoids_output_diff_batch_2.v_inner_mesh) ** (2/3))) / lumenoids_output_diff_batch_2.area_inner_mesh
# Same formula for the outer surface, currently disabled.
# lumenoids_output_diff_batch_2["lumenoid_sphericity_mesh"] = ((np.pi ** (1/3)) * ((6 * lumenoids_output_diff_batch_2.v_outer_mesh) ** (2/3))) / lumenoids_output_diff_batch_2.area_outer_mesh

In [None]:
def geometric_diff(x, y):
    """Return |x - y| divided by the geometric mean sqrt(x * y).

    Works element-wise on anything that supports subtraction, multiplication
    and ``abs`` (scalars, numpy arrays, pandas Series).
    """
    absolute_gap = abs(x - y)
    geometric_mean = np.sqrt(x * y)
    return absolute_gap / geometric_mean

# Pairwise relative differences between the fitted axes of the inner ellipsoid (a-b, a-c, b-c).
lumenoids_output_diff_batch_2["ab_inner_diff"] = geometric_diff(lumenoids_output_diff_batch_2.a_inner_ellipsoid, lumenoids_output_diff_batch_2.b_inner_ellipsoid)
lumenoids_output_diff_batch_2["ac_inner_diff"] = geometric_diff(lumenoids_output_diff_batch_2.a_inner_ellipsoid, lumenoids_output_diff_batch_2.c_inner_ellipsoid)
lumenoids_output_diff_batch_2["bc_inner_diff"] = geometric_diff(lumenoids_output_diff_batch_2.b_inner_ellipsoid, lumenoids_output_diff_batch_2.c_inner_ellipsoid)

# Same three pairwise differences for the outer ellipsoid.
lumenoids_output_diff_batch_2["ab_outer_diff"] = geometric_diff(lumenoids_output_diff_batch_2.a_outer_ellipsoid, lumenoids_output_diff_batch_2.b_outer_ellipsoid)
lumenoids_output_diff_batch_2["ac_outer_diff"] = geometric_diff(lumenoids_output_diff_batch_2.a_outer_ellipsoid, lumenoids_output_diff_batch_2.c_outer_ellipsoid)
lumenoids_output_diff_batch_2["bc_outer_diff"] = geometric_diff(lumenoids_output_diff_batch_2.b_outer_ellipsoid, lumenoids_output_diff_batch_2.c_outer_ellipsoid)

In [None]:
# Visual check: inner pairwise differences next to the axes they were computed from (displays the whole frame).
lumenoids_output_diff_batch_2[["ab_inner_diff", "ac_inner_diff","bc_inner_diff","a_inner_ellipsoid", "b_inner_ellipsoid","c_inner_ellipsoid"]]

In [None]:
# Per row: name of the inner-axis pair column with the smallest / largest relative difference
# (values are column names such as "ab_inner_diff").
lumenoids_output_diff_batch_2["most_similar_lumen_axis"] = lumenoids_output_diff_batch_2[["ab_inner_diff", "ac_inner_diff","bc_inner_diff"]].idxmin(axis=1)
lumenoids_output_diff_batch_2["most_diff_lumen_axis"] = lumenoids_output_diff_batch_2[["ab_inner_diff", "ac_inner_diff","bc_inner_diff"]].idxmax(axis=1)

# The corresponding smallest / largest difference values.
lumenoids_output_diff_batch_2["most_similar_lumen_axis_value"] = lumenoids_output_diff_batch_2[["ab_inner_diff", "ac_inner_diff","bc_inner_diff"]].min(axis=1)
lumenoids_output_diff_batch_2["most_diff_lumen_axis_value"] = lumenoids_output_diff_batch_2[["ab_inner_diff", "ac_inner_diff","bc_inner_diff"]].max(axis=1)

In [None]:
# Same as the lumen-axis columns, but for the outer ellipsoid: which axis pair is most similar / most different.
lumenoids_output_diff_batch_2["most_similar_axis"] = lumenoids_output_diff_batch_2[["ab_outer_diff", "ac_outer_diff","bc_outer_diff"]].idxmin(axis=1)
lumenoids_output_diff_batch_2["most_diff_axis"] = lumenoids_output_diff_batch_2[["ab_outer_diff", "ac_outer_diff","bc_outer_diff"]].idxmax(axis=1)

# The corresponding smallest / largest difference values.
lumenoids_output_diff_batch_2["most_similar_axis_value"] = lumenoids_output_diff_batch_2[["ab_outer_diff", "ac_outer_diff","bc_outer_diff"]].min(axis=1)
lumenoids_output_diff_batch_2["most_diff_axis_value"] = lumenoids_output_diff_batch_2[["ab_outer_diff", "ac_outer_diff","bc_outer_diff"]].max(axis=1)

# Fitting the theory

## Roman's version

In [None]:
def add_data_for_theory(df):
    """
    Add lumen-shape columns derived from the inner-ellipsoid axes, in place.

    For each row, the pair of inner axes named by ``most_similar_lumen_axis``
    (e.g. ``"ab_inner_diff"`` -> axes a and b) is taken as ``x`` and ``y``, and
    the remaining axis as ``z``. Four columns are then written to ``df``:

    - ``cross_section_radius_squared``: sqrt(x * y)
    - ``aspect_ratio``: z / sqrt(x * y)
    - ``lumen_radius_cubic``: (x * y * z) ** (1/3)
    - ``diff_lumen_dim``: z

    The computation is positional, so it works for any DataFrame index (for
    example a filtered subset). Rows whose axis pair is missing or not one of
    ab/ac/bc get NaN in all four columns.

    Parameters:
    - df: DataFrame with ``most_similar_lumen_axis`` and the
      ``a/b/c_inner_ellipsoid`` columns; modified in place.

    Returns:
    - None
    """
    # "ab_inner_diff" -> "ab"
    pair = df.most_similar_lumen_axis.str.split('_').str[0]

    a = df.a_inner_ellipsoid.to_numpy(dtype=float)
    b = df.b_inner_ellipsoid.to_numpy(dtype=float)
    c = df.c_inner_ellipsoid.to_numpy(dtype=float)

    is_ab = pair.isin(["ab", "ba"]).to_numpy()
    is_ac = pair.isin(["ac", "ca"]).to_numpy()
    is_bc = pair.isin(["bc", "cb"]).to_numpy()
    conditions = [is_ab, is_ac, is_bc]

    # x, y: the most similar pair; z: the remaining axis.
    x = np.select(conditions, [a, a, b], default=np.nan)
    y = np.select(conditions, [b, c, c], default=np.nan)
    z = np.select(conditions, [c, b, a], default=np.nan)

    cross_section_radius = np.sqrt(x * y)

    df["aspect_ratio"] = z / cross_section_radius
    df["lumen_radius_cubic"] = (x * y * z) ** (1/3)
    df["cross_section_radius_squared"] = cross_section_radius
    df["diff_lumen_dim"] = z


In [None]:
# Adds aspect_ratio, lumen_radius_cubic, cross_section_radius_squared and diff_lumen_dim to the frame in place.
add_data_for_theory(lumenoids_output_diff_batch_2)

In [None]:
# Relative thickness = mean epithelium height divided by the lumen radius (a*b*c)^(1/3),
# once with the mesh-based height and once with the ellipsoid-based height.
lumenoids_output_diff_batch_2["relative_thickness_mesh"] = lumenoids_output_diff_batch_2.mean_height_mesh / lumenoids_output_diff_batch_2.lumen_radius_cubic
lumenoids_output_diff_batch_2["relative_thickness_ellipsoid"] = lumenoids_output_diff_batch_2.mean_height_ellipsoid / lumenoids_output_diff_batch_2.lumen_radius_cubic

In [None]:
# Persist the table with all derived columns.
# NOTE: written without index=False, so the row index is stored as an extra unnamed first column.
lumenoids_output_diff_batch_2.to_csv('/Users/ischneider/Switzerland/ETH/Thesis/ThesisSimuCell3D/lumenoids_analyzed_data_differentiation_experiment_batch2.csv')

In [None]:
# Rows whose two most similar lumen axes differ by at most 0.13 (relative difference).
# .copy() makes the subset an independent frame, so adding columns to it later
# neither touches the full table nor triggers pandas' SettingWithCopyWarning.
circular_cross_section_subset = lumenoids_output_diff_batch_2[lumenoids_output_diff_batch_2.most_similar_lumen_axis_value <= 0.13].copy()
circular_cross_section_subset.shape

In [None]:
def calculate_epi_lumen_circular_ratio(df):
    """
    Evaluate the epithelium/lumen ratio expression for the circular-cross-section case.

    With delta = ``df.relative_thickness_ellipsoid``, epsilon = ``df.aspect_ratio``
    and t = delta * epsilon**(1/3), the result is

        t / epsilon * (t**2 + (2 + epsilon) * t + 2 * epsilon + 1)

    evaluated element-wise.
    """
    delta = df.relative_thickness_ellipsoid
    epsilon = df.aspect_ratio
    cube_root = epsilon**(1/3)
    # Operation order is kept exactly as written so floating-point results do not change.
    prefactor = delta * cube_root / epsilon
    bracket = (delta * cube_root)**2 + (2 + epsilon) * delta * cube_root + 2*epsilon + 1
    return prefactor * bracket

In [None]:
# Theoretical epithelium/lumen ratio for the circular-cross-section subset.
circular_cross_section_subset["calculated_epi_lumen_circular_ratio"] = calculate_epi_lumen_circular_ratio(circular_cross_section_subset)  

In [None]:
# Rows whose most different pair of lumen axes still differs by at most 0.13, i.e. all three axes are similar.
# .copy() makes the subset an independent frame, so adding columns to it later
# neither touches the full table nor triggers pandas' SettingWithCopyWarning.
spherical_lumenoids = lumenoids_output_diff_batch_2[lumenoids_output_diff_batch_2.most_diff_lumen_axis_value <= 0.13].copy()
spherical_lumenoids.shape

In [None]:
def calculate_epi_lumen_spherical_ratio(df):
    """
    Evaluate delta * (delta**2 + 3*delta + 3) with delta = ``df.relative_thickness_mesh``.

    Algebraically this equals (1 + delta)**3 - 1; the expanded form is kept as written.
    """
    delta = df.relative_thickness_mesh
    polynomial = delta**2 + 3*delta + 3
    return delta*polynomial

In [None]:
# Theoretical epithelium/lumen ratio for the near-spherical subset.
spherical_lumenoids["calculated_epi_lumen_spherical_ratio"] = calculate_epi_lumen_spherical_ratio(spherical_lumenoids)

In [None]:
def plot_categorized_scatter(
    df, 
    x_col, 
    y_col, 
    x_lab,
    y_lab,
    control_column="control",
    image_day_column="imaging_day",
    x_equals_y=False,
    color_palette="viridis",
    show_means=False,
    export_path=None,
    export_dpi=300,
    units_x=None,
    units_y=None,
    fit_line=False,
    line_color='black',
    line_style='-',
    line_alpha=0.5,
    show_r_squared=False,
    log_scale=False,
    log_threshold=1e-20,
    ax=None,
):
    """
    Scatter ``y_col`` against ``x_col`` colored by category, optionally on an existing axes.

    Unlike the earlier definition of the same name in this notebook, this
    version takes explicit axis label texts (``x_lab``, ``y_lab``), can draw
    into a supplied ``ax`` (for composite figures), and sets neither a title
    nor a legend.

    Parameters:
    - df: DataFrame containing the data (not modified; it is copied before clipping)
    - x_col / y_col: Columns to plot
    - x_lab / y_lab: Axis label texts (wrapped as "log(...)" when ``log_scale``)
    - control_column / image_day_column: Columns defining the categories
    - x_equals_y: Draw the x = y reference line
    - color_palette: Forwarded to ``get_category_color_mapping``
    - show_means: Draw vertical/horizontal lines at the column means
    - export_path / export_dpi: Save location and resolution (optional)
    - units_x / units_y: Optional unit strings appended to the axis labels
    - fit_line, line_color, line_style, line_alpha: Fit line switch and style
    - show_r_squared: Append R² to the fit line's label
    - log_scale: Use log axes; non-positive values are clipped to ``log_threshold``.
      With ``fit_line`` the fit is done on log-transformed values and the
      exponent is annotated on the axes.
    - log_threshold: Lower bound applied when clipping for log scale
    - ax: Axes to draw into; a new 10x6 figure is created when None

    Returns:
    - (fig, ax)

    Raises:
    - ValueError: if any required column is missing from ``df``
    """
    for col in [x_col, y_col, control_column, image_day_column]:
        if col not in df.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame")

    if log_scale:
        if (df[x_col] <= 0).any() or (df[y_col] <= 0).any():
            print(f"Warning: Log scale requested but data contains non-positive values. "
                  f"Values <= {log_threshold} will be set to {log_threshold}.")
            df = df.copy()
            df[x_col] = df[x_col].clip(lower=log_threshold)
            df[y_col] = df[y_col].clip(lower=log_threshold)

    # Honour the caller's palette instead of always forcing "viridis".
    category_color_map = get_category_color_mapping(df, control_column, image_day_column, color_palette=color_palette)

    # Rebuild a row mask for each legend label.
    color_masks = []
    for label, color in category_color_map.items():
        if "Treatment" in label:
            mask = df[control_column] == 0
        elif "Control" in label:
            mask = df[control_column] == 1
        elif "Day" in label:
            # Days may be rendered as '2' or '2.0' depending on the column dtype.
            day_val = int(float(label.split()[1]))
            mask = (df[control_column] == -1) & (df[image_day_column] == day_val)
        else:
            continue
        color_masks.append((mask, color, label))

    # Remember whether the figure is ours: `ax` is reassigned just below, so it
    # can no longer be used to tell the two cases apart afterwards.
    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=(10, 6))
    else:
        fig = ax.get_figure()

    if log_scale:
        ax.set_xscale('log')
        ax.set_yscale('log')

    for mask, color, label in color_masks:
        ax.scatter(
            df.loc[mask, x_col],
            df.loc[mask, y_col],
            c=[color],
            label=label,
            alpha=0.6,
            edgecolors='w',
            linewidth=0.5
        )

    if x_equals_y:
        if log_scale:
            # On log axes the line spans only the range shared by both columns.
            lim_min = max(df[x_col].min(), df[y_col].min())
            lim_max = min(df[x_col].max(), df[y_col].max())
        else:
            lim_min = min(df[x_col].min(), df[y_col].min())
            lim_max = max(df[x_col].max(), df[y_col].max())

        ax.plot([lim_min, lim_max], [lim_min, lim_max], 
                '--', color='gray', alpha=0.7, label='x = y')

    if show_means:
        x_mean = df[x_col].mean()
        y_mean = df[y_col].mean()
        ax.axvline(x_mean, color='red', linestyle='--', alpha=0.7, label=f'Mean {x_col} = {x_mean:.2f}')
        ax.axhline(y_mean, color='blue', linestyle='--', alpha=0.7, label=f'Mean {y_col} = {y_mean:.2f}')

    if fit_line:
        x = df[x_col].values
        y = df[y_col].values
        # Regress only on rows where both values are present.
        valid = ~np.isnan(x) & ~np.isnan(y)
        x = x[valid]
        y = y[valid]

        if len(x) > 1:
            if log_scale:
                # Linear regression in log-log space: the slope is the power-law exponent.
                slope, intercept, r_value, p_value, std_err = stats.linregress(np.log(x), np.log(y))
                line_x = np.array([x.min(), x.max()])
                line_y = np.exp(intercept) * (line_x ** slope)
                r_squared = r_value**2
                # NOTE: the annotation text names the variables S and V regardless of x_col / y_col.
                ax.text(0.05, 0.95,  f'Power-law fit: $S \\sim V^{{\\alpha}}$\n$\\alpha = {slope:.2f} \\pm {std_err:.2f}$\n$(R^2={r_squared:.2f})$', transform=ax.transAxes, fontsize=11, verticalalignment='top')
            else:
                slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
                line_x = np.array([x.min(), x.max()])
                line_y = intercept + slope * line_x
                r_squared = r_value**2

            line_label = 'Fit line'
            if show_r_squared:
                line_label += f' (R² = {r_squared:.2f})'
                if log_scale:
                    line_label += ' [power law fit]'

            ax.plot(line_x, line_y, 
                    linestyle=line_style, 
                    color=line_color, 
                    alpha=line_alpha, 
                    label=line_label)

    x_label = f"{x_lab} [{units_x}]" if units_x else x_lab
    y_label = f"{y_lab} [{units_y}]" if units_y else y_lab
    if log_scale:
        x_label = f"log({x_lab}) [{units_x}]" if units_x else f"log({x_lab})"
        y_label = f"log({y_lab}) [{units_y}]" if units_y else f"log({y_lab})"

    ax.set_xlabel(x_label, fontsize=12,labelpad=4)
    ax.set_ylabel(y_label, fontsize=12, labelpad=4)
    ax.tick_params(axis='x', labelsize=9)
    ax.tick_params(axis='y', labelsize=9)

    ax.grid(True, alpha=0.3)

    # Only tighten the layout of a figure created here; a caller-provided axes
    # belongs to a figure whose layout the caller controls.
    if owns_figure:
        fig.tight_layout()

    if export_path:
        fig.savefig(export_path, dpi=export_dpi, bbox_inches='tight', transparent=False)
        print(f"Plot saved to {export_path}")

    return fig, ax

In [None]:
# --- Main Plotting ---
# Very tall figure (7 x 60 in); panels are placed manually near the top.
fig = plt.figure(figsize=(7, 60))

# Positions: [left, bottom, width, height] in 0-1 range

# Row 1: panels A, B
ax1 = fig.add_axes([0.0, 0.96, 0.425, 0.03]) 
ax2 = fig.add_axes([0.575, 0.96, 0.425, 0.03])

# Row 2: panels C, D; panel C also carries the formula of the relative axis difference.
ax3 = fig.add_axes([0.0, 0.91, 0.425, 0.03]) 
ax3.text(0.65, 0.65, r'$\alpha = \frac{|a-b|}{\sqrt{ab}}$', transform=ax3.transAxes, fontsize=16, verticalalignment='top')
ax4 = fig.add_axes([0.575, 0.91, 0.425, 0.03])

# Row 3: panels E, F
ax5 = fig.add_axes([0.0, 0.86, 0.425, 0.03]) 
ax6 = fig.add_axes([0.575, 0.86, 0.425, 0.03])

# ax7 = fig.add_axes([0.2875, 0.81, 0.425, 0.03]) 

axes = [ax1, ax2, ax3, ax4, ax5, ax6]
# One more letter than axes: zip stops at the shorter list, so 'G' is unused while ax7 is disabled.
letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G']

# Add reference letters to each subplot
for ax, letter in zip(axes, letters):
    ax.text(-0.07, 1.3, letter, transform=ax.transAxes, 
            fontsize=16, fontweight='bold', va='top', ha='right')

# --- Generate plots ---
# Panel A: lumenoid volume, mesh vs. ellipsoid fit.
# Unit strings are raw strings: "\m" is not a valid escape sequence in a normal string literal.
plot_categorized_scatter(lumenoids_output_diff_batch_2,
    x_col="v_outer_mesh",
    y_col="v_outer_ellipsoid",
    # x_lab="Mesh Lumenoid Volume ($S+V$)",
    # y_lab="Ellipsoid Lumenoid Volume ($S+V$)",
    x_lab="Mesh $S+V$",
    y_lab="Ellipsoid $S+V$",
    x_equals_y=True,
    units_x=r"$\mu m^3$",
    units_y=r"$\mu m^3$",
    # fit_line=True,
    show_r_squared=True,
    log_scale=True,
    ax=ax1)


# plot_categorized_scatter(lumenoids_processed_data,
#     x_col="v_inner_mesh",
#     y_col="v_inner_ellipsoid",
#     # x_lab="Mesh Lumen Volume $V$",
#     # y_lab="Ellipsoid Lumen Volume $V$",
#     x_lab="Mesh $V$",
#     y_lab="Ellipsoid $V$",
#     x_equals_y=True,
#     units_x = "$\mu m^3$",
#     units_y = "$\mu m^3$",
#     # fit_line=True,
#     show_r_squared=True,
#     log_scale=True,
#     ax=ax2)

# Panel B: mean epithelium height, mesh vs. ellipsoid fit.
plot_categorized_scatter(lumenoids_output_diff_batch_2,
    x_col="mean_height_mesh",
    y_col="mean_height_ellipsoid",
    # x_lab="Mesh Cell Height $h$",
    # y_lab="Ellipsoid Cell Height $h$",
    x_lab="Mesh $h$",
    y_lab="Ellipsoid $h$",
    x_equals_y=True,
    units_x=r"$\mu m$",
    units_y=r"$\mu m$",
    # fit_line=True,
    show_r_squared=True,
    ax=ax2
)

# Panel C: distribution of the smallest pairwise lumen-axis difference.
# NOTE(review): plot_side_by_side_histograms is not defined in any cell above this one —
# confirm that it is defined before this cell runs on a fresh kernel.
plot_side_by_side_histograms(ax3, lumenoids_output_diff_batch_2,'most_similar_lumen_axis_value', -0.07, 0.6, 0.0, 16, bins=10)

# Panel D: measured vs. calculated S/V for the circular-cross-section subset.
plot_categorized_scatter(
    circular_cross_section_subset,
    x_col="epithelium_lumen_v_ratio_mesh",
    y_col="calculated_epi_lumen_circular_ratio",
    # x_lab="Mesh Epithelium-lumen volume ratio $S/V$",
    # y_lab="Calculated Epithelium-lumen volume ratio $S/V$",
    x_lab="Mesh $S/V$",
    y_lab="Calculated $S/V$",
    x_equals_y=True,
    # fit_line=True,
    log_scale=True,
    ax=ax4
)

# Panel E: measured vs. calculated S/V for the near-spherical subset.
plot_categorized_scatter(
    spherical_lumenoids,
    x_col="epithelium_lumen_v_ratio_mesh",
    y_col="calculated_epi_lumen_spherical_ratio",
    # x_lab="Mesh Epithelium-lumen volume ratio $S/V$",
    # y_lab="Calculated Epithelium-lumen volume ratio $S/V$",
    x_lab="Mesh $S/V$",
    y_lab="Calculated $S/V$",
    x_equals_y=True,
    # fit_line=True,
    log_scale=True,
    ax=ax5
)

# Panel F: epithelium volume S vs. lumen volume V with a power-law fit, rows with control == -1 only.
# The frame being filtered is the same one the mask is built from; the previously used name
# `lumenoids_analyzed_data` is not defined anywhere above in this notebook.
plot_categorized_scatter(
    lumenoids_output_diff_batch_2[lumenoids_output_diff_batch_2.control == -1],
    x_col="v_inner_mesh",
    y_col="v_epithelium_mesh",
    # x_lab="Lumen Volume $V$",
    x_lab="$V$",
    y_lab="$S$",
    # x_equals_y=True,
    fit_line=True,
    log_scale=True,
    ax=ax6,
    units_x=r"$\mu m^3$",
    units_y=r"$\mu m^3$",
)



# Build the shared legend from the categories actually present in the plotted table.
# (The previously used name `lumenoids_processed_data` is not defined anywhere above in this notebook.)
color_mapping = get_category_color_mapping(lumenoids_output_diff_batch_2)

# One colored patch per category ...
legend_elements = [
    Patch(facecolor=color, edgecolor='white', label=category)
    for category, color in color_mapping.items()
]

# ... plus line samples for the statistical overlays.
legend_elements.extend([
    Line2D([0], [0], color='red', linestyle='--', lw=1.5, 
           label='Mean'),
    Line2D([0], [0], color='gray', linestyle='--', lw=1.5, 
           label='x=y'),
    Line2D([0], [0], color='black', linestyle='-', lw=1.5, 
           label='Fitted Line'),
])

legend_style = {
    'loc': 'center',
    'fontsize': 8,
    'framealpha': 1,
    'frameon': False,
    'title_fontsize': 14,
    'borderpad': 0.5,
    'handletextpad': 0.5,    # Reduced from 2 to 0.5 (symbol ↔ text)
    'columnspacing': 1.0,    # Added: space between columns
    'labelspacing': 0.3,     # Added: vertical space between rows
}

legend = fig.legend(
    handles=legend_elements,
    bbox_to_anchor=(0.5, 1.005),
    title='Legend:',
    # Roughly two rows; the column count must be an integer, np.round returns a float.
    ncol=int(np.round(len(legend_elements)/2)),
    **legend_style
)

legend.get_title().set_fontweight('bold')
legend.get_title().set_fontsize(10)
# Private matplotlib attribute: spacing between the legend title and its entries.
legend._legend_box.sep = 8

# --- Save ---
# fig.savefig(f"plots/image_analysis/luemnoids_main_plot{t_alpha}_{t_iou}_n{n}.pdf", bbox_inches='tight')
# plt.tight_layout()
# plt.savefig(f"plots/image_analysis/lumenoids_main_plot_initial_analysis.pdf", bbox_inches='tight', 
#             dpi=1200,
#             format='pdf')
# plt.show()

# Generate and save the table
# stats_table = generate_statistics_table(differentiation_circular_cross_section_subset_iou, variables_to_plot)
# save_statistics_table_latex(
#     stats_table, 
#     f"tables/image_analysis/luemnoids_main_plot{t_alpha}_{t_iou}_n{n}.tex",
#     caption="Statistical summary of thresholded lumenoid data",
#     label="tab:stats"
# )

# Save to CSV for LaTeX import
# stats_table.to_csv(f"stats_table_t{t}_{t_iou}_n{n}.csv", index=False)

# Display the table (optional)
# display(stats_table)

# # Example usage:
# plot_statistics_table(stats_table, title=f"Statistics for thresholds t_alpha={t_alpha}, t_iou={t_iou} (N={n})")
# plt.show()


In [None]:
# NOTE(review): `lumenoids_analyzed_data` is not defined in any cell above — this cell relies on
# kernel state from elsewhere; confirm where the variable is supposed to come from.
lumenoids_analyzed_data.control

In [None]:
# NOTE(review): `lumenoids_analyzed_data` is not defined in any cell above — confirm its origin.
lumenoids_analyzed_data.columns

## Composite Figure

In [None]:
# List the available columns (original measurements plus the derived ones added above).
lumenoids_output_diff_batch_2.columns

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import mannwhitneyu, wilcoxon
from statsmodels.stats.multitest import multipletests
import itertools
import numpy as np
from matplotlib.ticker import ScalarFormatter
import itertools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
from matplotlib.ticker import StrMethodFormatter
from matplotlib.ticker import FuncFormatter

# --- Helper Functions ---
def lighten_color(color, amount=0.5):
    """
    Return a brightened RGB version of ``color``.

    The colour is converted to HSV and its value channel ``v`` is replaced by
    ``1 - amount * (1 - v)``; hue and saturation are kept. ``amount=1``
    therefore returns the colour unchanged and ``amount=0`` pushes the value
    channel to 1.

    Parameters
    ----------
    color : str or colour spec
        A matplotlib colour name, or anything ``matplotlib.colors.to_rgb`` accepts.
    amount : float, optional
        Weight of the original darkness that is kept (default: 0.5).

    Returns
    -------
    tuple of float
        The ``(r, g, b)`` triple produced by ``colorsys.hsv_to_rgb``.
    """
    import matplotlib.colors as mc
    import colorsys
    try:
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # KeyError: not a named colour. TypeError: unhashable spec such as a
        # list or array. Anything else should surface instead of being hidden.
        c = color
    hue, saturation, value = colorsys.rgb_to_hsv(*mc.to_rgb(c))
    return colorsys.hsv_to_rgb(hue, saturation, 1 - amount * (1 - value))

def add_mwu_annotations(ax, df, var, group_column='imaging_day', 
                        correction_method='fdr_bh', 
                        show_test_info=False, 
                        test_info_off_set=0.9,
                        line_height_factor=0.1,  # Controls vertical spacing between annotations
                        bracket_height_factor=0.02,  # Controls bracket height
                        initial_offset_factor=0.1,  # Controls initial vertical offset from data
                        significance_fontsize=8):
    """
    Run pairwise Mann-Whitney U tests between groups and draw significance brackets.

    Every pair of groups found in ``group_column`` is compared with a
    two-sided Mann-Whitney U test on ``var`` (NaNs are dropped per group).
    The raw p-values are corrected with ``multipletests`` and each pair gets a
    bracket labelled ``***`` (p < 0.001), ``**`` (p < 0.01), ``*`` (p < 0.05)
    or ``n.s.``. A text report is printed as a side effect.

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        The axes to annotate
    df : pandas.DataFrame
        DataFrame containing the data
    var : str
        Column holding the values to compare
    group_column : str, optional
        Column used for grouping (default: 'imaging_day')
    correction_method : str, optional
        ``method`` passed to ``multipletests`` (default: 'fdr_bh')
    show_test_info : bool, optional
        If True, write an "MWU" tag inside the axes (default: False)
    test_info_off_set : float, optional
        Horizontal axes-fraction position of that tag (default: 0.9)
    line_height_factor : float, optional
        Multiplier for vertical spacing between annotations (default: 0.1)
    bracket_height_factor : float, optional
        Multiplier for bracket height (default: 0.02)
    initial_offset_factor : float, optional
        Multiplier for initial vertical offset from max data point (default: 0.1)
    significance_fontsize : int, optional
        Font size for significance markers (default: 8)

    Returns:
    --------
    dict
        ``{'test_type': 'Mann-Whitney U', 'pairwise': {(g1, g2): {'stat',
        'p_raw', 'p_corr'}}}``. ``'pairwise'`` is empty when fewer than two
        groups are present; in that case nothing is drawn or printed.
    """
    # NaN group labels cannot be matched with ``==`` and would yield empty samples.
    groups = sorted(df[group_column].dropna().unique())
    group_data = {g: df.loc[df[group_column] == g, var].dropna() for g in groups}

    pairs = list(itertools.combinations(groups, 2))
    if not pairs:
        # Nothing to compare: avoid handing an empty p-value list to multipletests.
        return {'test_type': 'Mann-Whitney U', 'pairwise': {}}

    pairwise_results = {
        (g1, g2): stats.mannwhitneyu(group_data[g1], group_data[g2], alternative='two-sided')
        for g1, g2 in pairs
    }
    p_values = [res.pvalue for res in pairwise_results.values()]
    _, corrected_pvals, _, _ = multipletests(p_values, method=correction_method)

    # --- Print Report ---
    print(f"\n=== Mann-Whitney U Test Report for {var} ===")
    print(f"Groups: {groups} (n={[len(group_data[g]) for g in groups]})")
    print("\nPairwise Tests (corrected p-values):")
    for (g1, g2), p_corr in zip(pairs, corrected_pvals):
        stat = pairwise_results[(g1, g2)].statistic
        p_raw = pairwise_results[(g1, g2)].pvalue
        print(f"  {g1} vs {g2}: U={stat}, p_raw={p_raw:.4f}, p_corr={p_corr:.4f}")

    y_max = df[var].max()
    y_min = df[var].min()
    y_range = y_max - y_min

    # Bracket x positions are the group's rank in sorted order; this assumes
    # the boxes were drawn in that same order at x = 0, 1, 2, ...
    x_of = {g: idx for idx, g in enumerate(groups)}

    for i, ((g1, g2), p_corr) in enumerate(zip(pairs, corrected_pvals)):
        # Stack brackets upwards, one level per comparison.
        y_pos = y_max + (initial_offset_factor * y_range) + (line_height_factor * y_range) * i
        tick_bottom = y_pos - bracket_height_factor * y_range

        # Draw bracket
        ax.plot([x_of[g1], x_of[g1], x_of[g2], x_of[g2]],
                [tick_bottom, y_pos, y_pos, tick_bottom],
                lw=1, color='black')

        # Add significance marker
        significance = ('***' if p_corr < 0.001 else
                       '**' if p_corr < 0.01 else
                       '*' if p_corr < 0.05 else
                       'n.s.')
        ax.text((x_of[g1] + x_of[g2]) / 2, 
               y_pos, 
               significance,
               ha='center', 
               va='bottom', 
               fontsize=significance_fontsize)

    if show_test_info:
        ax.text(test_info_off_set, 0.98, "MWU",
                transform=ax.transAxes,
                ha='right', va='top',
                fontsize=6,
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', boxstyle='round,pad=0.2'))

    return {
        'test_type': 'Mann-Whitney U',
        'pairwise': {
            (g1, g2): {
                'stat': pairwise_results[(g1, g2)].statistic,
                'p_raw': pairwise_results[(g1, g2)].pvalue,
                'p_corr': p_corr
            } for (g1, g2), p_corr in zip(pairs, corrected_pvals)
        }
    }

def add_statistical_annotations_3groups(
    ax, df, var, group_column='imaging_day',
    normality_threshold=0.05,
    equal_var_threshold=0.10,
    correction_method='fdr_bh',
    show_test_info=False,
    test_info_off_set = 0.22,
):
    """
    Choose a pairwise test from assumption checks and draw significance brackets.

    Each group is checked with Shapiro-Wilk. If every group passes
    (p > ``normality_threshold``) a t-test is used for all pairs: Student's
    when Levene's test gives p > ``equal_var_threshold``, Welch's otherwise.
    If any group fails, a two-sided Mann-Whitney U test is used instead.
    Groups with fewer than three values cannot be tested by Shapiro-Wilk and
    are treated as non-normal. Raw p-values are corrected with
    ``multipletests``; a text report is printed as a side effect. Despite the
    name, the code handles any number of groups.

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        The axes to annotate
    df : pandas.DataFrame
        DataFrame containing the data
    var : str
        Column holding the values to compare
    group_column : str, optional
        Column used for grouping (default: 'imaging_day')
    normality_threshold : float, optional
        Shapiro-Wilk p-value above which a group counts as normal (default: 0.05)
    equal_var_threshold : float, optional
        Levene p-value above which variances count as equal (default: 0.10)
    correction_method : str, optional
        ``method`` passed to ``multipletests`` (default: 'fdr_bh')
    show_test_info : bool, optional
        If True, write the test abbreviation (MWU/STT/WTT) inside the axes
    test_info_off_set : float, optional
        Horizontal axes-fraction position of that tag (default: 0.22)

    Returns:
    --------
    dict
        Keys ``'test_type'``, ``'normality'``, ``'equal_variance'`` (Levene
        p-value, or None when Levene was not run) and ``'pairwise'``. With
        fewer than two groups no test is run and the result is
        ``{'test_type': None, 'normality': {}, 'equal_variance': None,
        'pairwise': {}}``.
    """
    # --- Data Preparation ---
    # NaN group labels cannot be matched with ``==`` and would yield empty samples.
    groups = sorted(df[group_column].dropna().unique())
    group_data = {g: df.loc[df[group_column] == g, var].dropna() for g in groups}

    pairs = list(itertools.combinations(groups, 2))
    if not pairs:
        # Nothing to compare: avoid handing an empty p-value list to multipletests.
        return {'test_type': None, 'normality': {}, 'equal_variance': None, 'pairwise': {}}

    # --- Assumption Checks ---
    # Shapiro-Wilk needs at least three observations; a (nan, nan) placeholder
    # fails the ``p > threshold`` check below, selecting the rank-based test.
    norm_results = {
        g: (stats.shapiro(group_data[g]) if len(group_data[g]) >= 3
            else (float('nan'), float('nan')))
        for g in groups
    }
    normal_groups = all(p > normality_threshold for _, p in norm_results.values())

    levene_p = None
    if normal_groups:
        levene_p = stats.levene(*[group_data[g] for g in groups]).pvalue
        equal_var = levene_p > equal_var_threshold
    else:
        equal_var = False

    # --- Test Selection ---
    if normal_groups:
        pairwise_test = lambda x, y: stats.ttest_ind(x, y, equal_var=equal_var)
        test_type = "Student's t-test" if equal_var else "Welch's t-test"
    else:
        # Be explicit about the alternative instead of relying on the SciPy default.
        pairwise_test = lambda x, y: stats.mannwhitneyu(x, y, alternative='two-sided')
        test_type = "Mann-Whitney U"

    # --- Pairwise Tests ---
    pairwise_results = {
        (g1, g2): pairwise_test(group_data[g1], group_data[g2])
        for g1, g2 in pairs
    }
    p_values = [res.pvalue for res in pairwise_results.values()]

    # Multiple testing correction
    _, corrected_pvals, _, _ = multipletests(p_values, method=correction_method)

    # --- Print Report ---
    print(f"\n=== Pairwise Test Report for {var} ===")
    print(f"Groups: {groups} (n={[len(group_data[g]) for g in groups]})")
    print(f"\nNormality (Shapiro-Wilk):")
    for g in groups:
        print(f"  {g}: W={norm_results[g][0]}, p={norm_results[g][1]}")
    if normal_groups:
        print(f"\nEqual Variance (Levene): p={levene_p}")

    print(f"\nSelected: {test_type}")
    print("\nPairwise Tests (corrected p-values):")
    for (g1, g2), p_corr in zip(pairs, corrected_pvals):
        stat = pairwise_results[(g1, g2)].statistic
        print(f"  {g1} vs {g2}: stat={stat}, p={p_corr}")

    # --- Plot All Pairwise Comparisons ---
    y_max = df[var].max()
    y_min = df[var].min()
    y_range = y_max - y_min

    # Bracket x positions are the group's rank in sorted order; this assumes
    # the boxes were drawn in that same order at x = 0, 1, 2, ...
    x_of = {g: idx for idx, g in enumerate(groups)}

    for i, ((g1, g2), p_corr) in enumerate(zip(pairs, corrected_pvals)):
        y_pos = y_max + (0.15 * y_range) * (i + 1)  # Increased spacing

        # Draw line
        ax.plot([x_of[g1], x_of[g1], x_of[g2], x_of[g2]],
                [y_pos - 0.02 * y_range, y_pos, y_pos, y_pos - 0.02 * y_range],
                lw=1, color='black')

        # Add stars or 'n.s.'
        significance = ('***' if p_corr < 0.001 else 
                       '**' if p_corr < 0.01 else 
                       '*' if p_corr < 0.05 else 
                       'n.s.')
        ax.text((x_of[g1] + x_of[g2]) / 2, y_pos, significance,
                ha='center', va='bottom', fontsize=10)

    # --- Add test information to the plot ---
    if show_test_info:
        test_abbreviations = {
            "Mann-Whitney U": "MWU",
            "Student's t-test": "STT",
            "Welch's t-test": "WTT",
        }
        ax.text(test_info_off_set, 0.98, test_abbreviations[test_type],
                transform=ax.transAxes,
                ha='right', va='top',
                fontsize=10,
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', boxstyle='round,pad=0.2'))

    return {
        'test_type': test_type,
        'normality': norm_results,
        'equal_variance': levene_p,
        'pairwise': {
            (g1, g2): {
                'stat': pairwise_results[(g1, g2)].statistic,
                'p_raw': pairwise_results[(g1, g2)].pvalue,
                'p_corr': p_corr
            } for (g1, g2), p_corr in zip(pairs, corrected_pvals)
        }
    }

def add_statistical_annotations(ax, df, var, group_column='imaging_day', 
                              test_func=mannwhitneyu, correction_method='bonferroni',
                              line_height_factor=0.05, line_spacing_factor=0.10,
                              star_fontsize=10, line_y_offset=0.04):
    """
    Add statistical annotations to a plot with grouped data.

    Every pair of groups is compared with ``test_func`` (called with
    ``alternative='two-sided'``) on the non-missing values of ``var``. The
    p-values are corrected with ``multipletests`` and each pair gets a bracket
    labelled ``***`` (p < 0.001), ``**`` (p < 0.01), ``*`` (p < 0.05) or
    ``n.s.``. Nothing is drawn when fewer than two groups are present.

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        The axes to add annotations to
    df : pandas.DataFrame
        DataFrame containing the data
    var : str
        The variable being plotted (column name in df)
    group_column : str, optional
        The column used for grouping (default: 'imaging_day')
    test_func : function, optional
        Statistical test function returning ``(statistic, pvalue)`` and
        accepting an ``alternative`` keyword (default: mannwhitneyu)
    correction_method : str, optional
        Multiple testing correction method (default: 'bonferroni'). Could also consider: fdr_bh
    line_height_factor : float, optional
        Factor to determine height of annotation lines relative to y-range (default: 0.05)
    line_spacing_factor : float, optional
        Factor to determine vertical spacing between annotations (default: 0.10)
    star_fontsize : int, optional
        Font size for significance stars (default: 10)
    line_y_offset : float, optional
        Offset of the first bracket above the data maximum, as a fraction of
        the y-range (default: 0.04)

    Returns:
    --------
    None
    """
    # NaN group labels cannot be matched with ``==`` and would yield empty samples.
    groups = sorted(df[group_column].dropna().unique())

    # ==== Statistical Testing ====
    # Pairwise tests for all combinations of groups
    pairs = list(itertools.combinations(groups, 2))
    p_values = []

    for group1, group2 in pairs:
        # Drop NaNs so a single missing measurement cannot turn the p-value into NaN.
        vals1 = df.loc[df[group_column] == group1, var].dropna()
        vals2 = df.loc[df[group_column] == group2, var].dropna()
        _, p = test_func(vals1, vals2, alternative='two-sided')
        p_values.append(p)

    if not p_values:
        return

    # Multiple testing correction
    _, corrected_pvals, _, _ = multipletests(p_values, method=correction_method)

    # ==== Annotation ====
    y_max = df[var].max()
    y_min = df[var].min()
    y_range = y_max - y_min

    # Calculate line positions
    line_height = y_range * line_height_factor
    line_spacing = y_range * line_spacing_factor

    # Height of the first bracket
    start_y = y_max + (y_range * line_y_offset)

    # Bracket x positions are the group's rank in sorted order; this assumes
    # the boxes were drawn in that same order at x = 0, 1, 2, ...
    x_of = {g: idx for idx, g in enumerate(groups)}

    for i, ((group1, group2), p_corr) in enumerate(zip(pairs, corrected_pvals)):
        x1 = x_of[group1]
        x2 = x_of[group2]

        # One stacking level per comparison
        y = start_y + line_spacing * i

        # Determine significance stars
        if p_corr < 0.001:
            stars = '***'
        elif p_corr < 0.01:
            stars = '**'
        elif p_corr < 0.05:
            stars = '*'
        else:
            stars = 'n.s.'

        # Draw the bracket; its vertical ticks reach down to line_y
        line_y = y - (line_height * 0.3)
        ax.plot([x1, x1, x2, x2], 
               [line_y, y, y, line_y], 
               lw=1, color='black')

        # Add text (positioned slightly above the line)
        ax.text((x1 + x2) / 2, y + (line_height * 0.1), 
                stars, ha='center', va='bottom', 
                fontsize=star_fontsize)

def add_mean_std_annotations(ax, df, var, group_column='imaging_day', 
                            y_offset=-25, fontsize=11):
    """
    Write "μ=mean ± std" for every group, anchored at the bottom of the axes.

    Groups are taken from ``df.groupby(group_column)`` and the k-th group is
    labelled at x = k, with the text shifted by ``y_offset`` points from the
    current lower y-limit.

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        Axes receiving the labels
    df : pandas.DataFrame
        Source data
    var : str
        Column whose mean and standard deviation are reported
    group_column : str, optional
        Grouping column (default: 'imaging_day')
    y_offset : int, optional
        Vertical text offset in points (default: -25)
    fontsize : int, optional
        Label font size (default: 11)
    """
    # Named ``summary`` rather than ``stats`` so the scipy.stats import is not shadowed.
    summary = df.groupby(group_column)[var].agg(['mean', 'std'])

    for position, (_, row) in enumerate(summary.iterrows()):
        label = f"μ={row['mean']:.2f} ± {row['std']:.2f}"
        anchor = (position, ax.get_ylim()[0])
        ax.annotate(
            label,
            xy=anchor,
            xytext=(0, y_offset),
            textcoords='offset points',
            ha='center',
            va='top',
            fontsize=fontsize,
            bbox=dict(facecolor='white', alpha=0.9, pad=2, edgecolor='none'),
            zorder=5
        )

def plot_grouped_boxplot(ax, df, var, group_column='imaging_day', 
                        xlabel="Day", ylabel="", title="",
                        color_mapping_func=get_category_color_mapping,
                        lighten_amount=0.7, show_legend=False,
                        scale_y=None,
                        y_min=None,
                        y_max=None,
                        **boxplot_kwargs):
    """
    Draw one box per group of ``group_column`` with a lightened group colour.

    This function only draws the boxes and styles the axes; significance
    brackets are added separately (e.g. by ``add_mwu_annotations``).

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        The axes to plot on
    df : pandas.DataFrame
        DataFrame containing the data
    var : str
        The variable to plot (column name in df)
    group_column : str, optional
        The column used for grouping (default: 'imaging_day')
    xlabel : str, optional
        Label for x-axis (default: "Day")
    ylabel : str, optional
        Label for y-axis (default: "")
    title : str, optional
        Plot title (default: ""); no title is set when empty
    color_mapping_func : function, optional
        Called as ``color_mapping_func(df)``; must return a mapping that has a
        key ``'Day <group> (N=<row count>)'`` for every group
    lighten_amount : float, optional
        ``amount`` passed to ``lighten_color`` (default: 0.7)
    show_legend : bool, optional
        Whether to show legend (default: False)
    scale_y : float, optional
        Divisor applied to the y tick labels, e.g. 1e6 to label the axis in
        units of 10^6 (default: None, no scaling). The data are not changed.
    y_min : float, optional
        Minimum value for y-axis, given in scaled units when scale_y is set
    y_max : float, optional
        Maximum value for y-axis, given in scaled units when scale_y is set
    **boxplot_kwargs : dict
        Additional arguments passed to sns.boxplot; they override the defaults
        defined below
    """
    # Get groups and colors (NaN labels cannot be matched with ``==``)
    groups = sorted(df[group_column].dropna().unique())
    color_mapping = color_mapping_func(df)

    # Create palette with lightened colors, one entry per group in sorted order
    palette = [
        lighten_color(color_mapping[f'Day {group} (N={(df[group_column]==group).sum()})'], amount=lighten_amount)
        for group in groups
    ]

    # Default boxplot properties
    default_kwargs = {
        'width': 0.6,
        'linewidth': 2,
        'fliersize': 5,
        'showmeans': False,
        'medianprops': {"color": "black", "linestyle": '--', "linewidth": 2},
        'palette': palette,
        # Pin the box order to the sorted groups so the palette above (and
        # brackets positioned by sorted rank) line up with the boxes even for
        # non-numeric labels, which seaborn would otherwise keep in order of
        # appearance. A caller-supplied ``order`` still wins, but the palette
        # stays in sorted-group order.
        'order': groups,
    }
    default_kwargs.update(boxplot_kwargs)

    # Create boxplot
    sns.boxplot(
        data=df,
        x=group_column,
        y=var,
        ax=ax,
        **default_kwargs
    )

    # Apply scaling to y-axis labels if requested
    if scale_y is not None:
        # FuncFormatter callbacks are expected to return the label as a string.
        ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: str(x / scale_y)))
        # Limits are given in scaled units; convert them back to data units.
        if y_min is not None:
            y_min = y_min * scale_y
        if y_max is not None:
            y_max = y_max * scale_y

    # Apply y-axis limits if specified
    if y_min is not None or y_max is not None:
        ax.set_ylim(bottom=y_min, top=y_max)

    # Set labels and style
    ax.set_xlabel(xlabel, fontsize=12, labelpad=4)
    ax.set_ylabel(ylabel, fontsize=12, labelpad=4)
    if title:
        ax.set_title(title, fontsize=14, pad=10)
    ax.tick_params(axis='x', labelsize=9)
    ax.tick_params(axis='y', labelsize=9)
    ax.grid(axis='y', linestyle='--', alpha=0.5)

    if show_legend:
        ax.legend(fontsize=12, loc='upper right')

def plot_side_by_side_histograms(ax, df, value_column, x_min=None, x_max=None, y_min=None, y_max=None, bins=None, leged_loc='upper right', x_lab=r"Cross-sectional shape index $\alpha$", title="Lumen cross-section shape"):
    """
    Draw per-category histograms of ``value_column`` as side-by-side bars.

    Rows are split into categories using the ``control`` column (0 is
    labelled Treatment, 1 is labelled Control) and, when an ``imaging_day``
    column exists, one category per imaging day among rows with
    ``control == -1``. All categories share bin edges computed from the pooled
    data; within each bin the bars are drawn next to each other. A dashed red
    line marks the mean of the whole column.

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        The axes to plot on
    df : pandas.DataFrame
        Source data. The day grouping reads ``df['control']``, so that column
        is required whenever ``imaging_day`` is present.
    value_column : str
        Column to histogram
    x_min, x_max, y_min, y_max : float, optional
        Axis limits; a limit left as None is not constrained
    bins : int, sequence or str, optional
        Passed to ``np.histogram_bin_edges``; None means 'auto'
    leged_loc : str, optional
        Currently unused (the legend call is commented out)
    x_lab : str, optional
        x-axis label
    title : str, optional
        Currently unused (the title call is commented out)

    Raises:
    -------
    ValueError
        If no category contains any rows, so there is nothing to plot.
    """
    # Mean line (ensure it is above the bars but below the legend)
    mean_val = df[value_column].mean()
    ax.axvline(mean_val, color='red', linestyle='--', linewidth=1.5, label=f'Mean = {mean_val:.2f}', zorder=4)

    # --- Set x-axis limits if provided ---
    if x_min is not None or x_max is not None:
        ax.set_xlim(left=x_min, right=x_max)

    # --- Set y-axis limits if provided ---
    if y_min is not None or y_max is not None:
        ax.set_ylim(bottom=y_min, top=y_max)

    data_groups = []
    labels = []
    colors = []

    category_color_map = get_category_color_mapping(df, color_palette="viridis", 
        manual_colors={
            'Day 2': '#440154',
            'Day 3': '#2a788e',
            'Day 4': '#7ad151',
            'Control': '#414487', 
            'Treatment': '#22a884'
        })

    if 'control' in df.columns:
        for control_val in [0, 1]:
            mask = df['control'] == control_val
            if mask.sum() > 0:
                label = f'Treatment (N={mask.sum()})' if control_val == 0 else f'Control (N={mask.sum()})'
                data_groups.append(df.loc[mask, value_column])
                labels.append(label)
                colors.append(category_color_map[label])

    if 'imaging_day' in df.columns:
        days = sorted(df[df['control'] == -1]['imaging_day'].unique())
        for day in days:
            mask = (df['control'] == -1) & (df['imaging_day'] == day)
            if mask.sum() > 0:
                label = f'Day {day} (N={mask.sum()})'
                data_groups.append(df.loc[mask, value_column])
                labels.append(label)
                colors.append(category_color_map[label])

    if not data_groups:
        # pd.concat would fail with the opaque "No objects to concatenate".
        raise ValueError(
            f"No rows to plot for '{value_column}': no Treatment/Control or imaging-day category matched"
        )

    # --- Add grid lines FIRST (before plotting bars) ---
    ax.grid(True, which='both', axis='both', 
            linestyle='--', linewidth=0.5, color='lightgrey', alpha=0.7, zorder=0)

    # Shared bin edges based on all plotted data
    if bins is None:
        bins = 'auto'

    all_data = pd.concat(data_groups)
    bin_edges = np.histogram_bin_edges(all_data, bins=bins)
    bin_width = bin_edges[1] - bin_edges[0]

    # Plot each group with offset positions
    n_groups = len(data_groups)
    group_width = bin_width / (n_groups + 0.2)

    # NOTE(review): the offsets are symmetric around 0 and are added to the
    # LEFT bin edges, so each cluster of bars looks centred on the left edge
    # of its bin rather than on the bin centre. Left as is because callers
    # tune x_min/x_max to this layout; confirm whether that is intended.
    for i, (data, label, color) in enumerate(zip(data_groups, labels, colors)):
        offset = (i - (n_groups-1)/2) * group_width
        adjusted_bins = bin_edges[:-1] + offset

        counts, _ = np.histogram(data, bins=bin_edges)
        ax.bar(adjusted_bins, counts, width=group_width, 
               alpha=0.85, label=label, color=color,
               edgecolor='white', linewidth=0.5, zorder=3)  # zorder > grid's zorder

    ax.set_xlabel(x_lab, fontsize=12,labelpad=4)
    ax.set_ylabel("Count", fontsize=12, labelpad=4)
    # ax.set_title(title, fontsize=14, pad=10)
    ax.tick_params(axis='x', labelsize=9)
    ax.tick_params(axis='y', labelsize=9)

    # Optional: Uncomment if legend is needed
    # ax.legend(fontsize=12, loc=leged_loc)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

def crop_right(image_path, pixels_to_crop=300):
    """
    Crop an image from the right side and return it as an array for plotting.

    Args:
        image_path: Path to the input image
        pixels_to_crop: Number of pixel columns to remove from the right edge

    Returns:
        Numpy array of the cropped image ready for plt.imshow()

    Raises:
        ValueError: If pixels_to_crop is negative or not smaller than the
            image width (the crop box would be empty or inverted).
    """
    with Image.open(image_path) as img:
        width, height = img.size
        if not 0 <= pixels_to_crop < width:
            raise ValueError(
                f"pixels_to_crop must be in [0, {width}) for an image {width} px wide, got {pixels_to_crop}"
            )
        # Box is (left, upper, right, lower): keep everything left of the cut.
        crop_box = (0, 0, width - pixels_to_crop, height)
        cropped_img = img.crop(crop_box)
        return np.array(cropped_img)  # Convert to numpy array for matplotlib

In [None]:
# Selection thresholds: rows are kept when most_similar_lumen_axis_value is at
# most t_alpha and iou_inner_mesh is at least t_iou.
t_alpha = 1.0
t_iou = 0.0

# Column name / axis label pairs (consumed by the statistics-table helpers).
variables_to_plot = [
    {'var': column, 'ylabel': axis_label}
    for column, axis_label in (
        ('lumen_sphericity_mesh', 'Lumen Sphericity'),
        ('epithelium_lumen_v_ratio_mesh', 'Epithelium-lumen volume ratio'),
        ('v_epithelium_mesh', 'Epithelial volume V μm^3'),
        ('v_inner_mesh', 'Lumen volume S μm^3'),
        ('relative_thickness_mesh', 'Relative Thickness δ = h/r'),
        ('mean_height_ellipsoid', 'Cell Height h μm'),
        ('aspect_ratio', 'Aspect Ratio ε = c/√(ab)'),
    )
]


# Rows passing both thresholds; the boxplot panels are drawn from this subset.
differentiation_circular_cross_section_subset_iou = lumenoids_output_diff_batch_2.loc[
    (lumenoids_output_diff_batch_2["most_similar_lumen_axis_value"] <= t_alpha)
    & (lumenoids_output_diff_batch_2["iou_inner_mesh"] >= t_iou)
]
n = len(differentiation_circular_cross_section_subset_iou)

print(f"THRESHOLDS:{t_alpha}, {t_iou} | N:{n}")

# --- Main Plotting ---

fig = plt.figure(figsize=(7, 60))

# Axes rectangles are [left, bottom, width, height] in figure fractions (0-1).

# image1_ax = fig.add_axes([0.00, 0.94, 0.15, 0.15])
# image2_ax = fig.add_axes([0.15, 0.94, 0.15, 0.15])

# image3_ax = fig.add_axes([0.34, 0.935, 0.16, 0.16])
# image4_ax = fig.add_axes([0.50, 0.935, 0.16, 0.16])

# image5_ax = fig.add_axes([0.70, 0.9375, 0.16, 0.155])
# image6_ax = fig.add_axes([0.85, 0.9375, 0.16, 0.155])

# Top row: two wide histogram panels.
ax1, ax2 = (
    fig.add_axes([left, 0.955, 0.425, 0.03])
    for left in (0.0, 0.575)
)
ax1.text(0.65, 0.65, r'$\alpha = \frac{|a-b|}{\sqrt{ab}}$', transform=ax1.transAxes, fontsize=16, verticalalignment='top')

# Middle row: four boxplot panels.
ax3, ax4, ax5, ax6 = (
    fig.add_axes([left, 0.89, 0.15, 0.05])
    for left in (0.0, 0.275, 0.575, 0.85)
)

# Bottom row: three boxplot panels.
ax7, ax8, ax9 = (
    fig.add_axes([left, 0.825, 0.15, 0.05])
    for left in (0.1375, 0.4125, 0.6875)
)

# ax10 = fig.add_axes([0.95, 0.825, 0.15, 0.05])

axes = [ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8, ax9] # image1_ax, image3_ax, image5_ax,
letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L']

# Panel letters: the wide panels A and B use a different anchor from the
# smaller boxplot panels.
for ax, letter in zip(axes, letters):
    ax.text(
        *((-0.18, 1.2) if letter in ('A', 'B') else (-0.1, 1.12)),
        letter,
        transform=ax.transAxes,
        fontsize=16, fontweight='bold', va='top', ha='right',
    )

# # --- Add the image ---
# img1 = crop_right("/Users/ischneider/Switzerland/ETH/Thesis/CalTechData/raw_data/Differentiation_over_time/actin/d2_sample_33.png", 20)
# image1_ax.imshow(img1)
# image1_ax.axis('off')
# img2 = crop_right("/Users/ischneider/Switzerland/ETH/Thesis/CalTechData/raw_data/Differentiation_over_time/actin/d2_sample_48.png", 20)
# image2_ax.imshow(img2)
# image2_ax.axis('off')

# img3 = crop_right("/Users/ischneider/Switzerland/ETH/Thesis/CalTechData/raw_data/Differentiation_over_time/actin/d3_sample_114.png", 10)
# image3_ax.imshow(img3)
# image3_ax.axis('off')
# img4 = crop_right("/Users/ischneider/Switzerland/ETH/Thesis/CalTechData/raw_data/Differentiation_over_time/actin/d3_sample_142.png", 10)
# image4_ax.imshow(img4)
# image4_ax.axis('off')

# img5 = crop_right("/Users/ischneider/Switzerland/ETH/Thesis/CalTechData/raw_data/Differentiation_over_time/actin/d4_sample_88.png", 20)
# image5_ax.imshow(img5)
# image5_ax.axis('off')
# img6 = crop_right("/Users/ischneider/Switzerland/ETH/Thesis/CalTechData/raw_data/Differentiation_over_time/actin/d4_sample_98.png", 20)
# image6_ax.imshow(img6)
# image6_ax.axis('off')

# --- Generate plots ---

# Histograms use the full table; axis limits are passed by keyword so their
# meaning is visible at the call site.
plot_side_by_side_histograms(
    ax1, lumenoids_output_diff_batch_2, 'most_similar_lumen_axis_value',
    x_min=-0.07, x_max=0.8, y_min=0.0, y_max=16, bins=9,
)

plot_side_by_side_histograms(
    ax2, lumenoids_output_diff_batch_2, 'iou_inner_mesh',
    x_min=0.07, x_max=0.8, y_min=0.0, y_max=12.5,
    leged_loc='upper left', x_lab="Quality of ellipsoid fit [IoU]",
    title="Ellipsoid fit for the lumen", bins=9,
)

# One entry per boxplot panel:
# (axes, column, y-label, plot_grouped_boxplot kwargs, add_mwu_annotations kwargs).
# Labels are raw strings: sequences such as "\m", "\d", "\e" and "\s" are
# invalid escapes in normal string literals and trigger warnings.
boxplot_panels = [
    (ax3, "v_epithelium_mesh", r"Epithelial Volume $S$ $[10^6 \mu m^3]$",
     {"scale_y": 1e6}, {"initial_offset_factor": 0.01}),
    (ax4, "v_inner_mesh", r"Lumen Volume $V$ $[10^5 \mu m^3]$",
     {"scale_y": 1e5}, {"initial_offset_factor": 0.03}),
    (ax5, "epithelium_lumen_v_ratio_mesh", r"Epithelium-lumen Volume Ratio $S/V$",
     {}, {"initial_offset_factor": 0.07}),
    (ax6, "mean_height_ellipsoid", r"Cell Height $h$ $[\mu m]$",
     {"y_min": 25, "y_max": 80}, {"initial_offset_factor": 0.03}),
    (ax7, "relative_thickness_mesh", r"Relative Thickness ($\delta = h/r$)",
     {}, {}),
    (ax8, "lumen_sphericity_mesh", r"Lumen Sphericity $\Psi$",
     {}, {"initial_offset_factor": 0.05}),
    (ax9, "aspect_ratio", r"Aspect Ratio ($\epsilon = c/ \sqrt{ab}$)",
     {"y_min": 0., "y_max": 7}, {"initial_offset_factor": 0.05}),
]

for panel_ax, panel_var, panel_ylabel, panel_plot_kwargs, panel_mwu_kwargs in boxplot_panels:
    plot_grouped_boxplot(
        panel_ax, differentiation_circular_cross_section_subset_iou, panel_var,
        ylabel=panel_ylabel, **panel_plot_kwargs,
    )
    add_mwu_annotations(
        panel_ax, differentiation_circular_cross_section_subset_iou, panel_var,
        **panel_mwu_kwargs,
    )

# plot_grouped_boxplot(ax10, differentiation_circular_cross_section_subset_iou, "inner_ellipsoid_shortest_longest_axis_ratio", 
#                 ylabel="Shortest to longest lumen axis ratio")
# add_mwu_annotations(ax10, differentiation_circular_cross_section_subset_iou, "inner_ellipsoid_shortest_longest_axis_ratio", initial_offset_factor=0.05)


# Get the actual color mapping from your data
color_mapping_hist = get_category_color_mapping(lumenoids_output_diff_batch_2)

# One colour patch per category, followed by the fixed statistical/box symbols.
legend_elements = [
    Patch(facecolor=lighten_color(color, amount=0.85), edgecolor='white', label=category)
    for category, color in color_mapping_hist.items()
]

legend_elements.extend([
    # Statistical annotations
    Line2D([0], [0], color='red', linestyle='--', lw=1.5, 
           label='Mean'),
    # Line2D([0], [0], color='black', linestyle='--', lw=1.5, 
    #        label='Median'),

    # Box elements
    Patch(facecolor='white', edgecolor='black', linewidth=2, 
          label='IQR (25-75%)', hatch='////'),
    Line2D([0], [0], color='black', linestyle='-', lw=1, 
           label='1.5×IQR'),
    Line2D([0], [0], color='black', linestyle='--', lw=2, 
           label='Median'),
    # Line2D([0], [0], marker='D', color='gray', markersize=8, 
    #        linestyle='None', markeredgecolor='black', label='Mean'),
    # Outliers
    Line2D([0], [0], marker='o', color='black', markersize=5, 
           linestyle='None', markerfacecolor='none', label='Outliers'),
])

legend_style = {
    'loc': 'center',
    'fontsize': 8,
    'framealpha': 1,
    'frameon': False,
    'title_fontsize': 14,    # overridden by set_fontsize(10) on the title below
    'borderpad': 0.5,
    'handletextpad': 0.5,    # Reduced from 2 to 0.5 (symbol ↔ text)
    'columnspacing': 1.0,    # Added: space between columns
    'labelspacing': 0.3,     # Added: vertical space between rows
}

legend = fig.legend(
    handles=legend_elements,
    bbox_to_anchor=(0.5, 0.997),
    title='Legend:',
    # The column count must be an integer: "/" yields a float, so use floor
    # division and never ask for fewer than one column.
    ncol=max(1, len(legend_elements) // 2),
    **legend_style
)

legend.get_title().set_fontweight('bold')
legend.get_title().set_fontsize(10)
# NOTE(review): _legend_box is a private matplotlib attribute; this spacing
# tweak may break on a matplotlib upgrade.
legend._legend_box.sep = 8

# --- Save ---
plt.tight_layout()
# NOTE(review): the "n69" in the file name is hard-coded, while the subset
# size is available as `n`; confirm the two agree before publishing.
output_pdf = "plots/image_analysis/lumenoids_diff_batch2_all_n69.pdf"
# savefig does not create missing directories, so a fresh checkout would fail.
os.makedirs(os.path.dirname(output_pdf), exist_ok=True)
plt.savefig(output_pdf, bbox_inches='tight', 
            dpi=1200,
            format='pdf')
plt.show()

# Generate and save the table

# Save to CSV for LaTeX import
# stats_table.to_csv(f"stats_table_t{t}_{t_iou}_n{n}.csv", index=False)

# Display the table (optional)
# display(stats_table)

# # Example usage:
# plot_statistics_table(stats_table, title=f"Statistics for thresholds t_alpha={t_alpha}, t_iou={t_iou} (N={n})")
# plt.show()

### Power-law fit between the Epithelium and Lumen Volumes

In [None]:
# Read the lumenoids_data_d2.csv output table and preview its three
# inner-ellipsoid columns (a, b, c).
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
lumenoids_path = "/Users/ischneider/Switzerland/ETH/Thesis/ThesisSimuCell3D/lumenoids_image_analysis_output_differentiation/lumenoids_data_d2.csv"
d2 = pd.read_csv(lumenoids_path)
d2[["a_inner_ellipsoid", "b_inner_ellipsoid", "c_inner_ellipsoid"]]

In [None]:
# Read the lumenoids_data_d3.csv output table and preview its three
# inner-ellipsoid columns (a, b, c). Note that `lumenoids_path` is rebound here.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
lumenoids_path = "/Users/ischneider/Switzerland/ETH/Thesis/ThesisSimuCell3D/lumenoids_image_analysis_output_differentiation/lumenoids_data_d3.csv"
d3 = pd.read_csv(lumenoids_path)
d3[["a_inner_ellipsoid", "b_inner_ellipsoid", "c_inner_ellipsoid"]]

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

def plot_log_scatter(
    df,
    x_col,
    y_col,
    control_column="control",
    image_day_column="imaging_day",
    x_equals_y=False,
    plot_title=None,
    color_palette="viridis",
    show_means=False,
    export_path=None,
    export_dpi=300,
    units_x=None,
    units_y=None,
    fit_line=False
):
    """
    Scatter-plot two columns on log-log axes, coloured by category.

    The raw values are plotted and both axes are set to log scale. One
    scatter series is drawn per entry of the label -> colour mapping returned
    by ``get_category_color_mapping``. Optionally a straight line is fitted in
    log10 space (log10(y) = slope * log10(x) + intercept) and drawn as the
    equivalent power law y = 10**intercept * x**slope.

    Args:
        df: DataFrame holding all referenced columns.
        x_col: Column plotted on the x axis; values must be > 0.
        y_col: Column plotted on the y axis; values must be > 0.
        control_column: Column used to select rows per category
            (0 for "Treatment" labels, 1 for "Control" labels, -1 otherwise).
        image_day_column: Column compared against the day parsed from the
            label for rows whose control value is -1.
        x_equals_y: If True, draw a dashed x = y reference line.
        plot_title: Axes title; defaults to "<y_col> vs <x_col> (log-log)".
        color_palette: Forwarded to ``get_category_color_mapping``.
        show_means: Currently unused; kept for interface compatibility.
        export_path: If given, the figure is saved to this path.
        export_dpi: Resolution used when saving.
        units_x: Optional unit string appended to the x label in brackets.
        units_y: Optional unit string appended to the y label in brackets.
        fit_line: If True, add the power-law fit; the legend entry reports the
            exponent, its standard error and R^2.

    Returns:
        Tuple ``(fig, ax)`` of the created Matplotlib figure and axes.

    Raises:
        ValueError: If a referenced column is missing, or if ``x_col`` or
            ``y_col`` contains a value <= 0.
    """
    for col in [x_col, y_col, control_column, image_day_column]:
        if col not in df.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame")
    if (df[x_col] <= 0).any() or (df[y_col] <= 0).any():
        raise ValueError("Log scale requires positive values in x_col and y_col")

    x_data = df[x_col]
    y_data = df[y_col]
    total_n = len(df)

    category_color_map = get_category_color_mapping(
        df, control_column, image_day_column, color_palette
    )

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_xscale("log")
    ax.set_yscale("log")

    for label, color in category_color_map.items():
        if "Treatment" in label:
            mask = df[control_column] == 0
        elif "Control" in label:
            mask = df[control_column] == 1
        else:
            # Assumes the remaining labels look like "<word> <day>", with the
            # imaging day as the second token - confirm against
            # get_category_color_mapping.
            imaging_day = int(label.split()[1])
            mask = (df[control_column] == -1) & (df[image_day_column] == imaging_day)
        ax.scatter(x_data[mask], y_data[mask], c=[color], label=label, alpha=0.6)

    if fit_line:
        # NaN passes the positivity check above (NaN <= 0 is False) but would
        # turn every regression output into NaN, so fit complete pairs only.
        paired = x_data.notna() & y_data.notna()
        fit_x = x_data[paired]
        fit_y = y_data[paired]

        # The fifth unpacked value is the standard error of the slope (alpha).
        slope, intercept, r_value, _, stderr_slope = linregress(
            np.log10(fit_x), np.log10(fit_y)
        )

        x_vals = np.logspace(np.log10(fit_x.min()), np.log10(fit_x.max()), 100)
        y_vals = 10**intercept * x_vals**slope
        ax.plot(x_vals, y_vals, 'r--', 
                label=f'Power-law fit: $S \\sim V^{{\\alpha}}$\n$\\alpha = {slope:.2f} \\pm {stderr_slope:.2f}$\n$(R^2={r_value**2:.2f})$')

    if x_equals_y:
        lim_min = min(x_data.min(), y_data.min())
        lim_max = max(x_data.max(), y_data.max())
        ax.plot([lim_min, lim_max], [lim_min, lim_max], 'k--', alpha=0.5, label='$x = y$')

    ax.set_xlabel(f'{x_col}' + (f' [{units_x}]' if units_x else ''))
    ax.set_ylabel(f'{y_col}' + (f' [{units_y}]' if units_y else ''))
    ax.set_title(plot_title or f'{y_col} vs {x_col} (log-log)')
    ax.legend(title=f'Categories (Total N={total_n})', bbox_to_anchor=(1.05, 1))
    plt.tight_layout()

    if export_path:
        fig.savefig(export_path, dpi=export_dpi, bbox_inches='tight')

    return fig, ax


In [None]:
# Log-log scatter of epithelium volume against lumen volume, with the fit line.
fig, ax = plot_log_scatter(
    differentiation_experiment,
    y_col="v_epithelium_mesh",
    x_col="v_inner_mesh",
    # Raw strings: "\m" is not a valid escape sequence in a normal literal.
    units_x=r"$\mu m^3$",
    units_y=r"$\mu m^3$",
    plot_title="Epithelium vs Lumen Volume ",
    fit_line=True,
    # export_path="/Users/ischneider/Switzerland/ETH/Thesis/Plots/imaging_data/lum_v_height.png"
)
plt.show()
