In [None]:
from __future__ import (absolute_import, division, print_function)
from functools import reduce
from operator import mul
import sympy as sp
import numpy as np
import matplotlib.pyplot as plt
from pyneqsys.symbolic import SymbolicSys, TransformedSys, linear_exprs
# Enable sympy's pretty printing for the expressions displayed in later cells.
sp.init_printing()
def prod(x):
    """Return the product of all items in the iterable *x*.

    Works for numbers and sympy expressions alike. An empty iterable yields 1
    (the multiplicative identity) instead of raising TypeError.
    """
    return reduce(mul, x, 1)
# Show which sympy version produced the outputs below (useful for reproducibility).
print(sp.__version__)

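Let's consider the following two equilibria:
$$
\begin{aligned}
\mathrm{H_2O} &\leftrightharpoons \mathrm{H^+} + \mathrm{OH^-} \\
\mathrm{NH_4^+} &\leftrightharpoons \mathrm{H^+} + \mathrm{NH_3}
\end{aligned}
$$
with the equilibrium conditions $[\mathrm{H^+}][\mathrm{OH^-}]/[\mathrm{H_2O}] = 10^{-14}/55$ and $[\mathrm{H^+}][\mathrm{NH_3}]/[\mathrm{NH_4^+}] = 10^{-9.24}$.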

In [None]:
# Species, in the order used for every concentration vector in this notebook.
names = 'H+ OH- NH4+ NH3 H2O'.split()
# Index of the species whose initial concentration is scanned in solve_and_plot.
NH3_idx = names.index('NH3')
# Scanned initial [NH3]: 50 log-spaced values from 1e-7 M to 1 M.
NH3_varied = np.logspace(-7, 0, 50)
# Concentration used for water (~ molarity of pure water, 55.5 M); it is both
# the initial [H2O] and the divisor in the water equilibrium constant below.
H2O_conc = 55
# Initial concentrations / M, ordered like `names`.
c0 = 1e-7, 1e-7, 1e-7, 1, H2O_conc
# Equilibrium constants, one per reaction:
#   Kw = [H+][OH-]/[H2O]  -- note: already divided by [H2O]
#   Ka = [H+][NH3]/[NH4+]
K = Kw, Ka = 10**-14/H2O_conc, 10**-9.24

In [None]:
# Stoichiometric coefficients, one row per reaction, columns ordered like
# `names`. Positive entries end up in the numerator of the equilibrium
# expression built in get_f, negative ones in the denominator:
#   H2O  <=> H+ + OH-
#   NH4+ <=> H+ + NH3
stoichs = [[1, 1, 0, 0, -1], [1, 0, -1, 1, 0]]
# Per-species content of each conserved quantity (columns ordered like `names`).
H = [1, 1, 4, 3, 2]   # hydrogen atoms
N = [0, 0, 1, 1, 0]   # nitrogen atoms
O = [0, 1, 0, 0, 1]   # oxygen atoms
e = [1, -1, 1, 0, 0]  # net charge
preserv = [H, N, O, e]
# Sanity check: every reaction must leave every conserved quantity unchanged.
assert all(
    sum(nu * q for nu, q in zip(reaction, invariant)) == 0
    for reaction in stoichs for invariant in preserv
), 'a reaction in `stoichs` does not conserve one of the invariants in `preserv`'
def get_f(x, params):
    """Build the residual expressions of the equilibrium system.

    Parameters
    ----------
    x : sequence of sympy symbols
        Unknown concentrations, ordered like ``names``.
    params : sequence
        The ``len(x)`` initial concentrations followed by one equilibrium
        constant per row of ``stoichs``.

    Returns
    -------
    list
        The conservation-law expressions produced by ``linear_exprs`` (with
        ``rref=True``), followed by one ``prod(x_i**nu_i) - K`` residual per
        reaction.
    """
    n_species = len(x)
    init_concs = params[:n_species]
    eq_constants = params[n_species:]
    # Conserved quantities of the unknowns, presumably constrained to equal
    # those of the initial concentrations (see pyneqsys.symbolic.linear_exprs).
    le = linear_exprs(preserv, x, linear_exprs(preserv, init_concs), rref=True)
    # Mass-action residuals; `k_eq` avoids shadowing the notebook-level `K`.
    return le + [
        prod(xi**p for xi, p in zip(x, coeffs)) - k_eq
        for coeffs, k_eq in zip(stoichs, eq_constants)
    ]

In [None]:
# One unknown per species; parameters are the initial concentrations followed
# by the equilibrium constants (derived from the config instead of magic numbers).
neqsys = SymbolicSys.from_callback(get_f, len(names), len(c0) + len(K), names=names)
neqsys.exprs

In [None]:
# Symbolic Jacobian of the system's expressions.
jac = neqsys.get_jac()
jac

In [None]:
%matplotlib inline
def solve_and_plot(nsys):
    """Solve *nsys* over a scan of initial [NH3] and plot results and residuals.

    Relies on the notebook-level globals ``c0`` (initial concentrations, also
    used as the initial guess), ``K`` (equilibrium constants), ``NH3_varied``
    (scanned values) and ``NH3_idx`` (index of the scanned parameter).

    Parameters
    ----------
    nsys : SymbolicSys or TransformedSys
        The system to solve with the ``'scipy'`` solver.

    Returns
    -------
    dict
        ``'avg_nfev'`` / ``'avg_njev'``: mean number of function / Jacobian
        evaluations per solve; ``'success'``: fraction of successful solves.
    """
    _, (ax_out, ax_err) = plt.subplots(1, 2, figsize=(16, 6))
    ax_out.set_xscale('log')
    ax_out.set_yscale('log')
    ax_err.set_xscale('log')
    try:
        # Matplotlib < 3.3 only understands `linthreshy` (it silently ignores
        # `linthresh`), so try the old spelling first.
        ax_err.set_yscale('symlog', linthreshy=1e-14)
    except TypeError:
        # Matplotlib >= 3.5 removed `linthreshy` in favour of `linthresh`.
        ax_err.set_yscale('symlog', linthresh=1e-14)
    _, sols = nsys.solve_and_plot_series(
        'scipy', c0, c0 + K, NH3_varied, NH3_idx, ax_out, ax_err)
    for ax in (ax_out, ax_err):
        ax.set_xlabel('[NH3]0 / M')
    ax_out.set_ylabel('Concentration / M')
    ax_out.legend(loc='best')
    ax_err.set_ylabel('Residuals')

    avg_nfev = np.average([sol.nfev for sol in sols])
    avg_njev = np.average([sol.njev for sol in sols])
    success = np.average([int(sol.success) for sol in sols])
    return {'avg_nfev': avg_nfev, 'avg_njev': avg_njev, 'success': success}


solve_and_plot(neqsys)

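Now let's see how pyneqsys can transform our system. Below, `TransformedSys` is given an expression callback (`my_log_transform`, which rewrites an equation $a = b$ as $\ln a - \ln b$) together with the variable-transformation pair `(sp.exp, sp.log)`: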

In [None]:
def my_log_transform(expr):
    """Rewrite an equation ``lhs = rhs`` as ``log(lhs) - log(rhs)``.

    The logarithms are expanded with ``force=True``, so products and powers
    become sums of logarithms regardless of sympy's sign assumptions.
    Anything that is not an ``sp.Eq`` instance is returned unchanged.

    NOTE(review): ``get_f`` builds its equilibrium residuals as
    ``prod(...) - K`` rather than ``sp.Eq(...)``, so those expressions appear
    to pass through this function unchanged -- confirm the log transform is
    actually applied as intended.
    """
    if not isinstance(expr, sp.Eq):
        return expr
    log_lhs = sp.expand_log(sp.log(expr.lhs), force=True)
    log_rhs = sp.expand_log(sp.log(expr.rhs), force=True)
    return log_lhs - log_rhs

In [None]:
# Same residuals as `neqsys`, handed to TransformedSys together with the
# expression callback and the (sp.exp, sp.log) transformation pair. Sizes are
# derived from the config, and `names` labels the series like `neqsys`.
tneqsys = TransformedSys.from_callback(
    get_f, len(names), len(c0) + len(K), my_log_transform, (sp.exp, sp.log),
    names=names)
tneqsys.exprs

In [None]:
# One solve at the nominal conditions: c0 is passed as the first argument
# (presumably the initial guess), and initial concentrations + equilibrium
# constants as the parameter vector.
param_values = np.array(c0 + K)
c_res, sol = tneqsys.solve_scipy(c0, param_values)
c0, c_res, sol.success

In [None]:
# Repeat the [NH3]0 scan with the transformed system for comparison.
transformed_stats = solve_and_plot(tneqsys)
transformed_stats