Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transform kernels and convolutions using a transformer before code generation #1050

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pynestml/cocos/co_co_all_variables_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False):
# check if it is part of an invariant
# if it is the case, there is no "recursive" declaration
# so check if the parent is a declaration and the expression the invariant
expr_par = node.get_parent(expr)
expr_par = expr.get_parent()
if isinstance(expr_par, ASTDeclaration) and expr_par.get_invariant() == expr:
# in this case its ok if it is recursive or defined later on
continue
Expand Down
4 changes: 2 additions & 2 deletions pynestml/cocos/co_co_no_kernels_except_in_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ def visit_variable(self, node: ASTNode):
if not symbol.is_kernel():
continue
if node.get_complete_name() == kernelName:
parent = self.__neuron_node.get_parent(node)
parent = node.get_parent()
if parent is not None:
if isinstance(parent, ASTKernel):
continue
grandparent = self.__neuron_node.get_parent(parent)
grandparent = parent.get_parent()
if grandparent is not None and isinstance(grandparent, ASTFunctionCall):
grandparent_func_name = grandparent.get_name()
if grandparent_func_name == 'convolve':
Expand Down
2 changes: 1 addition & 1 deletion pynestml/cocos/co_co_resolution_func_legally_used.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def visit_simple_expression(self, node):
if function_name == PredefinedFunctions.TIME_RESOLUTION:
_node = node
while _node:
_node = self.neuron.get_parent(_node)
_node = _node.get_parent()

if isinstance(_node, ASTEquationsBlock) \
or isinstance(_node, ASTFunction):
Expand Down
2 changes: 1 addition & 1 deletion pynestml/cocos/co_co_simple_delta_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def check_co_co(cls, model: ASTModel):
def check_simple_delta(_expr=None):
if _expr.is_function_call() and _expr.get_function_call().get_name() == "delta":
deltafunc = _expr.get_function_call()
parent = model.get_parent(_expr)
parent = _expr.get_parent()

# check the argument
if not (len(deltafunc.get_args()) == 1
Expand Down
2 changes: 1 addition & 1 deletion pynestml/cocos/co_co_vector_declaration_right_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class VectorDeclarationVisitor(ASTVisitor):
def visit_variable(self, node: ASTVariable):
vector_parameter = node.get_vector_parameter()
if vector_parameter is not None:
if isinstance(self._neuron.get_parent(node), ASTDeclaration):
if isinstance(node.get_parent(), ASTDeclaration):
# node is being declared: size should be >= 1
min_index = 1

Expand Down
1 change: 0 additions & 1 deletion pynestml/codegeneration/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def _setup_template_env(self, template_files: List[str], templates_root_dir: str
# Environment for neuron templates
env = Environment(loader=FileSystemLoader(_template_dirs))
env.globals["raise"] = self.raise_helper
env.globals["is_delta_kernel"] = ASTUtils.is_delta_kernel

# Load all the templates
_templates = list()
Expand Down
162 changes: 14 additions & 148 deletions pynestml/codegeneration/nest_code_generator.py

Large diffs are not rendered by default.

122 changes: 13 additions & 109 deletions pynestml/codegeneration/nest_compartmental_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,22 +280,16 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None:

def create_ode_indict(self,
neuron: ASTModel,
parameters_block: ASTBlockWithVariables,
kernel_buffers: Mapping[ASTKernel,
ASTInputPort]):
odetoolbox_indict = self.transform_ode_and_kernels_to_json(
neuron, parameters_block, kernel_buffers)
parameters_block: ASTBlockWithVariables):
odetoolbox_indict = self.transform_ode_and_kernels_to_json(neuron, parameters_block)
odetoolbox_indict["options"] = {}
odetoolbox_indict["options"]["output_timestep_symbol"] = "__h"
return odetoolbox_indict

def ode_solve_analytically(self,
neuron: ASTModel,
parameters_block: ASTBlockWithVariables,
kernel_buffers: Mapping[ASTKernel,
ASTInputPort]):
odetoolbox_indict = self.create_ode_indict(
neuron, parameters_block, kernel_buffers)
parameters_block: ASTBlockWithVariables):
odetoolbox_indict = self.create_ode_indict(neuron, parameters_block)

full_solver_result = analysis(
odetoolbox_indict,
Expand All @@ -314,8 +308,7 @@ def ode_solve_analytically(self,

return full_solver_result, analytic_solver

def ode_toolbox_analysis(self, neuron: ASTModel,
kernel_buffers: Mapping[ASTKernel, ASTInputPort]):
def ode_toolbox_analysis(self, neuron: ASTModel):
"""
Prepare data for ODE-toolbox input format, invoke ODE-toolbox analysis via its API, and return the output.
"""
Expand All @@ -324,15 +317,13 @@ def ode_toolbox_analysis(self, neuron: ASTModel,

equations_block = neuron.get_equations_blocks()[0]

if len(equations_block.get_kernels()) == 0 and len(
equations_block.get_ode_equations()) == 0:
if len(equations_block.get_ode_equations()) == 0:
# no equations defined -> no changes to the neuron
return None, None

parameters_block = neuron.get_parameters_blocks()[0]

solver_result, analytic_solver = self.ode_solve_analytically(
neuron, parameters_block, kernel_buffers)
solver_result, analytic_solver = self.ode_solve_analytically(neuron, parameters_block)

# if numeric solver is required, generate a stepping function that
# includes each state variable
Expand All @@ -341,8 +332,7 @@ def ode_toolbox_analysis(self, neuron: ASTModel,
x for x in solver_result if x["solver"].startswith("numeric")]

if numeric_solvers:
odetoolbox_indict = self.create_ode_indict(
neuron, parameters_block, kernel_buffers)
odetoolbox_indict = self.create_ode_indict(neuron, parameters_block)
solver_result = analysis(
odetoolbox_indict,
disable_stiffness_check=True,
Expand Down Expand Up @@ -417,24 +407,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:

return []

# goes through all convolve() inside ode's from equations block
# if they have delta kernels, use sympy to expand the expression, then
# find the convolve calls and replace them with constant value 1
# then return every subexpression that had that convolve() replaced
delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block)

# goes through all convolve() inside equations block
# extracts what kernel is paired with what spike buffer
# returns pairs (kernel, spike_buffer)
kernel_buffers = ASTUtils.generate_kernel_buffers(
neuron, equations_block)

# replace convolve(g_E, spikes_exc) with g_E__X__spikes_exc[__d]
# done by searching for every ASTSimpleExpression inside equations_block
# which is a convolve call and substituting that call with
# newly created ASTVariable kernel__X__spike_buffer
ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block)

# substitute inline expressions with each other
# such that no inline expression references another inline expression
ASTUtils.make_inline_expressions_self_contained(
Expand All @@ -450,14 +422,13 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
# "update_expressions" key in those solvers contains a mapping
# {expression1: update_expression1, expression2: update_expression2}

analytic_solver, numeric_solver = self.ode_toolbox_analysis(
neuron, kernel_buffers)
analytic_solver, numeric_solver = self.ode_toolbox_analysis(neuron)

"""
# separate analytic solutions by kernel
# this is needed for the synaptic case
self.kernel_name_to_analytic_solver[neuron.get_name(
)] = self.ode_toolbox_anaysis_cm_syns(neuron, kernel_buffers)
)] = self.ode_toolbox_anaysis_cm_syns(neuron)
"""

self.analytic_solver[neuron.get_name()] = analytic_solver
Expand All @@ -472,12 +443,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
# by odetoolbox, higher order variables don't get deleted here
ASTUtils.remove_initial_values_for_kernels(neuron)

# delete all kernels as they are all converted into buffers
# and corresponding update formulas calculated by odetoolbox
# Remember them in a variable though
kernels = ASTUtils.remove_kernel_definitions_from_equations_block(
neuron)

# Every ODE variable (a variable of order > 0) is renamed according to ODE-toolbox conventions
# their initial values are replaced by expressions suggested by ODE-toolbox.
# Differential order can now be set to 0, because they can directly represent the value of the derivative now.
Expand All @@ -491,22 +456,11 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
# corresponding updates
ASTUtils.remove_ode_definitions_from_equations_block(neuron)

# restore state variables that were referenced by kernels
# and set their initial values by those suggested by ODE-toolbox
ASTUtils.create_initial_values_for_kernels(
neuron, [analytic_solver, numeric_solver], kernels)

# Inside all remaining expressions, translate all remaining variable names
# according to the naming conventions of ODE-toolbox.
ASTUtils.replace_variable_names_in_expressions(
neuron, [analytic_solver, numeric_solver])

# find all inline kernels defined as ASTSimpleExpression
# that have a single kernel convolution aliasing variable ('__X__')
# translate all remaining variable names according to the naming
# conventions of ODE-toolbox
ASTUtils.replace_convolution_aliasing_inlines(neuron)

# add variable __h to internals block
ASTUtils.add_timestep_symbol(neuron)

Expand Down Expand Up @@ -677,13 +631,9 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
expr_ast.accept(ASTSymbolTableVisitor())
namespace["numeric_update_expressions"][sym] = expr_ast

namespace["spike_updates"] = neuron.spike_updates

namespace["recordable_state_variables"] = [
sym for sym in neuron.get_state_symbols() if namespace["declarations"].get_domain_from_type(
sym.get_type_symbol()) == "double" and sym.is_recordable and not ASTUtils.is_delta_kernel(
neuron.get_kernel_by_name(
sym.name))]
sym.get_type_symbol()) == "double" and sym.is_recordable]
namespace["recordable_inline_expressions"] = [
sym for sym in neuron.get_inline_expression_symbols() if namespace["declarations"].get_domain_from_type(
sym.get_type_symbol()) == "double" and sym.is_recordable]
Expand Down Expand Up @@ -807,7 +757,7 @@ def get_spike_update_expressions(
for var_order in range(
ASTUtils.get_kernel_var_order_from_ode_toolbox_result(
kernel_var.get_name(), solver_dicts)):
kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name(
kernel_spike_buf_name = ASTUtils.construct_kernel_spike_buf_name(
kernel_var.get_name(), spike_input_port, var_order)
expr = ASTUtils.get_initial_value_from_ode_toolbox_result(
kernel_spike_buf_name, solver_dicts)
Expand Down Expand Up @@ -849,18 +799,9 @@ def get_spike_update_expressions(
def transform_ode_and_kernels_to_json(
self,
neuron: ASTModel,
parameters_block,
kernel_buffers):
parameters_block):
"""
Converts AST node to a JSON representation suitable for passing to ode-toolbox.

Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements

convolve(G, ex_spikes)
convolve(G, in_spikes)

then `kernel_buffers` will contain the pairs `(G, ex_spikes)` and `(G, in_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__ex_spikes` and `G__X__in_spikes`.

:param parameters_block: ASTBlockWithVariables
:return: Dict
"""
Expand Down Expand Up @@ -890,43 +831,6 @@ def transform_ode_and_kernels_to_json(
iv_symbol_name)] = expr
odetoolbox_indict["dynamics"].append(entry)

# write a copy for each (kernel, spike buffer) combination
for kernel, spike_input_port in kernel_buffers:

if ASTUtils.is_delta_kernel(kernel):
# delta function -- skip passing this to ode-toolbox
continue

for kernel_var in kernel.get_variables():
expr = ASTUtils.get_expr_from_kernel_var(
kernel, kernel_var.get_complete_name())
kernel_order = kernel_var.get_differential_order()
kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_X_spike_buf_name(
kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'")

ASTUtils.replace_rhs_variables(expr, kernel_buffers)

entry = {}
entry["expression"] = kernel_X_spike_buf_name_ticks + " = " + str(expr)

# initial values need to be declared for order 1 up to kernel
# order (e.g. none for kernel function f(t) = ...; 1 for kernel
# ODE f'(t) = ...; 2 for f''(t) = ... and so on)
entry["initial_values"] = {}
for order in range(kernel_order):
iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_X_spike_buf_name(
kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'")
symbol_name_ = kernel_var.get_name() + "'" * order
symbol = equations_block.get_scope().resolve_to_symbol(
symbol_name_, SymbolKind.VARIABLE)
assert symbol is not None, "Could not find initial value for variable " + symbol_name_
initial_value_expr = symbol.get_declaring_expression()
assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_
entry["initial_values"][iv_sym_name_ode_toolbox] = self._ode_toolbox_printer.print(
initial_value_expr)

odetoolbox_indict["dynamics"].append(entry)

odetoolbox_indict["parameters"] = {}
if parameters_block is not None:
for decl in parameters_block.get_declarations():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,8 @@ std::vector< std::tuple< int, int > > {{neuronName}}::rport_to_nestml_buffer_idx

// copy state struct S_
{%- for init in neuron.get_state_symbols() %}
{%- if not is_delta_kernel(neuron.get_kernel_by_name(init.name)) %}
{%- set node = utils.get_state_variable_by_name(astnode, init.get_symbol_name()) %}
{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }} = __n.{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }};
{%- endif %}
{%- endfor %}

// copy internals V_
Expand Down Expand Up @@ -719,28 +717,6 @@ void {{neuronName}}::update(nest::Time const & origin,const long from, const lon
update_delay_variables();
{%- endif %}

/**
* subthreshold updates of the convolution variables
*
* step 1: regardless of whether and how integrate_odes() will be called, update variables due to convolutions
**/

{%- if uses_analytic_solver %}
{%- for variable_name in analytic_state_variables: %}
{%- if "__X__" in variable_name %}
{%- set update_expr = update_expressions[variable_name] %}
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
{%- if use_gap_junctions %}
const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}__tmp_ = {{ printer.print(update_expr) | replace("B_." + gap_junction_port + "_grid_sum_", "(B_." + gap_junction_port + "_grid_sum_ + __I_gap)") }};
{%- else %}
const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}__tmp_ = {{ printer.print(update_expr) }};
{%- endif %}
{%- endif %}
{%- endfor %}
{%- endif %}


/**
* Begin NESTML generated code for the update block(s)
**/
Expand Down Expand Up @@ -770,30 +746,6 @@ const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}_
}
{%- endfor %}

/**
* subthreshold updates of the convolution variables
*
* step 2: regardless of whether and how integrate_odes() was called, update variables due to convolutions. Set to the updated values at the end of the timestep.
**/
{% if uses_analytic_solver %}
{%- for variable_name in analytic_state_variables: %}
{%- if "__X__" in variable_name %}
{%- set update_expr = update_expressions[variable_name] %}
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
{{ printer.print(var_ast) }} = {{variable_name}}__tmp_;
{%- endif %}
{%- endfor %}
{%- endif %}


/**
* spike updates due to convolutions
**/
{% filter indent(4) %}
{%- include "directives_cpp/ApplySpikesFromBuffers.jinja2" %}
{%- endfilter %}

/**
* Begin NESTML generated code for the onCondition block(s)
**/
Expand Down Expand Up @@ -1149,13 +1101,9 @@ void
{%- endfor %}

/**
* print updates due to convolutions
* push back spike history
**/

{%- for _, spike_update in post_spike_updates.items() %}
{{ printer.print(utils.get_variable_by_name(astnode, spike_update.get_variable().get_complete_name())) }} += 1.;
{%- endfor %}

last_spike_ = t_sp_ms;
history_.push_back( histentry__{{neuronName}}( last_spike_
{%- for var in purely_numeric_state_variables_moved|sort %}
Expand Down
Loading
Loading