From 8f1fac68dea0f4902cc031367360a77b1a6ff4c7 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 1 Jun 2024 23:16:39 +0800 Subject: [PATCH 1/2] support `Integrator.to_math_expr()` --- .../_src/integrators/_jaxpr_to_source_code.py | 1132 +++++++++++++++++ brainpy/_src/integrators/base.py | 54 +- brainpy/_src/integrators/ode/base.py | 2 +- .../integrators/tests/test_to_source_code.py | 48 + brainpy/_src/math/others.py | 4 +- brainpy/integrators/__init__.py | 1 + 6 files changed, 1237 insertions(+), 4 deletions(-) create mode 100644 brainpy/_src/integrators/_jaxpr_to_source_code.py create mode 100644 brainpy/_src/integrators/tests/test_to_source_code.py diff --git a/brainpy/_src/integrators/_jaxpr_to_source_code.py b/brainpy/_src/integrators/_jaxpr_to_source_code.py new file mode 100644 index 000000000..3fa1d9006 --- /dev/null +++ b/brainpy/_src/integrators/_jaxpr_to_source_code.py @@ -0,0 +1,1132 @@ +# Modified from: https://github.com/dlwh/jax_sourceror +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import ast +import enum +import warnings +from collections.abc import MutableMapping, MutableSet +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Callable, Union + +import jax +import jax.numpy as jnp +import numpy as np +from jax._src.sharding_impls import UNSPECIFIED +from jax.core import Literal, Var, Jaxpr + +__all__ = [ + 'fn_to_python_code', + 'jaxpr_to_python_code', +] + + +class IdentitySet(MutableSet): + """Set that compares objects by identity. + + This is a set that compares objects by identity instead of equality. It is + useful for storing objects that are not hashable or that should be compared + by identity. + + This is a mutable set, but it does not support the ``__hash__`` method and + therefore cannot be used as a dictionary key or as an element of another set. + """ + + def __init__(self, iterable=None): + self._data = {} + if iterable is not None: + self.update(iterable) + + def __contains__(self, value): + return id(value) in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self): + return len(self._data) + + def add(self, value): + self._data[id(value)] = value + + def discard(self, value): + self._data.pop(id(value), None) + + def __repr__(self): + return f"IdentitySet({list(repr(x) for x in self._data.values())})" + + def __str__(self): + return f"IdentitySet({list(str(x) for x in self._data.values())})" + + +class IdentityMap(MutableMapping): + """Map that compares keys by identity. + + This is a map that compares keys by identity instead of equality. It is + useful for storing objects that are not hashable or that should be compared + by identity. + + This is a mutable mapping, but it does not support the ``__hash__`` method + and therefore cannot be used as a dictionary key or as an element of another + set. 
+ """ + + def __init__(self, iterable=None): + self._data = {} + if iterable is not None: + self.update(iterable) + + def __contains__(self, key): + return id(key) in self._data + + def __getitem__(self, key): + return self._data[id(key)] + + def __setitem__(self, key, value): + self._data[id(key)] = value + + def __delitem__(self, key): + del self._data[id(key)] + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self): + return len(self._data) + + def __repr__(self): + return f"IdentityMap({list(repr(x) for x in self._data.values())})" + + def __str__(self): + return f"IdentityMap({list(str(x) for x in self._data.values())})" + + +@dataclass +class SourcerorState: + """State for the auto-minimizer. Basically just in charge of naming variables.""" + _var_names: IdentityMap[Var, str] = field(default_factory=IdentityMap) + _skolem_count: int = 0 + + def name(self, var, ctx=ast.Load()) -> ast.Name: + return ast.Name(id=self.str_name(var), ctx=ctx) + + def str_name(self, var: Var): + # Names things in a way vaguely compatible with + # JAX's naming scheme, which is 'a'-'z' followed + # by 'aa'-'az' etc. + if var in self._var_names: + return self._var_names[var] + else: + cur_count = len(self._var_names) + name = "" + while cur_count >= 26: + name += chr(ord('a') + cur_count % 26) + cur_count //= 26 + + name += chr(ord('a') + cur_count) + + name = name[::-1] + + self._var_names[var] = name + + return name + + def skolem(self, prefix: str): + self._skolem_count += 1 + return f"{prefix}_{self._skolem_count}" + + +prefix_imports = set() + + +@contextmanager +def catch_imports(): + try: + prefix_imports.clear() + yield + finally: + prefix_imports.clear() + + +def fn_to_python_code(fn, *args, **kwargs): + """ + Given a function which is defined by jax primitives and the function arguments, + return the Python code that would be generated by JAX for that function. + + :param fn: The function to generate code for + :param args: The positional arguments to the function + :param kwargs: The keyword arguments to the function + :return: The Python code that would be generated by JAX for that function + """ + closed_jaxpr = jax.make_jaxpr(fn)(*args, **kwargs) + jaxpr = constant_fold_jaxpr(closed_jaxpr.jaxpr) + state = SourcerorState() + try: + name = fn.__name__ + except AttributeError: + name = "unknown" + with catch_imports(): + node = jaxpr_to_py_ast(state, jaxpr, fn_name=name) + node = _maybe_wrap_fn_for_leaves(node, fn, len(args) + len(kwargs)) + ast.fix_missing_locations(node) + source = ast.unparse(node) + if len(prefix_imports): + source = "\n".join(prefix_imports) + "\n\n" + source + return source + + +def jaxpr_to_python_code(jaxpr: jax.core.Jaxpr, + fn_name: str = "generated_function"): + """ + Given a JAX jaxpr, return the Python code that would be generated by JAX for that jaxpr. + + :param jaxpr: The jaxpr to generate code. + :param fn_name: The name of the function to generate code. 
+ :return: The Python code that would be generated by JAX for that jaxpr + """ + jaxpr = constant_fold_jaxpr(jaxpr) + state = SourcerorState() + with catch_imports(): + node = jaxpr_to_py_ast(state, jaxpr, fn_name=fn_name) + ast.fix_missing_locations(node) + source = ast.unparse(node) + if len(prefix_imports): + source = "\n".join(prefix_imports) + "\n\n" + source + return source + + +def register_prim_handler(prim_name, handler): + """ + Register a handler for a primitive for automin + :param prim_name: + :param handler: + :return: + """ + if prim_name in prim_to_python: + warnings.warn(f"Overwriting handler for primitive {prim_name}") + prim_to_python[prim_name] = handler + + +def register_prim_as(prim_name): + """ + Decorator to register a handler for a primitive. + + :param prim_name: + :return: + """ + + def decorator(fn): + register_prim_handler(prim_name, fn) + return fn + + return decorator + + +def _assign_stmt(call_expr: Callable): + """ + Create a handler for a primitive that is a simple assignment. + :param call_expr: + :return: + """ + + def binop_fn(state, eqn): + invars = [_astify_atom(state, v) for v in eqn.invars] + outvars = _astify_outvars(state, eqn.outvars) + return ast.Assign( + outvars, + call_expr( + *invars, + **{k: _astify_value(v) for k, v in eqn.params.items()} + ) + ) + + return binop_fn + + +def _binop_fn(op: ast.operator): + return _assign_stmt(lambda x, y: ast.BinOp(left=x, op=op, right=y)) + + +def _cmpop_fn(op: ast.cmpop): + return _assign_stmt(lambda x, y: ast.Compare(left=x, ops=[op], comparators=[y])) + + +def normal_fn(fn_name): + """ + Create a handler for a normal function call. + :param fn_name: + :return: + """ + return _assign_stmt( + lambda *args, **kwargs: ast.Call( + func=ast.Name(id=fn_name, ctx=ast.Load()), + args=list(args), + keywords=[ast.keyword(arg=k, value=v) for k, v in kwargs.items()] + ) + ) + + +def _reduce_fn(fn_name: str): + def reduce_fn_inner(state: SourcerorState, eqn): + invars = [_astify_atom(state, v) for v in eqn.invars] + outvars = _astify_outvars(state, eqn.outvars) + if eqn.params: + params = eqn.params.copy() + params['axis'] = tuple(params['axes']) + del params['axes'] + call_op = ast.Call( + func=ast.Name(id=fn_name, ctx=ast.Load()), + args=invars, + keywords=[ast.keyword(arg=k, value=_astify_value(v)) for k, v in params.items()] + ) + else: + call_op = ast.Call( + func=ast.Name(id=fn_name, ctx=ast.Load()), + args=invars, + keywords=[] + ) + + return ast.Assign(outvars, call_op) + + return reduce_fn_inner + + +prim_to_python = dict() + +register_prim_handler('add', _binop_fn(ast.Add())) +register_prim_handler('sub', _binop_fn(ast.Sub())) +register_prim_handler('mul', _binop_fn(ast.Mult())) +register_prim_handler('div', _binop_fn(ast.Div())) +register_prim_handler('neg', normal_fn('jax.lax.neg')) +register_prim_handler('lt', _cmpop_fn(ast.Lt())) +register_prim_handler('gt', _cmpop_fn(ast.Gt())) +register_prim_handler('le', _cmpop_fn(ast.LtE())) +register_prim_handler('ge', _cmpop_fn(ast.GtE())) +register_prim_handler('eq', _cmpop_fn(ast.Eq())) +register_prim_handler('ne', _cmpop_fn(ast.NotEq())) +register_prim_handler('min', normal_fn('jax.lax.min')) +register_prim_handler('max', normal_fn('jax.lax.max')) +register_prim_handler('select_n', normal_fn('jax.lax.select_n')) +register_prim_handler('squeeze', normal_fn('jax.lax.squeeze')) +register_prim_handler('broadcast', normal_fn('jax.lax.broadcast')) +register_prim_handler('reduce_sum', _reduce_fn('jax.numpy.sum')) +register_prim_handler('transpose', 
normal_fn('jax.lax.transpose')) + + +def _maybe_wrap_fn_for_leaves(node, f, num_args): + if len(node.args.args) == num_args: + return node + + wrapped_node = ast.FunctionDef( + name=f.__name__, + args=ast.arguments( + args=[], + vararg=ast.arg(arg="args", annotation=None), + kwarg=ast.arg(arg="kwargs", annotation=None), + kwonlyargs=[], kw_defaults=[], defaults=[], + posonlyargs=[] + ), + body=[ + node, + ast.Return( + ast.Call( + func=ast.Name(id=node.name, ctx=ast.Load()), + args=[ + ast.Starred( + ast.Call( + func=ast.Attribute(value=ast.Name(id="jax", ctx=ast.Load()), + attr="tree_leaves", + ctx=ast.Load()), + args=[ast.Tuple(elts=[ast.Name(id="args", ctx=ast.Load()), + ast.Name(id="kwargs", ctx=ast.Load())], + ctx=ast.Load())], + keywords=[] + ) + ) + ], + keywords=[] + ) + ), + ], + decorator_list=[] + ) + + return wrapped_node + + +def jaxpr_to_py_ast(state: SourcerorState, + jaxpr: jax.core.Jaxpr, + fn_name: str = "function"): + # Generate argument declarations + ast_args = [ast.arg(arg=state.str_name(var), annotation=None) + for var in jaxpr.invars] + ast_args = ast.arguments(args=ast_args, + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[], + posonlyargs=[]) + + stmts = [] + + # Generate body of the function + for eqn in jaxpr.eqns: + prim = str(eqn.primitive) + if prim in prim_to_python: + eqn_stmts = prim_to_python[prim](state, eqn) + else: + eqn_stmts = normal_fn(prim)(state, eqn) + + if isinstance(eqn_stmts, list): + stmts.extend(eqn_stmts) + else: + stmts.append(eqn_stmts) + + # Generate return statement + if len(jaxpr.outvars) == 1: + returns = state.name(jaxpr.outvars[0]) + else: + returns = ast.Tuple(elts=[state.name(var) for var in jaxpr.outvars], ctx=ast.Load()) + stmts.append(ast.Return(value=returns)) + + return ast.FunctionDef(name=fn_name, args=ast_args, body=stmts, decorator_list=[]) + + +def constant_fold_jaxpr(jaxpr: jax.core.Jaxpr): + """ + Given a jaxpr, return a new jaxpr with all constant folding done. 
+ """ + return partial_eval_jaxpr(jaxpr, {}) + + +def partial_eval_jaxpr(jaxpr, env): + env = env.copy() + new_eqns = [] + + def read(var): + if isinstance(var, Literal): + return var.val + else: + return env.get(var, None) + + def read_or_self(var): + out = read(var) + if out is None: + return var + elif isinstance(out, Var): + return out + elif isinstance(out, Literal): + return Literal(out.val, var.aval) + else: + assert not isinstance(out, Jaxpr) + return Literal(out, var.aval) + + for eqn in jaxpr.eqns: + vals = [read(var) for var in eqn.invars] + if eqn.primitive.name in constant_fold_blacklist: + new_eqns.append(eqn) + elif all(val is not None for val in vals): + # go ahead and eval it + out = _eval_eqn(eqn, vals) + + # two options: either it's a jaxpr result (partial eval) or it's a value or a list of values + if isinstance(out, Jaxpr): + # we need to inline this + new_eqns.extend(out.eqns) + out = out.outvars + elif not isinstance(out, tuple) and not isinstance(out, list): + out = (out,) + + for var, val in zip(eqn.outvars, out): + assert not isinstance(val, Jaxpr) + if isinstance(val, Literal): + env[var] = val.val + else: + env[var] = val + else: + new_eqns.append(eqn) + + # now that we've evaled everything, inline all the constants + out_eqns = [] + for eqn in new_eqns: + eqn = eqn.replace(invars=tuple(read_or_self(var) for var in eqn.invars)) + out_eqns.append(eqn) + + invars_still_used = IdentitySet() + for eqn in out_eqns: + for var in eqn.invars: + invars_still_used.add(var) + + invars = tuple(var for var in jaxpr.invars if var in invars_still_used) + + # sub in any constants for outvars + outvars = tuple(read_or_self(var) for var in jaxpr.outvars) + + return jaxpr.replace(eqns=out_eqns, outvars=outvars, invars=invars) + + +def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jnp.ndarray]: + if eqn.primitive.name == "closed_call": + assert eqn.primitive.call_primitive == True + assert eqn.primitive.map_primitive == False + + out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, + {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) + elif eqn.primitive.name == "scan": + out = eqn.primitive.bind(*vals, **eqn.params) + else: + out = eqn.primitive.bind(*vals, **eqn.params) + return out + + +@register_prim_as('dot_general') +def _astify_dot_general(state, eqn): + x, y = eqn.invars + d = eqn.params['dimension_numbers'] + precision = eqn.params['precision'] + preferred_element_type = eqn.params['preferred_element_type'] + + has_dtype = preferred_element_type is None or x.aval.dtype == y.aval.dtype == preferred_element_type + + # recognize simple matmul case + if d == (((1,), (0,)), ((), ())) and precision == None: + invars = [_astify_atom(state, x), _astify_atom(state, y)] + outvars = _astify_outvars(state, eqn.outvars) + out = ast.Assign(targets=outvars, value=ast.Call( + func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), attr='matmul', ctx=ast.Load()), args=invars, + keywords=[])) + if not has_dtype: + out = ast.Assign(targets=outvars, + value=ast.Call(func=ast.Attribute(value=out.value, attr='astype', ctx=ast.Load()), + args=[_astify_value(preferred_element_type)], keywords=[])) + + return out + + # TODO: convert to einsum? 
+ + invars = [_astify_atom(state, x), + _astify_atom(state, y), + _astify_value(d), + _astify_value(precision), + _astify_value(preferred_element_type)] + outvars = _astify_outvars(state, eqn.outvars) + return ast.Assign( + targets=outvars, + value=ast.Call( + func=ast.Attribute(value=ast.Name(id='jax.lax', ctx=ast.Load()), attr='dot_general', ctx=ast.Load()), + args=invars, + keywords=[] + ) + ) + + +@register_prim_as('dynamic_slice') +def _sourcify_dynamic_slice(state, eqn): + sliced = eqn.invars[0] + invars = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[1:]], ctx=ast.Load()) + outvars = _astify_outvars(state, eqn.outvars) + params = [ast.keyword(arg=k, value=_astify_value(v)) for k, v in eqn.params.items()] + return ast.Assign( + targets=outvars, + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='jax.lax', ctx=ast.Load()), + attr='dynamic_slice', + ctx=ast.Load() + ), + args=[_astify_atom(state, sliced), invars], + keywords=params + ) + ) + + +@register_prim_as('slice') +def _sourcify_slice(state, eqn): + sliced = eqn.invars[0] + # invars = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[1:]], ctx=ast.Load()) + outvars = _astify_outvars(state, eqn.outvars) + start_indices = eqn.params['start_indices'] + limit_indices = eqn.params['limit_indices'] + strides = eqn.params['strides'] + if strides is None: + strides = (None,) * len(start_indices) + indices = [_astify_value(slice(s, e, stride)) + for s, e, stride in zip(start_indices, limit_indices, strides)] + # params = [ast.keyword(arg=k, value=_astify_value(v)) for k, v in eqn.params.items()] + return ast.Assign( + targets=outvars, + value=ast.Subscript( + value=_astify_atom(state, sliced), + slice=ast.Tuple(elts=indices, ctx=ast.Load()), + ctx=ast.Load() + ) + ) + + +@register_prim_as('dynamic_update_slice') +def _sourcify_dynamic_update_slice(state, eqn): + sliced = eqn.invars[0] + # the first two arguments are the sliced array and the update array + # the remaining are start indices and should be packaged into a tuple + target = _astify_atom(state, eqn.invars[0]) + update = _astify_atom(state, eqn.invars[1]) + start_indices = maybe_tuple_vars([_astify_atom(state, var) for var in eqn.invars[2:]]) + outvars = _astify_outvars(state, eqn.outvars) + + return ast.Assign(targets=outvars, value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='jax.lax', ctx=ast.Load()), + attr='dynamic_update_slice', + ctx=ast.Load() + ), + args=[target, update, start_indices], + keywords=[] + )) + + +@register_prim_as('convert_element_type') +def _astify_convert_element_type(state, eqn): + # now we use ast + outvars = _astify_outvars(state, eqn.outvars) + assert len(eqn.invars) == 1 + invar = _astify_atom(state, eqn.invars[0]) + dtype = _astify_value(eqn.params['new_dtype']) + return ast.Assign(targets=outvars, value=ast.Call( + func=ast.Attribute( + value=invar, + attr='astype', + ctx=ast.Load() + ), + args=[dtype], + keywords=[] + )) + + +def is_array(arr): + return isinstance(arr, (np.ndarray, np.generic, jnp.ndarray)) + + +def _astify_array(value): + assert is_array(value) + if isinstance(value, np.int64): + return ast.Constant(value=int(value)) + + if value.ndim == 0 and value.dtype in (jnp.float32, jnp.int32, jnp.bool_, jnp.int64): + return ast.Constant(value=value.item()) + + if value.ndim == 0: + dtype_value = _astify_value(value.dtype) + return ast.Call( + dtype_value, + args=[ast.Constant(value=value.item())], + keywords=[], + ) + + values = value.tolist() + + def rec_astify_list(values): + if 
isinstance(values, list): + return ast.List(elts=[rec_astify_list(val) for val in values], ctx=ast.Load()) + else: + return ast.Constant(value=values) + + return ast.Call( + func=ast.Attribute( + value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr='array', + ctx=ast.Load() + ), + args=[rec_astify_list(values)], + keywords=[ast.keyword(arg='dtype', + value=_astify_value(value.dtype))] + ) + + +def _astify_atom(state: SourcerorState, var: Union[Literal, Var]): + if isinstance(var, Literal): + return _astify_value(var.val) + elif isinstance(var, Var): + return state.name(var) + else: + raise NotImplementedError() + + +def _astify_value(value): + assert not isinstance(value, (Literal, Var)) + + if is_array(value): + return _astify_array(value) + elif isinstance(value, (int, bool, float, str, type(None))): + return ast.Constant(value=value) + elif isinstance(value, (tuple, list)): + return ast.Tuple(elts=[_astify_value(v) for v in value], ctx=ast.Load()) + elif isinstance(value, jnp.dtype): + # return ast.Call(func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), attr='dtype', ctx=ast.Load()), args=[ast.Constant(value=str(value))], keywords=[]) + if value.name in ('float32', 'float64', 'int32', 'int64', 'bfloat16', 'float16'): + # return ast.Constant(value=getattr(jnp, value.name)) + return ast.Attribute( + value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr=value.name, + ctx=ast.Load() + ) + elif value.name == 'bool': + return ast.Attribute( + value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr='bool_', + ctx=ast.Load() + ) + else: + return ast.Call( + func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr='dtype', + ctx=ast.Load()), + args=[ast.Constant(value=str(value))], + keywords=[] + ) + elif value is UNSPECIFIED: + prefix_imports.add('from jax._src.sharding_impls import UNSPECIFIED') + return ast.Name(id='UNSPECIFIED', ctx=ast.Load()) + elif isinstance(value, enum.Enum): + return ast.Attribute( + value=ast.Name(id=value.__class__.__qualname__, ctx=ast.Load()), + attr=value.name, + ctx=ast.Load() + ) + + else: + warnings.warn(f"Unknown value type {type(value)}") + return ast.parse(repr(value)).body[0] + + +def _astify_outvars(state, outvars): + out = [state.name(v, ctx=ast.Store()) for v in outvars] + if len(out) == 1: + return out + else: + return [ast.Tuple(elts=out, ctx=ast.Store())] + + +def maybe_tuple_vars(vars): + if len(vars) == 1: + return vars[0] + else: + return ast.Tuple(elts=vars, ctx=ast.Load()) + + +def maybe_untuple_vars(var, is_tuple): + if is_tuple: + return ast.Starred(value=var, ctx=ast.Load()) + else: + return var + + +@register_prim_as('scan') +def _astify_scan(state, eqn): + assert eqn.primitive.name == 'scan' + + # the args to scan are [constants, carry, xs] + # constants aren't exposed in the Python API, so we need to handle them specially (we use a lambda) + num_consts = eqn.params['num_consts'] + num_carry = eqn.params['num_carry'] + + # TODO: bring back map + # if num_carry == 0: + # this is a map + # return _astify_map(eqn) + + constant_args = eqn.invars[:num_consts] + carries = eqn.invars[num_consts:num_consts + num_carry] + xs = eqn.invars[num_consts + num_carry:] + + jaxpr = eqn.params['jaxpr'] + + if num_consts != 0: + # we want to construct an environment where we partial eval the function using the constants as the env + env = dict(zip(jaxpr.jaxpr.invars, constant_args)) + jaxpr = partial_eval_jaxpr(jaxpr.jaxpr, env) + else: + jaxpr = constant_fold_jaxpr(jaxpr.jaxpr) + + fn_name = state.skolem('fn') + fn_ast = 
jaxpr_to_py_ast(state, jaxpr, fn_name) + + length = _astify_value(eqn.params['length']) + unroll = _astify_value(eqn.params['unroll']) + reverse = _astify_value(eqn.params['reverse']) + + stmts = [] + + if num_carry != 1 or len(jaxpr.invars) != 2: + # what we want is something like: + # fn_name = lambda carry, xs: fn_name(*carry, *xs) + # jax.lax.scan(fn_name, (carries...), (xs...)) + + modified_signature = ast.arguments( + args=[ast.arg(arg='carry'), ast.arg(arg='x')], + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[], + posonlyargs=[] + ) + + initial_assign = ast.Assign( + targets=[ast.Tuple(elts=[ast.Name(a.arg) for a in fn_ast.args.args], + ctx=ast.Store())], + value=ast.Tuple( + elts=[maybe_untuple_vars(ast.Name(id='carry', ctx=ast.Load()), num_carry != 1), + maybe_untuple_vars(ast.Name(id='x', ctx=ast.Load()), len(xs) != 1)] + ) + ) + + fn_return = fn_ast.body[-1] + assert isinstance(fn_return, ast.Return) + + fn_return_value = fn_return.value + + if isinstance(fn_return_value, ast.Tuple): + fn_return_value = fn_return_value.elts + ret_carries = maybe_tuple_vars(fn_return_value[:num_carry]) + ret_ys = maybe_tuple_vars(fn_return_value[num_carry:]) + elif num_carry == 0: + ret_carries = _astify_value(()) + ret_ys = fn_return_value + else: + ret_carries = fn_return_value + ret_ys = _astify_value(()) + + scan_return = ast.Return( + value=ast.Tuple(elts=[ret_carries, ret_ys], ctx=ast.Load()) + ) + + new_body = [initial_assign] + list(fn_ast.body[:-1]) + [scan_return] + + fn_ast = ast.FunctionDef( + name=fn_name, + args=modified_signature, + body=new_body, + decorator_list=[] + ) + + stmts.append(fn_ast) + + scan_call = ast.Assign( + # targets=_astify_outvars(eqn.outvars), + targets=[ + ast.Tuple( + elts=[ast.Name(id='final_carry', ctx=ast.Store()), + ast.Name(id='ys', ctx=ast.Store())], + ctx=ast.Store() + ) + ], + value=ast.Call( + func=ast.Name(id='jax.lax.scan', ctx=ast.Load()), + args=[ast.Name(id=fn_name, ctx=ast.Load()), + maybe_tuple_vars([_astify_atom(state, v) for v in carries]), + maybe_tuple_vars([_astify_atom(state, v) for v in xs])], + keywords=[ast.keyword(arg='length', value=length), + ast.keyword(arg='unroll', value=unroll), + ast.keyword(arg='reverse', value=reverse)] + ) + ) + stmts.append(scan_call) + + if num_carry > 0: + assign_carry = ast.Assign( + targets=_astify_outvars(state, eqn.outvars[:num_carry]), + value=ast.Name(id='final_carry', ctx=ast.Load()) + ) + + stmts.append(assign_carry) + + if num_carry < len(eqn.outvars): + assign_ys = ast.Assign( + targets=_astify_outvars(state, eqn.outvars[num_carry:]), + value=ast.Name(id='ys', ctx=ast.Load()) + ) + + stmts.append(assign_ys) + else: + stmts.append(fn_ast) + + scan_call = ast.Assign( + targets=_astify_outvars(state, eqn.outvars), + value=ast.Call( + func=ast.Name(id='jax.lax.scan', ctx=ast.Load()), + args=[ast.Name(id=fn_name, ctx=ast.Load())] + [_astify_atom(state, v) for v in eqn.invars], + keywords=[ast.keyword(arg='length', value=length), + ast.keyword(arg='unroll', value=unroll), + ast.keyword(arg='reverse', value=reverse)] + ) + ) + + stmts.append(scan_call) + + return stmts + + +def _astify_map(state, eqn): + assert eqn.primitive.name == 'scan' + assert eqn.params['num_carry'] == 0 + + jaxpr = eqn.params['jaxpr'] + jaxpr = constant_fold_jaxpr(jaxpr.jaxpr) + + fn_name = state.skolem('fn') + fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name) + + # map is a bit funny, because the jaxpr takes K args, but the jax.lax.map function takes a single tuple arg + # so we need to use a lambda 
to redirect the call + lam = ast.parse(f"lambda args: {fn_name}(*args)").body[0] + + assign = ast.Assign( + targets=_astify_outvars(state, eqn.outvars), + value=ast.Call( + func=ast.Name(id='jax.lax.map', ctx=ast.Load()), + args=[lam, + ast.Tuple(elts=[_astify_atom(state, v) for v in eqn.invars], + ctx=ast.Load())], + keywords=[] + ) + ) + + return [fn_ast, assign] + + +@register_prim_as('closed_call') +def _astify_closed_call(state, eqn): + # out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, + # {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) + raw_jaxpr = eqn.params['call_jaxpr'].jaxpr + literal_args = {k: v.val + for k, v in zip(raw_jaxpr.invars, eqn.invars) + if isinstance(v, Literal)} + call_japr = partial_eval_jaxpr(raw_jaxpr, literal_args) + fn_name = state.skolem('fn') + + fn_ast = jaxpr_to_py_ast(state, call_japr, fn_name) + + invars = [_astify_atom(state, v) + for v in eqn.invars + if not isinstance(v, Literal)] + outvars = _astify_outvars(state, eqn.outvars) + + assign = ast.Assign( + targets=outvars, + value=ast.Call( + func=ast.Name(id=fn_name, ctx=ast.Load()), + args=invars, + keywords=[] + ) + ) + + return [fn_ast, assign] + + +@register_prim_as('pjit') +def _astify_pjit(state, eqn): + # this one's a real pain. + # pjit's params are : + # jaxpr + # donated_invars: + # in_shardings, out_shardings + # resource env + # name (yay) + # keep_unused, inline (which we won't use) + + jaxpr = eqn.params['jaxpr'] + donated_invars = eqn.params['donated_invars'] + in_shardings = eqn.params['in_shardings'] + out_shardings = eqn.params['out_shardings'] + resource_env = eqn.params['resource_env'] + name = eqn.params['name'] + + can_ignore_donated = not any(donated_invars) + + # preprocess the function + jaxpr = constant_fold_jaxpr(jaxpr.jaxpr) + fn_name = state.skolem(name) + fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name) + + in_shardings = _astify_value(in_shardings) + out_shardings = _astify_value(out_shardings) + + keywords = [ + ast.keyword(arg='in_shardings', value=in_shardings), + ast.keyword(arg='out_shardings', value=out_shardings), + ] + + if not can_ignore_donated: + donated_invars = _astify_value(donated_invars) + keywords.append(ast.keyword(arg='donated_invars', value=donated_invars)) + + jitted_fn = ast.Call( + func=ast.Attribute( + ast.Name(id='jax', ctx=ast.Load()), + attr='jit' + ), + args=[ast.Name(id=fn_name, ctx=ast.Load())], + keywords=keywords + ) + + assign = ast.Assign( + targets=_astify_outvars(state, eqn.outvars), + value=ast.Call( + func=jitted_fn, + args=[_astify_atom(state, v) for v in eqn.invars], + keywords=[] + ) + ) + + return [fn_ast, assign] + + +@register_prim_as('remat2') +def _astify_remat(state: SourcerorState, eqn): + # out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, + # {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) + call_japr = constant_fold_jaxpr(eqn.params['jaxpr']) + fn_name = state.skolem('fn') + + fn_ast = jaxpr_to_py_ast(state, call_japr, fn_name) + + invars = [_astify_atom(state, v) for v in eqn.invars] + outvars = _astify_outvars(state, eqn.outvars) + + lam = ast.Assign( + targets=[ast.Name(id=f"ckpt_{fn_name}", ctx=ast.Store())], + # value=ast.parse(f"jax.checkpoint({fn_name})").body[0] + value=ast.Call( + func=ast.Name(id='jax.checkpoint', ctx=ast.Load()), + args=[ast.Name(id=fn_name, ctx=ast.Load())], + keywords=[]) + ) + + assign = ast.Assign( + targets=outvars, + value=ast.Call( + func=ast.Name(id=f"ckpt_{fn_name}"), + args=invars, + keywords=[] + )) + + return 
[fn_ast, lam, assign] + + +@register_prim_as('reshape') +def _astify_reshape(state, eqn): + # the lax reshape is a bit different, because it can combine a transpose and reshape into one. + # np.reshape(np.transpose(operand, dimensions), new_sizes) + dimensions = eqn.params['dimensions'] + new_sizes = eqn.params['new_sizes'] + + source = _astify_atom(state, eqn.invars[0]) + + if dimensions is not None: + source = ast.Call( + func=ast.Name(id='jax.numpy.transpose', ctx=ast.Load()), + args=[source, _astify_value(dimensions)], + keywords=[] + ) + + assign = ast.Assign( + targets=_astify_outvars(state, eqn.outvars), + value=ast.Call( + func=ast.Name(id='jax.numpy.reshape', ctx=ast.Load()), + args=[source, _astify_value(new_sizes)], + keywords=[] + )) + + return [assign] + + +@register_prim_as('add_any') +def _astify_add_any(state, eqn): + # add_any is a weird undocumented jax primitive. best guess is it adds? + return _binop_fn(ast.Add())(state, eqn) + + +@register_prim_as('broadcast_in_dim') +def _astify_broadcast_in_dim(state, eqn): + # broadcast_in_dim is how zeros, ones, full, etc are implemented, + # so we prefer to use those where possible + assert len(eqn.invars) == 1 + value = eqn.invars[0] + shape = eqn.params['shape'] + broadcast_dimensions = eqn.params['broadcast_dimensions'] + + if not isinstance(value, Literal) or broadcast_dimensions != (): + return normal_fn('jax.lax.broadcast_in_dim')(state, eqn) + + if not isinstance(value.val, np.ndarray) or value.val.ndim != 0: + return normal_fn('jax.lax.broadcast_in_dim')(state, eqn) + else: + constant_value = value.val.item() + if constant_value == 0: + call = ast.Call( + ast.Attribute( + value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr='zeros', + ctx=ast.Load() + ), + args=[_astify_value(shape), + _astify_value(value.val.dtype)], + keywords=[] + ) + elif constant_value == 1: + call = ast.Call( + ast.Attribute( + value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr='ones', + ctx=ast.Load() + ), + args=[_astify_value(shape), + _astify_value(value.val.dtype)], + keywords=[] + ) + else: + call = ast.Call( + ast.Attribute( + value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr='full', + ctx=ast.Load() + ), + args=[_astify_value(shape), + _astify_value(constant_value), + _astify_value(value.val.dtype)], + keywords=[] + ) + + return [ast.Assign( + targets=_astify_outvars(state, eqn.outvars), + value=call + )] + + +@register_prim_as('random_wrap') +def _astify_random_wrap(state, eqn): + # we treat this as a noop + return ast.Assign( + targets=_astify_outvars(state, eqn.outvars), + value=_astify_atom(state, eqn.invars[0]) + ) + + +constant_fold_blacklist = { + 'broadcast_in_dim', + 'broadcast', +} diff --git a/brainpy/_src/integrators/base.py b/brainpy/_src/integrators/base.py index 6168ffd87..7853123bc 100644 --- a/brainpy/_src/integrators/base.py +++ b/brainpy/_src/integrators/base.py @@ -1,7 +1,9 @@ # -*- coding: utf-8 -*- -from typing import Dict, Sequence, Union +from typing import Dict, Sequence, Union, Callable + +import jax from brainpy._src.math.object_transform.base import BrainPyObject from brainpy._src.math import TimeDelay, LengthDelay @@ -9,6 +11,9 @@ from brainpy.errors import DiffEqError from .constants import DT +from ._jaxpr_to_source_code import jaxpr_to_python_code +from contextlib import contextmanager + __all__ = [ 'Integrator', ] @@ -58,6 +63,9 @@ def __init__( self._state_delays[key] = delay self.register_implicit_nodes(self._state_delays) + # math expression + self._math_expr = None + @property def dt(self): 
"""The numerical integration precision.""" @@ -119,6 +127,18 @@ def state_delays(self): def state_delays(self, value): raise ValueError('Cannot set "state_delays" by users.') + def _call_integral(self, *args, **kwargs): + if _during_compile: + jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs) + outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs)) + _, tree = jax.tree.flatten(out_shapes) + new_vars = tree.unflatten(outs) + self._math_expr = jaxpr_to_python_code(jaxpr.jaxpr) + + else: + new_vars = self.integral(**kwargs) + return new_vars + def __call__(self, *args, **kwargs): assert self.integral is not None, 'Please build the integrator first.' @@ -127,7 +147,9 @@ def __call__(self, *args, **kwargs): kwargs[self.arg_names[i]] = arg # integral - new_vars = self.integral(**kwargs) + new_vars = self._call_integral(**kwargs) + + # post-process if len(self.variables) == 1: dict_vars = {self.variables[0]: new_vars} else: @@ -146,3 +168,31 @@ def __call__(self, *args, **kwargs): f'While we got {delay}') return new_vars + + def to_math_expr(self): + if self._math_expr is None: + raise ValueError('Please call ``brainpy.integrators.compile_integrators`` first.') + return self._math_expr + + +_during_compile = False + + +@contextmanager +def _during_compile_context(): + global _during_compile + try: + _during_compile = True + yield + finally: + _during_compile = False + + +def compile_integrators(f: Callable, *args, **kwargs): + """ + Compile integrators in the given function. + """ + with _during_compile_context(): + return f(*args, **kwargs) + + diff --git a/brainpy/_src/integrators/ode/base.py b/brainpy/_src/integrators/ode/base.py index b34dd4bf4..36b0c5f04 100644 --- a/brainpy/_src/integrators/ode/base.py +++ b/brainpy/_src/integrators/ode/base.py @@ -111,7 +111,7 @@ def __call__(self, *args, **kwargs): kwargs[self.arg_names[i]] = arg # integral - new_vars = self.integral(**kwargs) + new_vars = self._call_integral(**kwargs) if len(self.variables) == 1: dict_vars = {self.variables[0]: new_vars} else: diff --git a/brainpy/_src/integrators/tests/test_to_source_code.py b/brainpy/_src/integrators/tests/test_to_source_code.py new file mode 100644 index 000000000..aecf83230 --- /dev/null +++ b/brainpy/_src/integrators/tests/test_to_source_code.py @@ -0,0 +1,48 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +import brainpy as bp + + +class EINet3(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + syn=bp.dyn.Expon.desc(size=4000, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + syn=bp.dyn.Expon.desc(size=4000, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:3200]) + self.I(spk[3200:]) + self.delay(self.N(input)) + return self.N.spike.value + + +def test1(): + model = EINet3() + + bp.integrators.compile_integrators(model.step_run, 0, 0.) + for intg in model.nodes().subset(bp.Integrator).values(): + print(intg.to_math_expr()) diff --git a/brainpy/_src/math/others.py b/brainpy/_src/math/others.py index 59588d3b9..776da1b5c 100644 --- a/brainpy/_src/math/others.py +++ b/brainpy/_src/math/others.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp from jax.tree_util import tree_map +import numpy as np from brainpy import check, tools from .compat_numpy import fill_diagonal @@ -100,7 +101,8 @@ def false_f(x): return (jnp.exp(x) - 1) / x # return jax.lax.cond(jnp.abs(x) < threshold, true_f, false_f, x) - return jnp.where(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x) + # return jnp.where(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x) + return jax.lax.select(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x) def exprel(x, threshold: float = None): diff --git a/brainpy/integrators/__init__.py b/brainpy/integrators/__init__.py index 176a71aec..7696bd33a 100644 --- a/brainpy/integrators/__init__.py +++ b/brainpy/integrators/__init__.py @@ -3,4 +3,5 @@ from . import ode from . import sde from . import fde +from brainpy._src.integrators.base import compile_integrators from brainpy._src.integrators.constants import * From b059cb757b82d8513bc58667e12634290610fe6a Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 1 Jun 2024 23:20:39 +0800 Subject: [PATCH 2/2] rename to math_expr --- .../tests/{test_to_source_code.py => test_to_math_expr.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename brainpy/_src/integrators/tests/{test_to_source_code.py => test_to_math_expr.py} (100%) diff --git a/brainpy/_src/integrators/tests/test_to_source_code.py b/brainpy/_src/integrators/tests/test_to_math_expr.py similarity index 100% rename from brainpy/_src/integrators/tests/test_to_source_code.py rename to brainpy/_src/integrators/tests/test_to_math_expr.py
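
Note for reviewers (illustrative, not part of the patch): the intended workflow is to run one traced step inside ``compile_integrators`` and then read the recorded expression back from each ``Integrator`` node, exactly as the test above does. The sketch below assumes a minimal hypothetical ``Net`` model; whether a given built-in neuron contributes ``Integrator`` nodes depends on how that neuron is implemented, and no particular output of ``to_math_expr()`` is claimed here.

import brainpy as bp


class Net(bp.DynSysGroup):
    """Hypothetical single-population model whose update step drives a LIF neuron group."""

    def __init__(self):
        super().__init__()
        self.N = bp.dyn.LifRef(10, V_rest=-60., V_th=-50., V_reset=-60.,
                               tau=20., tau_ref=5.)

    def update(self, inp):
        return self.N(inp)


model = Net()

# One traced step inside the compile context records, for every Integrator node,
# the constant-folded jaxpr and the Python source generated from it.
bp.integrators.compile_integrators(model.step_run, 0, 0.)

# The recorded expression can then be read back from each integrator.
for intg in model.nodes().subset(bp.Integrator).values():
    print(intg.to_math_expr())

Because ``_call_integral`` only takes the ``make_jaxpr`` path while the ``_during_compile`` flag is set, the recording adds no overhead to ordinary integrator calls outside the compile context.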