# Sympy vs The World

**Quick summary:**
- In general, sympy is slow.
    - Building the initial expression graph can be *very* slow.
    - Solving equations of high polynomial degree can be *extremely* slow.
- Sympy can be quite quick when it manages to simplify the expression.
    - In that case it is on par with numpy, but still slower than numba and jax.
- If an expression is built recursively and cannot be simplified, sympy may fail.
    - `lambdify` failed with a `RecursionError` on an expression built from 10000 loop iterations.
 
 
**One-liner:** As far as performance is concerned, THE WORLD WINS!

In [75]:
import scipy as sp
import numpy as np
import sympy as sy
from numba import njit, vectorize as nvectorize
from jax import numpy as jnp, scipy as jsp, jit as jjit
import matplotlib.pyplot as plt

In [3]:
np.set_printoptions(precision=2)

# The 'seaborn' style was renamed to 'seaborn-v0_8' in matplotlib 3.6
plt.style.use('seaborn-v0_8')

# Direct computation

## Scalar 1

### python

In [25]:
def foo(a):
    s = 0
    for i in range(100000):
        s += a
    return s

%timeit foo(5.)

2.58 ms ± 51.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### numba

In [50]:
@njit('f8(f8)')
def foo(a):
    s = 0.
    for i in range(100000):
        s += a
    return s

%timeit foo(5.)

93 µs ± 1.73 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### jax

In [50]:
@jjit
def foo(a):
    s = 0.
    for i in range(100000):
        s += a
    return s

%timeit foo(5.)

93 µs ± 1.73 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### sympy

In [27]:
%%time
a = sy.symbols('a')
s = 0
for i in range(100000):
    s = s+a

CPU times: user 10.2 s, sys: 0 ns, total: 10.2 s
Wall time: 10.2 s


In [35]:
%timeit _ = s.evalf(subs=dict(a=5.))

172 µs ± 1.73 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
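The fast evaluation above is consistent with sympy merging like terms while the expression is built, so that the loop ends with a single product rather than a long sum. A minimal sketch of that behaviour, using 1000 iterations instead of 100000 to keep the construction short (the iteration count and the name `value` are choices made for this illustration only):

```python
import sympy as sy

a = sy.symbols('a')

# Repeated addition of the same symbol: sympy combines like terms on
# every step, so the expression stays a single product (coefficient * a).
s = 0
for _ in range(1000):
    s = s + a

print(s)  # 1000*a

# Substituting a number into the collapsed expression is a single multiplication.
value = s.subs(a, 5.0)
print(value)
```

The construction loop itself is still slow, because every `s + a` creates and canonicalizes a new sympy object.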


## Scalar 2

### python

In [89]:
def foo(a):
    p = 1
    s = 0
    for i in range(10000):
        s += p
        p *= a
    return s

print(foo(1.001))
%timeit foo(1.001)

21915681.339056846
405 µs ± 7.38 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


### numba

In [58]:
@njit('f8(f8)')
def foo(a):
    p = 1.
    s = 0.
    for i in range(10000):
        s = s + p
        p = p * a
    return s

print(foo(1.001))
%timeit foo(1.001)

21915681.339056846
9.19 µs ± 88.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


### sympy

In [59]:
%%time
a = sy.symbols('a')
s = 0
p = 1
for i in range(10000):
    s = s + p
    p = p * a

CPU times: user 3min 42s, sys: 264 ms, total: 3min 42s
Wall time: 3min 42s


In [63]:
print(s.evalf(subs=dict(a=1.001)))
%timeit _ = s.evalf(subs=dict(a=1.001))

21915681.3390567
894 ms ± 8.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [67]:
foo = sy.lambdify(a, s)
foo(1.001)

RecursionError: maximum recursion depth exceeded during compilation

In [88]:
# --- simplified: 1000 terms instead of 10000, so that lambdify can compile the expression
a = sy.symbols('a')
s = 0
p = 1
for i in range(1000):
    s = s + p
    p = p * a
    
foo = sy.lambdify(a, s)
%timeit foo(1.001)

47.3 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
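The `RecursionError` on the 10000-term version probably comes from compiling one very long generated expression. One possible workaround, sketched here under the assumption that `a != 1`, is to hand sympy the closed form of the geometric sum instead of building it term by term. The names `s_closed` and `foo_closed` are introduced for this illustration and are not part of the original benchmark.

```python
import sympy as sy

a = sy.symbols('a')
n = 10000  # number of terms, as in the loop that failed to lambdify

# Closed form of 1 + a + a**2 + ... + a**(n-1).
# Assumption: a != 1, otherwise the denominator is zero.
s_closed = (a**n - 1) / (a - 1)

# The expression is tiny, so both construction and compilation are cheap.
foo_closed = sy.lambdify(a, s_closed)

result = foo_closed(1.001)
print(result)
```

The result should agree with the pure-Python loop to within floating-point rounding. This sidesteps the problem rather than fixing it: it only works because this particular sum has a known closed form.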


## Vector 1

### python

In [74]:
def foo(a):
    p = 1
    s = 0
    for i in range(10000):
        s += p
        p *= a
    return s

%timeit [foo(1.001) for _ in range(10000)]

4.1 s ± 49.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### numpy

In [73]:
@np.vectorize
def foo(a):
    p = 1
    s = 0
    for i in range(10000):
        s += p
        p *= a
    return s

%timeit foo(np.full(10000, 1.001))

3.99 s ± 66.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
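The `np.vectorize` timing is almost identical to the plain Python list comprehension, which matches the numpy documentation's note that `np.vectorize` is essentially a for loop provided for convenience, not performance. A sketch of an array-level alternative, where the loop runs over the terms and each step operates on the whole array at once (`foo_array` and `n_terms` are names chosen for this illustration; it was not part of the original timings):

```python
import numpy as np

def foo_array(a, n_terms=10000):
    """Sum of a**k for k = 0..n_terms-1, computed element-wise on an array."""
    a = np.asarray(a, dtype=float)
    p = np.ones_like(a)   # current power a**k, one entry per input element
    s = np.zeros_like(a)  # running sum, one entry per input element
    for _ in range(n_terms):
        s += p            # in-place update over the whole array
        p *= a
    return s

out = foo_array(np.full(10000, 1.001))
print(out[:3])
```

Each element goes through the same operations in the same order as in the scalar Python loop, so the values should match `foo(1.001)` from the python cell.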


### numba

In [76]:
@nvectorize
def foo(a):
    p = 1
    s = 0
    for i in range(10000):
        s += p
        p *= a
    return s

%timeit foo(np.full(10000, 1.001))

90.5 ms ± 808 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### jax

In [79]:
@jnp.vectorize
@jjit
def foo(a):
    p = 1
    s = 0
    for i in range(10000):
        s += p
        p *= a
    return s

%timeit foo(jnp.full(10000, 1.001))

The slowest run took 5.69 times longer than the fastest. This could mean that an intermediate result is being cached.
779 µs ± 665 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


### sympy

In [109]:
%%time
# --- simplified: 1000 terms instead of 10000
a = sy.symbols('a')
s = 0
p = 1.
for i in range(1000):
    s = s + p
    p = p * a

CPU times: user 7.21 s, sys: 7.75 ms, total: 7.21 s
Wall time: 7.21 s


In [110]:
%%time
from sympy.utilities.autowrap import ufuncify
foo = ufuncify([a], s)

CPU times: user 412 ms, sys: 4.03 ms, total: 416 ms
Wall time: 2.17 s


In [111]:
%timeit foo(np.full(10000, 1.001))

147 ms ± 2.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


# Root finding

## Scalar 1

### python

In [127]:
def foo(a):
    p = 1
    s = 0
    for i in range(1000):
        s += p
        p *= a
    return s


print(sp.optimize.brentq(lambda x: foo(x)-10, 0, 2))
%timeit sp.optimize.brentq(lambda x: foo(x)-10, 0, 2)

0.9000000000000001
642 µs ± 14.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


### numba

In [128]:
@njit('f8(f8)')
def foo(a):
    p = 1.
    s = 0.
    for i in range(1000):
        s += p
        p *= a
    return s


print(sp.optimize.brentq(lambda x: foo(x)-10, 0, 2))
%timeit sp.optimize.brentq(lambda x: foo(x)-10, 0, 2)

0.9000000000000001
19.9 µs ± 504 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### sympy

In [134]:
%%time
# --- very simplified: only 10 terms, i.e. a polynomial of degree 9
a = sy.symbols('a')
s = 0
p = 1
for i in range(10):
    s = s + p
    p = p * a

CPU times: user 1.22 ms, sys: 3 µs, total: 1.23 ms
Wall time: 1.23 ms


In [135]:
%%time
sy.solve(sy.Eq(s, 10))

CPU times: user 280 ms, sys: 0 ns, total: 280 ms
Wall time: 278 ms


[1,
 CRootOf(x**8 + 2*x**7 + 3*x**6 + 4*x**5 + 5*x**4 + 6*x**3 + 7*x**2 + 8*x + 9, 0),
 CRootOf(x**8 + 2*x**7 + 3*x**6 + 4*x**5 + 5*x**4 + 6*x**3 + 7*x**2 + 8*x + 9, 1),
 CRootOf(x**8 + 2*x**7 + 3*x**6 + 4*x**5 + 5*x**4 + 6*x**3 + 7*x**2 + 8*x + 9, 2),
 CRootOf(x**8 + 2*x**7 + 3*x**6 + 4*x**5 + 5*x**4 + 6*x**3 + 7*x**2 + 8*x + 9, 3),
 CRootOf(x**8 + 2*x**7 + 3*x**6 + 4*x**5 + 5*x**4 + 6*x**3 + 7*x**2 + 8*x + 9, 4),
 CRootOf(x**8 + 2*x**7 + 3*x**6 + 4*x**5 + 5*x**4 + 6*x**3 + 7*x**2 + 8*x + 9, 5),
 CRootOf(x**8 + 2*x**7 + 3*x**6 + 4*x**5 + 5*x**4 + 6*x**3 + 7*x**2 + 8*x + 9, 6),
 CRootOf(x**8 + 2*x**7 + 3*x**6 + 4*x**5 + 5*x**4 + 6*x**3 + 7*x**2 + 8*x + 9, 7)]
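For a like-for-like comparison with the python and numba cells, the sympy expression can also be solved numerically rather than symbolically. A minimal sketch, assuming the same bracket `[0, 2]` used above: compile the residual with `lambdify` and pass it to `scipy.optimize.brentq`. The names `residual` and `root` are introduced here for illustration.

```python
import sympy as sy
from scipy.optimize import brentq

a = sy.symbols('a')

# Same 10-term polynomial as in the sympy cell above.
s = 0
p = 1
for _ in range(10):
    s = s + p
    p = p * a

# Compile s - 10 into a plain numeric function of one variable.
residual = sy.lambdify(a, s - 10)

# Bracketing root search on [0, 2]; the residual changes sign on this interval.
root = brentq(residual, 0, 2)
print(root)
```

This finds only the real root inside the bracket, which corresponds to the first entry of the `solve` output, and does not recover the eight complex roots.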