# In [4]:
import numpy as np
from math import gcd
from functools import reduce
from TapF import F
from TapO import O


def gcd_list(numbers):
    """Return the greatest common divisor of all integers in *numbers*.

    Args:
        numbers (iterable of int): Integers to reduce; may include negatives.

    Returns:
        int: The non-negative GCD of the inputs; 0 for an empty input
        (0 is the identity element of gcd).
    """
    # math.gcd always returns a non-negative result, unlike a hand-rolled
    # Euclid loop built on Python's floor-modulo (which yields -2 for [4, -6]).
    return reduce(gcd, numbers, 0)


def balance_equation(reaction, compound_db=None, max_denominator=1000):
    """Balance a chemical equation using the nullspace of its element matrix.

    Builds a matrix with one row per element and one column per compound
    (reactant counts negated). It takes the right-singular vector for the
    smallest singular value as the candidate solution, converts its ratios to
    exact fractions, and scales them to the smallest positive integers. The
    result is accepted only if it balances every element exactly.

    Args:
        reaction (dict): Reaction with keys 'equation' (str), 'reactants' and
            'products' (iterables of compound names).
        compound_db (list of dict, optional): Dicts mapping a compound name to
            a record whose 'elements' dict maps element -> count. Defaults to
            the module-level ``O``.
        max_denominator (int): Largest denominator allowed when converting the
            floating-point coefficient ratios to fractions.

    Returns:
        tuple: (original_equation, balanced_equation). The original equation is
        returned in both slots if a compound is unknown, no strictly positive
        integer solution exists, or any error occurs.
    """
    from fractions import Fraction

    original_equation = reaction['equation']
    try:
        db = O if compound_db is None else compound_db
        reactants = list(reaction['reactants'])
        products = list(reaction['products'])
        compounds = reactants + products
        n_reactants = len(reactants)

        # Look up each compound's element composition (first match wins).
        compound_elements = {}
        for compound in compounds:
            for d in db:
                if compound in d:
                    compound_elements[compound] = d[compound]['elements']
                    break
            else:
                return original_equation, original_equation
        elements = list({el for counts in compound_elements.values() for el in counts})

        # Reactants count negatively so that A @ x == 0 for a balanced equation.
        signs = [-1] * n_reactants + [1] * (len(compounds) - n_reactants)
        A = np.array(
            [[s * compound_elements[c].get(el, 0) for s, c in zip(signs, compounds)] for el in elements],
            dtype=float,
        ).reshape(len(elements), len(compounds))

        _, _, vh = np.linalg.svd(A)
        candidate = vh[-1]
        magnitudes = np.abs(candidate)
        if np.any(magnitudes <= 1e-10):
            # A zero coefficient means a compound takes no part in the reaction.
            return original_equation, original_equation
        ratios = candidate / magnitudes.min()
        if ratios[0] < 0:
            ratios = -ratios  # the singular vector's overall sign is arbitrary
        if np.any(ratios <= 0):
            # Mixed signs cannot describe a physical reaction.
            return original_equation, original_equation

        # Recover exact ratios rather than rounding a scaled float. Rounding
        # turned 8/3 into 2667/1000 and produced nonsense coefficients.
        exact = [Fraction(float(r)).limit_denominator(max_denominator) for r in ratios]
        common_den = reduce(lambda a, b: a * b // gcd(a, b), (f.denominator for f in exact), 1)
        coeffs_int = [int(f * common_den) for f in exact]
        g = reduce(gcd, coeffs_int)
        coeffs_int = [c // g for c in coeffs_int]

        # Reject anything that does not balance exactly (e.g. no nullspace exists).
        for el in elements:
            net = sum(s * k * compound_elements[c].get(el, 0) for s, k, c in zip(signs, coeffs_int, compounds))
            if net != 0:
                return original_equation, original_equation

        # Format equation, omitting coefficients of 1.
        def format_term(coeff, compound):
            return compound if coeff == 1 else f"{coeff}{compound}"

        terms = [format_term(k, c) for k, c in zip(coeffs_int, compounds)]
        balanced_equation = f"{' + '.join(terms[:n_reactants])} → {' + '.join(terms[n_reactants:])}"
        return original_equation, balanced_equation
    except Exception:
        # Deliberate best-effort: an unbalanceable input is echoed back rather than raising.
        return original_equation, original_equation


def find_reaction_for_reactants(reactants):
    """
    Look up the first reaction in F whose reactant set equals *reactants* and balance it.
    Args:
        reactants (set): Exact reactant set to match, e.g., {'Zn', 'O2'}
    Returns:
        tuple: (original_equation, balanced_equation, None) for the first match in F,
        otherwise (None, None, error_message)
    """
    match = next((candidate for candidate in F if candidate['reactants'] == reactants), None)
    if match is None:
        label = '+'.join(sorted(reactants))
        return None, None, f"No reaction found for {label} in F"
    original_equation, balanced_equation = balance_equation(match)
    return original_equation, balanced_equation, None


def find_direct_reactions(start, end):
    """
    Collect and balance every reaction in F that consumes *start* and yields *end*.
    Args:
        start (str): Compound that must appear among the reactants, e.g., 'KClO3'
        end (str): Compound that must appear among the products, e.g., 'KCl'
    Returns:
        tuple: (matches, None) where matches lists (original_equation, balanced_equation)
        pairs in F order, or (None, error_message) when nothing matches
    """
    matches = []
    for candidate in F:
        if start not in candidate['reactants'] or end not in candidate['products']:
            continue
        original, balanced = balance_equation(candidate)
        matches.append((original, balanced))
    if not matches:
        return None, f"No reactions found from {start} to {end} in F"
    return matches, None


def find_reaction_chain(compounds):
    """
    Find a reaction chain from the first compound to the last, passing through the others.

    Runs a depth-first search over the reactions in F, starting at compounds[0] and
    ending at compounds[-1]. Each intermediate (compounds[1:-1]) must be produced by
    some step, and the chain continues from it. The order in which intermediates are
    listed is NOT enforced. The first chain found is returned; it is not guaranteed
    to be the shortest.

    Args:
        compounds (list): List of compounds, e.g., ['Zn', 'ZnO', 'ZnSO4']
    Returns:
        tuple: (path, None) where path is a list of (original_equation, balanced_equation)
        pairs (empty when start == end with no intermediates), or (None, error_message)
    """
    if len(compounds) < 2:
        return None, "At least two compounds are required."

    start = compounds[0]
    intermediates = compounds[1:-1]
    end = compounds[-1]

    all_compounds = {compound for d in O for compound in d}
    if not all(compound in all_compounds for compound in compounds):
        return None, "One or more compounds not found in the compound set."

    # Index reactions by reactant once (preserving F order) instead of rescanning F
    # at every DFS step.
    reactions_by_reactant = {}
    for reaction in F:
        for reactant in set(reaction['reactants']):
            reactions_by_reactant.setdefault(reactant, []).append(reaction)

    def build_path(current, target, path, used_compounds, intermediates_left):
        """Return a list of reaction dicts leading from current to target, or None."""
        if current == target and not intermediates_left:
            return path
        for reaction in reactions_by_reactant.get(current, []):
            for product in reaction['products']:
                step = path + [reaction]
                if product == target and not intermediates_left:
                    return step
                if product in intermediates_left:
                    remaining = intermediates_left.copy()
                    remaining.remove(product)
                    found = build_path(product, target, step, used_compounds | {product}, remaining)
                elif product not in used_compounds:
                    found = build_path(product, target, step, used_compounds | {product}, intermediates_left)
                else:
                    continue
                if found:
                    return found
        return None

    reaction_path = build_path(start, end, [], {start}, intermediates)
    if reaction_path is None:
        return None, "No reaction chain found."
    # Balance only the reactions on the returned path. Balancing runs an SVD, so doing
    # it for every explored edge was wasted work.
    return [balance_equation(reaction) for reaction in reaction_path], None


def refine_solution(S, H, G):
    """
    Prune solution S for problem H → G down to the steps that are needed (Algorithm 2.2).

    Walks S from the last step to the first while tracking the compounds that are still
    required, starting from G. A step is kept only if it produces a required compound.
    Its products then stop being required, and its reactants that are not in H become
    required. Steps whose original equation is not found in F are dropped.

    Args:
        S (list): Initial solution of (original_equation, balanced_equation) pairs
        H (set): Initial compounds
        G (set): Target compounds
    Returns:
        list: Kept (original_equation, balanced_equation) pairs, in their original order
    """
    required = set(G)
    reaction_by_equation = {r['equation']: r for r in F}
    kept_backwards = []
    for original_equation, balanced_equation in reversed(S):
        reaction = reaction_by_equation.get(original_equation)
        if not reaction:
            continue
        produced = reaction['products']
        if not required & produced:
            continue
        kept_backwards.append((original_equation, balanced_equation))
        required = (required - produced) | (reaction['reactants'] - H)
    kept_backwards.reverse()
    return kept_backwards


def main():
    """
    Interactive entry point.

    'A+B' input looks up and balances the reaction with exactly those reactants.
    'A,B' input lists every direct reaction from A to B. 'A,B,...,Z' input searches
    for a reaction chain and prints its refined form. Errors are printed, not raised.
    """

    def print_pairs(pairs):
        # Numbered listing of (original_equation, balanced_equation) pairs.
        for number, (original, balanced) in enumerate(pairs, 1):
            print(f"{number}) Original: {original}")
            print(f"   Balanced: {balanced}")

    try:
        raw = input("Enter compounds (e.g., Zn,ZnO,ZnSO4 or Zn+O2): ").strip()
        known = {name for table in O for name in table}

        if '+' in raw:
            reactant_set = {part.strip() for part in raw.split('+')}
            if any(name not in known for name in reactant_set):
                print("Error: One or more reactants not found in the compound set.")
                return
            original, balanced, error = find_reaction_for_reactants(reactant_set)
            if error:
                print(error)
                return
            print(f"Original reaction: {original}")
            print(f"Balanced reaction: {balanced}")
            return

        names = [part.strip() for part in raw.split(',')]
        if any(name not in known for name in names):
            print("Error: One or more compounds not found in the compound set.")
            return

        if len(names) == 2:
            equations, error = find_direct_reactions(names[0], names[1])
            if error:
                print(error)
                return
            print("Reactions found:")
            print_pairs(equations)
            return

        chain, error = find_reaction_chain(names)
        if error:
            print(f"Error: {error}")
            return
        if not chain:
            print("No reaction chain found for the given compounds.")
            return
        refined = refine_solution(chain, {names[0]}, {names[-1]})
        if not refined:
            print("No refined reaction chain found.")
            return
        print("Refined reaction chain found:")
        print_pairs(refined)

    except Exception as e:
        print(f"An error occurred: {str(e)}")


if __name__ == '__main__':
    # Start the interactive prompt only when executed directly, not on import.
    main()

# Output of the run above: Error: One or more compounds not found in the compound set.
