Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix indexing of vector input ports #1042

Merged
merged 10 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 9 additions & 9 deletions doc/running/running_nest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ For the example mentioned :ref:`here <Multiple input ports with vectors>`, the `
neuron = nest.Create("multi_synapse_vectors")
receptor_types = nest.GetStatus(neuron, "receptor_types")

The name of the receptors of the input ports are denoted by suffixing the ``vector index + 1`` to the port name. For instance, the receptor name for ``foo[0]`` would be ``FOO_1``.
The name of the receptors of the input ports are denoted by suffixing the ``vector index`` to the port name. For instance, the receptor name for ``foo[0]`` would be ``FOO_0``.

The above code querying for ``receptor_types`` gives a list of port names and NEST ``rport`` numbers as shown below:

Expand All @@ -155,21 +155,21 @@ The above code querying for ``receptor_types`` gives a list of port names and NE
- 1
* - NMDA_spikes
- 2
* - FOO_1
* - FOO_0
- 3
* - FOO_2
* - FOO_1
- 4
* - EXC_SPIKES_1
* - EXC_SPIKES_0
- 5
* - EXC_SPIKES_2
* - EXC_SPIKES_1
- 6
* - EXC_SPIKES_3
* - EXC_SPIKES_2
- 7
* - INH_SPIKES_1
* - INH_SPIKES_0
- 5
* - INH_SPIKES_2
* - INH_SPIKES_1
- 6
* - INH_SPIKES_3
* - INH_SPIKES_2
- 7

For a full example, please see `iaf_psc_exp_multisynapse_vectors.nestml <https://github.com/nest/nestml/blob/master/tests/nest_tests/resources/iaf_psc_exp_multisynapse_vectors.nestml>`_ for the neuron model and ``test_multisynapse_with_vector_input_ports`` in `tests/nest_tests/nest_multisynapse_test.py <https://github.com/nest/nestml/blob/master/tests/nest_tests/nest_multisynapse_test.py>`_ for the corresponding test.
Expand Down
4 changes: 3 additions & 1 deletion pynestml/cocos/co_co_vector_input_port_correct_size_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
from pynestml.utils.ast_utils import ASTUtils

from pynestml.meta_model.ast_expression import ASTExpression

from pynestml.cocos.co_co import CoCo
Expand Down Expand Up @@ -54,7 +56,7 @@ def visit_input_port(self, node: ASTInputPort):
return

# otherwise, it is a simple expression
if size_parameter.is_variable() or (size_parameter.is_numeric_literal() and not isinstance(size_parameter.get_numeric_literal(), int)):
if not isinstance(ASTUtils.get_numeric_vector_input_port_size(node), int):
code, message = Messages.get_input_port_size_not_integer(node.get_name())
Logger.log_message(error_position=node.get_source_position(), log_level=LoggingLevel.ERROR,
code=code, message=message)
Expand Down
4 changes: 3 additions & 1 deletion pynestml/codegeneration/nest_assignments_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@

from typing import Optional

from pynestml.utils.ast_utils import ASTUtils

from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.symbols.symbol import SymbolKind
from pynestml.symbols.variable_symbol import VariableSymbol
from pynestml.symbols.variable_symbol import VariableSymbol, VariableType, BlockType
from pynestml.utils.logger import LoggingLevel, Logger


Expand Down
10 changes: 6 additions & 4 deletions pynestml/codegeneration/printers/gsl_variable_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,12 @@ def _print_buffer_value(self, variable: ASTVariable) -> str:
variable_symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE)
if variable_symbol.is_spike_input_port():
var_name = variable_symbol.get_symbol_name().upper()
if variable_symbol.get_vector_parameter() is not None:
vector_parameter = ASTUtils.get_numeric_vector_size(variable_symbol)
var_name = var_name + "_" + str(vector_parameter)

if variable.has_vector_parameter():
if variable.get_vector_parameter().is_variable():
# the enum corresponding to the first input port in a vector of input ports will have the _0 suffixed to the enum's name.
var_name += "_0 + " + variable.get_vector_parameter().get_variable().get_name()
pnbabu marked this conversation as resolved.
Show resolved Hide resolved
else:
var_name += "_" + str(variable.get_vector_parameter())
return "spike_inputs_grid_sum_[node." + var_name + " - node.MIN_SPIKE_RECEPTOR]"

return variable_symbol.get_symbol_name() + '_grid_sum_'
10 changes: 6 additions & 4 deletions pynestml/codegeneration/printers/nest_variable_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,12 @@ def _print_buffer_value(self, variable: ASTVariable) -> str:
variable_symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE)
if variable_symbol.is_spike_input_port():
var_name = variable_symbol.get_symbol_name().upper()
if variable.get_vector_parameter() is not None:
vector_parameter = ASTUtils.get_numeric_vector_size(variable)
var_name = var_name + "_" + str(vector_parameter)

if variable.has_vector_parameter():
if variable.get_vector_parameter().is_variable():
# the enum corresponding to the first input port in a vector of input ports will have the _0 suffixed to the enum's name.
var_name += "_0 + " + variable.get_vector_parameter().get_variable().get_name()
else:
var_name += "_" + str(variable.get_vector_parameter())
return "spike_inputs_grid_sum_[" + var_name + " - MIN_SPIKE_RECEPTOR]"

return variable_symbol.get_symbol_name() + '_grid_sum_'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ std::vector< std::tuple< int, int > > {{neuronName}}::rport_to_nestml_buffer_idx
{%- if ports[0].has_vector_parameter() %}
{%- set size = utils.get_numeric_vector_size(ports[0]) %}
{%- for i in range(size) %}
{{ rport_to_port_map_entry.RportToBufferIndexEntry(ports, ns.rport, index=i+1) }}
{{ rport_to_port_map_entry.RportToBufferIndexEntry(ports, ns.rport, index=i) }}
{%- set ns.rport = ns.rport + 1 %}
{%- endfor %}
{%- else %}
Expand All @@ -150,7 +150,7 @@ std::vector< std::tuple< int, int > > {{neuronName}}::rport_to_nestml_buffer_idx
std::string {{neuronName}}::get_var_name(size_t elem, std::string var_name)
{
std::stringstream n;
n << var_name << elem + 1;
n << var_name << elem;
return n.str();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ private:
{%- if port.has_vector_parameter() -%}
{% set size = utils.get_numeric_vector_size(port) | int %}
{%- for i in range(size) %}
{{port.get_symbol_name().upper()}}_{{i + 1}} = {{ns.count}},
{{port.get_symbol_name().upper()}}_{{i}} = {{ns.count}},
{%- set ns.count = ns.count + 1 -%}
{%- endfor %}
{%- else %}
Expand Down Expand Up @@ -991,7 +991,7 @@ inline void {{neuronName}}::get_status(DictionaryDatum &__d) const
{%- else %}
{%- set size = utils.get_numeric_vector_size(port) %}
{%- for i in range(size) %}
( *__receptor_type )[ "{{port.get_symbol_name().upper()}}_{{i + 1}}" ] = {{ns.rport + i + 1}},
( *__receptor_type )[ "{{port.get_symbol_name().upper()}}_{{i}}" ] = {{ns.rport + i + 1}},
{%- endfor %}
{%- endif %}
{%- endfor %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,4 @@
@param ast ASTAssignment
#}
{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
{%- set lhs_variable = ast.get_variable() %}

{%- if assignments.is_vectorized_assignment(ast) %}
{%- if lhs_variable.has_vector_parameter() %}
{%- set lhs_vector_variable = lhs_variable.get_vector_parameter() %}
{%- if lhs_vector_variable.is_numeric_literal() %}
{%- set lhs_variable_sym = assignments.lhs_variable(ast) %}
{%- if lhs_variable_sym is none %}
{{ raise('Symbol with name "%s" could not be resolved' % ast.lhs.get_complete_name()) }}
{%- endif %}
{{ nest_codegen_utils.print_symbol_origin(lhs_variable_sym, lhs_variable) % printer_no_origin.print(lhs_variable) }}[{{ ast.get_variable().get_vector_parameter() }}]
{%- elif lhs_vector_variable.is_variable() %}
{%- set vec_symbol = lhs_vector_variable.get_scope().resolve_to_symbol(lhs_vector_variable.get_variable().get_complete_name(), SymbolKind.VARIABLE) %}
{{ printer.print(lhs_variable) }}
{%- else -%}
{{ raise("Cannot handle vector index expression") }}
{%- endif %}
{%- else %}
{{ printer.print(lhs_variable) }}
{%- endif %}
{{assignments.print_assignments_operation(ast)}} {{ printer.print(ast.get_expression()) }};
{%- else %}
{{ printer.print(lhs_variable) }} {{ assignments.print_assignments_operation(ast) }} {{ printer.print(ast.get_expression()) }};
{%- endif %}
{{ printer.print(ast.get_variable()) }} {{ assignments.print_assignments_operation(ast) }} {{ printer.print(ast.get_expression()) }};
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{%- macro RportToBufferIndexEntry(ports, rport, index=0) -%}
{%- if index > 0 -%}
{%- macro RportToBufferIndexEntry(ports, rport, index=-1) -%}
{%- if index >= 0 -%}
{%- set name = "{}_" ~ index|string %}
{%- else -%}
{%- set name = "{}" %}
Expand Down
20 changes: 18 additions & 2 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def get_vectorized_variable(cls, ast, scope):
return None

@classmethod
def get_numeric_vector_size(cls, variable: VariableSymbol) -> int:
def get_numeric_vector_size(cls, variable: ASTVariable) -> int:
"""
Returns the numerical size of the vector by resolving any variable used as a size parameter in declaration
:param variable: vector variable
Expand All @@ -355,6 +355,22 @@ def get_numeric_vector_size(cls, variable: VariableSymbol) -> int:
assert vector_parameter.is_numeric_literal()
return int(vector_parameter.get_numeric_literal())

@classmethod
def get_numeric_vector_input_port_size(cls, port: ASTInputPort) -> int:
"""
Returns the numerical size of the vector by resolving any variable used as a size parameter in declaration
:param port: input port
:return: the size of the vector as a numerical value
"""
size_parameter = port.get_size_parameter()
if size_parameter.is_variable():
symbol = port.get_scope().resolve_to_symbol(size_parameter.get_variable().get_name(),
SymbolKind.VARIABLE)
return symbol.get_declaring_expression().get_numeric_literal()

assert size_parameter.is_numeric_literal()
return int(size_parameter.get_numeric_literal())

@classmethod
def get_function_call(cls, ast, function_name):
"""
Expand Down Expand Up @@ -1382,7 +1398,7 @@ def get_input_port_by_name(cls, input_blocks: List[ASTInputBlock], port_name: st
if isinstance(size_parameter, ASTSimpleExpression):
size_parameter = size_parameter.get_numeric_literal()
port_name, port_index = port_name.split("_")
assert int(port_index) > 0
assert int(port_index) >= 0
assert int(port_index) <= size_parameter
if input_port.name == port_name:
return input_port
Expand Down
5 changes: 5 additions & 0 deletions pynestml/visitors/ast_symbol_table_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pynestml.meta_model.ast_model_body import ASTModelBody
from pynestml.meta_model.ast_namespace_decorator import ASTNamespaceDecorator
from pynestml.meta_model.ast_declaration import ASTDeclaration
from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
from pynestml.meta_model.ast_stmt import ASTStmt
from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.symbol_table.scope import Scope, ScopeType
Expand Down Expand Up @@ -470,6 +471,7 @@ def visit_simple_expression(self, node):
def visit_variable(self, node: ASTVariable):
if node.has_vector_parameter():
node.get_vector_parameter().update_scope(node.get_scope())
node.get_vector_parameter().accept(self)

def visit_inline_expression(self, node):
"""
Expand Down Expand Up @@ -595,6 +597,9 @@ def endvisit_input_port(self, node):
if node.is_continuous() and node.has_datatype():
type_symbol = node.get_datatype().get_type_symbol()
type_symbol.is_buffer = True # set it as a buffer
if node.has_size_parameter():
if isinstance(node.get_size_parameter(), ASTSimpleExpression) and node.get_size_parameter().is_variable():
node.get_size_parameter().update_scope(node.get_scope())
symbol = VariableSymbol(element_reference=node, scope=node.get_scope(), name=node.get_name(),
block_type=BlockType.INPUT, vector_parameter=node.get_size_parameter(),
is_predefined=False, is_inline_expression=False, is_recordable=False,
Expand Down
2 changes: 1 addition & 1 deletion tests/cocos_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,4 +695,4 @@ def test_invalid_co_co_vector_input_port(self):
os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
'CoCoVectorInputPortSizeAndType.nestml'))
self.assertEqual(len(
Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)