## Analytic solution
Consider this system of chemical reactions:

<pre>
A -&gt; B; k=a*f(t)
B -&gt; C; k=b*f(t)

f(t) = 1/(1s + t)
0 &lt;= t
</pre>

The corresponding system of ODEs is:
$$
\begin{aligned}
\frac{\mathrm{d}A}{\mathrm{d}t} &= -a f(t) A \\
\frac{\mathrm{d}B}{\mathrm{d}t} &= -b f(t) B + a f(t) A \\
\frac{\mathrm{d}C}{\mathrm{d}t} &= +b f(t) B
\end{aligned}
$$

In [None]:
from sympy import symbols, Function, Tuple, Eq, dsolve, cse

In [None]:
# Independent variable, rate constants and initial values (real, non-negative).
t, a, b, A0, B0, C0 = symbols('t a b A0 B0 C0', real=True, nonnegative=True)

# Undefined functions; f corresponds to f(t) in the text, A/B/C to the species.
funcs_anon = symbols('f A B C', cls=Function, nonnegative=True)
f, A, B, C = funcs_anon

# The functions applied to t, and the time derivatives of A, B and C.
funcs_t = [fn(t) for fn in funcs_anon]
ft, At, Bt, Ct = funcs_t
funcs_dt = [species.diff(t) for species in funcs_t[1:]]
dAdt, dBdt, dCdt = funcs_dt

# Explicit form of f(t) used in the rates, and the two reaction rates.
e_ft = 1/(1 + t)
r1 = a*e_ft*At  # A -> B
r2 = b*e_ft*Bt  # B -> C

# Right-hand sides: A loses r1, B gains r1 and loses r2, C gains r2.
rhss = (-r1, r1 - r2, r2)
e_dAdt, e_dBdt, e_dCdt = rhss

odes3 = Tuple(*(Eq(lhs, rhs) for lhs, rhs in zip(funcs_dt, rhss)))
odes3

In [None]:
# Solve dA/dt = -a f(t) A subject to A(0) = A0.
sA = dsolve(odes3[0], func=At, ics={A(0): A0})
# Sanity checks: the initial condition holds and the ODE residual is zero.
assert sA.subs(t, 0).rhs - A0 == 0
assert sA.rhs.diff(t) - odes3[0].rhs.subs(At, sA.rhs) == 0
sA

In [None]:
# Insert the closed-form A(t) into the remaining ODEs for B and C.
odes2 = odes3[1:].subs(sA.lhs, sA.rhs)
odes2

In [None]:
# dsolve (effectively) hangs on this ODE, e.g.
#     dsolve(odes2[0], func=Bt, ics={B(0): B0})
# so the analytic solution is prescribed by hand.  With s = log(1 + t) the
# ODE reads dB/ds = -b B + a A0 exp(-a s), whose solution is
#     B(t) = a A0 (1 + t)**-a / (b - a) + c1 (1 + t)**-b,
# with c1 fixed by B(0) = B0.  NOTE: this form requires a != b.
_c1 = B0 - a*A0/(b - a)
sB = Eq(Bt, a*A0*(t + 1)**-a/(b - a) + _c1*(t + 1)**-b)
# Verify the initial condition and that the ODE residual vanishes.
assert sB.subs(t, 0).rhs - B0 == 0
assert (sB.rhs.diff(t).expand().factor()
        - odes2[0].rhs.subs(Bt, sB.rhs).expand().factor()) == 0
sB

In [None]:
# Display only the right-hand side of the prescribed B(t) solution.
sB.rhs

In [None]:
# Insert the closed-form B(t) into the remaining ODE for C.
odes1 = odes2[1:].subs(sB.lhs, sB.rhs)
odes1

In [None]:
# Particular solution of dC/dt = b f(t) B(t).  Algebraically this equals
# -(A(t) + B(t)), so C(t) below expresses A + B + C = A0 + B0 + C0.
# Like B(t), it is singular for a == b.
_sC = (t + 1)**(-a - b)*(
    b*(B0*(t + 1)**a + A0*(t + 1)**b) - a*(t + 1)**a*(A0 + B0)
)/(a - b)
# Shift by a constant so that C(0) = C0.
sC = Eq(Ct, _sC - _sC.subs(t, 0) + C0)
# Verify the initial condition and that the ODE residual vanishes.
assert sC.subs(t, 0).rhs - C0 == 0
assert (sC.rhs.diff(t).expand().factor()
        - odes1[0].rhs.subs(Ct, sC.rhs).expand().factor()) == 0
sC

In [None]:
# Collect the three closed-form solutions and run SymPy's CSE on them.
analytic_rhss = [solution.rhs for solution in (sA, sB, sC)]
cses, red = cse(analytic_rhss)
cses, red

In [None]:
# Back-substitute the CSE replacements (last one first) and show the residual
# against each original expression; every entry is expected to be 0.
[reduced.subs(list(reversed(cses))) - reference
 for reduced, reference in zip(red, analytic_rhss)]

### Deferred: improved CSE
An attempt at better common-subexpression elimination (CSE) than plain `cse`; the implementation is not yet complete.

In [None]:
from collections import defaultdict


def my_cse(exprs, *, verbose=False):
    """Run SymPy's ``cse`` and post-process its replacements (work in progress).

    Two rewrites are applied to the ``(symbol, expression)`` replacement list
    returned by ``cse``:

    1. Replacements that are just the negation of a single symbol
       (``x_i = -s``) are dropped; ``-s`` is substituted back into the later
       replacements and into the reduced expressions.
    2. For every pair of non-numeric denominators ``D`` and ``-D`` found among
       the remaining replacements (``x_k = n_k/(+-D)``), one reference
       ``x = n_x/D_x`` is chosen from the two groups.  Numeric numerators are
       preferred, smallest magnitude first; otherwise the numerator with the
       fewest operations is chosen.  Every other member of the two groups that
       appears *after* ``x`` is rewritten as ``+-x*n_k/n_x``.

    Parameters
    ----------
    exprs : sequence of SymPy expressions
        Passed unchanged to ``sympy.cse``.
    verbose : bool, optional
        Print intermediate data while rewriting (default ``False``).

    Returns
    -------
    replacements : list of (Symbol, expression) tuples
        Each entry refers only to symbols defined before it.  The reduced
        expressions can therefore be expanded with
        ``expr.subs(replacements[::-1])``.
    reduced : list of expressions
        The reduced expressions from ``cse`` with rewrite 1 applied.
    """
    replacements, reduced = cse(exprs)

    # Rewrite 1: drop replacements of the form ``-s`` for a single symbol s.
    keys, values = [], []
    backsubs = {}
    denoms = defaultdict(list)  # denominator -> [(symbol, numerator), ...]
    for sym, sub in replacements:
        # Apply earlier drops first, so a dropped symbol never ends up inside
        # another back-substitution value (or any kept replacement).
        sub = sub.subs(backsubs)
        coeff, rest = sub.as_coeff_Mul()
        if coeff == -1 and rest.is_Symbol:
            backsubs[sym] = sub
            continue
        keys.append(sym)
        values.append(sub)
        numer, denom = sub.as_numer_denom()
        if not denom.is_number:
            denoms[denom].append((sym, numer))

    def _cost(numer):
        """Rank a numerator: numbers first (by magnitude), then by op count."""
        if numer.is_number:
            return (0, abs(numer))
        return (1, numer.count_ops())

    # Rewrite 2: share denominators that differ only in sign.
    taken = set()
    for denom, group in denoms.items():
        if denom in taken:
            continue
        mirror = denoms.get(-denom)
        if mirror is None:
            continue
        taken.add(-denom)

        numers = (dict(group), dict(mirror))  # side 0: denom, side 1: -denom
        # Cheapest candidate on each side, then the cheaper of the two; we
        # divide by the reference numerator, so simpler is better.
        best = [min(pairs, key=lambda pair: _cost(pair[1]))
                for pairs in (group, mirror)]
        ref_side = 1 if _cost(best[1][1]) < _cost(best[0][1]) else 0
        ref_sym, ref_numer = best[ref_side]
        if verbose:
            print(numers)

        # Only entries *after* the reference may refer to it; rewriting an
        # earlier entry would break the dependency order of the list.
        for pos in range(keys.index(ref_sym) + 1, len(keys)):
            k = keys[pos]
            if k in numers[0]:
                side = 0
            elif k in numers[1]:
                side = 1
            else:
                continue
            # With x = n_x/D (D the reference's denominator):
            #   n_k/D == x*n_k/n_x   and   n_k/(-D) == -x*n_k/n_x.
            sign = 1 if side == ref_side else -1
            new_value = sign*ref_sym*numers[side][k]/ref_numer
            if verbose:
                print(ref_sym, k, values[pos], sign)
                print(new_value)
            values[pos] = new_value

    return list(zip(keys, values)), [expr.subs(backsubs) for expr in reduced]

In [None]:
# Apply the experimental CSE and require that back-substitution reproduces
# every analytic solution (structural comparison, no simplification).
cses2, red2 = my_cse(analytic_rhss)
assert not any(reduced.subs(cses2[::-1]) - reference != 0
               for reduced, reference in zip(red2, analytic_rhss))
cses2, red2

In [None]:
# Display the post-processed replacement list on its own.
cses2

In [None]:
# Split the expression of the fourth replacement into numerator/denominator.
numer, denom = cses2[3][1].as_numer_denom()

In [None]:
# Inspect how that denominator itself splits under as_numer_denom.
denom.as_numer_denom()

In [None]:
# Check whether the denominator is numeric (my_cse skips numeric denominators).
denom.is_number