https://www.youtube.com/watch?v=5jzIVp6bTy0

Slightly beyond 1:30:00

- use a code printer to generate C code
- use an array-compatible assignment (MatrixSymbol) to print C code
- subclass the printer class and modify it
- use common subexpression elimination (CSE)

In [1]:
import sympy as sym

In [21]:
sym.init_printing()

Here are some junky ODEs, stored as a matrix operator:

In [28]:
y0, y1, y2, y3, y4 = sym.symbols('y0 y1 y2 y3 y4')

state = sym.Matrix([y0, y1, y2, y3, y4])

lhs_odes = sym.Matrix([
    -25 * y0 * y0 + y3,
    0.1 * y2 * y1 + y2 - 2*y4,
    3 * y4 + y0 * y1 - 5,
    8 - y2 * y3,
    y0 - y4**3
])

Now we want to apply this operator to a state vector (a Matrix in SymPy, an array in C99) and store the result. Here's how to do it using a MatrixSymbol and a C99CodePrinter:

In [29]:
rhs_result = sym.MatrixSymbol('rhs_result', 5, 1)

In [30]:
print(rhs_result)

rhs_result


In [31]:
print(rhs_result[0])

rhs_result[0, 0]


In [32]:
from sympy.printing.c import C99CodePrinter  # on SymPy < 1.7: from sympy.printing.ccode import C99CodePrinter

In [33]:
printer = C99CodePrinter()

In [34]:
print(printer.doprint(lhs_odes, assign_to=rhs_result))

rhs_result[0] = -25*pow(y0, 2) + y3;
rhs_result[1] = 0.10000000000000001*y1*y2 + y2 - 2*y4;
rhs_result[2] = y0*y1 + 3*y4 - 5;
rhs_result[3] = -y2*y3 + 8;
rhs_result[4] = y0 - pow(y4, 3);
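The generated assignments still read the state from scalar variables y0 ... y4. If the C function receives the state as an array too, one option (a sketch, not something shown in the video) is to replace each state symbol with an element of a second MatrixSymbol before printing. The input array name y is an assumption made for illustration.

```python
import sympy as sym

try:
    from sympy.printing.c import C99CodePrinter  # SymPy >= 1.7
except ImportError:
    from sympy.printing.ccode import C99CodePrinter  # older SymPy

y0, y1, y2, y3, y4 = sym.symbols('y0 y1 y2 y3 y4')
state_symbols = [y0, y1, y2, y3, y4]

lhs_odes = sym.Matrix([
    -25 * y0 * y0 + y3,
    0.1 * y2 * y1 + y2 - 2*y4,
    3 * y4 + y0 * y1 - 5,
    8 - y2 * y3,
    y0 - y4**3
])

# Hypothetical input array: the C function would receive the state as `y`.
y_in = sym.MatrixSymbol('y', 5, 1)
rhs_result = sym.MatrixSymbol('rhs_result', 5, 1)

# Map each scalar state symbol y_i to the array element y[i, 0].
to_array = {s: y_in[i, 0] for i, s in enumerate(state_symbols)}
odes_on_array = lhs_odes.xreplace(to_array)

code = C99CodePrinter().doprint(odes_on_array, assign_to=rhs_result)
print(code)
```

xreplace does a purely structural swap of the five symbols, which is all that is needed here; the same mapping could be applied to lhs_odes_jac before printing the Jacobian.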


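Let's do the same for the Jacobian of this thing, because why not. Note how it handles the 2D result automatically: the 5x5 matrix is flattened into a 1D array in row-major order, so element (i, j) is written to jac_result[5*i + j].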

In [36]:
lhs_odes_jac = lhs_odes.jacobian(state)

jac_result = sym.MatrixSymbol('jac_result', *lhs_odes_jac.shape)
print(printer.doprint(lhs_odes_jac, assign_to=jac_result))

jac_result[0] = -50*y0;
jac_result[1] = 0;
jac_result[2] = 0;
jac_result[3] = 1;
jac_result[4] = 0;
jac_result[5] = 0;
jac_result[6] = 0.10000000000000001*y2;
jac_result[7] = 0.10000000000000001*y1 + 1;
jac_result[8] = 0;
jac_result[9] = -2;
jac_result[10] = y1;
jac_result[11] = y0;
jac_result[12] = 0;
jac_result[13] = 0;
jac_result[14] = 3;
jac_result[15] = 0;
jac_result[16] = 0;
jac_result[17] = -y3;
jac_result[18] = -y2;
jac_result[19] = 0;
jac_result[20] = 1;
jac_result[21] = 0;
jac_result[22] = 0;
jac_result[23] = 0;
jac_result[24] = -3*pow(y4, 2);
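The flattening is row-major: element (i, j) of the 5x5 MatrixSymbol is printed as jac_result[5*i + j], which is why d(rhs_1)/d(y2) = 0.1*y1 + 1 lands in jac_result[7]. A quick way to check this is to print single elements of the MatrixSymbol; flat_index below is a hypothetical helper written only for the comparison.

```python
import sympy as sym

try:
    from sympy.printing.c import C99CodePrinter  # SymPy >= 1.7
except ImportError:
    from sympy.printing.ccode import C99CodePrinter  # older SymPy

printer = C99CodePrinter()
jac_result = sym.MatrixSymbol('jac_result', 5, 5)


def flat_index(i, j, ncols):
    """Hypothetical helper: row-major offset of element (i, j)."""
    return i * ncols + j


# Print individual 2D elements and compare them with the row-major formula.
for i, j in [(0, 3), (1, 2), (4, 4)]:
    printed = printer.doprint(jac_result[i, j])
    print(printed, '->', flat_index(i, j, 5))
```

On the C side this means jac_result can be declared as a flat double array of 25 entries.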


All code printers can be customized. In IPython/Jupyter, type the following and press Tab:

C99CodePrinter._print

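You get a list of the _print_<ClassName> methods, one for each type of SymPy object the printer knows how to print. (Tab completion is interactive, so the cell below only shows the generic _print method itself.)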

In [38]:
C99CodePrinter._print

<function sympy.printing.printer.Printer._print(self, expr, **kwargs)>
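Outside an interactive session, the same list can be collected with dir(): each per-type print method follows the _print_<ClassName> naming pattern. A small sketch:

```python
try:
    from sympy.printing.c import C99CodePrinter  # SymPy >= 1.7
except ImportError:
    from sympy.printing.ccode import C99CodePrinter  # older SymPy

# Collect every per-type print method, including inherited ones.
print_methods = sorted(
    name for name in dir(C99CodePrinter) if name.startswith('_print_')
)
print(len(print_methods))
print(print_methods[:10])
```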

We can look at the source code of any of them by appending ??:

In [40]:
C99CodePrinter._print_Symbol??

[1;31mSignature:[0m [0mC99CodePrinter[0m[1;33m.[0m[0m_print_Symbol[0m[1;33m([0m[0mself[0m[1;33m,[0m [0mexpr[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m
[1;31mDocstring:[0m <no docstring>
[1;31mSource:[0m   
    [1;32mdef[0m [0m_print_Symbol[0m[1;33m([0m[0mself[0m[1;33m,[0m [0mexpr[0m[1;33m)[0m[1;33m:[0m[1;33m
[0m        [0mname[0m [1;33m=[0m [0msuper[0m[1;33m([0m[0mC89CodePrinter[0m[1;33m,[0m [0mself[0m[1;33m)[0m[1;33m.[0m[0m_print_Symbol[0m[1;33m([0m[0mexpr[0m[1;33m)[0m[1;33m
[0m        [1;32mif[0m [0mexpr[0m [1;32min[0m [0mself[0m[1;33m.[0m[0m_settings[0m[1;33m[[0m[1;34m'dereference'[0m[1;33m][0m[1;33m:[0m[1;33m
[0m            [1;32mreturn[0m [1;34m'(*{0})'[0m[1;33m.[0m[0mformat[0m[1;33m([0m[0mname[0m[1;33m)[0m[1;33m
[0m        [1;32melse[0m[1;33m:[0m[1;33m
[0m            [1;32mreturn[0m [0mname[0m[1;33m[0m[1;33m[0m[0m
[1;31mFile:[0m      c:\users\zandv\miniconda3\env

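(Note the class hierarchy of printers: _print_Symbol is actually defined on C89CodePrinter, which C99CodePrinter inherits from, mimicking the way each successive C standard is largely a superset of the previous one.)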
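The hierarchy can also be listed directly from the method resolution order; C89CodePrinter should appear right after C99CodePrinter, followed by the generic CodePrinter and StrPrinter bases.

```python
try:
    from sympy.printing.c import C89CodePrinter, C99CodePrinter  # SymPy >= 1.7
except ImportError:
    from sympy.printing.ccode import C89CodePrinter, C99CodePrinter  # older SymPy

# Walk the MRO to see which printers C99CodePrinter builds on.
mro_names = [cls.__name__ for cls in C99CodePrinter.__mro__]
print(mro_names)
```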

Using the above information, we can subclass the C99 printer to change how specific SymPy objects are printed. Here's a trivial example:

In [41]:
class MyCodePrinter(C99CodePrinter):
    def _print_Symbol(self, expr):
        return self._print("I'll always print this text no matter what symbol you pass, lol")

In [42]:
my_printer = MyCodePrinter()

In [45]:
x = sym.symbols('x')
my_printer.doprint(x)

"I'll always print this text no matter what symbol you pass, lol"

Note that calling self._print() (rather than formatting subexpressions yourself) is what makes recursion on compound expressions work. It dynamically dispatches to the _print_<ClassName> method matching each subexpression's type, so if your method only handles part of an expression, pass the remaining subexpressions to self._print() and the other printer methods take care of them. Eventually the leaf printers (for symbols, numbers, and so on) return plain strings, which ends the recursion and unwinds the stack.
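To see that recursion at work, here is a hypothetical printer (ExpandPowPrinter is not part of SymPy) that writes squares and cubes as repeated multiplication. It only overrides _print_Pow: the base is printed through self.parenthesize, which in turn calls self._print, so a base like x + 1 is handled by the normal Add printer, and every other power falls back to the parent class.

```python
import sympy as sym
from sympy.printing.precedence import PRECEDENCE

try:
    from sympy.printing.c import C99CodePrinter  # SymPy >= 1.7
except ImportError:
    from sympy.printing.ccode import C99CodePrinter  # older SymPy


class ExpandPowPrinter(C99CodePrinter):
    """Hypothetical printer: x**2 -> x*x and x**3 -> x*x*x, everything else unchanged."""

    def _print_Pow(self, expr):
        base, exp = expr.base, expr.exp
        if exp.is_Integer and 2 <= int(exp) <= 3:
            # parenthesize() delegates to self._print(), so the base is printed
            # by whichever _print_<ClassName> method matches its type, and it is
            # wrapped in parentheses when it binds more loosely than '*'.
            base_str = self.parenthesize(base, PRECEDENCE['Mul'])
            return '*'.join([base_str] * int(exp))
        # Anything else (e.g. x**5, sqrt(x)) uses the stock C99 behaviour.
        return super()._print_Pow(expr)


x = sym.symbols('x')
p = ExpandPowPrinter()
print(p.doprint(x**3))        # x*x*x
print(p.doprint((x + 1)**2))  # (x + 1)*(x + 1)
print(p.doprint(x**5))        # pow(x, 5)
```

Because dispatch is keyed on the Pow class, the override also applies to powers nested inside larger expressions, such as the pow(y4, 2) term in the Jacobian above.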