In [2]:
import numpy as np
import sympy as sp
from typing import List, Tuple, Dict, Any, Set
from scipy.stats import levy
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

class all_parameter_generation:
    """
    Generate state transitions and random rate parameters for an n-site phosphorylation model.

    Args:
        n: number of sites; the model has 2**n states.
        reaction_types: reaction-scheme label; stored on the instance but not used by
            any method of this class.
        distribution: distribution used by k_parameter_generation / p_parameter_generation,
            either "gamma" or "levy".
        distribution_paramaters: two numbers. For "gamma" they are [shape, scale]; for
            "levy" they are passed to scipy.stats.levy as [loc, scale].
        verbose: if True, print every valid transition as it is found.

    NOTE(review): alpha/beta (catalytic) rate matrices are always drawn from a gamma
    distribution with the two stored parameters, whatever `distribution` is — confirm
    this is intended.
    """

    _SUPPORTED_DISTRIBUTIONS = ("gamma", "levy")

    def __init__(self, n: int, reaction_types: str, distribution: str, distribution_paramaters: List[float], verbose: bool = False):
        self.n = n
        self.num_states = 2 ** n
        self.distribution = distribution
        self.params = distribution_paramaters
        self.reaction_types = reaction_types
        self.verbose = verbose
        self.rng = np.random.default_rng()

    @staticmethod
    def padded_binary(i: int, n: int) -> str:
        """Binary representation of i, left-padded with zeros to n characters."""
        return bin(i)[2:].zfill(n)

    @staticmethod
    def binary_string_to_array(string: str) -> np.ndarray:
        """Convert a '0'/'1' string into an integer numpy array of bits."""
        return np.array([int(i) for i in string], dtype=int)

    def calculate_valid_transitions(self) -> Tuple[List[List[Any]], List[List[Any]]]:
        """
        Enumerate all single-bit transitions between the 2**n states.

        Returns:
            valid_X_reactions: list of [state_i_str, state_j_str, i, j, "E"] (a bit flips 0 -> 1)
            valid_Y_reactions: list of [state_i_str, state_j_str, i, j, "F"] (a bit flips 1 -> 0)
        """
        all_states = [self.padded_binary(i, self.n) for i in range(self.num_states)]
        # Convert each state once instead of inside the O(N^2) pair loop.
        bit_vectors = [self.binary_string_to_array(s) for s in all_states]

        valid_X_reactions: List[List[Any]] = []
        valid_Y_reactions: List[List[Any]] = []

        for i, arr_i in enumerate(bit_vectors):
            for j, arr_j in enumerate(bit_vectors):
                if i == j:
                    continue
                diff = arr_j - arr_i
                if np.sum(np.abs(diff)) != 1:
                    continue
                # +1 -> phosphorylation (E), -1 -> dephosphorylation (F)
                element = "E" if np.any(diff == 1) else "F"
                if self.verbose:
                    print(f"{all_states[i]} --> {all_states[j]} ({element}), {i}, {j}")
                target = valid_X_reactions if element == "E" else valid_Y_reactions
                target.append([all_states[i], all_states[j], i, j, element])

        return valid_X_reactions, valid_Y_reactions

    def _gamma_rate_matrix(self, reactions: List[List[Any]]) -> np.ndarray:
        """(num_states, num_states) matrix with one gamma(shape, scale) draw per listed (i, j)."""
        shape, scale = self.params
        matrix = np.zeros((self.num_states, self.num_states))
        for _, _, i, j, _ in reactions:
            matrix[i][j] = self.rng.gamma(shape, scale)
        return matrix

    def _sample_rates(self, size: int) -> np.ndarray:
        """Draw `size` rates from self.distribution using self.params."""
        first, scale = self.params
        if self.distribution == "gamma":
            return self.rng.gamma(first, scale, size)
        if self.distribution == "levy":
            return levy.rvs(loc=first, scale=scale, size=size, random_state=self.rng)
        raise ValueError(
            f"Unsupported distribution {self.distribution!r}; "
            f"expected one of {self._SUPPORTED_DISTRIBUTIONS}"
        )

    def alpha_parameter_generation(self) -> np.ndarray:
        """Random catalytic rates for the E (0 -> 1) transitions; zero elsewhere."""
        valid_X_reactions, _ = self.calculate_valid_transitions()
        return self._gamma_rate_matrix(valid_X_reactions)

    def beta_parameter_generation(self) -> np.ndarray:
        """Random catalytic rates for the F (1 -> 0) transitions; zero elsewhere."""
        _, valid_Y_reactions = self.calculate_valid_transitions()
        return self._gamma_rate_matrix(valid_Y_reactions)

    def k_parameter_generation(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Draw (k_positive_rates, k_negative_rates), each of length num_states - 1.

        Raises:
            ValueError: if self.distribution is neither "gamma" nor "levy".
        """
        k_positive_rates = self._sample_rates(self.num_states - 1)
        k_negative_rates = self._sample_rates(self.num_states - 1)
        return k_positive_rates, k_negative_rates

    def p_parameter_generation(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Draw (p_positive_rates, p_negative_rates), each of length num_states - 1.

        Raises:
            ValueError: if self.distribution is neither "gamma" nor "levy".
        """
        p_positive_rates = self._sample_rates(self.num_states - 1)
        p_negative_rates = self._sample_rates(self.num_states - 1)
        return p_positive_rates, p_negative_rates
    

# BEST SYMBOLIC ODE GENERATOR:

In [12]:
import sympy as sp
def padded_binary(i: int, n: int) -> str:
    """Return the binary digits of i, left-padded with '0' to at least n characters."""
    digits = bin(i)[2:]
    return digits.rjust(n, "0")

# Plain module-level helper (no @staticmethod: outside a class that decorator only
# wraps the function, and staticmethod objects are not callable before Python 3.10).
def binary_string_to_array(string: str) -> np.ndarray:
    """Convert a string of '0'/'1' characters into an integer numpy array of bits."""
    # use a comprehension so we don't rely on the builtin `list` name being callable
    return np.array([int(ch) for ch in string], dtype=int)

def calculate_valid_transitions(n: int, verbose: bool = True):
    """
    Enumerate the single-bit transitions between the 2**n binary states.

    Args:
        n: number of sites.
        verbose: print each transition as it is found (default True, as before).

    Returns:
        (valid_X_reactions, valid_Y_reactions): lists of
        [state_i_str, state_j_str, i, j, label], where label is "X" when one bit
        flips 0 -> 1 (phosphorylation) and "Y" when one bit flips 1 -> 0
        (dephosphorylation).
    """
    num_states = 2 ** n
    all_states = [padded_binary(i, n) for i in range(num_states)]
    # Convert each state to a bit vector once, not repeatedly inside the O(N^2) pair loop.
    bit_vectors = [binary_string_to_array(s) for s in all_states]

    valid_X_reactions = []  # distributively
    valid_Y_reactions = []  # distributively

    for i, arr_i in enumerate(bit_vectors):
        for j, arr_j in enumerate(bit_vectors):
            # Do not consider transitions from a state to itself
            if i == j:
                continue
            diff = arr_j - arr_i
            if np.sum(np.abs(diff)) != 1:
                continue
            # A +1 indicates phosphorylation, a -1 indicates dephosphorylation
            element = "X" if np.any(diff == 1) else "Y"
            if verbose:
                print(f"{all_states[i]} --> {all_states[j]} ({element}), {i} -> {j}")
            target = valid_X_reactions if element == "X" else valid_Y_reactions
            target.append([all_states[i], all_states[j], i, j, element])

    return valid_X_reactions, valid_Y_reactions

def build_phos_odes(n, prefix=""):
    """
    Build the symbolic ODE right-hand sides of the n-site phosphorylation network.

    There are N = 2**n substrate states. All symbol names carry `prefix`:
      a_i   free substrate in state i
      b_i   substrate in state i bound to enzyme x (formed at k^+_i * x * a_i,
            released through k^-_i and the catalytic rates alpha_i_j)
      c_i   substrate in state i bound to enzyme y (p^+_i, p^-_i, beta_i_j)
      x, y  free enzymes
    k^+_{N-1}, k^-_{N-1}, p^+_0 and p^-_0 are fixed to 0, and bdot[N-1], cdot[0]
    are set to 0.

    Side effects: prints the transition list (via calculate_valid_transitions) and
    the diagonal row-sum matrices of Alpha and Beta.

    Returns:
        dict with "adot", "bdot", "cdot" (N x 1 sympy Matrices) and "xdot", "ydot"
        (scalar expressions), each passed through sp.expand.
    """
    N = 2 ** n

    def indexed(stem):
        # one symbol per state: f"{prefix}{stem}_{i}"
        return sp.symbols([f"{prefix}{stem}_{i}" for i in range(N)])

    a = sp.Matrix(N, 1, indexed("a"))
    b = sp.Matrix(N, 1, indexed("b"))
    c = sp.Matrix(N, 1, indexed("c"))
    x, y = sp.symbols(f"{prefix}x {prefix}y")

    kplus, kminus = indexed("k^+"), indexed("k^-")
    pplus, pminus = indexed("p^+"), indexed("p^-")
    # last state has no x-binding rates, first state has no y-binding rates
    kplus[-1] = kminus[-1] = sp.Integer(0)
    pplus[0] = pminus[0] = sp.Integer(0)

    Kp, Km, Pp, Pm = (sp.diag(*rates) for rates in (kplus, kminus, pplus, pminus))

    valid_X_reactions, valid_Y_reactions = calculate_valid_transitions(n)

    # Catalytic rate matrices: a symbol only where a transition is allowed.
    Alpha = sp.zeros(N, N)
    for reaction in valid_X_reactions:
        src, dst = reaction[2], reaction[3]
        Alpha[src, dst] = sp.symbols(f"{prefix}alpha_{src}_{dst}")
    Beta = sp.zeros(N, N)
    for reaction in valid_Y_reactions:
        src, dst = reaction[2], reaction[3]
        Beta[src, dst] = sp.symbols(f"{prefix}beta_{src}_{dst}")

    ones = sp.Matrix([1] * N)
    DAlpha = sp.diag(*list(Alpha * ones))
    DBeta = sp.diag(*list(Beta * ones))
    print(DAlpha)
    print(DBeta)

    x_binding = x * (Kp * a)
    y_binding = y * (Pp * a)
    x_release = (Km + DAlpha) * b
    y_release = (Pm + DBeta) * c

    adot = (Km + Alpha.T) * b + (Pm + Beta.T) * c - x_binding - y_binding
    bdot = x_binding - x_release
    cdot = y_binding - y_release
    bdot[-1] = sp.Integer(0)
    cdot[0] = sp.Integer(0)

    xdot = sum(x_release) - x * sum(Kp * a)
    ydot = sum(y_release) - y * sum(Pp * a)

    return {
        "adot": sp.expand(adot), "bdot": sp.expand(bdot), "cdot": sp.expand(cdot),
        "xdot": sp.expand(xdot), "ydot": sp.expand(ydot),
    }

# Example: build symbolic ODEs for n=2 (N=4).
# build_phos_odes prints the transition list and the DAlpha / DBeta row-sum
# matrices (see the cell output below); the last expression displays dc_1/dt.
n = 2
odes = build_phos_odes(n)
odes["cdot"][1]
# Uncomment to pretty-print the full systems:
# sp.pprint(odes["adot"])
# sp.pprint(odes["bdot"])
# sp.pprint(odes["cdot"])
# sp.pprint(odes["xdot"])
# sp.pprint(odes["ydot"])

00 --> 01 (X), 0 -> 1
00 --> 10 (X), 0 -> 2
01 --> 00 (Y), 1 -> 0
01 --> 11 (X), 1 -> 3
10 --> 00 (Y), 2 -> 0
10 --> 11 (X), 2 -> 3
11 --> 01 (Y), 3 -> 1
11 --> 10 (Y), 3 -> 2
Matrix([[alpha_0_1 + alpha_0_2, 0, 0, 0], [0, alpha_1_3, 0, 0], [0, 0, alpha_2_3, 0], [0, 0, 0, 0]])
Matrix([[0, 0, 0, 0], [0, beta_1_0, 0, 0], [0, 0, beta_2_0, 0], [0, 0, 0, beta_3_1 + beta_3_2]])


a_1*p^+_1*y - beta_1_0*c_1 - c_1*p^-_1

In [13]:
# Display dc_2/dt for the n=2 system built above.
odes["cdot"][2]


a_2*p^+_2*y - beta_2_0*c_2 - c_2*p^-_2

In [14]:
# Display dc_3/dt for the n=2 system built above.
odes["cdot"][3]


a_3*p^+_3*y - beta_3_1*c_3 - beta_3_2*c_3 - c_3*p^-_3

In [15]:
# Display dx/dt (free enzyme x) for the n=2 system built above.
odes["xdot"]


-a_0*k^+_0*x - a_1*k^+_1*x - a_2*k^+_2*x + alpha_0_1*b_0 + alpha_0_2*b_0 + alpha_1_3*b_1 + alpha_2_3*b_2 + b_0*k^-_0 + b_1*k^-_1 + b_2*k^-_2

# PARAMETER SEARCH CODE:

In [28]:
from numpy import linalg as LA
import numpy as np
import sympy as sp
from itertools import product
from scipy.optimize import root

def polynomial_finder(n, alpha_matrix, beta_matrix, k_positive_rates, k_negative_rates, p_positive_rates, p_negative_rates):
    """
    Assemble the numeric matrices L1, L2, W1, W2 of the reduced model (N = 2**n).

    Kinase side:      W1 = M^-1 Q,  L1 = G M^-1 Q - Kp
        Kp = diag(k^+ padded with a trailing 0),  Q = first N-1 rows of Kp,
        M  = diag(k^-) + diag(row sums of alpha[:-1, 1:]),
        G  = [diag(k^-); 0-row] + alpha[:-1].T
    Phosphatase side: W2 = N^-1 D,  L2 = H N^-1 D - Pp
        Pp = diag(p^+ padded with a leading 0),   D = last N-1 rows of Pp,
        N  = diag(p^-) + diag(row sums of beta[1:, :-1]),
        H  = [0-row; diag(p^-)] + beta[1:].T

    Args:
        n: number of sites.
        alpha_matrix, beta_matrix: (N, N) numpy arrays of catalytic rates.
        k_positive_rates, k_negative_rates, p_positive_rates, p_negative_rates:
            length N-1 rate vectors.

    Returns:
        L1, L2 of shape (N, N) and W1, W2 of shape (N-1, N).
    """
    size = 2 ** n
    unit = np.ones(size - 1)

    # --- kinase (x / alpha / k) side ---
    kp_full = np.diag(np.append(k_positive_rates, 0))
    km_block = np.vstack([np.diag(k_negative_rates), np.zeros((1, len(k_negative_rates)))])
    kinase_out = np.diag(k_negative_rates) + np.diag(alpha_matrix[:-1, 1:] @ unit)
    kinase_gain = km_block + alpha_matrix[:-1].T
    kinase_bind = kp_full[:-1, :]
    kinase_inv = np.linalg.inv(kinase_out)

    # --- phosphatase (y / beta / p) side ---
    pp_full = np.diag(np.insert(p_positive_rates, 0, 0))
    pm_block = np.vstack([np.zeros((1, len(p_negative_rates))), np.diag(p_negative_rates)])
    phosph_out = np.diag(p_negative_rates) + np.diag(beta_matrix[1:, :-1] @ unit)
    phosph_gain = pm_block + beta_matrix[1:].T
    phosph_bind = pp_full[1:, :]
    phosph_inv = np.linalg.inv(phosph_out)

    L1 = kinase_gain @ kinase_inv @ kinase_bind - kp_full
    L2 = phosph_gain @ phosph_inv @ phosph_bind - pp_full
    W1 = kinase_inv @ kinase_bind
    W2 = phosph_inv @ phosph_bind
    return L1, L2, W1, W2

def stability_calculator(a_fixed_points, p, L1, L2, W1, W2):
    """
    Jacobian, at a_fixed_points, of
        F(a) = p * L1 a / (1 + 1^T W1 a) + L2 a / (1 + 1^T W2 a).

    Args:
        a_fixed_points: length-N state vector.
        p: scalar weight of the L1 term (callers pass x_tot / y_tot).
        L1, L2: (N, N) arrays; W1, W2: (N-1, N) arrays.

    Returns:
        (N, N) numpy array. Entries are inf/nan if a denominator is zero; callers
        are expected to check np.isfinite.
    """
    a = np.asarray(a_fixed_points, dtype=float).ravel()
    N = a.size
    a = a.reshape((N, 1))
    ones_row = np.ones((1, N - 1))  # shape (1, N-1)

    L1 = np.array(L1, dtype=float)
    L2 = np.array(L2, dtype=float)
    W1 = np.array(W1, dtype=float)
    W2 = np.array(W2, dtype=float)

    # .item() extracts the scalar from the (1, 1) product; float() on an array with
    # ndim > 0 is deprecated since NumPy 1.25.
    denom1 = 1.0 + (ones_row @ (W1 @ a)).item()
    denom2 = 1.0 + (ones_row @ (W2 @ a)).item()

    # d/da [L a / (1 + w.a)] = L / d - (L a) w^T / d^2
    term1 = (L1 / denom1) - (((L1 @ a) @ (ones_row @ W1)) / (denom1 ** 2))
    term2 = (L2 / denom2) - (((L2 @ a) @ (ones_row @ W2)) / (denom2 ** 2))
    return p * term1 + term2

from scipy.optimize import root

def fp_checker(a_tot_value, x_tot_value, y_tot_value,
                      alpha_matrix, beta_matrix,
                      k_positive_rates, k_negative_rates,
                      p_positive_rates, p_negative_rates):
    """
    Find positive fixed points of the reduced model and split them into stable / unstable.

    The state a (length N = 2**n) is normalised to sum to 1, so a_0..a_{N-2} are the
    unknowns and a_{N-1} = 1 - sum(others). A fixed point is a root of
        p * (L1 a) * (1 + 1^T W2 a) + (L2 a) * (1 + 1^T W1 a) = 0,   p = x_tot / y_tot,
    with L1, L2, W1, W2 from polynomial_finder. Roots come from scipy.optimize.root
    started at 8 guesses (the uniform point, then random points on the simplex).
    A root is kept only if every residual satisfies |r| < 1e-10 and all components
    are > 0. Stability uses the Jacobian from stability_calculator reduced to the
    constraint sum(a) = 1:
      * unstable: some eigenvalue real part > 1e-10;
      * stable:   all eigenvalue real parts < -1e-10;
      * otherwise (marginal / inconclusive) the root is discarded.
    Near-duplicates within a category are dropped.

    Args:
        a_tot_value: kept for interface compatibility; not used (state is normalised).
        x_tot_value, y_tot_value: enzyme totals; only their ratio enters.
        alpha_matrix, beta_matrix: (N, N) catalytic rate matrices, N a power of two.
        k_positive_rates, k_negative_rates, p_positive_rates, p_negative_rates:
            length N-1 rate vectors passed to polynomial_finder.

    Returns:
        (stable, unstable): arrays with one length-N fixed point per row; an empty
        array for a category with no points.
    """
    alpha_matrix = np.asarray(alpha_matrix, dtype=float)
    N = alpha_matrix.shape[0]
    n = N.bit_length() - 1
    if N < 2 or 2 ** n != N:
        raise ValueError(f"alpha_matrix must have shape (2**n, 2**n); got {alpha_matrix.shape}")

    L1, L2, W1, W2 = polynomial_finder(
        n, alpha_matrix, beta_matrix,
        k_positive_rates, k_negative_rates,
        p_positive_rates, p_negative_rates
    )

    p = x_tot_value / y_tot_value
    w1 = W1.sum(axis=0)  # row vector 1^T W1
    w2 = W2.sum(axis=0)  # row vector 1^T W2

    root_finder_tol = 1e-10
    residual_tol = 1e-10
    eig_tol = 1e-10
    norm_tol = 1e-2  # cannot be <= 1e-8
    attempt_total_num = 8

    def residuals_vec(a_vec):
        """All N fixed-point polynomials evaluated at the full state a_vec."""
        a = np.asarray(a_vec, dtype=float).ravel()
        return p * (L1 @ a) * (1.0 + w2 @ a) + (L2 @ a) * (1.0 + w1 @ a)

    def residuals_vec_red(a_red):
        """First N-1 polynomials with a_{N-1} eliminated via sum(a) = 1."""
        a_red = np.asarray(a_red, dtype=float).ravel()
        return residuals_vec(np.append(a_red, 1.0 - a_red.sum()))[: N - 1]

    def is_duplicate(candidate, collection):
        for v in collection:
            if np.allclose(1e5 * v, 1e5 * candidate) and np.linalg.norm(v - candidate) < norm_tol:
                return True
        return False

    def reduced_eigenvalues_real(full_sol):
        J = stability_calculator(full_sol, p, L1, L2, W1, W2)
        if not np.isfinite(J).all():
            return None
        # Substituting a_{N-1} = 1 - sum(a_0..a_{N-2}) gives the reduced Jacobian
        # dF_i/da_j - dF_i/da_{N-1}. When the full dynamics conserve sum(a), the full
        # J has an extra eigenvalue that is 0 up to round-off; dropping it keeps that
        # noise from being read as (in)stability.
        J_red = J[:-1, :-1] - J[:-1, -1:]
        return np.real(LA.eigvals(J_red))

    guesses = [np.full(N - 1, 1.0 / N)]
    for _ in range(attempt_total_num - 1):
        r = np.random.rand(N - 1)
        guesses.append(r / r.sum())

    final_sol_list_stable = []
    final_sol_list_unstable = []

    for guess in guesses:
        try:
            sol = root(residuals_vec_red, guess, method='hybr', tol=root_finder_tol)
        except Exception:
            # best effort: any numerical failure just discards this starting point
            continue
        if not sol.success:
            continue

        reduced = np.asarray(sol.x, dtype=float).ravel()
        full_sol = np.append(reduced, 1.0 - reduced.sum())  # sums to 1 by construction

        if not np.all(np.abs(residuals_vec(full_sol)) < residual_tol):
            continue
        if not np.all(full_sol > 0):
            continue

        eigenvalues_real = reduced_eigenvalues_real(full_sol)
        if eigenvalues_real is None:
            continue  # J contains NaN or inf

        if np.any(eigenvalues_real > eig_tol):
            target = final_sol_list_unstable
        elif np.all(eigenvalues_real < -eig_tol):
            target = final_sol_list_stable
        else:
            continue  # marginal stability, inconclusive

        if not is_duplicate(full_sol, target):
            target.append(full_sol)

    return np.array(final_sol_list_stable), np.array(final_sol_list_unstable)

def log_uniform_sample(a: float, b: float, n: int, base: float = np.e) -> np.ndarray:
    """
    Draw n values whose base-`base` logarithm is uniform between log(a) and log(b).

    Args:
        a (float): Lower bound, > 0.
        b (float): Upper bound, > a.
        n (int): Number of values.
        base (float): Logarithm base (natural log by default).

    Returns:
        np.ndarray: shape (n,), drawn with the global np.random state.

    Raises:
        ValueError: if a or b is not positive, or a >= b.
    """
    if a <= 0 or b <= 0:
        raise ValueError("a and b must be positive.")
    if a >= b:
        raise ValueError("a must be smaller than b.")

    log_of_base = np.log(base)
    lo, hi = np.log(a) / log_of_base, np.log(b) / log_of_base
    exponents = np.random.uniform(lo, hi, n)
    return np.power(base, exponents)

def simulation(sites_n, simulation_size):
    """
    Monte-Carlo scan over random parameter sets; save mono- and bistable cases to CSV.

    For each of `simulation_size` draws, x_tot and y_tot are sampled log-uniformly
    (base 10) between 1e-4 and 1e-1, and every rate constant log-uniformly between
    1e-1 and 1e10. fp_checker then finds and classifies fixed points. A draw is saved
      * as monostable when >= 1 stable and 0 unstable points were found (first stable
        point stored) -> monostability_parameters_new_{simulation_size}.csv;
      * as bistable when exactly 2 stable and 1 unstable point were found (the
        unstable point stored) -> bistability_parameters_new_{simulation_size}.csv.
    Each CSV row is: sim index, fixed point (N values), a_tot, x_tot, y_tot, then the
    alpha, beta, k+, k-, p+, p- parameters; the header names every column.

    Raises:
        NotImplementedError: if sites_n != 2 (the alpha/beta layouts below are
            hard-coded for the 4-state system).
    """
    if sites_n != 2:
        raise NotImplementedError(
            "simulation() only supports sites_n == 2: the alpha/beta matrix layouts are hard-coded for 4 states"
        )

    a_tot_value = 1
    base = 10
    N = 2 ** sites_n  # number of states

    def draw(low, high, count):
        return np.array([log_uniform_sample(low, high, n=count, base=base) for _ in range(simulation_size)])

    # Draw order kept as before so seeded runs reproduce the same parameters.
    x_tot_value_parameter_array = draw(1e-4, 1e-1, 1)
    y_tot_value_parameter_array = draw(1e-4, 1e-1, 1)
    alpha_matrix_parameter_array = draw(1e-1, 1e10, N)
    beta_matrix_parameter_array = draw(1e-1, 1e10, N)
    k_positive_parameter_array = draw(1e-1, 1e10, N - 1)
    k_negative_parameter_array = draw(1e-1, 1e10, N - 1)
    p_positive_parameter_array = draw(1e-1, 1e10, N - 1)
    p_negative_parameter_array = draw(1e-1, 1e10, N - 1)

    def make_row(i, fixed_point, x_tot_value, y_tot_value):
        return np.concatenate([
            np.array([i]).astype(int),
            fixed_point,
            np.array([a_tot_value, x_tot_value, y_tot_value]),
            np.ravel(alpha_matrix_parameter_array[i]),
            np.ravel(beta_matrix_parameter_array[i]),
            np.ravel(k_positive_parameter_array[i]),
            np.ravel(k_negative_parameter_array[i]),
            np.ravel(p_positive_parameter_array[i]),
            np.ravel(p_negative_parameter_array[i]),
        ])

    monostable_final_list = []
    bistable_final_list = []

    for i in range(simulation_size):
        x_tot_value = x_tot_value_parameter_array[i][0]
        y_tot_value = y_tot_value_parameter_array[i][0]
        alpha_params = alpha_matrix_parameter_array[i]
        beta_params = beta_matrix_parameter_array[i]

        alpha_matrix = np.array([
            [0, alpha_params[0], alpha_params[1], 0],
            [0, 0, 0, alpha_params[2]],
            [0, 0, 0, alpha_params[3]],
            [0, 0, 0, 0]
        ])
        beta_matrix = np.array([
            [0, 0, 0, 0],
            [beta_params[0], 0, 0, 0],
            [beta_params[1], 0, 0, 0],
            [0, beta_params[2], beta_params[3], 0]
        ])

        stable_fp_array, unstable_fp_array = fp_checker(
            a_tot_value, x_tot_value, y_tot_value,
            alpha_matrix, beta_matrix,
            k_positive_parameter_array[i], k_negative_parameter_array[i],
            p_positive_parameter_array[i], p_negative_parameter_array[i],
        )
        n_stable, n_unstable = len(stable_fp_array), len(unstable_fp_array)

        # >= 1 stable and 0 unstable: likely monostable (assuming no numerical failure)
        if n_stable >= 1 and n_unstable == 0:
            monostable_final_list.append(make_row(i, stable_fp_array[0], x_tot_value, y_tot_value))

        # 2 stable and 1 unstable: likely bistable; the unstable point is stored
        if n_stable == 2 and n_unstable == 1:
            bistable_final_list.append(make_row(i, unstable_fp_array[0], x_tot_value, y_tot_value))

    parameter_columns = (
        ["a_tot", "x_tot", "y_tot"]
        + ["alpha_0_1", "alpha_0_2", "alpha_1_3", "alpha_2_3"]
        + ["beta_1_0", "beta_2_0", "beta_3_1", "beta_3_2"]
        + [f"k_plus_{j}" for j in range(N - 1)]
        + [f"k_minus_{j}" for j in range(N - 1)]
        + [f"p_plus_{j + 1}" for j in range(N - 1)]
        + [f"p_minus_{j + 1}" for j in range(N - 1)]
    )

    def header_for(fp_kind):
        # one name per CSV column so the header lines up with the data
        return ",".join(["sim"] + [f"a_{j}_{fp_kind}_fp" for j in range(N)] + parameter_columns)

    if monostable_final_list:
        np.savetxt(f"monostability_parameters_new_{simulation_size}.csv", monostable_final_list,
                   delimiter=",", header=header_for("stable"), comments='')
    if bistable_final_list:
        np.savetxt(f"bistability_parameters_new_{simulation_size}.csv", bistable_final_list,
                   delimiter=",", header=header_for("unstable"), comments='')

    return

# simulation(2, 5000)


# PARALLELIZED VERSION:

In [None]:
# add near the top of your file
from joblib import Parallel, delayed

# helper must be top-level so it is picklable by joblib

from numpy import linalg as LA
import numpy as np
import sympy as sp
from scipy.optimize import root

def polynomial_finder(n, alpha_matrix, beta_matrix, k_positive_rates, k_negative_rates, p_positive_rates, p_negative_rates):
    """
    Return (L1, L2, W1, W2) for the reduced model with N = 2**n states.

    Each enzyme side is reduced by the same recipe, with B a diagonal binding
    matrix, R the kept rows of B, u the unbinding rates and C the kept rows of the
    catalytic matrix:
        out = inv(diag(u) + diag(row sums of the catalytic outflow block))
        L   = (U + C.T) @ out @ R - B,     W = out @ R
    where U is diag(u) embedded in an N x (N-1) block.
      kinase:      B = diag(k^+, 0),  R = B[:-1],  C = alpha[:-1],  outflow = alpha[:-1, 1:]
      phosphatase: B = diag(0, p^+),  R = B[1:],   C = beta[1:],    outflow = beta[1:, :-1]

    Returns:
        L1, L2 of shape (N, N); W1, W2 of shape (N-1, N).
    """
    N = 2 ** n
    ones_col = np.ones(N - 1)

    def reduce_side(bind_full, bind_rows, unbind_block, unbind_rates, catalytic_kept, outflow_block):
        outflow_inv = np.linalg.inv(np.diag(unbind_rates) + np.diag(outflow_block @ ones_col))
        gain = unbind_block + catalytic_kept.T
        return gain @ outflow_inv @ bind_rows - bind_full, outflow_inv @ bind_rows

    Kp = np.diag(np.append(k_positive_rates, 0))
    L1, W1 = reduce_side(
        Kp,
        Kp[:-1, :],
        np.append(np.diag(k_negative_rates), np.zeros((1, len(k_negative_rates))), axis=0),
        k_negative_rates,
        alpha_matrix[:-1],
        alpha_matrix[:-1, 1:],
    )

    Pp = np.diag(np.insert(p_positive_rates, 0, 0))
    L2, W2 = reduce_side(
        Pp,
        Pp[1:, :],
        np.vstack([np.zeros((1, len(p_negative_rates))), np.diag(p_negative_rates)]),
        p_negative_rates,
        beta_matrix[1:],
        beta_matrix[1:, :-1],
    )

    return L1, L2, W1, W2

def stability_calculator(a_fixed_points, p, L1, L2, W1, W2):
    """
    Jacobian, at a_fixed_points, of
        F(a) = p * L1 a / (1 + 1^T W1 a) + L2 a / (1 + 1^T W2 a).

    Args:
        a_fixed_points: length-N state vector (any array-like, flattened).
        p: scalar weight of the L1 term (callers pass x_tot / y_tot).
        L1, L2: (N, N) arrays; W1, W2: (N-1, N) arrays.

    Returns:
        (N, N) numpy array; inf/nan entries if a denominator vanishes.
    """
    a_fixed_points = np.asarray(a_fixed_points, dtype=float).ravel()
    N = a_fixed_points.size
    ones_vec_j = np.ones((1, N - 1), dtype=float)  # shape (1, N-1)
    a_fixed_points = a_fixed_points.reshape((N, 1))  # shape (N, 1)

    L1 = np.array(L1, dtype=float)
    L2 = np.array(L2, dtype=float)
    W1 = np.array(W1, dtype=float)
    W2 = np.array(W2, dtype=float)

    # The products below are (1, 1) arrays; .item() extracts the scalar because
    # float() on an array with ndim > 0 is deprecated since NumPy 1.25.
    denom1 = 1.0 + np.dot(ones_vec_j, W1 @ a_fixed_points).item()
    denom2 = 1.0 + np.dot(ones_vec_j, W2 @ a_fixed_points).item()

    # d/da [L a / (1 + w.a)] = L / d - (L a) w^T / d^2
    term1 = (L1 / denom1) - (((L1 @ a_fixed_points) @ (ones_vec_j @ W1)) / (denom1 ** 2))
    term2 = (L2 / denom2) - (((L2 @ a_fixed_points) @ (ones_vec_j @ W2)) / (denom2 ** 2))

    return (p * term1) + term2

from scipy.optimize import root
from sympy import shape

def fp_checker(a_tot_value, x_tot_value, y_tot_value,
                      alpha_matrix, beta_matrix,
                      k_positive_rates, k_negative_rates,
                      p_positive_rates, p_negative_rates):
    """
    Find positive fixed points of the reduced model and split them into stable / unstable.

    The state a (length N = 2**n) is normalised to sum to 1, so a_0..a_{N-2} are the
    unknowns and a_{N-1} = 1 - sum(others). A fixed point is a root of
        p * (L1 a) * (1 + 1^T W2 a) + (L2 a) * (1 + 1^T W1 a) = 0,   p = x_tot / y_tot,
    with L1, L2, W1, W2 from polynomial_finder. Roots come from scipy.optimize.root
    started at 8 guesses (u/8 + U(0.05, 0.1) noise per component, u = 0..7). A root
    is kept only if every residual satisfies |r| < 1e-8 and all components are > 0.
    Stability uses the Jacobian from stability_calculator reduced to the constraint
    sum(a) = 1:
      * unstable: some eigenvalue real part > 1e-8;
      * stable:   all eigenvalue real parts < -1e-8;
      * otherwise (marginal / inconclusive) the root is discarded.
    Near-duplicates within a category are dropped.

    Args:
        a_tot_value: kept for interface compatibility; not used (state is normalised).
        x_tot_value, y_tot_value: enzyme totals; only their ratio enters.
        alpha_matrix, beta_matrix: (N, N) catalytic rate matrices, N a power of two.
        k_positive_rates, k_negative_rates, p_positive_rates, p_negative_rates:
            length N-1 rate vectors passed to polynomial_finder.

    Returns:
        (stable, unstable): arrays with one length-N fixed point per row; an empty
        array for a category with no points.
    """
    alpha_matrix = np.asarray(alpha_matrix, dtype=float)
    N = alpha_matrix.shape[0]
    n = N.bit_length() - 1
    if N < 2 or 2 ** n != N:
        raise ValueError(f"alpha_matrix must have shape (2**n, 2**n); got {alpha_matrix.shape}")

    L1, L2, W1, W2 = polynomial_finder(
        n, alpha_matrix, beta_matrix,
        k_positive_rates, k_negative_rates,
        p_positive_rates, p_negative_rates
    )

    p = x_tot_value / y_tot_value
    w1 = W1.sum(axis=0)  # row vector 1^T W1
    w2 = W2.sum(axis=0)  # row vector 1^T W2

    root_finder_tol = 1e-8
    residual_tol = 1e-8
    eig_tol = 1e-8
    norm_tol = 1e-2  # cannot be <= 1e-8
    attempt_total_num = 8

    def residuals_vec(a_vec):
        """All N fixed-point polynomials evaluated at the full state a_vec."""
        a = np.asarray(a_vec, dtype=float).ravel()
        return p * (L1 @ a) * (1.0 + w2 @ a) + (L2 @ a) * (1.0 + w1 @ a)

    def residuals_vec_red(a_red):
        """First N-1 polynomials with a_{N-1} eliminated via sum(a) = 1."""
        a_red = np.asarray(a_red, dtype=float).ravel()
        return residuals_vec(np.append(a_red, 1.0 - a_red.sum()))[: N - 1]

    def is_duplicate(candidate, collection):
        for v in collection:
            if np.allclose(v, candidate) and np.linalg.norm(v - candidate) < norm_tol:
                return True
        return False

    def reduced_eigenvalues_real(full_sol):
        J = stability_calculator(full_sol, p, L1, L2, W1, W2)
        if not np.isfinite(J).all():
            return None
        # Substituting a_{N-1} = 1 - sum(a_0..a_{N-2}) gives the reduced Jacobian
        # dF_i/da_j - dF_i/da_{N-1}. When the full dynamics conserve sum(a), the full
        # J has an extra eigenvalue that is 0 up to round-off; dropping it keeps that
        # noise from being read as (in)stability.
        J_red = J[:-1, :-1] - J[:-1, -1:]
        return np.real(LA.eigvals(J_red))

    guesses = [
        np.array([u / attempt_total_num] * (N - 1)) + np.random.uniform(low=0.05, high=0.1, size=N - 1)
        for u in range(attempt_total_num)
    ]

    final_sol_list_stable = []
    final_sol_list_unstable = []

    for guess in guesses:
        try:
            sol = root(residuals_vec_red, guess, method='hybr', tol=root_finder_tol)
        except Exception:
            # best effort: any numerical failure just discards this starting point
            continue
        if not sol.success:
            continue

        reduced = np.asarray(sol.x, dtype=float).ravel()
        full_sol = np.append(reduced, 1.0 - reduced.sum())  # sums to 1 by construction

        if not np.all(np.abs(residuals_vec(full_sol)) < residual_tol):
            continue
        if not np.all(full_sol > 0):
            continue

        eigenvalues_real = reduced_eigenvalues_real(full_sol)
        if eigenvalues_real is None:
            continue  # J contains NaN or inf

        if np.any(eigenvalues_real > eig_tol):
            target = final_sol_list_unstable
        elif np.all(eigenvalues_real < -eig_tol):
            target = final_sol_list_stable
        else:
            continue  # marginal stability, inconclusive

        if not is_duplicate(full_sol, target):
            target.append(full_sol)

    return np.array(final_sol_list_stable), np.array(final_sol_list_unstable)

def log_uniform_sample(a: float, b: float, n: int, base: float = np.e) -> np.ndarray:
    """
    Sample n points uniformly in logarithmic space between a and b.

    Each sample is ``base ** u`` with ``u`` drawn uniformly between ``log_base(a)``
    and ``log_base(b)``. The resulting distribution does not depend on ``base``;
    the base only changes the intermediate scale. Draws come from NumPy's global
    legacy random state (``np.random.uniform``), so ``np.random.seed`` controls them.

    Args:
        a (float): Lower bound (must be > 0).
        b (float): Upper bound (must be > a).
        n (int): Number of samples to draw.
        base (float): Logarithmic base (default: natural log). Must be > 0 and != 1.

    Returns:
        np.ndarray: Array of shape (n,) with log-uniformly sampled values.

    Raises:
        ValueError: if a or b is not positive, if a >= b, or if base is not a
            valid logarithm base.
    """
    if a <= 0 or b <= 0:
        raise ValueError("a and b must be positive.")
    if a >= b:
        raise ValueError("a must be smaller than b.")
    # log(base) is 0 for base == 1 (division by zero below) and undefined for base <= 0.
    if base <= 0 or base == 1:
        raise ValueError("base must be positive and different from 1.")

    log_base = np.log(base)
    log_a = np.log(a) / log_base
    log_b = np.log(b) / log_base
    # For 0 < base < 1 the logarithm is decreasing, so log_a > log_b; np.random.uniform
    # is undefined for low > high, so always pass the bounds in ascending order.
    low, high = min(log_a, log_b), max(log_a, log_b)
    log_samples = np.random.uniform(low, high, n)
    samples = base ** log_samples
    return samples


def process_sample(i,
                   sites_n,
                   a_tot_value,
                   base,
                   x_tot_value_parameter_array,
                   y_tot_value_parameter_array,
                   alpha_matrix_parameter_array,
                   beta_matrix_parameter_array,
                   k_positive_parameter_array,
                   k_negative_parameter_array,
                   p_positive_parameter_array,
                   p_negative_parameter_array):
    """
    Classify parameter sample ``i`` as monostable or bistable.

    Builds the alpha/beta rate matrices from row ``i`` of the parameter arrays,
    calls ``fp_checker`` to obtain the stable and unstable fixed points, and packs
    the sample into a flat result row.

    Args:
        i: sample index. Selects row ``i`` of every parameter array and is stored
            as the first column of the result row.
        sites_n: number of sites. The matrices below are hard-coded 4x4 (two
            sites), so any other value is rejected.
        a_tot_value: copied into the result row and passed to fp_checker.
        base: unused; kept so existing callers keep working.
        x_tot_value_parameter_array, y_tot_value_parameter_array: per-sample
            totals; element ``[i][0]`` is used.
        alpha_matrix_parameter_array, beta_matrix_parameter_array: per-sample rows
            whose first four values are placed into the alpha/beta matrices.
        k_positive_parameter_array, k_negative_parameter_array,
        p_positive_parameter_array, p_negative_parameter_array: per-sample rate
            vectors; row ``i`` is passed to fp_checker.

    Returns:
        (mon_row, bist_row), at most one of which is not None. Both use the layout
        [i, fixed point, a_tot, x_tot, y_tot, alpha, beta, k+, k-, p+, p-].
        mon_row is set (with the first stable fixed point) when there is at least
        one stable and no unstable fixed point. bist_row is set (with the first
        unstable fixed point) when there are at least two stable and at least one
        unstable fixed point.

    Raises:
        ValueError: if sites_n != 2.
    """
    if sites_n != 2:
        raise ValueError(
            f"process_sample only supports sites_n == 2 (got {sites_n}); "
            "the alpha/beta matrices are hard-coded for four states."
        )

    k_positive_rates = k_positive_parameter_array[i]
    k_negative_rates = k_negative_parameter_array[i]
    p_positive_rates = p_positive_parameter_array[i]
    p_negative_rates = p_negative_parameter_array[i]

    x_tot_value = x_tot_value_parameter_array[i][0]
    y_tot_value = y_tot_value_parameter_array[i][0]

    alpha_params = alpha_matrix_parameter_array[i]
    beta_params = beta_matrix_parameter_array[i]

    # alpha fills the upper-triangular entries (0,1), (0,2), (1,3), (2,3);
    # beta fills the mirrored lower-triangular entries (1,0), (2,0), (3,1), (3,2).
    # NOTE(review): presumably states 0..3 are the binary states 00, 01, 10, 11 — confirm.
    alpha_matrix = np.array([
        [0, alpha_params[0], alpha_params[1], 0],
        [0, 0, 0, alpha_params[2]],
        [0, 0, 0, alpha_params[3]],
        [0, 0, 0, 0]
    ])

    beta_matrix = np.array([
        [0, 0, 0, 0],
        [beta_params[0], 0, 0, 0],
        [beta_params[1], 0, 0, 0],
        [0, beta_params[2], beta_params[3], 0]
    ])

    stable_fp_array, unstable_fp_array = fp_checker(a_tot_value, x_tot_value, y_tot_value,
                                                     alpha_matrix, beta_matrix, k_positive_rates,
                                                     k_negative_rates, p_positive_rates, p_negative_rates)

    meta = np.array([i]).astype(int)
    # Everything after the fixed point is identical for both row types, so build it once.
    sample_tail = np.concatenate([
        np.array([a_tot_value, x_tot_value, y_tot_value]),
        np.ravel(alpha_params),
        np.ravel(beta_params),
        np.ravel(k_positive_rates),
        np.ravel(k_negative_rates),
        np.ravel(p_positive_rates),
        np.ravel(p_negative_rates),
    ])

    def _with_fixed_point(fixed_point):
        # [i, fixed point coordinates, shared sample parameters]
        return np.concatenate([meta, fixed_point, sample_tail])

    mon_row = None
    bist_row = None

    # The two conditions are mutually exclusive (no unstable FP vs. at least one).
    if len(stable_fp_array) >= 1 and len(unstable_fp_array) == 0:
        mon_row = _with_fixed_point(stable_fp_array[0])
    elif len(stable_fp_array) >= 2 and len(unstable_fp_array) >= 1:
        bist_row = _with_fixed_point(unstable_fp_array[0])

    return mon_row, bist_row


def _result_csv_header(row_length: int, num_params: int, fp_label: str) -> str:
    """
    Return one comma-separated column name per value of a process_sample row.

    The row layout is: sim index, fixed-point coordinates, a_tot, x_tot, y_tot,
    ``num_params`` alpha values, ``num_params`` beta values, and
    ``num_params - 1`` values each for k+, k-, p+ and p-. The number of
    fixed-point coordinates is whatever remains of ``row_length`` after those
    fixed columns, so the header always has exactly ``row_length`` names.
    """
    num_rates = num_params - 1
    num_fp = row_length - (1 + 3 + 2 * num_params + 4 * num_rates)
    names = ["sim"]
    names += [f"a_{s}_{fp_label}_fp" for s in range(num_fp)]
    names += ["a_tot", "x_tot", "y_tot"]
    names += [f"alpha_{j}" for j in range(num_params)]
    names += [f"beta_{j}" for j in range(num_params)]
    for prefix in ("k_plus", "k_minus", "p_plus", "p_minus"):
        names += [f"{prefix}_{j}" for j in range(num_rates)]
    return ",".join(names)


def simulation(sites_n, simulation_size):
    """
    Sample random parameter sets, classify each in parallel, and write CSV results.

    Every parameter is drawn log-uniformly: x_tot and y_tot from [1e-4, 1e-1],
    all rate parameters from [1e-1, 1e5]. Each sample is classified by
    process_sample using joblib with all available cores.

    Args:
        sites_n: number of sites. process_sample currently builds two-site (4x4)
            matrices, so only sites_n == 2 is meaningful.
        simulation_size: number of parameter sets to sample.

    Writes (only if the corresponding list is non-empty):
        monostability_parameters_new_{simulation_size}.csv: rows holding the
            first stable fixed point of monostable samples.
        bistability_parameters_new_{simulation_size}.csv: rows holding the
            first unstable fixed point of bistable samples.
        Both files get a one-name-per-column header.
    """
    a_tot_value = 1
    base = np.e
    range_high = 1e5
    # NOTE(review): used as the parameter count per matrix; equals the number of
    # states (2**sites_n) only when sites_n == 2.
    N = sites_n**2

    def draw(low, high, width):
        # One independent log-uniform row of length `width` per sample.
        return np.array([log_uniform_sample(low, high, n=width, base=base)
                         for _ in range(simulation_size)])

    # The draw order is significant: it fixes which part of the global RNG stream
    # each parameter family consumes.
    x_tot_value_parameter_array = draw(1e-4, 1e-1, 1)
    y_tot_value_parameter_array = draw(1e-4, 1e-1, 1)
    alpha_matrix_parameter_array = draw(1e-1, range_high, N)
    beta_matrix_parameter_array = draw(1e-1, range_high, N)
    k_positive_parameter_array = draw(1e-1, range_high, N - 1)
    k_negative_parameter_array = draw(1e-1, range_high, N - 1)
    p_positive_parameter_array = draw(1e-1, range_high, N - 1)
    p_negative_parameter_array = draw(1e-1, range_high, N - 1)

    results = Parallel(n_jobs=-1, backend="loky")(
        delayed(process_sample)(
            i,
            sites_n,
            a_tot_value,
            base,
            x_tot_value_parameter_array,
            y_tot_value_parameter_array,
            alpha_matrix_parameter_array,
            beta_matrix_parameter_array,
            k_positive_parameter_array,
            k_negative_parameter_array,
            p_positive_parameter_array,
            p_negative_parameter_array
        ) for i in range(simulation_size)
    )

    monostable_final_list = [mon_row for mon_row, _ in results if mon_row is not None]
    bistable_final_list = [bist_row for _, bist_row in results if bist_row is not None]

    # Headers are derived from the actual row length so they always line up with the
    # data columns; bistable rows carry an unstable fixed point and are labelled so.
    if monostable_final_list:
        mono_header = _result_csv_header(len(monostable_final_list[0]), N, "stable")
        np.savetxt(f"monostability_parameters_new_{simulation_size}.csv", monostable_final_list,
                   delimiter=",", header=mono_header, comments='')
    if bistable_final_list:
        bist_header = _result_csv_header(len(bistable_final_list[0]), N, "unstable")
        np.savetxt(f"bistability_parameters_new_{simulation_size}.csv", bistable_final_list,
                   delimiter=",", header=bist_header, comments='')

if __name__ == "__main__":
    # Run the two-site, 6000-sample sweep only when executed directly (script or
    # notebook cell), not when this module is imported.
    simulation(2, 6000)


