In [None]:
import pandas as pd

# Location of the collected mutual-information results.
# NOTE(review): absolute, machine-specific path -- adjust when running on another machine.
MI_RESULTS_CSV = "/home/jupyter/igor_repos/exploration/noise_scaling_laws/Scaling-up-measurement-noise-scaling-laws/collect_mi_results.csv"

# Read the raw per-run MI results into a DataFrame.
df = pd.read_csv(MI_RESULTS_CSV)

In [None]:
# The raw CSV calls the estimator "algorithm" and the quantity "signal"; the rest of the
# notebook refers to them as "method" and "metric". `rename` ignores labels that are absent,
# so re-running this cell leaves an already-renamed frame unchanged.
column_aliases = {"algorithm": "method", "signal": "metric"}
df = df.rename(columns=column_aliases)
df

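### Fitting noise scaling laws

The first code cell in this section plots the *cell-number* scaling fits stored in `final_results`. That variable (and `cell_number_scaling`) is only created in the cell-scaling-law section further down, so run that section first; otherwise the cell just prints a notice.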

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Plot the cell-number scaling fits: one panel per (metric, method), one colour per quality level.
# Points are seed means (+/- std when several seeds exist) joined by solid lines; dashed lines are
# the fitted cell_number_scaling curves with a 2-sigma band from first-order error propagation.
# `final_results` and `cell_number_scaling` are defined in the cell-scaling-law section further
# down; until that section has been run this cell only prints a notice.
# NOTE(review): rows from different datasets are pooled per panel/quality and only the first fitted
# row's parameters are drawn -- confirm there is a single dataset, or split the figure by dataset.


def _first_fit_value(frame, column):
    """Return the first value of ``frame[column]`` as a float; None/NA (e.g. a missing stderr) becomes NaN."""
    value = frame[column].iloc[0]
    return np.nan if pd.isna(value) else float(value)


def _curve_grid(values, num=100):
    """Grid spanning ``values`` for drawing a fitted curve; log-spaced when all values are positive."""
    lo, hi = float(np.min(values)), float(np.max(values))
    if lo > 0:
        # The x-axis is log-scaled, so a geometric grid keeps the curve smooth at small sizes.
        return np.geomspace(lo, hi, num)
    return np.linspace(lo, hi, num)


if "final_results" in locals() and len(final_results) > 0:
    metrics = final_results["metric"].unique()
    methods = final_results["method"].unique()

    # squeeze=False always yields a 2-D array of Axes, even for a single row or column.
    fig, axes = plt.subplots(
        len(metrics), len(methods), figsize=(4 * len(methods), 4 * len(metrics)), squeeze=False
    )

    # Color scheme for different quality levels (UMI counts)
    qualities = sorted(final_results["quality"].unique())
    colors = plt.cm.tab10(np.linspace(0, 1, len(qualities)))
    quality_color_map = dict(zip(qualities, colors))

    for i, metric in enumerate(metrics):
        for j, method in enumerate(methods):
            ax = axes[i, j]
            subset = final_results[(final_results["metric"] == metric) & (final_results["method"] == method)]

            if len(subset) == 0:
                ax.set_title(f"{metric}\n{method}\n(No data)")
                continue

            for quality in qualities:
                quality_data = subset[subset["quality"] == quality]
                if len(quality_data) == 0:
                    continue
                color = quality_color_map[quality]

                # Mean (and std across seeds, when there are several) at each cell number;
                # groupby returns the sizes in sorted order, so the connecting line is monotone in x.
                x_vals, y_means, y_stds = [], [], []
                for size, group in quality_data.groupby("size"):
                    mi_values = group["mi_value"]
                    x_vals.append(size)
                    y_means.append(mi_values.mean())
                    y_stds.append(mi_values.std() if len(mi_values) > 1 else 0)
                x_vals, y_means, y_stds = np.array(x_vals), np.array(y_means), np.array(y_stds)

                # Error bars only when some point has a non-zero spread across seeds.
                if np.any(y_stds > 0):
                    ax.errorbar(
                        x_vals,
                        y_means,
                        yerr=y_stds,
                        color=color,
                        fmt="o",
                        capsize=5,
                        capthick=2,
                        alpha=0.7,
                        markersize=6,
                        label=f"Quality {quality:.3f}",
                    )
                else:
                    ax.scatter(
                        x_vals,
                        y_means,
                        color=color,
                        alpha=0.7,
                        s=50,
                        label=f"Quality {quality:.3f}",
                    )

                ax.plot(x_vals, y_means, color=color, linestyle="-", linewidth=1.5, alpha=0.8)

                # Fitted curve, only if some row of this group carries cell-scaling parameters.
                if "A_cell" not in quality_data.columns or quality_data["A_cell"].isna().all():
                    continue

                # Rows coming from the outer merge may lack cell-fit values; read parameters from a fitted row.
                fit_rows = quality_data[quality_data["A_cell"].notna()]
                A_cell = _first_fit_value(fit_rows, "A_cell")
                A_cell_err = _first_fit_value(fit_rows, "A_cell_err")
                B_cell = _first_fit_value(fit_rows, "B_cell")
                B_cell_err = _first_fit_value(fit_rows, "B_cell_err")
                C_cell = _first_fit_value(fit_rows, "C_cell")
                C_cell_err = _first_fit_value(fit_rows, "C_cell_err")

                size_range = _curve_grid(quality_data["size"])
                fitted_curve = cell_number_scaling(size_range, A_cell, B_cell, C_cell)

                # Error propagation for I(n) = C - (n/A)^(-B), ignoring parameter covariances:
                #   dI/dA = -B * (n/A)^(-B) / A,  dI/dB = (n/A)^(-B) * ln(n/A),  dI/dC = 1
                term = (size_range / A_cell) ** (-B_cell)
                dI_dA = -B_cell * term / A_cell
                dI_dB = term * np.log(size_range / A_cell)
                dI_dC = np.ones_like(size_range)
                uncertainty = np.sqrt(
                    (dI_dA * A_cell_err) ** 2 + (dI_dB * B_cell_err) ** 2 + (dI_dC * C_cell_err) ** 2
                )

                # 2-sigma band (nothing is drawn where the uncertainty is NaN) and the fitted line.
                ax.fill_between(
                    size_range,
                    fitted_curve - 2 * uncertainty,
                    fitted_curve + 2 * uncertainty,
                    color=color,
                    alpha=0.2,
                )
                ax.plot(size_range, fitted_curve, color=color, linestyle="--", linewidth=2, alpha=0.9)

            ax.set_xlabel("Cell number")
            ax.set_ylabel("MI value")
            ax.set_title(f"{metric}\n{method}")
            ax.set_xscale("log")
            ax.grid(True, alpha=0.3)

            # Legend only on the first subplot (colours are shared across panels).
            if i == 0 and j == 0:
                ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.tight_layout()
    plt.show()
else:
    print("No fitted data available for plotting")


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from lmfit import Model, Parameters


def info_scaling(x, A, B):
    """
    Saturating information curve I(x) = 0.5 * log2((x*B + 1) / (1 + A*x)).

    x is the UMI-per-cell depth in this notebook; A and B are the fitted model parameters.
    Evaluates element-wise when x is a NumPy array; I(0) = 0.
    """
    ratio = (x * B + 1) / (1 + A * x)
    return 0.5 * np.log2(ratio)


def fit_noise_scaling_model(u_values, mi_values):
    """
    Fit the noise scaling model to data and return I_max and u_bar with uncertainties.

    The model is I(u) = 0.5 * log2((u*B + 1) / (1 + A*u)), fitted with lmfit with A, B >= 0.
    Derived quantities: I_max = 0.5 * log2(B / A) (the u -> infinity limit of the model)
    and u_bar = 1 / A.

    Parameters:
    u_values: array of UMI per cell values
    mi_values: array of mutual information values

    Returns:
    dict with I_max, u_bar and their uncertainties (I_max_err, u_bar_err), fit_success,
    the lmfit result and the raw A_info / B_info values.
    Uncertainties are NaN when lmfit could not estimate standard errors (its stderr is None);
    the fit is still reported as successful in that case.
    On failure (including A_info == 0, where I_max and u_bar are undefined) every numeric
    entry is NaN, fit_success is False and result is None.
    """

    def info_scaling_local(x, A, B):
        """Information scaling function: I(x) = 0.5 * log2((x*B + 1)/(1 + A*x))"""
        return 0.5 * np.log2((x * B + 1) / (1 + A * x))

    def info_max(A, B):
        """Maximum information, the large-x limit of info_scaling_local"""
        return 0.5 * np.log2(B / A)

    def _stderr(param):
        """lmfit leaves ``stderr`` as None when it cannot estimate it; report that as NaN."""
        return np.nan if param.stderr is None else param.stderr

    # Create lmfit model; parameter names A and B come from info_scaling_local's signature.
    model = Model(info_scaling_local)

    # Initial guesses and non-negativity bounds
    params = model.make_params(A=1e-2, B=1e-2)
    params["A"].min = 0
    params["B"].min = 0

    try:
        result = model.fit(mi_values, params, x=u_values)
        A_info = result.params["A"].value
        B_info = result.params["B"].value
        A_err = _stderr(result.params["A"])
        B_err = _stderr(result.params["B"])

        # Derived quantities (A_info == 0 raises ZeroDivisionError and is treated as a failed fit)
        I_max = info_max(A_info, B_info)
        u_bar = 1 / A_info

        # First-order error propagation, ignoring the A/B covariance:
        # I_max = 0.5 * log2(B/A)  ->  sigma = 0.5 / ln(2) * sqrt((sA/A)^2 + (sB/B)^2)
        I_max_err = 0.5 * np.sqrt((A_err / A_info) ** 2 + (B_err / B_info) ** 2) / np.log(2)

        # u_bar = 1/A  ->  sigma = |d u_bar / dA| * sA = sA / A^2
        u_bar_err = A_err / (A_info**2)

        return {
            "I_max": I_max,
            "I_max_err": I_max_err,
            "u_bar": u_bar,
            "u_bar_err": u_bar_err,
            "fit_success": True,
            "result": result,
            "A_info": A_info,
            "B_info": B_info,
        }

    except Exception as e:
        print(f"    Fitting failed: {e}")
        return {
            "I_max": np.nan,
            "I_max_err": np.nan,
            "u_bar": np.nan,
            "u_bar_err": np.nan,
            "fit_success": False,
            "result": None,
            "A_info": np.nan,
            "B_info": np.nan,
        }

In [None]:
# Create a loop to iterate over different datasets, methods, metrics and cell sizes
import os

# Fit the UMI (noise) scaling model separately for every (dataset, size, method, metric) group.
group_columns = ["dataset", "size", "method", "metric"]
unique_combinations = df.groupby(group_columns).size().reset_index(name="count")
n_combos = len(unique_combinations)

print(f"Found {n_combos} unique combinations to fit")

# One DataFrame per successfully fitted group; concatenated after the loop.
all_fit_results = []

for idx, row in unique_combinations.iterrows():
    dataset, size, method, metric = (row[col] for col in group_columns)
    print(f"Processing {idx+1}/{n_combos}: {dataset}, {size}, {method}, {metric}")

    # All rows (every seed and UMI depth) belonging to the current group.
    in_group = (
        (df["dataset"] == dataset)
        & (df["size"] == size)
        & (df["method"] == method)
        & (df["metric"] == metric)
    )
    filtered_data = df[in_group]

    n_points = len(filtered_data)
    if n_points < 3:
        print(f"  Skipping - insufficient data points ({n_points})")
        continue

    u_values = filtered_data["umis_per_cell"].values
    mi_values = filtered_data["mi_value"].values
    fit_results = fit_noise_scaling_model(u_values, mi_values)

    if not fit_results["fit_success"]:
        print("  Failed: Fit unsuccessful")
        continue

    print(f"  Success: I_max={fit_results['I_max']:.4f}, u_bar={fit_results['u_bar']:.1f}")

    # Attach the fitted curve value and the group's fit parameters to every row of the group.
    filtered_data_with_fits = filtered_data.assign(
        fitted_mi_value=info_scaling(u_values, fit_results["A_info"], fit_results["B_info"]),
        I_max=fit_results["I_max"],
        I_max_err=fit_results["I_max_err"],
        u_bar=fit_results["u_bar"],
        u_bar_err=fit_results["u_bar_err"],
        A_fit=fit_results["A_info"],
        B_fit=fit_results["B_info"],
    )
    all_fit_results.append(filtered_data_with_fits)

if not all_fit_results:
    print("\nNo successful fits to save")
else:
    combined_results = pd.concat(all_fit_results, ignore_index=True)

    # Persist every fitted row together with its group's parameters.
    output_filename = "/home/jupyter/igor_repos/exploration/noise_scaling_laws/Scaling-up-measurement-noise-scaling-laws/fitted_data_with_results.csv"
    combined_results.to_csv(output_filename, index=False)
    print(f"\nAll fitted data saved to: {output_filename}")

    print(f"\nSuccessfully fitted {len(all_fit_results)} combinations")
    print(f"Total data points with fits: {len(combined_results)}")

    print("\nFirst few rows of combined results:")
    display(combined_results.head())

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Plot the UMI (noise) scaling fits: one panel per (metric, method), one colour per cell-number size.
# Points are seed means (+/- std when several seeds exist); solid lines are the fitted info_scaling
# curves with a 2-sigma band propagated from the fitted I_max and u_bar uncertainties.
# NOTE(review): rows from different datasets are pooled per panel/size and only the first row's
# fit parameters are drawn -- confirm there is a single dataset, or split the figure by dataset.
if all_fit_results:
    metrics = combined_results["metric"].unique()
    methods = combined_results["method"].unique()

    # squeeze=False always yields a 2-D array of Axes, even for a single row or column.
    fig, axes = plt.subplots(
        len(metrics), len(methods), figsize=(4 * len(methods), 4 * len(metrics)), squeeze=False
    )

    # Color scheme for different sizes
    sizes = sorted(combined_results["size"].unique())
    colors = plt.cm.tab10(np.linspace(0, 1, len(sizes)))
    size_color_map = dict(zip(sizes, colors))

    for i, metric in enumerate(metrics):
        for j, method in enumerate(methods):
            ax = axes[i, j]
            subset = combined_results[(combined_results["metric"] == metric) & (combined_results["method"] == method)]

            if len(subset) == 0:
                ax.set_title(f"{metric}\n{method}\n(No data)")
                continue

            for size in sizes:
                size_data = subset[subset["size"] == size]
                if len(size_data) == 0:
                    continue
                color = size_color_map[size]

                # Mean (and std across seeds, when there are several) at each UMI depth.
                x_vals, y_means, y_stds = [], [], []
                for umis, group in size_data.groupby("umis_per_cell"):
                    mi_values = group["mi_value"]
                    x_vals.append(umis)
                    y_means.append(mi_values.mean())
                    y_stds.append(mi_values.std() if len(mi_values) > 1 else 0)
                x_vals, y_means, y_stds = np.array(x_vals), np.array(y_means), np.array(y_stds)

                # Error bars only when some point has a non-zero spread across seeds.
                if np.any(y_stds > 0):
                    ax.errorbar(
                        x_vals,
                        y_means,
                        yerr=y_stds,
                        color=color,
                        fmt="o",
                        capsize=5,
                        capthick=2,
                        alpha=0.7,
                        markersize=6,
                        label=f"Size {size}",
                    )
                else:
                    ax.scatter(
                        x_vals,
                        y_means,
                        color=color,
                        alpha=0.7,
                        s=50,
                        label=f"Size {size}",
                    )

                # Fitted curve on a log-spaced grid, because the x-axis is log-scaled.
                # The grid falls back to linear spacing if any depth is non-positive.
                u_min = size_data["umis_per_cell"].min()
                u_max = size_data["umis_per_cell"].max()
                u_range = np.geomspace(u_min, u_max, 100) if u_min > 0 else np.linspace(u_min, u_max, 100)

                # The fit loop writes the same parameters to every row of a
                # (dataset, size, method, metric) group, so the first row is representative.
                A_fit = size_data["A_fit"].iloc[0]
                B_fit = size_data["B_fit"].iloc[0]
                I_max_err = size_data["I_max_err"].iloc[0]
                u_bar = size_data["u_bar"].iloc[0]
                u_bar_err = size_data["u_bar_err"].iloc[0]

                fitted_curve = info_scaling(u_range, A_fit, B_fit)

                # Uncertainty band for the fitted model itself. With K = B/A = 2**(2*I_max) and
                # u_bar = 1/A, the model reads I(u) = 0.5 * log2((u*K + u_bar) / (u + u_bar)), so
                #   dI/dI_max = u*K / (u*K + u_bar)
                #   dI/du_bar = (1/(u*K + u_bar) - 1/(u + u_bar)) / (2 ln 2)
                # (the covariance between I_max and u_bar is ignored).
                K = B_fit / A_fit
                shifted = u_range * K + u_bar
                dI_dImax = u_range * K / shifted
                dI_dubar = (1.0 / shifted - 1.0 / (u_range + u_bar)) / (2.0 * np.log(2.0))
                uncertainty = np.sqrt((dI_dImax * I_max_err) ** 2 + (dI_dubar * u_bar_err) ** 2)

                # 2-sigma band and fitted line
                ax.fill_between(
                    u_range,
                    fitted_curve - 2 * uncertainty,
                    fitted_curve + 2 * uncertainty,
                    color=color,
                    alpha=0.2,
                )
                ax.plot(u_range, fitted_curve, color=color, linestyle="-", linewidth=2)

            ax.set_xlabel("UMIs per cell")
            ax.set_ylabel("MI value")
            ax.set_title(f"{metric}\n{method}")
            ax.set_xscale("log")
            ax.grid(True, alpha=0.3)

            # Legend only on the first subplot (colours are shared across panels).
            if i == 0 and j == 0:
                ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.tight_layout()
    plt.show()
else:
    print("No fitted data available for plotting")

### Fitting cell scaling laws

In [None]:
# Create a loop to iterate over different datasets, methods, metrics and cell sizes
import os

import matplotlib.pyplot as plt
import numpy as np
from lmfit import Model, Parameters


def cell_number_scaling(x, A, B, C):
    """
    Saturating power law in the number of cells: f(x) = C - (x/A)^(-B).

    For B > 0, f approaches the asymptote C as x grows; the deficit (x/A)^(-B) equals 1 at x = A.
    Evaluates element-wise when x is a NumPy array. The argument names A, B and C are the
    parameter names lmfit reads from this signature.
    """
    deficit = (x / A) ** (-B)
    return C - deficit


def fit_cell_scaling_model(x_values, mi_values, mi_key=None):
    """
    Fit the cell number scaling model f(x) = C - (x/A)^(-B) and return parameters with uncertainties.

    Parameters:
    x_values: array of cell number values
    mi_values: NumPy array of mutual information values (its ``.max()`` bounds C)
    mi_key: string indicating which MI metric is being fitted (for parameter initialization);
        "Spatial neighborhood MI" selects a smaller initial A, a larger lower bound on B and
        a wider upper bound on C (2x instead of 1.5x the largest observed MI).

    Returns:
    dict with A, B, C parameters and their uncertainties (A_err, B_err, C_err), plus
    fit_success and the lmfit result. Uncertainties are NaN when lmfit could not estimate
    standard errors (its stderr is None). On failure every numeric entry is NaN,
    fit_success is False and result is None.
    """
    failure = {
        "A": np.nan,
        "A_err": np.nan,
        "B": np.nan,
        "B_err": np.nan,
        "C": np.nan,
        "C_err": np.nan,
        "fit_success": False,
        "result": None,
    }

    def _stderr(param):
        """lmfit leaves ``stderr`` as None when it cannot estimate it; report that as NaN."""
        return np.nan if param.stderr is None else param.stderr

    # Create lmfit model; parameter names A, B, C come from cell_number_scaling's signature.
    model = Model(cell_number_scaling)

    try:
        # C is bounded to [max MI, factor * max MI]. Its initial value of 1 may lie outside that
        # range, in which case lmfit presumably clips it to the nearest bound.
        # The parameters are built inside the try so that an empty input, or degenerate bounds
        # (e.g. all-zero MI gives min == max for C), count as a failed fit. Otherwise they would
        # abort the caller's loop.
        mi_max = mi_values.max()
        if mi_key == "Spatial neighborhood MI":
            params = model.make_params(
                A=dict(value=10**3, min=1e-6),
                B=dict(value=1, min=1e-3),
                C=dict(value=1, min=mi_max, max=mi_max * 2),
            )
        else:
            params = model.make_params(
                A=dict(value=10**4, min=1e-6),
                B=dict(value=1, min=1e-6),
                C=dict(value=1, min=mi_max, max=mi_max * 1.5),
            )

        result = model.fit(mi_values, params, x=x_values)
    except Exception as e:
        print(f"    Cell-scaling fit failed: {e}")
        return failure

    fitted = result.params
    return {
        "A": fitted["A"].value,
        "A_err": _stderr(fitted["A"]),
        "B": fitted["B"].value,
        "B_err": _stderr(fitted["B"]),
        "C": fitted["C"].value,
        "C_err": _stderr(fitted["C"]),
        "fit_success": True,
        "result": result,
    }


# Fit the cell-number scaling law separately for every (dataset, method, metric, quality) group.
# Including `quality` makes the per-size average below an average over seeds only, as intended,
# rather than a mix of different UMI-depth levels. This assumes `quality` identifies the UMI-depth
# level -- TODO confirm. Rows with a missing quality are dropped by groupby.
cell_group_columns = ["dataset", "method", "metric", "quality"]
unique_combinations = df.groupby(cell_group_columns).size().reset_index(name="count")

# Per-row results (original rows plus fitted values/parameters) for every successful fit
all_cell_fit_results = []
# One summary record of fit parameters per successful fit
cell_fit_parameters_summary = []

for idx, row in unique_combinations.iterrows():
    dataset = row["dataset"]
    method = row["method"]
    metric = row["metric"]
    quality = row["quality"]

    # Filter data for current combination
    filtered_data = df[
        (df["dataset"] == dataset)
        & (df["method"] == method)
        & (df["metric"] == metric)
        & (df["quality"] == quality)
    ]

    # Average MI values across seeds for each cell number
    size_averaged_data = filtered_data.groupby("size")["mi_value"].mean().reset_index()

    # Three free parameters (A, B, C) need at least three distinct cell numbers
    if len(size_averaged_data) < 3:
        continue

    size_values = size_averaged_data["size"].values
    mi_values = size_averaged_data["mi_value"].values

    # Pass the metric name so metric-specific initial values/bounds in fit_cell_scaling_model apply
    fit_results = fit_cell_scaling_model(size_values, mi_values, mi_key=metric)

    if not fit_results["fit_success"]:
        print(f"  Cell-scaling fit failed: {dataset}, {method}, {metric}, quality={quality}")
        continue

    # Fitted MI values for every original row (all seeds)
    fitted_mi_values = cell_number_scaling(
        filtered_data["size"].values, fit_results["A"], fit_results["B"], fit_results["C"]
    )

    filtered_data_with_fits = filtered_data.copy()
    filtered_data_with_fits["fitted_mi_value_cell"] = fitted_mi_values
    filtered_data_with_fits["A_cell"] = fit_results["A"]
    filtered_data_with_fits["A_cell_err"] = fit_results["A_err"]
    filtered_data_with_fits["B_cell"] = fit_results["B"]
    filtered_data_with_fits["B_cell_err"] = fit_results["B_err"]
    filtered_data_with_fits["C_cell"] = fit_results["C"]
    filtered_data_with_fits["C_cell_err"] = fit_results["C_err"]
    all_cell_fit_results.append(filtered_data_with_fits)

    cell_fit_parameters_summary.append(
        {
            "dataset": dataset,
            "method": method,
            "metric": metric,
            "quality": quality,
            "A_cell": fit_results["A"],
            "A_cell_err": fit_results["A_err"],
            "B_cell": fit_results["B"],
            "B_cell_err": fit_results["B_err"],
            "C_cell": fit_results["C"],
            "C_cell_err": fit_results["C_err"],
            "n_size_points": len(size_averaged_data),
        }
    )

# Combine cell-scaling results with the UMI-scaling results (if those exist).
merged_results = None
if all_cell_fit_results:
    combined_cell_results = pd.concat(all_cell_fit_results, ignore_index=True)

    if "combined_results" in locals():
        # Join on the columns identifying an original data row; outer keeps rows fitted by only one model.
        merge_cols = ["dataset", "size", "quality", "method", "metric", "seed", "mi_value", "umis_per_cell"]
        merged_results = pd.merge(
            combined_results,
            combined_cell_results[
                merge_cols
                + ["fitted_mi_value_cell", "A_cell", "A_cell_err", "B_cell", "B_cell_err", "C_cell", "C_cell_err"]
            ],
            on=merge_cols,
            how="outer",
        )
    else:
        # No UMI-scaling fits: the cell-scaling results are the final results
        merged_results = combined_cell_results
elif "combined_results" in locals():
    # No cell-scaling fits: keep the UMI-scaling results as final results
    merged_results = combined_results

if merged_results is None:
    print("No fitted data available to save")
else:
    final_results = merged_results
    final_results.to_csv(
        "/home/jupyter/igor_repos/exploration/noise_scaling_laws/Scaling-up-measurement-noise-scaling-laws/fitted_data_with_results.csv",
        index=False,
    )
    display(final_results)

In [None]:
import matplotlib.pyplot as plt
import numpy as np


# Plot the UMI (noise) scaling fits: one panel per (metric, method), one colour per cell-number size.
# NOTE(review): this cell repeats the UMI-scaling figure drawn earlier in the notebook from the same
# `combined_results`; consider deleting one copy (or plotting `final_results` here instead).
# NOTE(review): rows from different datasets are pooled per panel/size and only the first row's
# fit parameters are drawn -- confirm there is a single dataset, or split the figure by dataset.
if all_fit_results:
    metrics = combined_results["metric"].unique()
    methods = combined_results["method"].unique()

    # squeeze=False always yields a 2-D array of Axes, even for a single row or column.
    fig, axes = plt.subplots(
        len(metrics), len(methods), figsize=(4 * len(methods), 4 * len(metrics)), squeeze=False
    )

    # Color scheme for different sizes
    sizes = sorted(combined_results["size"].unique())
    colors = plt.cm.tab10(np.linspace(0, 1, len(sizes)))
    size_color_map = dict(zip(sizes, colors))

    for i, metric in enumerate(metrics):
        for j, method in enumerate(methods):
            ax = axes[i, j]
            subset = combined_results[(combined_results["metric"] == metric) & (combined_results["method"] == method)]

            if len(subset) == 0:
                ax.set_title(f"{metric}\n{method}\n(No data)")
                continue

            for size in sizes:
                size_data = subset[subset["size"] == size]
                if len(size_data) == 0:
                    continue
                color = size_color_map[size]

                # Mean (and std across seeds, when there are several) at each UMI depth.
                x_vals, y_means, y_stds = [], [], []
                for umis, group in size_data.groupby("umis_per_cell"):
                    mi_values = group["mi_value"]
                    x_vals.append(umis)
                    y_means.append(mi_values.mean())
                    y_stds.append(mi_values.std() if len(mi_values) > 1 else 0)
                x_vals, y_means, y_stds = np.array(x_vals), np.array(y_means), np.array(y_stds)

                # Error bars only when some point has a non-zero spread across seeds.
                if np.any(y_stds > 0):
                    ax.errorbar(
                        x_vals,
                        y_means,
                        yerr=y_stds,
                        color=color,
                        fmt="o",
                        capsize=5,
                        capthick=2,
                        alpha=0.7,
                        markersize=6,
                        label=f"Size {size}",
                    )
                else:
                    ax.scatter(
                        x_vals,
                        y_means,
                        color=color,
                        alpha=0.7,
                        s=50,
                        label=f"Size {size}",
                    )

                # Fitted curve on a log-spaced grid, because the x-axis is log-scaled.
                # The grid falls back to linear spacing if any depth is non-positive.
                u_min = size_data["umis_per_cell"].min()
                u_max = size_data["umis_per_cell"].max()
                u_range = np.geomspace(u_min, u_max, 100) if u_min > 0 else np.linspace(u_min, u_max, 100)

                # The fit loop writes the same parameters to every row of a
                # (dataset, size, method, metric) group, so the first row is representative.
                A_fit = size_data["A_fit"].iloc[0]
                B_fit = size_data["B_fit"].iloc[0]
                I_max_err = size_data["I_max_err"].iloc[0]
                u_bar = size_data["u_bar"].iloc[0]
                u_bar_err = size_data["u_bar_err"].iloc[0]

                fitted_curve = info_scaling(u_range, A_fit, B_fit)

                # Uncertainty band for the fitted model itself. With K = B/A = 2**(2*I_max) and
                # u_bar = 1/A, the model reads I(u) = 0.5 * log2((u*K + u_bar) / (u + u_bar)), so
                #   dI/dI_max = u*K / (u*K + u_bar)
                #   dI/du_bar = (1/(u*K + u_bar) - 1/(u + u_bar)) / (2 ln 2)
                # (the covariance between I_max and u_bar is ignored).
                K = B_fit / A_fit
                shifted = u_range * K + u_bar
                dI_dImax = u_range * K / shifted
                dI_dubar = (1.0 / shifted - 1.0 / (u_range + u_bar)) / (2.0 * np.log(2.0))
                uncertainty = np.sqrt((dI_dImax * I_max_err) ** 2 + (dI_dubar * u_bar_err) ** 2)

                # 2-sigma band and fitted line
                ax.fill_between(
                    u_range,
                    fitted_curve - 2 * uncertainty,
                    fitted_curve + 2 * uncertainty,
                    color=color,
                    alpha=0.2,
                )
                ax.plot(u_range, fitted_curve, color=color, linestyle="-", linewidth=2)

            ax.set_xlabel("UMIs per cell")
            ax.set_ylabel("MI value")
            ax.set_title(f"{metric}\n{method}")
            ax.set_xscale("log")
            ax.grid(True, alpha=0.3)

            # Legend only on the first subplot (colours are shared across panels).
            if i == 0 and j == 0:
                ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.tight_layout()
    plt.show()
else:
    print("No fitted data available for plotting")