Skip to content

Commit ea13b4e

Browse files
committed
Support re-opt for IPOPT
1 parent fdf0d8b commit ea13b4e

File tree

12 files changed

+188
-90
lines changed

12 files changed

+188
-90
lines changed

include/pyoptinterface/ipopt_model.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ struct IpoptModel
117117

118118
FunctionIndex _register_function(const AutodiffSymbolicStructure &structure);
119119
void _set_function_evaluator(const FunctionIndex &k, const AutodiffEvaluator &evaluator);
120+
bool _has_function_evaluator(const FunctionIndex &k);
120121

121122
NLConstraintIndex _add_nl_constraint_bounds(const FunctionIndex &k,
122123
const std::vector<VariableIndex> &xs,
@@ -134,6 +135,7 @@ struct IpoptModel
134135

135136
void clear_nl_objective();
136137

138+
void analyze_structure();
137139
void optimize();
138140

139141
// set options

include/pyoptinterface/nleval.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ struct LinearQuadraticEvaluator
209209
struct NonlinearFunctionEvaluator
210210
{
211211
std::vector<AutodiffSymbolicStructure> nl_function_structures;
212-
std::vector<AutodiffEvaluator> nl_function_evaluators;
212+
std::vector<std::optional<AutodiffEvaluator>> nl_function_evaluators;
213213
std::vector<FunctionInstances> constraint_function_instances;
214214
std::vector<size_t> active_constraint_function_indices;
215215
std::vector<FunctionInstances> objective_function_instances;
@@ -222,6 +222,7 @@ struct NonlinearFunctionEvaluator
222222

223223
FunctionIndex register_function(const AutodiffSymbolicStructure &structure);
224224
void set_function_evaluator(const FunctionIndex &k, const AutodiffEvaluator &evaluator);
225+
bool has_function_evaluator(const FunctionIndex &k);
225226

226227
NLConstraintIndex add_nl_constraint(const FunctionIndex &k,
227228
const std::vector<VariableIndex> &xs,

lib/ipopt_model.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,11 @@ void IpoptModel::_set_function_evaluator(const FunctionIndex &k, const AutodiffE
337337
m_function_model.set_function_evaluator(k, evaluator);
338338
}
339339

340+
bool IpoptModel::_has_function_evaluator(const FunctionIndex &k)
341+
{
342+
return m_function_model.has_function_evaluator(k);
343+
}
344+
340345
NLConstraintIndex IpoptModel::_add_nl_constraint_bounds(const FunctionIndex &k,
341346
const std::vector<VariableIndex> &xs,
342347
const std::vector<ParameterIndex> &ps,
@@ -444,8 +449,17 @@ static bool eval_h(ipindex n, ipnumber *x, bool new_x, ipnumber obj_factor, ipin
444449
return true;
445450
}
446451

447-
void IpoptModel::optimize()
452+
void IpoptModel::analyze_structure()
448453
{
454+
// init variables
455+
m_jacobian_nnz = 0;
456+
m_jacobian_rows.clear();
457+
m_jacobian_cols.clear();
458+
m_hessian_nnz = 0;
459+
m_hessian_rows.clear();
460+
m_hessian_cols.clear();
461+
m_hessian_index_map.clear();
462+
449463
// analyze structure
450464
m_function_model.analyze_active_functions();
451465
m_function_model.analyze_dense_gradient_structure();
@@ -465,6 +479,11 @@ void IpoptModel::optimize()
465479
fmt::print("Hessian has {} nonzeros\n", m_hessian_nnz);
466480
fmt::print("Hessian rows : {}\n", m_hessian_rows);
467481
fmt::print("Hessian cols : {}\n", m_hessian_cols);*/
482+
}
483+
484+
void IpoptModel::optimize()
485+
{
486+
analyze_structure();
468487

469488
auto problem_ptr =
470489
ipopt::CreateIpoptProblem(n_variables, m_var_lb.data(), m_var_ub.data(), n_constraints,

lib/ipopt_model_ext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ NB_MODULE(ipopt_model_ext, m)
147147

148148
.def("_register_function", &IpoptModel::_register_function)
149149
.def("_set_function_evaluator", &IpoptModel::_set_function_evaluator)
150+
.def("_has_function_evaluator", &IpoptModel::_has_function_evaluator)
150151

151152
.def("_add_nl_constraint_bounds", &IpoptModel::_add_nl_constraint_bounds)
152153
.def("_add_nl_constraint_eq", &IpoptModel::_add_nl_constraint_eq)

lib/nleval.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ FunctionIndex NonlinearFunctionEvaluator::register_function(
377377
idx.index = nl_function_structures.size();
378378

379379
nl_function_structures.push_back(structure);
380-
nl_function_evaluators.emplace_back();
380+
nl_function_evaluators.emplace_back(std::nullopt);
381381
constraint_function_instances.emplace_back();
382382
objective_function_instances.emplace_back();
383383

@@ -390,6 +390,11 @@ void NonlinearFunctionEvaluator::set_function_evaluator(const FunctionIndex &k,
390390
nl_function_evaluators[k.index] = evaluator;
391391
}
392392

393+
bool NonlinearFunctionEvaluator::has_function_evaluator(const FunctionIndex &k)
394+
{
395+
return nl_function_evaluators[k.index].has_value();
396+
}
397+
393398
NLConstraintIndex NonlinearFunctionEvaluator::add_nl_constraint(
394399
const FunctionIndex &k, const std::vector<VariableIndex> &xs,
395400
const std::vector<ParameterIndex> &ps, size_t y)
@@ -693,7 +698,7 @@ void NonlinearFunctionEvaluator::eval_objective(const double *x, double *y)
693698
for (auto k : active_objective_function_indices)
694699
{
695700
auto &kernel = nl_function_structures[k];
696-
auto &evaluator = nl_function_evaluators[k];
701+
auto &evaluator = nl_function_evaluators[k].value();
697702
bool has_parameter = kernel.has_parameter;
698703
auto &inst_vec = objective_function_instances[k];
699704

@@ -728,7 +733,7 @@ void NonlinearFunctionEvaluator::eval_objective_gradient(const double *x, double
728733
for (auto k : active_objective_function_indices)
729734
{
730735
auto &kernel = nl_function_structures[k];
731-
auto &evaluator = nl_function_evaluators[k];
736+
auto &evaluator = nl_function_evaluators[k].value();
732737
if (!kernel.has_jacobian)
733738
continue;
734739

@@ -766,7 +771,7 @@ void NonlinearFunctionEvaluator::eval_constraint(const double *x, double *con)
766771
for (auto k : active_constraint_function_indices)
767772
{
768773
auto &kernel = nl_function_structures[k];
769-
auto &evaluator = nl_function_evaluators[k];
774+
auto &evaluator = nl_function_evaluators[k].value();
770775
bool has_parameter = kernel.has_parameter;
771776
auto &inst_vec = constraint_function_instances[k];
772777
if (has_parameter)
@@ -799,7 +804,7 @@ void NonlinearFunctionEvaluator::eval_constraint_jacobian(const double *x, doubl
799804
for (auto k : active_constraint_function_indices)
800805
{
801806
auto &kernel = nl_function_structures[k];
802-
auto &evaluator = nl_function_evaluators[k];
807+
auto &evaluator = nl_function_evaluators[k].value();
803808
if (!kernel.has_jacobian)
804809
continue;
805810

@@ -837,7 +842,7 @@ void NonlinearFunctionEvaluator::eval_lagrangian_hessian(const double *x, const
837842
for (auto k : active_constraint_function_indices)
838843
{
839844
auto &kernel = nl_function_structures[k];
840-
auto &evaluator = nl_function_evaluators[k];
845+
auto &evaluator = nl_function_evaluators[k].value();
841846
if (!kernel.has_hessian)
842847
continue;
843848

@@ -874,7 +879,7 @@ void NonlinearFunctionEvaluator::eval_lagrangian_hessian(const double *x, const
874879
for (auto k : active_objective_function_indices)
875880
{
876881
auto &kernel = nl_function_structures[k];
877-
auto &evaluator = nl_function_evaluators[k];
882+
auto &evaluator = nl_function_evaluators[k].value();
878883
if (!kernel.has_hessian)
879884
continue;
880885

src/pyoptinterface/_src/ipopt.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import types
33
import logging
44
import platform
5-
from typing import Optional, Dict
5+
from typing import Optional, Dict, Set
66

77
from llvmlite import ir
88

@@ -70,16 +70,22 @@ def autoload_library():
7070

7171

7272
def compile_functions_c(model: "Model", jit_compiler: TCCJITCompiler):
73+
needs_compile_function_indices = []
74+
for function_index in model.function_indices:
75+
if not model._has_function_evaluator(function_index):
76+
needs_compile_function_indices.append(function_index)
77+
78+
if len(needs_compile_function_indices) == 0:
79+
return
80+
7381
io = StringIO()
7482

7583
generate_csrc_prelude(io)
7684

77-
for (
78-
function_index,
79-
cppad_autodiff_graph,
80-
) in model.function_cppad_autodiff_graphs.items():
85+
for function_index in needs_compile_function_indices:
8186
name = model.function_names[function_index]
8287
autodiff_structure = model.function_autodiff_structures[function_index]
88+
cppad_autodiff_graph = model.function_cppad_autodiff_graphs[function_index]
8389

8490
np = autodiff_structure.np
8591
ny = autodiff_structure.ny
@@ -131,14 +137,10 @@ def compile_functions_c(model: "Model", jit_compiler: TCCJITCompiler):
131137

132138
csrc = io.getvalue()
133139

134-
jit_compiler.source_code = csrc
140+
state = jit_compiler.create_state()
141+
jit_compiler.compile_string(state, csrc)
135142

136-
jit_compiler.compile_string(csrc.encode())
137-
138-
for (
139-
function_index,
140-
cppad_autodiff_graph,
141-
) in model.function_cppad_autodiff_graphs.items():
143+
for function_index in needs_compile_function_indices:
142144
name = model.function_names[function_index]
143145
autodiff_structure = model.function_autodiff_structures[function_index]
144146

@@ -147,13 +149,13 @@ def compile_functions_c(model: "Model", jit_compiler: TCCJITCompiler):
147149
gradient_name = name + "_gradient"
148150
hessian_name = name + "_hessian"
149151

150-
f_ptr = jit_compiler.get_symbol(f_name.encode())
152+
f_ptr = jit_compiler.get_symbol(state, f_name)
151153
jacobian_ptr = gradient_ptr = hessian_ptr = 0
152154
if autodiff_structure.has_jacobian:
153-
jacobian_ptr = jit_compiler.get_symbol(jacobian_name.encode())
154-
gradient_ptr = jit_compiler.get_symbol(gradient_name.encode())
155+
jacobian_ptr = jit_compiler.get_symbol(state, jacobian_name)
156+
gradient_ptr = jit_compiler.get_symbol(state, gradient_name)
155157
if autodiff_structure.has_hessian:
156-
hessian_ptr = jit_compiler.get_symbol(hessian_name.encode())
158+
hessian_ptr = jit_compiler.get_symbol(state, hessian_name)
157159

158160
evaluator = AutodiffEvaluator(
159161
autodiff_structure, f_ptr, jacobian_ptr, gradient_ptr, hessian_ptr
@@ -162,17 +164,23 @@ def compile_functions_c(model: "Model", jit_compiler: TCCJITCompiler):
162164

163165

164166
def compile_functions_llvm(model: "Model", jit_compiler: LLJITCompiler):
167+
needs_compile_function_indices = []
168+
for function_index in model.function_indices:
169+
if not model._has_function_evaluator(function_index):
170+
needs_compile_function_indices.append(function_index)
171+
172+
if len(needs_compile_function_indices) == 0:
173+
return
174+
165175
module = ir.Module(name="my_module")
166176
create_llvmir_basic_functions(module)
167177

168178
export_functions = []
169179

170-
for (
171-
function_index,
172-
cppad_autodiff_graph,
173-
) in model.function_cppad_autodiff_graphs.items():
180+
for function_index in needs_compile_function_indices:
174181
name = model.function_names[function_index]
175182
autodiff_structure = model.function_autodiff_structures[function_index]
183+
cppad_autodiff_graph = model.function_cppad_autodiff_graphs[function_index]
176184

177185
np = autodiff_structure.np
178186
ny = autodiff_structure.ny
@@ -224,12 +232,9 @@ def compile_functions_llvm(model: "Model", jit_compiler: LLJITCompiler):
224232

225233
export_functions.extend([f_name, jacobian_name, gradient_name, hessian_name])
226234

227-
jit_compiler.compile_module(module, export_functions)
235+
rt = jit_compiler.compile_module(module, export_functions)
228236

229-
for (
230-
function_index,
231-
cppad_autodiff_graph,
232-
) in model.function_cppad_autodiff_graphs.items():
237+
for function_index in needs_compile_function_indices:
233238
name = model.function_names[function_index]
234239
autodiff_structure = model.function_autodiff_structures[function_index]
235240

@@ -238,13 +243,13 @@ def compile_functions_llvm(model: "Model", jit_compiler: LLJITCompiler):
238243
gradient_name = name + "_gradient"
239244
hessian_name = name + "_hessian"
240245

241-
f_ptr = jit_compiler.get_symbol(f_name)
246+
f_ptr = rt[f_name]
242247
jacobian_ptr = gradient_ptr = hessian_ptr = 0
243248
if autodiff_structure.has_jacobian:
244-
jacobian_ptr = jit_compiler.get_symbol(jacobian_name)
245-
gradient_ptr = jit_compiler.get_symbol(gradient_name)
249+
jacobian_ptr = rt[jacobian_name]
250+
gradient_ptr = rt[gradient_name]
246251
if autodiff_structure.has_hessian:
247-
hessian_ptr = jit_compiler.get_symbol(hessian_name)
252+
hessian_ptr = rt[hessian_name]
248253

249254
evaluator = AutodiffEvaluator(
250255
autodiff_structure, f_ptr, jacobian_ptr, gradient_ptr, hessian_ptr
@@ -445,6 +450,7 @@ def __init__(self, jit: str = "LLVM"):
445450
self.jit = jit
446451
self.add_variables = types.MethodType(make_nd_variable, self)
447452

453+
self.function_indices: Set[FunctionIndex] = set()
448454
self.function_cppad_autodiff_graphs: Dict[FunctionIndex, CppADAutodiffGraph] = (
449455
{}
450456
)
@@ -514,6 +520,7 @@ def register_function(
514520
)
515521

516522
function_index = super()._register_function(autodiff_structure)
523+
self.function_indices.add(function_index)
517524
self.function_cppad_autodiff_graphs[function_index] = cppad_graph
518525
self.function_autodiff_structures[function_index] = autodiff_structure
519526
self.function_tracing_results[function_index] = tracing_result

src/pyoptinterface/_src/jit_c.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,26 @@ def __init__(self, libtcc_path=libtcc_path):
3232
# Initialize libtcc function prototypes
3333
self._initialize_function_prototypes()
3434

35+
# store all TCC states
36+
self.states = []
37+
38+
self.source_codes = []
39+
40+
def create_state(self):
3541
# Create a new TCC state
36-
self.state = self.libtcc.tcc_new()
42+
state = self.libtcc.tcc_new()
3743

3844
# Ensure the state was successfully created
39-
if not self.state:
45+
if not state:
4046
raise Exception("Failed to create TCC state")
4147

4248
# Set the output type to memory
43-
if self.libtcc.tcc_set_output_type(self.state, TCC_OUTPUT_MEMORY) == -1:
44-
self.cleanup()
49+
if self.libtcc.tcc_set_output_type(state, TCC_OUTPUT_MEMORY) == -1:
4550
raise Exception("Failed to set output type")
4651

47-
# relocate has been called
48-
self.relocated = False
52+
self.states.append(state)
53+
54+
return state
4955

5056
def _initialize_function_prototypes(self):
5157
libtcc = self.libtcc
@@ -60,37 +66,30 @@ def _initialize_function_prototypes(self):
6066
libtcc.tcc_get_symbol.argtypes = [TCCState, ctypes.c_char_p]
6167
libtcc.tcc_get_symbol.restype = ctypes.c_void_p
6268

63-
def compile_string(self, c_code):
64-
# Compile C code string
65-
if self.libtcc.tcc_compile_string(self.state, c_code) == -1:
69+
def compile_string(self, state, c_code: str):
70+
if self.libtcc.tcc_compile_string(state, c_code.encode()) == -1:
6671
raise Exception("Failed to compile code")
6772

68-
self.relocated = False
73+
if self.libtcc.tcc_relocate(state) == -1:
74+
raise Exception("Failed to relocate")
6975

70-
def add_symbol(self, symbol_name, symbol_address):
76+
self.source_codes.append(c_code)
77+
78+
def add_symbol(self, state, symbol_name: str, symbol_address):
7179
# Add a symbol to the TCC state
72-
if self.libtcc.tcc_add_symbol(self.state, symbol_name, symbol_address) == -1:
80+
if (
81+
self.libtcc.tcc_add_symbol(state, symbol_name.encode(), symbol_address)
82+
== -1
83+
):
7384
raise Exception(f"Failed to add symbol {symbol_name} to TCC state")
7485

75-
self.relocated = False
76-
77-
def get_symbol(self, symbol_name):
78-
if not self.relocated:
79-
if self.libtcc.tcc_relocate(self.state) == -1:
80-
raise Exception("Failed to relocate")
81-
self.relocated = True
86+
def get_symbol(self, state, symbol_name: str):
8287
# Get the symbol for the compiled function
83-
symbol = self.libtcc.tcc_get_symbol(self.state, symbol_name)
88+
symbol = self.libtcc.tcc_get_symbol(state, symbol_name.encode())
8489
if not symbol:
85-
raise Exception(f"Symbol {symbol_name.decode()} not found")
90+
raise Exception(f"Symbol {symbol_name} not found")
8691
return symbol
8792

88-
def cleanup(self):
89-
# Clean up the TCC state
90-
if self.state:
91-
self.libtcc.tcc_delete(self.state)
92-
self.state = None
93-
9493
def __del__(self):
95-
# Ensure clean up is called when the instance is destroyed
96-
self.cleanup()
94+
for state in self.states:
95+
self.libtcc.tcc_delete(state)

0 commit comments

Comments
 (0)