# In [4]:
import numpy as np
from scipy.special import expm1

import numpy as np
from scipy.special import expm1


def calc_growth_rate():
    """
    Sum the pre-selection and selection phase contributions.

    Each phase contributes ``(dk + k + m*theta*A) * t`` evaluated with that
    phase's own ``k``, ``m`` and ``t`` (``*_pre`` and ``*_sel`` respectively);
    the two contributions are added together.

    NOTE(review): this function takes no arguments. ``dk``, ``k_pre``,
    ``m_pre``, ``k_sel``, ``m_sel``, ``theta``, ``A``, ``t_pre`` and ``t_sel``
    are resolved as module-level names at call time. None of them is defined
    in the visible part of this file, so calling it raises ``NameError``
    unless they are defined elsewhere -- confirm where they come from.

    Returns
    -------
    The sum of the two phase terms. Its type (scalar or array) follows from
    the types of the module-level operands.
    """
    # Pre-selection phase: rate terms scaled by the pre-selection duration.
    pre_phase = (dk + k_pre + m_pre*theta*A)*t_pre

    # Selection phase: same form, selection-specific constants and duration.
    sel_phase = (dk + k_sel + m_sel*theta*A)*t_sel

    return pre_phase + sel_phase



def quantile_bin(thetas, freqs, n_bins):
    """
    Split genotypes into equal-count bins ordered by theta.

    Genotypes are ranked by ascending theta and rank ``r`` (0-based) is placed
    in bin ``floor(r * n_bins / n_genotypes)``, so each bin receives (as close
    as possible to) the same number of genotypes. The per-bin frequency totals
    are then accumulated from the lowest-theta bin upward.

    Parameters
    ----------
    thetas : array-like
        1D theta (occupancy) value for each genotype.
    freqs : array-like
        1D frequency (p_i(0)) for each genotype, in the same order as
        `thetas`.
    n_bins : int
        Number of bins to partition the genotypes into.

    Returns
    -------
    Q : np.ndarray
        Length-`n_bins` cumulative frequency distribution; ``Q[i]`` is the sum
        of the frequencies of all genotypes in bins 0 through i.
    bin_indices : np.ndarray
        Integer array with one entry per genotype (original order) giving the
        bin index (0 to n_bins-1) that genotype was assigned to.

    Raises
    ------
    ValueError
        If `thetas` and `freqs` do not have the same shape.
    """
    theta_arr = np.asarray(thetas)
    freq_arr = np.asarray(freqs)
    if theta_arr.shape != freq_arr.shape:
        raise ValueError("`thetas` and `freqs` must have the same shape.")

    # order[r] is the original position of the genotype with rank r.
    order = np.argsort(theta_arr)
    n_genotypes = len(theta_arr)

    # Bin id for every rank, in rank (theta-sorted) order.
    ranks = np.arange(n_genotypes)
    bin_of_rank = np.floor(ranks*n_bins/n_genotypes).astype(int)

    # Guard against floating-point overshoot past the last bin.
    overshoot = bin_of_rank >= n_bins
    bin_of_rank[overshoot] = n_bins - 1

    # Scatter the rank-ordered bin ids back to the genotypes' original slots.
    bin_indices = np.empty_like(order)
    bin_indices[order] = bin_of_rank

    # Per-bin frequency totals (summed in theta-sorted order), then the
    # running total across bins.
    per_bin_freq = np.bincount(
        bin_of_rank,
        weights=freq_arr[order],
        minlength=n_bins,
    )
    Q = np.cumsum(per_bin_freq)

    return Q, bin_indices

def ztp_avg(cum_freq_array, lam=2.5):
    """
    Average over a zero-truncated Poisson distribution.

    Evaluates, element-wise,

        S(x) = [lam * exp(-lam) / (1 - exp(-lam))] * (exp(lam * x) - 1) / x

    using the limit ``(exp(lam * x) - 1) / x -> lam`` at ``x == 0``.

    Parameters
    ----------
    cum_freq_array : array-like
        Cumulative frequencies ``x`` at which to evaluate S. Converted to a
        float array; the input itself is not modified.
    lam : float
        The lambda of the zero-truncated Poisson.

    Returns
    -------
    np.ndarray
        S(x) for every element of `cum_freq_array`, with the same shape.
    """
    x = np.asarray(cum_freq_array, dtype=float)

    # Pre-compute constants.
    c1 = lam * np.exp(-lam)
    c2 = -expm1(-lam)  # (1 - exp(-lam)), computed stably for small lam

    # expm1 keeps (exp(lam*x) - 1) accurate when lam*x is small.
    numerator = expm1(lam * x)

    # Start from the x -> 0 limit everywhere, then overwrite only the entries
    # where the division is well defined. Supplying `out` matters: without it
    # the entries skipped by `where` would be left uninitialised.
    ratio = np.full_like(x, lam)
    np.divide(numerator, x, out=ratio, where=(x != 0))

    return (c1 / c2) * ratio

def calculate_binned_growth_rate(bin_indices, Q, k_calc, lam, n_bins):
    """
    Calculates the effective growth rate for genotypes based on binned dominance.

    For each bin ``m`` the function computes

        g_bar[m] = g[m] * S(Q[m]) + sum_{b > m} g[b] * (S(Q[b]) - S(Q[b-1]))

    where ``g[b]`` is the mean of `k_calc` over the genotypes in bin ``b`` and
    ``S`` is `ztp_avg` evaluated with `lam`. The second term is the "rescue"
    contribution from all higher-theta bins. Every genotype is then assigned
    the ``g_bar`` of its bin.

    Parameters
    ----------
    bin_indices : np.ndarray
        Integer array of shape (n_genotypes,) giving the bin index
        (0 to n_bins-1) of each genotype, as returned by `quantile_bin`.
    Q : np.ndarray
        1D array of shape (n_bins,) with the cumulative frequency distribution
        across theta-sorted bins, as returned by `quantile_bin`.
    k_calc : np.ndarray
        1D array of shape (n_genotypes,) with the intrinsic growth rate of
        each genotype, in the same order as `bin_indices`.
    lam : float
        Lambda parameter of the zero-truncated Poisson distribution, passed
        through to `ztp_avg`.
    n_bins : int
        The total number of bins used.

    Returns
    -------
    np.ndarray
        Array of shape (n_genotypes,) containing the effective, bin-level
        growth rate for each genotype, in the original order.
    """

    # Mean intrinsic growth rate per bin: sum of k_calc / number of genotypes.
    g_bin_sums = np.bincount(bin_indices, weights=k_calc, minlength=n_bins)
    bin_counts = np.bincount(bin_indices, minlength=n_bins)

    # Empty bins (possible when n_bins exceeds the number of genotypes) get a
    # rate of exactly 0. Dividing into a zero-initialised `out` buffer is what
    # guarantees this; without `out`, the entries skipped by `where` would be
    # uninitialised memory and could leak NaN/inf into the tail sums below.
    g_bins = np.zeros_like(g_bin_sums)
    np.divide(g_bin_sums, bin_counts, out=g_bins, where=(bin_counts != 0))

    # S(Q_b) for every bin, and delta_S[b] = S(Q_b) - S(Q_{b-1}), with 0 used
    # in place of S(Q_{-1}) for the first difference.
    S_vals = ztp_avg(Q, lam)
    delta_S = np.diff(S_vals, prepend=0.0)

    # Per-bin contribution to the rescue sum of every lower bin.
    rescue_terms = g_bins * delta_S

    # Right-to-left cumulative sum: tail_sums[m] = sum_{b >= m} rescue_terms[b].
    tail_sums = np.cumsum(rescue_terms[::-1])[::-1]

    # The sum for bin m starts at b = m + 1, so shift the tail sums left by
    # one. The last bin has no higher bins and keeps 0. As a consequence
    # rescue_terms[0] (and therefore delta_S[0]) never reaches the output.
    sum_term = np.zeros(n_bins)
    sum_term[:-1] = tail_sums[1:]

    # Effective growth rate for each bin.
    g_bar_bins = g_bins * S_vals + sum_term

    # Map the effective bin growth rates back to each original genotype.
    effective_growth_rates = g_bar_bins[bin_indices]

    return effective_growth_rates

def calc_cycle(ln_cfu_obs_final, theta, other_params, lam, n_bins, t_final,
               ln_cfu_initial=None):
    """
    Executes one full forward pass of the dominance-corrected growth model.

    Starting from the initial log CFU of every genotype, the function
    converts them to frequencies, obtains intrinsic growth rates from
    `big_function`, bins genotypes by theta with `quantile_bin`, computes
    effective growth rates with `calculate_binned_growth_rate`, predicts the
    final log CFU and returns the squared-error loss against the observations.

    Parameters
    ----------
    ln_cfu_obs_final : np.ndarray
        The observed log-scale absolute CFU/mL for each genotype at the final
        time point.
    theta : np.ndarray
        The current trial values for the theta parameter for each genotype.
    other_params : dict
        Passed unchanged as the second argument of `big_function`.
        NOTE(review): `big_function` is not defined in the visible part of
        this file; it is assumed to return one intrinsic growth rate per
        genotype -- confirm against its definition.
    lam : float
        The trial value for the ZTP distribution's lambda parameter.
    n_bins : int
        The number of bins to use for the theta-based histogram.
    t_final : float
        The duration of the experiment (e.g., number of generations or hours).
    ln_cfu_initial : np.ndarray, optional
        The initial log-scale absolute CFU/mL for each genotype at t=0.
        It is placed last and optional only so existing positional callers
        keep working. When omitted, a module-level variable named
        ``ln_cfu_initial`` is used if one exists, which is the only way the
        function could obtain it before this parameter was added.

    Returns
    -------
    tuple[np.ndarray, float]
        - ln_cfu_pred_final (np.ndarray): The predicted final log CFU for
          each genotype.
        - loss (float): The sum of squared errors between the observed and
          predicted final log CFUs, which an optimizer would seek to minimize.

    Raises
    ------
    NameError
        If `ln_cfu_initial` is neither passed nor defined at module level.
    """

    if ln_cfu_initial is None:
        # Backward-compatible fallback to the module-level name.
        try:
            ln_cfu_initial = globals()["ln_cfu_initial"]
        except KeyError:
            raise NameError(
                "`ln_cfu_initial` was not passed to calc_cycle and is not "
                "defined at module level."
            ) from None
    ln_cfu_initial = np.asarray(ln_cfu_initial, dtype=float)

    # 1. Initial frequencies. This happens in linear space.
    cfu_initial = np.exp(ln_cfu_initial)
    total_initial_cfu = np.sum(cfu_initial)
    initial_frequencies = cfu_initial / total_initial_cfu

    # 2. Intrinsic Phenotype Layer: Calculate intrinsic growth rate (g).
    k_calc = big_function(theta, other_params)

    # 3. Dominance & Averaging Layer (The Convolution Step).

    # Bin genotypes based on the current trial thetas. `quantile_bin` returns
    # the cumulative distribution first and the per-genotype bin index second.
    Q, bin_indices = quantile_bin(theta, initial_frequencies, n_bins)

    # Calculate the effective growth rate (g_bar) for each genotype.
    g_effective = calculate_binned_growth_rate(bin_indices, Q, k_calc, lam, n_bins)

    # 4. Population Dynamics Layer: Predict final counts.
    # This uses the simple exponential growth equation with the *effective* rate.
    ln_cfu_pred_final = ln_cfu_initial + g_effective * t_final

    # 5. Calculate Loss: Compare prediction to observation.
    # This value is what the optimizer or VI engine tries to minimize.
    loss = np.sum((ln_cfu_obs_final - ln_cfu_pred_final)**2)

    return ln_cfu_pred_final, loss