In [4]:
import numpy as np

def padded_binary(i: int, n: int) -> str:
    """Return the binary digits of *i*, left-padded with zeros to width *n*.

    Args:
        i: Non-negative integer to render in base 2.
        n: Minimum width of the result. If *i* needs more than *n* digits
            the result is simply longer; nothing is truncated.

    Returns:
        A string of ``'0'``/``'1'`` characters, most significant bit first.

    Raises:
        ValueError: If *i* is negative. ``bin(-1)`` is ``'-0b1'``, so slicing
            off two characters would otherwise leave the malformed ``'b1'``.
    """
    if i < 0:
        raise ValueError(f"i must be non-negative, got {i}")
    # bin() always prefixes non-negative values with "0b"; drop it before padding.
    return bin(i)[2:].zfill(n)

# Module-level helper. It is deliberately NOT wrapped in @staticmethod here:
# a bare staticmethod object is only callable from Python 3.10 onwards, and
# this file calls the helper as a plain function. Add @staticmethod back only
# when pasting it into a class body.
def binary_string_to_array(string: str) -> np.ndarray:
    """Convert a string of digit characters into a 1-D integer array.

    Args:
        string: Text such as ``"0110"``; each character is passed to ``int()``.

    Returns:
        An array with ``dtype=int`` holding one element per character, in the
        same order. An empty string gives an empty array.

    Raises:
        ValueError: If any character cannot be parsed by ``int()``.
    """
    # use a comprehension so we don't rely on the builtin `list` name being callable
    return np.array([int(ch) for ch in string], dtype=int)


def calculate_valid_transitions(n: int):
    """Enumerate every single-site transition between the 2**n binary states.

    States are the integers ``0 .. 2**n - 1``, written as n-character binary
    strings. A transition ``i --> j`` is valid when the two strings differ in
    exactly one position. A 0 -> 1 change (phosphorylation) is labelled
    ``"E"``; a 1 -> 0 change (dephosphorylation) is labelled ``"F"``.

    Each valid transition is also printed as
    ``"<from> --> <to> (<E|F>), <i>, <j>"``.

    Args:
        n: Number of binary sites. ``n == 0`` yields two empty lists.

    Returns:
        A tuple ``(valid_E_reactions, valid_F_reactions)``. Each entry is a
        list ``[from_string, to_string, i, j, label]``. Entries are ordered by
        ``i`` and then by ascending ``j``.
    """
    num_states = 2**n
    all_states = [padded_binary(i, n) for i in range(num_states)]

    valid_E_reactions = []  # distributively
    valid_F_reactions = []  # distributively

    for i in range(num_states):
        # Two n-bit states differ in exactly one position iff one is the other
        # with a single bit flipped, so visit those n neighbours directly
        # instead of testing all 2**n candidates. Sorting reproduces the
        # ascending-j order of a full scan.
        for j in sorted(i ^ (1 << bit) for bit in range(n)):
            # Setting a bit makes the number larger (0 -> 1, "E");
            # clearing one makes it smaller (1 -> 0, "F").
            element = "E" if j > i else "F"
            print(f"{all_states[i]} --> {all_states[j]} ({element}), {i}, {j}")

            reaction = [all_states[i], all_states[j], i, j, element]
            if element == "E":
                valid_E_reactions.append(reaction)
            else:
                valid_F_reactions.append(reaction)

    return valid_E_reactions, valid_F_reactions

import numpy as np
import sympy as sp
from typing import Dict

def generate_odes(n: int) -> Dict[sp.Expr, sp.Expr]:
    """Build the symbolic right-hand sides of the ODE system for *n* sites.

    With ``N = 2**n`` states, the returned mapping holds, in this insertion
    order, one expression for each of ``S[0..N-1]``, ``ES[0..N-2]``,
    ``FS[1..N-1]``, ``E`` and ``F``::

        dS[i]/dt  = (b^E[i]*ES[i] - a^E[i]*E*S[i])        only for i < N-1
                  + (b^F[i]*FS[i] - a^F[i]*F*S[i])        only for i > 0
                  + sum_k c^E[k, i]*ES[k] + sum_k c^F[k, i]*FS[k]
        dES[i]/dt = a^E[i]*E*S[i] - (b^E[i] + sum_j c^E[i, j]) * ES[i]
        dFS[i]/dt = a^F[i]*F*S[i] - (b^F[i] + sum_j c^F[i, j]) * FS[i]
        dE/dt     = sum_i [-a^E[i]*E*S[i] + (b^E[i] + sum_j c^E[i, j]) * ES[i]]
        dF/dt     = sum_i [-a^F[i]*F*S[i] + (b^F[i] + sum_j c^F[i, j]) * FS[i]]

    The ``k`` sums run over the sources of E- (resp. F-) transitions into
    ``i``, and the ``j`` sums over the targets of transitions out of ``i``,
    as reported by ``calculate_valid_transitions``.

    Args:
        n: Number of binary sites; must be positive.

    Returns:
        Dict mapping each state variable (a sympy ``Indexed`` or ``Symbol``)
        to the sympy expression for its time derivative.

    Raises:
        ValueError: If ``n <= 0``.

    Note:
        This calls ``calculate_valid_transitions(n)``, so any output that
        function produces (it prints every transition) appears here too.
    """
    if n <= 0:
        raise ValueError("n must be a positive integer")

    valid_E_reactions, valid_F_reactions = calculate_valid_transitions(n)
    num_states = 2**n
    last = num_states - 1  # all-ones state: never binds E; state 0 never binds F

    # symbolic objects
    E = sp.Symbol('E')
    F = sp.Symbol('F')
    S = sp.IndexedBase('S')
    ES = sp.IndexedBase('ES')
    FS = sp.IndexedBase('FS')

    a_E = sp.IndexedBase('a^E')
    a_F = sp.IndexedBase('a^F')
    b_E = sp.IndexedBase('b^E')
    b_F = sp.IndexedBase('b^F')
    c_E = sp.IndexedBase('c^E')
    c_F = sp.IndexedBase('c^F')

    def _adjacency(reactions):
        # One pass per reaction list: source -> targets and target -> sources.
        outgoing, incoming = {}, {}
        for _, _, src, dst, _ in reactions:
            outgoing.setdefault(src, []).append(dst)
            incoming.setdefault(dst, []).append(src)
        return outgoing, incoming

    cE_out, cE_in = _adjacency(valid_E_reactions)
    cF_out, cF_in = _adjacency(valid_F_reactions)

    def _total_out(c, src, outgoing):
        # Sum of c[src, dst] over every transition leaving src (0 if none).
        return sum(c[src, dst] for dst in outgoing.get(src, []))

    d_dt = {}

    for i in range(num_states):
        rate = 0
        if i < last:
            rate += b_E[i] * ES[i] - a_E[i] * E * S[i]
        if i > 0:
            rate += b_F[i] * FS[i] - a_F[i] * F * S[i]
        # Inflow from every complex whose transition lands on state i.
        rate += sum(c_E[k, i] * ES[k] for k in cE_in.get(i, []))
        rate += sum(c_F[k, i] * FS[k] for k in cF_in.get(i, []))
        d_dt[S[i]] = rate

    for i in range(last):
        d_dt[ES[i]] = (
            a_E[i] * E * S[i] - (b_E[i] + _total_out(c_E, i, cE_out)) * ES[i]
        )

    for i in range(1, num_states):
        d_dt[FS[i]] = (
            a_F[i] * F * S[i] - (b_F[i] + _total_out(c_F, i, cF_out)) * FS[i]
        )

    d_dt[E] = sum(
        -a_E[j] * E * S[j] + (b_E[j] + _total_out(c_E, j, cE_out)) * ES[j]
        for j in range(last)
    )

    d_dt[F] = sum(
        -a_F[j] * F * S[j] + (b_F[j] + _total_out(c_F, j, cF_out)) * FS[j]
        for j in range(1, num_states)
    )

    return d_dt


In [3]:
calculate_valid_transitions(3)

000 --> 001 (E), 0, 1
000 --> 010 (E), 0, 2
000 --> 100 (E), 0, 4
001 --> 000 (F), 1, 0
001 --> 011 (E), 1, 3
001 --> 101 (E), 1, 5
010 --> 000 (F), 2, 0
010 --> 011 (E), 2, 3
010 --> 110 (E), 2, 6
011 --> 001 (F), 3, 1
011 --> 010 (F), 3, 2
011 --> 111 (E), 3, 7
100 --> 000 (F), 4, 0
100 --> 101 (E), 4, 5
100 --> 110 (E), 4, 6
101 --> 001 (F), 5, 1
101 --> 100 (F), 5, 4
101 --> 111 (E), 5, 7
110 --> 010 (F), 6, 2
110 --> 100 (F), 6, 4
110 --> 111 (E), 6, 7
111 --> 011 (F), 7, 3
111 --> 101 (F), 7, 5
111 --> 110 (F), 7, 6


([['000', '001', 0, 1, 'E'],
  ['000', '010', 0, 2, 'E'],
  ['000', '100', 0, 4, 'E'],
  ['001', '011', 1, 3, 'E'],
  ['001', '101', 1, 5, 'E'],
  ['010', '011', 2, 3, 'E'],
  ['010', '110', 2, 6, 'E'],
  ['011', '111', 3, 7, 'E'],
  ['100', '101', 4, 5, 'E'],
  ['100', '110', 4, 6, 'E'],
  ['101', '111', 5, 7, 'E'],
  ['110', '111', 6, 7, 'E']],
 [['001', '000', 1, 0, 'F'],
  ['010', '000', 2, 0, 'F'],
  ['011', '001', 3, 1, 'F'],
  ['011', '010', 3, 2, 'F'],
  ['100', '000', 4, 0, 'F'],
  ['101', '001', 5, 1, 'F'],
  ['101', '100', 5, 4, 'F'],
  ['110', '010', 6, 2, 'F'],
  ['110', '100', 6, 4, 'F'],
  ['111', '011', 7, 3, 'F'],
  ['111', '101', 7, 5, 'F'],
  ['111', '110', 7, 6, 'F']])