/
GSL_generator.py
1161 lines (1028 loc) · 44.2 KB
/
GSL_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
GSLCodeGenerators for code that uses the ODE solver provided by the GNU Scientific Library (GSL)
"""
import os
import re
import numpy as np
from brian2.codegen.generators import c_data_type
from brian2.codegen.permutation_analysis import (
OrderDependenceError,
check_for_order_independence,
)
from brian2.codegen.translation import make_statements
from brian2.core.functions import Function
from brian2.core.preferences import BrianPreference, PreferenceError, prefs
from brian2.core.variables import ArrayVariable, AuxiliaryVariable, Constant
from brian2.parsing.statements import parse_statement
from brian2.units.fundamentalunits import fail_for_dimension_mismatch
from brian2.utils.logger import get_logger
from brian2.utils.stringtools import get_identifiers, word_substitute
# Public API of this module. GSLCPPCodeGenerator and GSLCythonCodeGenerator are
# not visible in this excerpt -- presumably target-specific subclasses defined
# further down in the file; verify.
__all__ = ["GSLCodeGenerator", "GSLCPPCodeGenerator", "GSLCythonCodeGenerator"]
# Module logger (used by GSLCodeGenerator.translate to warn about abstract code
# whose outcome may depend on the order of execution).
logger = get_logger(__name__)
def valid_gsl_dir(val):
    """
    Check that ``val`` is an acceptable value for the GSL directory preference.

    ``None`` is always accepted (per the preference documentation, for the case
    where the headers are already on Python's include path). Any other value
    must be a string naming an existing directory whose ``gsl`` subdirectory
    contains ``gsl_odeiv2.h``, ``gsl_errno.h`` and ``gsl_matrix.h``.

    Parameters
    ----------
    val : str or None
        Candidate preference value.

    Returns
    -------
    valid : bool
        Always ``True``; invalid values raise instead of returning ``False``.

    Raises
    ------
    PreferenceError
        If ``val`` is not a string, is not an existing directory, or lacks one
        of the required header files.
    """
    if val is None:
        return True
    if not isinstance(val, str):
        raise PreferenceError(
            f"Illegal value for GSL directory: {str(val)}, has to be str"
        )
    if not os.path.isdir(val):
        raise PreferenceError(
            f"Illegal value for GSL directory: {val}, has to be existing directory"
        )
    header_dir = os.path.join(val, "gsl")
    required_headers = ("gsl_odeiv2.h", "gsl_errno.h", "gsl_matrix.h")
    headers_present = all(
        os.path.isfile(os.path.join(header_dir, header))
        for header in required_headers
    )
    if not headers_present:
        raise PreferenceError(
            f"Illegal value for GSL directory: '{val}', "
            "has to contain gsl_odeiv2.h, gsl_errno.h "
            "and gsl_matrix.h"
        )
    return True
# Register a ``directory`` preference in the ``GSL`` preference category.
# Values are checked with ``valid_gsl_dir``; the default ``None`` means that no
# explicit GSL header directory is configured.
prefs.register_preferences(
"GSL",
"Directory containing GSL code",
directory=BrianPreference(
validator=valid_gsl_dir,
docs=(
"Set path to directory containing GSL header files (gsl_odeiv2.h etc.)"
"\nIf this directory is already in Python's include (e.g. because of "
"conda installation), this path can be set to None."
),
default=None,
),
)
class GSLCodeGenerator:
"""
GSL code generator.
Notes
-----
Approach is to first let the already existing code generator for a target
language do the bulk of the translating from abstract_code to actual code.
This generated code is slightly adapted to render it GSL compatible.
The most critical part here is that the vector_code that is normally
contained in a loop in the ```main()``` is moved to the function ```_GSL_func```
that is sent to the GSL integrator. The variables used in the vector_code are
added to a struct named ```dataholder``` and their values are set from the
Brian namespace just before the scalar code block.
"""
def __init__(
    self,
    variables,
    variable_indices,
    owner,
    iterate_all,
    codeobj_class,
    name,
    template_name,
    override_conditional_write=None,
    allows_scalar_write=False,
):
    """
    Wrap the target language's ordinary generator and collect GSL settings.

    All arguments are forwarded unchanged (positionally) to
    ``codeobj_class.original_generator_class``. The resulting generator does
    the bulk of the translation and is stored as ``self.generator``; unknown
    attributes of this object are delegated to it (see ``__getattr__``).
    Integration settings are read from ``owner.state_updater``.
    """
    base_generator_class = codeobj_class.original_generator_class
    self.generator = base_generator_class(
        variables,
        variable_indices,
        owner,
        iterate_all,
        codeobj_class,
        name,
        template_name,
        override_conditional_write,
        allows_scalar_write,
    )
    state_updater = owner.state_updater
    # copied, so that adding ``dt_start`` below leaves the state updater's
    # own options untouched
    self.method_options = dict(state_updater.method_options)
    self.integrator = state_updater.integrator
    # default timestep to start with is the timestep of the NeuronGroup itself
    self.method_options["dt_start"] = owner.dt.variable.get_value()[0]
    self.variable_flags = state_updater._gsl_variable_flags
def __getattr__(self, item):
    """
    Delegate attribute lookups that fail on this object to ``self.generator``.

    Python only calls ``__getattr__`` when normal attribute lookup fails, so
    everything not defined on the GSL generator itself is taken from the
    wrapped target-language generator.

    Raises
    ------
    AttributeError
        If neither this object nor the wrapped generator has ``item``, or if
        ``self.generator`` has not been set.
    """
    # If ``generator`` itself is missing (e.g. on an instance created via
    # ``__new__`` while copying/unpickling), evaluating ``self.generator`` below
    # would re-enter this method and recurse until RecursionError. Raise a
    # proper AttributeError instead, so that ``hasattr``/``getattr`` work.
    if item == "generator":
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute 'generator'"
        )
    return getattr(self.generator, item)
# A series of functions that should be overridden by child class:
def c_data_type(self, dtype):
    """
    Get string version of object dtype that is attached to Brian variables.

    cpp_generator already has this function, but the Cython generator does
    not, and it is needed for GSL code generation. Child classes must
    override this method.

    Parameters
    ----------
    dtype : dtype
        Data type of a Brian variable.

    Returns
    -------
    c_type : str
        Name of the corresponding C type.

    Raises
    ------
    NotImplementedError
        Always, in this base class.
    """
    # Raise (rather than return) the exception, so that a missing override
    # fails loudly instead of leaking the exception class into the generated
    # code (e.g. via write_dataholder_single's f-string).
    raise NotImplementedError
def initialize_array(self, varname, values):
    """
    Return code that initializes a static array of floating point values.

    Child classes must implement this. For example, the C++ version called
    with ``array`` and ``[1.0, 3.0, 2.0]`` should return
    ``double array[] = {1.0, 3.0, 2.0}``.

    Parameters
    ----------
    varname : str
        Name of the array variable to initialize.
    values : list of float
        Values to store in the array.

    Returns
    -------
    code : str
        One or more lines of array initialization code.

    Raises
    ------
    NotImplementedError
        Always, in this base class.
    """
    raise NotImplementedError
def var_init_lhs(self, var, type):
    """
    Return the left hand side of an initializing expression.

    Child classes must implement this: the C++ version returns the type
    followed by the variable name, the Cython version only the name.

    Parameters
    ----------
    var : str
        Name of the variable being initialized.
    type : str
        Declared type of the variable (e.g. ``"const double "``).

    Returns
    -------
    code : str
        Left hand side of the initialization.

    Raises
    ------
    NotImplementedError
        Always, in this base class.
    """
    raise NotImplementedError
def unpack_namespace_single(self, var_obj, in_vector, in_scalar):
    """
    Return the code that pulls one variable out of the Brian namespace into
    the generated code.

    The required code differs substantially between C++ and Cython. Unlike
    most other helpers here, which differ only syntactically, this one has no
    shared implementation; child classes must provide it.

    Parameters
    ----------
    var_obj : `Variable`
        The variable to unpack.
    in_vector : bool
        Whether the variable is used in the vector code.
    in_scalar : bool
        Whether the variable is used in the scalar code.

    Raises
    ------
    NotImplementedError
        Always, in this base class.
    """
    raise NotImplementedError
# GSL functions that are the same for all target languages:
def find_function_names(self):
    """
    List the names in ``self.variables`` that refer to functions.

    The ordinary Brian generator already handles functions completely, so
    the GSL translation must ignore them. That generator removes functions
    from the variables dictionary during translation, so afterwards they can
    no longer be recognized. This method is therefore called before the
    translation, and its result is stored for use in the GSL translation.

    Returns
    -------
    function_names : list
        Names (str) of all `Function` objects in ``self.variables``.
    """
    function_names = []
    for identifier, obj in self.variables.items():
        if isinstance(obj, Function):
            function_names.append(identifier)
    return function_names
def is_cpp_standalone(self):
    """
    Report whether the current device is the C++ standalone device.

    Returns
    -------
    is_cpp_standalone : bool
        ``True`` if ``get_device()`` returns a `CPPStandaloneDevice`.

    See Also
    --------
    is_constant_and_cpp_standalone : caches and uses the returned value
    """
    # imported locally to avoid circular imports
    from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
    from brian2.devices.device import get_device

    return isinstance(get_device(), CPPStandaloneDevice)
def is_constant_and_cpp_standalone(self, var_obj):
    """
    Tell whether ``var_obj`` is a `Constant` while running on cpp_standalone.

    With the cpp_standalone device, code generation replaces constants by
    their values ("freezing"). Applying the GSL conversion
    (var --> _GSL_dataholder.var) to such a variable would therefore produce
    code like ``_GSL_dataholder.1.2 = 1.2`` and ``_GSL_dataholder->1.2`` (for
    var = 1.2). Such variables must be left alone.

    The device check runs only once; its result is cached in
    ``self.cpp_standalone``.

    Parameters
    ----------
    var_obj : `Variable`
        The Brian variable to check.

    Returns
    -------
    is_cpp_standalone : bool
        Whether the device is cpp_standalone and ``var_obj`` is a `Constant`.
    """
    try:
        standalone = self.cpp_standalone
    except AttributeError:
        standalone = self.is_cpp_standalone()
        self.cpp_standalone = standalone
    return isinstance(var_obj, Constant) and standalone
def find_differential_variables(self, code):
    """
    Find the variables that were tagged _gsl_{var}_f{ind} and return var, ind pairs.

    `GSLStateUpdater` tagged the differential variables; this method extracts
    the information encoded in those tags from the left hand sides of the
    statements.

    Parameters
    ----------
    code : list of strings
        Code blocks (each possibly spanning several lines) containing GSL
        tagged variables.

    Returns
    -------
    diff_vars : dict
        Variable names as keys and the differential equation index as value
        (the index is the digit string captured from the tag).
    """
    diff_vars = {}
    for expr_set in code:
        for expr in expr_set.split("\n"):
            expr = expr.strip(" ")
            try:
                lhs, _op, _rhs, _comment = parse_statement(expr)
            except ValueError:
                # parse_statement could not parse the line, so there is no
                # tagged left hand side to extract. Skipping is essential:
                # falling through would use an unbound ``lhs`` or reuse the one
                # from a previous line.
                continue
            m = re.search("_gsl_(.+?)_f([0-9]*)$", lhs)
            if m:
                diff_vars[m.group(1)] = m.group(2)
    return diff_vars
def diff_var_to_replace(self, diff_vars):
    """
    Collect the differential-variable replacements for the GSL translation.

    Code produced by Brian's ordinary generators (cpp_generator or
    cython_generator) must have a few fragments replaced to become GSL
    compatible. This method builds the replacements that concern differential
    equation variables:

    * whatever ``self.var_replace_diff_var_lhs`` returns for each variable, and
    * ``{var} = {array}[_idx]`` is replaced by ``{var} = _GSL_y[{index}]``, so
      the value is read from GSL's y vector instead of the Brian array.

    Parameters
    ----------
    diff_vars : dict
        Variable names as keys and differential equation index as value.

    Returns
    -------
    to_replace : dict
        Strings to be replaced as keys, their replacements as values.
    """
    brian_variables = self.variables
    index_name = "_idx"  # TODO: could be dynamic?
    replacements = {}
    for var_name, index in list(diff_vars.items()):
        replacements.update(self.var_replace_diff_var_lhs(var_name, index))
        array_name = self.generator.get_array_name(
            brian_variables[var_name], access_data=True
        )
        brian_read = f"{var_name} = {array_name}[{index_name}]"
        replacements[brian_read] = f"{var_name} = _GSL_y[{index}]"
    return replacements
def get_dimension_code(self, diff_num):
    """
    Generate the ``set_dimension`` function that tells GSL the system size.

    GSL needs to know how many differential variables the ODE system has.
    The vector-loop code is kept identical for all simulations, so the
    dimension is provided by this separate function instead. Only the syntax
    (taken from ``self.syntax``) differs between target languages.

    Parameters
    ----------
    diff_num : int
        Number of differential variables that describe the ODE system.

    Returns
    -------
    set_dimension_code : str
        The complete target-language function as a single string.
    """
    template = "\n".join(
        (
            "\n{start_declare}int set_dimension(size_t * dimension){open_function}",
            "\tdimension[0] = %d{end_statement}" % diff_num,
            "\treturn GSL_SUCCESS{end_statement}{end_function}",
        )
    )
    return template.format(**self.syntax)
def yvector_code(self, diff_vars):
    """
    Generate the functions that copy values between Brian arrays and GSL's y.

    Before integration the differential variables are copied from Brian's
    arrays into the y vector handed to GSL (``_fill_y_vector``). Afterwards
    they are copied back (``_empty_y_vector``). Both functions access the
    arrays through the ``_GSL_dataholder`` struct at index ``_idx``.

    Parameters
    ----------
    diff_vars : dict
        Variable names (str) as keys and differential variable index as value.

    Returns
    -------
    yvector_code : str
        The code of both functions as a single string.
    """
    closing = "\treturn GSL_SUCCESS{end_statement}{end_function}"
    fill_lines = [
        "\n{start_declare}int _fill_y_vector(_dataholder *_GSL_dataholder, "
        "double * _GSL_y, int _idx){open_function}"
    ]
    empty_lines = [
        "\n{start_declare}int _empty_y_vector(_dataholder * _GSL_dataholder, "
        "double * _GSL_y, int _idx){open_function}"
    ]
    for var_name, index in list(diff_vars.items()):
        position = int(index)
        pointer = self.generator.get_array_name(
            self.variables[var_name], access_data=True
        )
        # %-interpolation happens before .format(), exactly as before, so
        # only the template placeholders are filled in by .format()
        fill_lines.append(
            "\t_GSL_y[%d] = _GSL_dataholder{access_pointer}%s[_idx]{end_statement}"
            % (position, pointer)
        )
        empty_lines.append(
            "\t_GSL_dataholder{access_pointer}%s[_idx] = _GSL_y[%d]{end_statement}"
            % (pointer, position)
        )
    fill_lines.append(closing)
    empty_lines.append(closing)
    return "\n".join(fill_lines + empty_lines).format(**self.syntax)
def make_function_code(self, lines):
    """
    Wrap GSL-translated vector code in the fixed ``_GSL_func`` skeleton.

    The lines produced by `translate_vector_code` are placed between the
    unchanging header of ``_GSL_func`` (which unpacks the ``_dataholder``
    struct and ``_idx`` from ``params``) and its ``return GSL_SUCCESS``. The
    result is formatted with the target-language syntax in ``self.syntax``.

    Parameters
    ----------
    lines : str
        Code containing the GSL version of the equations.

    Returns
    -------
    function_code : str
        Code of the ``_GSL_func`` function that is sent to the GSL integrator.
    """
    header = (
        "\n{start_declare}int _GSL_func(double t, const double "
        "_GSL_y[], double f[], void * params){open_function}"
        "\n\t{start_declare}_dataholder * _GSL_dataholder = {open_cast}"
        "_dataholder *{close_cast} params{end_statement}"
        "\n\t{start_declare}int _idx = _GSL_dataholder{access_pointer}_idx"
        "{end_statement}"
    )
    footer = "\treturn GSL_SUCCESS{end_statement}{end_function}"
    return "\n".join((header, lines, footer)).format(**self.syntax)
def write_dataholder_single(self, var_obj):
    """
    Return the ``_dataholder`` struct member declaration for one variable.

    Array variables become pointers (e.g. ``double* _array_neurongroup_v``),
    marked with the generator's ``restrict`` keyword if it has one, except for
    scalar or single-element arrays. Other variables are declared by value.
    The ``{end_statement}`` placeholder is left in the result, to be filled
    in when the whole struct is formatted.

    Parameters
    ----------
    var_obj : `Variable`
        The variable to declare.

    Returns
    -------
    code : str
        The member declaration.
    """
    dtype = self.c_data_type(var_obj.dtype)
    if not isinstance(var_obj, ArrayVariable):
        return f"{dtype} {var_obj.name}{{end_statement}}"
    pointer_name = self.get_array_name(var_obj, access_data=True)
    generator_restrict = getattr(self.generator, "restrict", "")
    single_value = var_obj.scalar or var_obj.size == 1
    restrict = "" if single_value else generator_restrict
    return f"{dtype}* {restrict} {pointer_name}{{end_statement}}"
def write_dataholder(self, variables_in_vector):
    """
    Return the full code of the ``_dataholder`` struct.

    The struct always holds ``_idx`` plus one member per variable used in the
    vector code. ``t``, GSL tag variables (names containing ``_gsl``) and
    constants that the cpp_standalone device freezes are skipped.

    Parameters
    ----------
    variables_in_vector : dict
        Variable name as key and `Variable` as value.

    Returns
    -------
    code : str
        Code for the ``_dataholder`` struct, formatted with ``self.syntax``.
    """
    struct_lines = [
        "\n{start_declare}struct _dataholder{open_struct}",
        "\tint _idx{end_statement}",
    ]
    for name, obj in list(variables_in_vector.items()):
        excluded = (
            name == "t"
            or "_gsl" in name
            or self.is_constant_and_cpp_standalone(obj)
        )
        if not excluded:
            struct_lines.append(f" {self.write_dataholder_single(obj)}")
    struct_lines.append("{end_struct}")
    return "\n".join(struct_lines).format(**self.syntax)
def scale_array_code(self, diff_vars, method_options):
    """
    Return code for definition of ``_GSL_scale_array`` in generated code.

    Entry ``i`` of the array is the absolute error for the differential
    variable with index ``i``, i.e. the variable stored in ``_GSL_y[i]``. The
    entries are therefore ordered by that index.

    Parameters
    ----------
    diff_vars : dict
        Variable name (str) as key and differential variable index (int or
        digit string) as value.
    method_options : dict
        Integrator settings; the keys ``absolute_error`` (float) and
        ``absolute_error_per_variable`` (None or dict mapping variable names to
        errors) are used.

    Returns
    -------
    code : str
        Code defining an array of doubles with the absolute error of each
        differential variable, in index order.

    Raises
    ------
    TypeError
        If ``absolute_error`` is not a float, or ``absolute_error_per_variable``
        is neither None nor a dict.
    KeyError
        If ``absolute_error_per_variable`` names a variable that does not exist
        or is not being integrated.
    """
    # get scale values per variable from method_options
    abs_per_var = method_options["absolute_error_per_variable"]
    abs_default = method_options["absolute_error"]
    if not isinstance(abs_default, float):
        raise TypeError(
            "The absolute_error key in method_options should be "
            f"a float. Was type {type(abs_default)}"
        )
    if abs_per_var is None:
        diff_scale = {var: float(abs_default) for var in diff_vars}
    elif isinstance(abs_per_var, dict):
        diff_scale = {}
        for var, error in abs_per_var.items():
            # first do some checks on input
            if var not in diff_vars:
                if var not in self.variables:
                    raise KeyError(
                        "absolute_error specified for variable that "
                        f"does not exist: {var}"
                    )
                raise KeyError(
                    "absolute_error specified for variable that is "
                    f"not being integrated: {var}"
                )
            fail_for_dimension_mismatch(
                error,
                self.variables[var],
                "Unit of absolute_error_per_variable "
                f"for variable {var} does not match "
                "unit of variable itself",
            )
            # if all these are passed we can add the value for error in base units
            diff_scale[var] = float(error)
        # set the variables that are not mentioned to default value
        for var in diff_vars:
            if var not in abs_per_var:
                diff_scale[var] = float(abs_default)
    else:
        raise TypeError(
            "The absolute_error_per_variable key in method_options "
            "should either be None or a dictionary "
            "containing the error for each individual state variable. "
            f"Was type {type(abs_per_var)}"
        )
    # GSL scales component i of y with entry i of this array, and _GSL_y[i]
    # holds the variable tagged with index i (see yvector_code). Sorting by
    # variable name only matches that order when the indices happen to be
    # assigned alphabetically, so order by the index itself.
    ordered_vars = sorted(diff_vars, key=lambda var: int(diff_vars[var]))
    return self.initialize_array(
        "_GSL_scale_array", [diff_scale[var] for var in ordered_vars]
    )
def find_undefined_variables(self, statements):
    r"""
    Return `AuxiliaryVariable`\ s for statement targets unknown to Brian.

    Brian does not record the ``_lio_`` variables it introduces. The GSL code
    stores them in the ``_dataholder`` struct, which requires their data
    type. This method therefore wraps every left hand side variable that is
    missing from ``self.variables`` in an `AuxiliaryVariable` carrying the
    statement's dtype (the dtype is all that is needed later).

    Parameters
    ----------
    statements : list
        Statement objects (need ``var`` and ``dtype`` attributes).

    Returns
    -------
    other_variables : dict
        Variable name as key and `AuxiliaryVariable` as value.

    Notes
    -----
    The result is kept separate from ``self.variables``, so that variables
    from the Brian namespace can be told apart from those defined in the
    generated code itself.
    """
    known = self.variables
    return {
        statement.var: AuxiliaryVariable(statement.var, dtype=statement.dtype)
        for statement in statements
        if statement.var not in known
    }
def find_used_variables(self, statements, other_variables):
    """
    Collect every variable the given statements depend on.

    Identifiers on the right hand sides are looked up in ``self.variables``
    first and then in ``other_variables``. Function names are skipped. The
    arrays that the statements read or write (according to
    ``self.array_read_write``) are added as well; these are always present
    in ``self.variables``.

    Parameters
    ----------
    statements : list
        Statement objects.
    other_variables : dict
        Variables defined in the generated code itself (see
        `find_undefined_variables`).

    Returns
    -------
    used_variables : dict
        Variable name (str) as key and `Variable` as value (the object holds
        all needed info: dtype, name, whether it is an array).
    """
    known_variables = self.variables
    used_variables = {}
    for statement in statements:
        for identifier in get_identifiers(statement.expr):
            if identifier in self.function_names:
                continue
            try:
                used_variables[identifier] = known_variables[identifier]
            except KeyError:
                used_variables[identifier] = other_variables[identifier]
    # Right hand side identifiers miss variables that are only written
    # (e.g. not_refractory), so add the read/written arrays as well.
    read, write, _ = self.array_read_write(statements)
    used_variables.update(
        {
            name: known_variables[name]
            for name in read | write
            if name not in used_variables
        }
    )
    return used_variables
def to_replace_vector_vars(self, variables_in_vector, ignore=frozenset()):
    """
    Build the replacements that route vector-code variables through
    ``_GSL_dataholder``.

    Array variables have their data pointer replaced, other variables their
    name (e.g. ``_lio_1`` becomes ``_GSL_dataholder._lio_1`` in Cython or
    ``_GSL_dataholder->_lio_1`` in C++). The following are skipped:

    * GSL tag variables (names containing ``_gsl``),
    * names in ``ignore``,
    * constants frozen by cpp_standalone; these are also removed from
      ``self.variables_to_be_processed``.

    GSL provides its own ``t``. If ``t`` occurs, its declaration (e.g.
    ``const double t = _ptr_array_defaultclock_t[0];`` in C++) is replaced by
    an empty string, and ``"t"`` is removed from
    ``self.variables_to_be_processed``.

    Parameters
    ----------
    variables_in_vector : dict
        Variable name (str) as key and `Variable` as value, for the variables
        occurring in vector code.
    ignore : set, optional
        Names of variables that should be ignored.

    Returns
    -------
    to_replace : dict
        Strings to be replaced as keys, their replacements as values.
    """
    access_pointer = self.syntax["access_pointer"]
    replacements = {}
    time_variable = None
    for name, obj in list(variables_in_vector.items()):
        if obj.name == "t":
            time_variable = obj
        elif "_gsl" in name or name in ignore:
            pass
        elif self.is_constant_and_cpp_standalone(obj):
            # does not have to be processed by GSL generator
            self.variables_to_be_processed.remove(obj.name)
        elif isinstance(obj, ArrayVariable):
            pointer = self.get_array_name(obj, access_data=True)
            replacements[pointer] = f"_GSL_dataholder{access_pointer}{pointer}"
        else:
            replacements[name] = f"_GSL_dataholder{access_pointer}{name}"
    if time_variable is not None:
        declaration = self.var_init_lhs("t", "const double ")
        time_array = self.get_array_name(time_variable, access_data=True)
        end_statement = self.syntax["end_statement"]
        replacements[f"{declaration} = {time_array}[0]{end_statement}"] = ""
        self.variables_to_be_processed.remove("t")
    return replacements
def unpack_namespace(
    self, variables_in_vector, variables_in_scalar, ignore=frozenset()
):
    """
    Write code that unpacks the Brian namespace into the cython/cpp namespace.

    For every variable in ``self.variables`` (except ignored names and
    constants frozen by cpp_standalone), the code from
    ``unpack_namespace_single`` is emitted. For vector variables this
    means setting members of the ``_dataholder`` struct (e.g.
    ``_GSL_dataholder->var`` or ``_GSL_dataholder.var = ...``). A variable may
    occur in both scalar and vector code. Variables used in the vector code
    are removed from ``self.variables_to_be_processed``.

    Parameters
    ----------
    variables_in_vector : dict
        Variable name (str) as key and `Variable` as value, for the variables
        occurring in vector code.
    variables_in_scalar : dict
        Variable name (str) as key and `Variable` as value, for the variables
        occurring in scalar code.
    ignore : set, optional
        Names of variables that should be ignored.

    Returns
    -------
    unpack_namespace_code : str
        Code fragment unpacking the Brian namespace.
    """
    unpack_lines = []
    for name, obj in list(self.variables.items()):
        if name in ignore or self.is_constant_and_cpp_standalone(obj):
            continue
        used_in_vector = name in variables_in_vector
        used_in_scalar = name in variables_in_scalar
        if used_in_vector:
            self.variables_to_be_processed.remove(name)
        unpack_lines.append(
            self.unpack_namespace_single(obj, used_in_vector, used_in_scalar)
        )
    return "\n".join(unpack_lines)
def translate_vector_code(self, code_lines, to_replace):
    """
    Translate vector code to GSL compatible code by substituting fragments.

    Every line is indented by one space. ``word_substitute`` then applies the
    replacements. Keys ending in an index (e.g. ``v = _array_v[_idx]``)
    cannot be matched with word boundaries and are additionally replaced
    literally.

    Parameters
    ----------
    code_lines : list
        Strings (possibly multi-line) describing the vector_code.
    to_replace : dict
        Strings to be replaced as keys, their replacements as values (see
        to_replace_vector_vars and diff_var_to_replace).

    Returns
    -------
    vector_code : str
        New code that is to be added to the function that is sent to the GSL
        integrator.

    Raises
    ------
    AssertionError
        If a ``_gsl`` tag is still present after the substitution.
    """
    code = "\n".join(
        f" {line}"
        for expr_set in code_lines
        for line in expr_set.split("\n")  # every line separately for tabbing
    )
    code = word_substitute(code, to_replace)
    # Regex word boundaries do not work for keys like "variable[_idx]", so
    # substitute those literally. str.replace (instead of a regex built by
    # escaping only "[") also keeps other metacharacters in the key and
    # backslashes in the replacement from being interpreted.
    for from_sub, to_sub in list(to_replace.items()):
        if re.search(r"\[(\w+)\];?$", from_sub):
            code = code.replace(from_sub, to_sub)
    if "_gsl" in code:
        raise AssertionError(
            "Translation failed, _gsl still in code (should only "
            "be tag, and should be replaced).\n"
            f"Code:\n{code}"
        )
    return code
def translate_scalar_code(
    self, code_lines, variables_in_scalar, variables_in_vector
):
    """
    Translate scalar code: if calculated variables are used in the vector_code
    their value is added to the variable in the _dataholder.

    Lines that are not assignments (no ``name = ...`` part, or not parseable)
    are kept unchanged, as are assignments to scalar-code variables.
    Assignments to variables used only in the vector code become assignments
    to ``_GSL_dataholder.<var>``, except for ``t``, which is dropped.
    Assignments to variables used in neither are dropped as well.

    Parameters
    ----------
    code_lines : list
        Strings containing scalar code.
    variables_in_scalar : dict
        Variable name (str) as key and `Variable` as value, for the variables
        occurring in scalar code.
    variables_in_vector : dict
        Variable name (str) as key and `Variable` as value, for the variables
        occurring in vector code.

    Returns
    -------
    scalar_code : str
        Code fragment that should be injected in the main before the loop.

    Raises
    ------
    AssertionError
        If a vector variable assigned here was already moved into the
        ``_GSL_dataholder`` (i.e. is no longer in
        ``self.variables_to_be_processed``).
    """
    code = []
    for line in code_lines:
        m = re.search(r"(\w+ = .*)", line)
        try:
            new_line = m.group(1)  # AttributeError if there is no match
            var, op, expr, comment = parse_statement(new_line)
        except (ValueError, AttributeError):
            code += [line]
            continue
        if var in variables_in_scalar:
            code += [line]
        elif var in variables_in_vector:
            if var == "t":
                continue
            try:
                self.variables_to_be_processed.remove(var)
            except (KeyError, ValueError):
                # variables_to_be_processed is a list, whose remove() raises
                # ValueError; KeyError is kept in case a set is used instead.
                raise AssertionError(
                    "Trying to process variable named %s by "
                    "putting its value in the _GSL_dataholder "
                    "based on scalar code, but the variable "
                    "has been processed already." % var
                ) from None
            code += [f"_GSL_dataholder.{var} {op} {expr} {comment}"]
    return "\n".join(code)
def add_gsl_variables_as_non_scalar(self, diff_vars):
    """
    Register the ``_gsl`` tag variables in ``self.variables`` as non-scalar.

    `GSLStateUpdater` replaces the differential equation variables with GSL
    tags (``_gsl_{var}_f{index}``) that encode what is needed for the
    translation to GSL code. Brian would render a tag variable as scalar if
    no vector variable appears on the right hand side of its expression. An
    `AuxiliaryVariable` with ``scalar=False`` is therefore added for every
    tag, so that it is always treated as a vector variable.

    Parameters
    ----------
    diff_vars : dict
        Variables as keys and differential equation index as value.
    """
    tag_to_variable = {
        f"_gsl_{var_name}_f{index}": var_name
        for var_name, index in diff_vars.items()
    }
    for tag, var_name in tag_to_variable.items():
        self.variables[tag] = AuxiliaryVariable(var_name, scalar=False)
def add_meta_variables(self, options):
    """
    Add the optional per-element bookkeeping arrays requested in ``options``.

    Depending on the flags ``use_last_timestep``, ``save_failed_steps`` and
    ``save_step_count``, the arrays ``_last_timestep`` (float64, initialised
    with ``options["dt_start"]``), ``_failed_steps`` and ``_step_count``
    (int32) are added to ``self.owner.variables``, with size ``N``. They are
    also made available in ``self.variables``. This happens before the actual
    translation, so that Brian unpacks them like any other array.

    Parameters
    ----------
    options : dict
        Integrator method options (see ``self.method_options``).

    Returns
    -------
    pointers : dict
        Keys ``pointer_last_timestep``, ``pointer_failed_steps`` and
        ``pointer_step_count``. Each value is ``"<array name>[_idx]"``, or
        ``None`` if the corresponding option is disabled.

    Raises
    ------
    KeyError
        If an option is enabled but ``N`` is not in ``self.variables``.
    """

    def _register_array(array_name, **add_array_kwargs):
        # Add the array to the owner (unless it exists), expose it in
        # self.variables and return the indexed access expression.
        try:
            self.owner.variables.add_array(array_name, **add_array_kwargs)
        except KeyError:
            # has already been run
            pass
        self.variables[array_name] = self.owner.variables.get(array_name)
        return f"{self.get_array_name(self.variables[array_name])}[_idx]"

    pointers = {
        "pointer_last_timestep": None,
        "pointer_failed_steps": None,
        "pointer_step_count": None,
    }
    if options["use_last_timestep"]:
        # looked up outside of the try in _register_array, so that a missing
        # "N" is reported instead of being mistaken for "already run"
        N = self.variables["N"].item()
        pointers["pointer_last_timestep"] = _register_array(
            "_last_timestep",
            size=N,
            values=np.ones(N) * options["dt_start"],
            dtype=np.float64,
        )
    if options["save_failed_steps"]:
        N = self.variables["N"].item()
        pointers["pointer_failed_steps"] = _register_array(
            "_failed_steps", size=N, dtype=np.int32
        )
    if options["save_step_count"]:
        N = int(self.variables["N"].get_value())
        pointers["pointer_step_count"] = _register_array(
            "_step_count", size=N, dtype=np.int32
        )
    return pointers
def translate(
self, code, dtype
): # TODO: it's not so nice we have to copy the contents of this function..
"""
Translates an abstract code block into the target language.
"""
# first check if user code is not using variables that are also used by GSL
reserved_variables = [
"_dataholder",
"_fill_y_vector",
"_empty_y_vector",
"_GSL_dataholder",
"_GSL_y",
"_GSL_func",
]
if any([var in self.variables for var in reserved_variables]):
# import here to avoid circular import
raise ValueError(
f"The variables {str(reserved_variables)} are reserved for the GSL"
" internal code."
)
# if the following statements are not added, Brian translates the
# differential expressions in the abstract code for GSL to scalar statements
# in the case no non-scalar variables are used in the expression
diff_vars = self.find_differential_variables(list(code.values()))
self.add_gsl_variables_as_non_scalar(diff_vars)
# add arrays we want to use in generated code before self.generator.translate() so
# brian does namespace unpacking for us
pointer_names = self.add_meta_variables(self.method_options)
scalar_statements = {}
vector_statements = {}
for ac_name, ac_code in code.items():
statements = make_statements(
ac_code, self.variables, dtype, optimise=True, blockname=ac_name
)
scalar_statements[ac_name], vector_statements[ac_name] = statements
for vs in vector_statements.values():
# Check that the statements are meaningful independent on the order of
# execution (e.g. for synapses)
try:
if self.has_repeated_indices(
vs
): # only do order dependence if there are repeated indices
check_for_order_independence(
vs, self.generator.variables, self.generator.variable_indices
)
except OrderDependenceError:
# If the abstract code is only one line, display it in full
if len(vs) <= 1:
error_msg = f"Abstract code: '{vs[0]}'\n"
else:
error_msg = (
f"{len(vs)} lines of abstract code, first line is: '{vs[0]}'\n"
)
logger.warn(
"Came across an abstract code block that may not be "
"well-defined: the outcome may depend on the "
"order of execution. You can ignore this warning if "
"you are sure that the order of operations does not "
"matter. " + error_msg
)
# save function names because self.generator.translate_statement_sequence
# deletes these from self.variables but we need to know which identifiers
# we can safely ignore (i.e. we can ignore the functions because they are
# handled by the original generator)
self.function_names = self.find_function_names()
scalar_code, vector_code, kwds = self.generator.translate_statement_sequence(
scalar_statements, vector_statements
)
############ translate code for GSL
# first check if any indexing other than '_idx' is used (currently not supported)
for code_list in list(scalar_code.values()) + list(vector_code.values()):
for code in code_list:
m = re.search(r"\[(\w+)\]", code)
if m is not None:
if m.group(1) != "0" and m.group(1) != "_idx":
from brian2.stateupdaters.base import (
UnsupportedEquationsException,
)
raise UnsupportedEquationsException(
"Equations result in state "
"updater code with indexing "
"other than '_idx', which "
"is currently not supported "
"in combination with the "
"GSL stateupdater."
)
# differential variable specific operations
to_replace = self.diff_var_to_replace(diff_vars)
GSL_support_code = self.get_dimension_code(len(diff_vars))
GSL_support_code += self.yvector_code(diff_vars)
# analyze all needed variables; if not in self.variables: put in separate dic.
# also keep track of variables needed for scalar statements and vector statements
other_variables = self.find_undefined_variables(
scalar_statements[None] + vector_statements[None]
)
variables_in_scalar = self.find_used_variables(
scalar_statements[None], other_variables
)
variables_in_vector = self.find_used_variables(
vector_statements[None], other_variables
)
# so that _dataholder holds diff_vars as well, even if they don't occur
# in the actual statements
for var in list(diff_vars.keys()):
if var not in variables_in_vector:
variables_in_vector[var] = self.variables[var]
# let's keep track of the variables that eventually need to be added to
# the _GSL_dataholder somehow
self.variables_to_be_processed = list(variables_in_vector.keys())
# add code for _dataholder struct
GSL_support_code = self.write_dataholder(variables_in_vector) + GSL_support_code
# add e.g. _lio_1 --> _GSL_dataholder._lio_1 to replacer
to_replace.update(
self.to_replace_vector_vars(
variables_in_vector, ignore=list(diff_vars.keys())