22import types
33import logging
44import platform
5- from typing import Optional , Dict
5+ from typing import Optional , Dict , Set
66
77from llvmlite import ir
88
@@ -70,16 +70,22 @@ def autoload_library():
7070
7171
7272def compile_functions_c (model : "Model" , jit_compiler : TCCJITCompiler ):
73+ needs_compile_function_indices = []
74+ for function_index in model .function_indices :
75+ if not model ._has_function_evaluator (function_index ):
76+ needs_compile_function_indices .append (function_index )
77+
78+ if len (needs_compile_function_indices ) == 0 :
79+ return
80+
7381 io = StringIO ()
7482
7583 generate_csrc_prelude (io )
7684
77- for (
78- function_index ,
79- cppad_autodiff_graph ,
80- ) in model .function_cppad_autodiff_graphs .items ():
85+ for function_index in needs_compile_function_indices :
8186 name = model .function_names [function_index ]
8287 autodiff_structure = model .function_autodiff_structures [function_index ]
88+ cppad_autodiff_graph = model .function_cppad_autodiff_graphs [function_index ]
8389
8490 np = autodiff_structure .np
8591 ny = autodiff_structure .ny
@@ -131,14 +137,10 @@ def compile_functions_c(model: "Model", jit_compiler: TCCJITCompiler):
131137
132138 csrc = io .getvalue ()
133139
134- jit_compiler .source_code = csrc
140+ state = jit_compiler .create_state ()
141+ jit_compiler .compile_string (state , csrc )
135142
136- jit_compiler .compile_string (csrc .encode ())
137-
138- for (
139- function_index ,
140- cppad_autodiff_graph ,
141- ) in model .function_cppad_autodiff_graphs .items ():
143+ for function_index in needs_compile_function_indices :
142144 name = model .function_names [function_index ]
143145 autodiff_structure = model .function_autodiff_structures [function_index ]
144146
@@ -147,13 +149,13 @@ def compile_functions_c(model: "Model", jit_compiler: TCCJITCompiler):
147149 gradient_name = name + "_gradient"
148150 hessian_name = name + "_hessian"
149151
150- f_ptr = jit_compiler .get_symbol (f_name . encode () )
152+ f_ptr = jit_compiler .get_symbol (state , f_name )
151153 jacobian_ptr = gradient_ptr = hessian_ptr = 0
152154 if autodiff_structure .has_jacobian :
153- jacobian_ptr = jit_compiler .get_symbol (jacobian_name . encode () )
154- gradient_ptr = jit_compiler .get_symbol (gradient_name . encode () )
155+ jacobian_ptr = jit_compiler .get_symbol (state , jacobian_name )
156+ gradient_ptr = jit_compiler .get_symbol (state , gradient_name )
155157 if autodiff_structure .has_hessian :
156- hessian_ptr = jit_compiler .get_symbol (hessian_name . encode () )
158+ hessian_ptr = jit_compiler .get_symbol (state , hessian_name )
157159
158160 evaluator = AutodiffEvaluator (
159161 autodiff_structure , f_ptr , jacobian_ptr , gradient_ptr , hessian_ptr
@@ -162,17 +164,23 @@ def compile_functions_c(model: "Model", jit_compiler: TCCJITCompiler):
162164
163165
164166def compile_functions_llvm (model : "Model" , jit_compiler : LLJITCompiler ):
167+ needs_compile_function_indices = []
168+ for function_index in model .function_indices :
169+ if not model ._has_function_evaluator (function_index ):
170+ needs_compile_function_indices .append (function_index )
171+
172+ if len (needs_compile_function_indices ) == 0 :
173+ return
174+
165175 module = ir .Module (name = "my_module" )
166176 create_llvmir_basic_functions (module )
167177
168178 export_functions = []
169179
170- for (
171- function_index ,
172- cppad_autodiff_graph ,
173- ) in model .function_cppad_autodiff_graphs .items ():
180+ for function_index in needs_compile_function_indices :
174181 name = model .function_names [function_index ]
175182 autodiff_structure = model .function_autodiff_structures [function_index ]
183+ cppad_autodiff_graph = model .function_cppad_autodiff_graphs [function_index ]
176184
177185 np = autodiff_structure .np
178186 ny = autodiff_structure .ny
@@ -224,12 +232,9 @@ def compile_functions_llvm(model: "Model", jit_compiler: LLJITCompiler):
224232
225233 export_functions .extend ([f_name , jacobian_name , gradient_name , hessian_name ])
226234
227- jit_compiler .compile_module (module , export_functions )
235+ rt = jit_compiler .compile_module (module , export_functions )
228236
229- for (
230- function_index ,
231- cppad_autodiff_graph ,
232- ) in model .function_cppad_autodiff_graphs .items ():
237+ for function_index in needs_compile_function_indices :
233238 name = model .function_names [function_index ]
234239 autodiff_structure = model .function_autodiff_structures [function_index ]
235240
@@ -238,13 +243,13 @@ def compile_functions_llvm(model: "Model", jit_compiler: LLJITCompiler):
238243 gradient_name = name + "_gradient"
239244 hessian_name = name + "_hessian"
240245
241- f_ptr = jit_compiler . get_symbol ( f_name )
246+ f_ptr = rt [ f_name ]
242247 jacobian_ptr = gradient_ptr = hessian_ptr = 0
243248 if autodiff_structure .has_jacobian :
244- jacobian_ptr = jit_compiler . get_symbol ( jacobian_name )
245- gradient_ptr = jit_compiler . get_symbol ( gradient_name )
249+ jacobian_ptr = rt [ jacobian_name ]
250+ gradient_ptr = rt [ gradient_name ]
246251 if autodiff_structure .has_hessian :
247- hessian_ptr = jit_compiler . get_symbol ( hessian_name )
252+ hessian_ptr = rt [ hessian_name ]
248253
249254 evaluator = AutodiffEvaluator (
250255 autodiff_structure , f_ptr , jacobian_ptr , gradient_ptr , hessian_ptr
@@ -445,6 +450,7 @@ def __init__(self, jit: str = "LLVM"):
445450 self .jit = jit
446451 self .add_variables = types .MethodType (make_nd_variable , self )
447452
453+ self .function_indices : Set [FunctionIndex ] = set ()
448454 self .function_cppad_autodiff_graphs : Dict [FunctionIndex , CppADAutodiffGraph ] = (
449455 {}
450456 )
@@ -514,6 +520,7 @@ def register_function(
514520 )
515521
516522 function_index = super ()._register_function (autodiff_structure )
523+ self .function_indices .add (function_index )
517524 self .function_cppad_autodiff_graphs [function_index ] = cppad_graph
518525 self .function_autodiff_structures [function_index ] = autodiff_structure
519526 self .function_tracing_results [function_index ] = tracing_result
0 commit comments