# Code Generation

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import neural
import numpy as np
from neural.basemodel import Model
from neural.model.neuron import *
import random
import inspect
import math

In [3]:
import ast
from sympy import *
import sympy as sp
import math
from neural.codegen._ast_transformers import NumPy2SymPy
from neural.codegen.parsedmodel import ParsedModel

In [4]:
class FakeModel(Model):
    Default_States = dict(x=0.0, y=0.0, z=0.0)
    Default_Params = dict(a=1.0, b=2.0, c=10.0)

    def ode(self, inp1=0.0, inp2=1.0):
        self.d_x = self.a * (1 - self.x) + self.b * self.x
        self.y = self.x * self.c
        self.z = self.x * self.c
        self.z = 0 if self.z < 1 else 100
        self.y = np.exp(np.cbrt(np.sqrt(self.z))) + random.gauss(0.0, self.c)
        tmp = (self.y > self.z) * self.y
        self.y = math.exp(10)

In [43]:
mod = ParsedModel(Rinzel)

In [84]:
dir(sp.printing)

['StrPrinter',
 'TableForm',
 '__all__',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 'c',
 'ccode',
 'codeprinter',
 'conventions',
 'cxxcode',
 'defaults',
 'dot',
 'dotprint',
 'fcode',
 'glsl',
 'glsl_code',
 'gtk',
 'jscode',
 'julia',
 'julia_code',
 'lambdarepr',
 'latex',
 'maple',
 'maple_code',
 'mathematica',
 'mathematica_code',
 'mathml',
 'multiline_latex',
 'numpy',
 'octave',
 'octave_code',
 'pager_print',
 'pprint',
 'pprint_try_use_unicode',
 'pprint_use_unicode',
 'precedence',
 'pretty',
 'pretty_print',
 'preview',
 'print_ccode',
 'print_fcode',
 'print_glsl',
 'print_gtk',
 'print_jscode',
 'print_latex',
 'print_maple_code',
 'print_mathml',
 'print_python',
 'print_rcode',
 'print_tree',
 'printer',
 'pycode',
 'python',
 'rcode',
 'repr',
 'rust',
 'rust_code',
 'srepr',
 'sstr',
 'sstrrepr',
 'str',
 'tableform',
 'tree']

In [105]:
from sympy.printing.c import C99CodePrinter
from sympy.codegen.ast import Assignment

class MyCodePrinter(C99CodePrinter):
    def __init__(self, parsed_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = parsed_model

    def _print_Relational(self, expr):
        """Convert sympy.Eq to Assignments"""
        if expr in self.model.ode and isinstance(expr, sp.Eq):
            lhs_code = self._print(expr.lhs)
            rhs_code = self._print(expr.rhs)
            return self._get_statement("%s = %s" % (lhs_code, rhs_code))
        return super()._print_Relational(expr)

    def _print_Symbol(self, expr):
        for var in ['params', 'internals', 'inputs']:
            sym_to_name = {val.sym:val.name for val in getattr(self.model, var).values()}
            if expr in sym_to_name:
                if var in ['internals', 'inputs']:
                    return sym_to_name[expr]
                else:
                    return f"{var}.{sym_to_name[expr]}"
        return super()._print_Symbol(expr)

    def _print_Function(self, expr):
        for var in ['states', 'gstates']:
            sym_to_name = {val.sym:val.name for val in getattr(self.model, var).values()}
            if expr in sym_to_name:
                return f"{var}.{sym_to_name[expr]}"
        return super()._print_Function(expr)
    
    def _print_Derivative(self, expr):
        for var in ['gstates']:
            sym_to_name = {val.sym:val.name for val in getattr(self.model, var).values()}
            if expr in sym_to_name:
                return f"{var}.{sym_to_name[expr]}"
        return super()._print_Function(expr)

In [106]:
for e in mod.ode:
    print(MyCodePrinter(mod).doprint(e))

alpha = exp((-states.v - 55.0)*1.0/10.0) - 1.0;
beta = 0.125*exp((-states.v - 65.0)*1.0/80.0);
alpha = ((fabs(alpha) <= 9.9999999999999995e-8) ? (
   0.10000000000000001
)
: (
   (-0.01*states.v - 0.55000000000000004)/alpha
));
n_infty = alpha/(alpha + beta);
alpha = exp((-states.v - 40.0)*1.0/10.0) - 1.0;
beta = 4.0*exp((-states.v - 65.0)*1.0/18.0);
alpha = ((fabs(alpha) <= 9.9999999999999995e-8) ? (
   1.0
)
: (
   (-0.10000000000000001*states.v - 4.0)/alpha
));
m_infty = alpha/(alpha + beta);
alpha = 0.070000000000000007*exp((-states.v - 65.0)*1.0/20.0);
beta = 1.0*1.0/(exp((-states.v - 35.0)*1.0/10.0) + 1.0);
h_infty = alpha/(alpha + beta);
w_infty = params.s*1.0/(pow(params.s, 2) + 1.0)*(n_infty + params.s*(-h_infty + 1.0));
tau_w = 1.0 + 5.0*exp((-states.v - 55.0)*(states.v + 55.0));
gstates.w = w_infty*3.0/tau_w - states.w*3.0/tau_w;
i_na = params.g_Na*(-params.E_Na + states.v)*pow(m_infty, 3)*(-states.w + 1.0);
i_k = params.g_K*(-params.E_K + states.v)*pow(states.w/params.s, 4)

In [125]:
import numpy as np
import cupy as cp

code = r"""
    template <typename T>
    __global__ void affect(T* const __restrict__ output,
                         const T arg0, const T arg1,
                         const size_t nelements) {
        ptrdiff_t i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= nelements)
            return;
        output[i] = (i%2 == 0 ? arg0 : arg1);
        # output[i].x = arg0.x;
    }
"""
module = cp.RawModule(code=code,
                      name_expressions=('affect<float3>',),
                      options=('-std=c++11',))
kernel = module.get_function('affect<float3>')

float3 = np.dtype( { 'names': ['x', 'y', 'z'],
                     'formats': [np.float32]*3 } )

arg0 = np.random.rand(3).astype(np.float32).view(float3)
arg1 = np.random.rand(3).astype(np.float32).view(float3)

N = 512
h_output = np.empty(dtype=float3, shape=N)
d_output = cp.asarray(h_output.view(np.float32))

grid = (N,)
block = (256,)
args = (d_output, arg0, arg1, np.uint64(N))
kernel(grid, block, args)

h_output[0::2] = arg0
h_output[1::2] = arg1

np.testing.assert_array_equal(h_output,
                              cp.asnumpy(d_output).view(float3))

CompileException: /tmp/tmp5qm3gw47/4c0eed21c490bbd6ba856961c76c99d155fa7211.cubin.cu(10): error: unrecognized preprocessing directive

1 error detected in the compilation of "/tmp/tmp5qm3gw47/4c0eed21c490bbd6ba856961c76c99d155fa7211.cubin.cu".


In [123]:
float3 = np.dtype( { 'names': ['x', 'y', 'z'],
                     'formats': [np.float32]*3 } )

arg0 = np.random.rand(3).astype(np.float32).view(float3)
arg1 = np.random.rand(3).astype(np.float32).view(float3)
h_output = np.empty(dtype=float3, shape=N)
d_output = cp.asarray(h_output.view(np.float32))
kernel = cp.ElementwiseKernel(
   "T in1, T in2",
   "T out"
   """
      out.x = in1.x + in2.x;
      out.y = in1.y + in2.y;
      out.z = in1.z + in2.z;
   """,
   "add"
)

Exception: Unknown keyword "T"

In [115]:
d_output

array([0.41416034, 0.5350983 , 0.638774  , ..., 0.810269  , 0.15582092,
       0.74441326], dtype=float32)

In [110]:
arg1

array([(0.810269, 0.15582092, 0.74441326)],
      dtype=[('x', '<f4'), ('y', '<f4'), ('z', '<f4')])

In [108]:
float3

dtype([('x', '<f4'), ('y', '<f4'), ('z', '<f4')])