Skip to content

Commit

Permalink
Merge pull request #149 from moorepants/failed-compile
Browse files Browse the repository at this point in the history
Added option to show Cython compilation output and display more informative error.
  • Loading branch information
moorepants committed May 3, 2024
2 parents 599d874 + d7c7037 commit 651df81
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 9 deletions.
17 changes: 11 additions & 6 deletions opty/direct_collocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self, obj, obj_grad, equations_of_motion, state_symbols,
known_parameter_map={}, known_trajectory_map={},
instance_constraints=None, time_symbol=None, tmp_dir=None,
integration_method='backward euler', parallel=False,
bounds=None):
bounds=None, show_compile_output=False):
"""
Parameters
Expand All @@ -119,8 +119,7 @@ def __init__(self, obj, obj_grad, equations_of_motion, state_symbols,
equations_of_motion, state_symbols, num_collocation_nodes,
node_time_interval, known_parameter_map, known_trajectory_map,
instance_constraints, time_symbol, tmp_dir, integration_method,
parallel
)
parallel, show_compile_output=show_compile_output)

self.bounds = bounds
self.obj = obj
Expand Down Expand Up @@ -388,7 +387,7 @@ def plot_constraint_violations(self, vector):
plot_inst_viols = self.collocator.instance_constraints is not None
fig, axes = plt.subplots(1 + plot_inst_viols, squeeze=False)
axes = axes.ravel()

axes[0].plot(con_nodes, state_violations.T)
axes[0].set_title('Constraint violations')
axes[0].set_xlabel('Node Number')
Expand Down Expand Up @@ -445,7 +444,8 @@ def __init__(self, equations_of_motion, state_symbols,
num_collocation_nodes, node_time_interval,
known_parameter_map={}, known_trajectory_map={},
instance_constraints=None, time_symbol=None, tmp_dir=None,
integration_method='backward euler', parallel=False):
integration_method='backward euler', parallel=False,
show_compile_output=False):
"""Instantiates a ConstraintCollocator object.
Parameters
Expand Down Expand Up @@ -499,6 +499,9 @@ def __init__(self, equations_of_motion, state_symbols,
the constraints will be executed across multiple threads. This is
only useful when the equations of motion have an extremely large
number of operations.
show_compile_output : boolean, optional
If True, STDOUT and STDERR of the Cython compilation call will be
shown.
"""
self.eom = equations_of_motion
Expand All @@ -525,6 +528,7 @@ def __init__(self, equations_of_motion, state_symbols,

self.tmp_dir = tmp_dir
self.parallel = parallel
self.show_compile_output = show_compile_output

self._sort_parameters()
self._check_known_trajectories()
Expand Down Expand Up @@ -934,7 +938,8 @@ def _gen_multi_arg_con_func(self):
logging.info('Compiling the constraint function.')
f = ufuncify_matrix(args, self.discrete_eom,
const=constant_syms + (h_sym,),
tmp_dir=self.tmp_dir, parallel=self.parallel)
tmp_dir=self.tmp_dir, parallel=self.parallel,
show_compile_output=self.show_compile_output)

def constraints(state_values, specified_values, constant_values,
interval_value):
Expand Down
48 changes: 48 additions & 0 deletions opty/tests/test_direct_collocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,54 @@
from pytest import raises

from ..direct_collocation import Problem, ConstraintCollocator
from ..utils import create_objective_function


def test_pendulum():

target_angle = np.pi
duration = 10.0
num_nodes = 500

interval_value = duration / (num_nodes - 1)

# Symbolic equations of motion
I, m, g, d, t = sym.symbols('I, m, g, d, t')
theta, omega, T = sym.symbols('theta, omega, T', cls=sym.Function)

state_symbols = (theta(t), omega(t))
specified_symbols = (T(t),)

eom = sym.Matrix([theta(t).diff() - omega(t),
I*omega(t).diff() + m*g*d*sym.sin(theta(t)) - T(t)])

# Specify the known system parameters.
par_map = OrderedDict()
par_map[I] = 1.0
par_map[m] = 1.0
par_map[g] = 9.81
par_map[d] = 1.0

# Specify the objective function and it's gradient.
obj_func = sym.Integral(T(t)**2, t)
obj, obj_grad = create_objective_function(obj_func, state_symbols,
specified_symbols, tuple(),
num_nodes,
node_time_interval=interval_value)

# Specify the symbolic instance constraints, i.e. initial and end
# conditions.
instance_constraints = (theta(0.0),
theta(duration) - target_angle,
omega(0.0),
omega(duration))

# This will test that a compilation works.
Problem(obj, obj_grad, eom, state_symbols, num_nodes, interval_value,
known_parameter_map=par_map,
instance_constraints=instance_constraints,
bounds={T(t): (-2.0, 2.0)},
show_compile_output=True)


def test_Problem():
Expand Down
20 changes: 17 additions & 3 deletions opty/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,8 @@ def openmp_installed():
return exit


def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False):
def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
show_compile_output=False):
"""Returns a function that evaluates a matrix of expressions in a tight
loop.
Expand All @@ -491,6 +492,9 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False):
If True and openmp is installed, the generated code will be
parallelized across threads. This is only useful when expr are
extremely large.
show_compile_output : boolean, optional
If True, STDOUT and STDERR of the Cython compilation call will be
shown.
"""

Expand Down Expand Up @@ -613,8 +617,18 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False):
f.write(code)
cmd = [sys.executable, d['file_prefix'] + '_setup.py', 'build_ext',
'--inplace']
subprocess.call(cmd, stderr=subprocess.STDOUT, stdout=subprocess.PIPE)
cython_module = importlib.import_module(d['file_prefix'])
proc = subprocess.run(cmd, capture_output=True, text=True)
if show_compile_output:
print(proc.stdout)
print(proc.stderr)
try:
cython_module = importlib.import_module(d['file_prefix'])
except ImportError as error:
msg = ('Unable to import the compiled Cython module {}, '
'compilation likely failed. STDERR output from '
'compilation:\n{}')
raise ImportError(msg.format(d['file_prefix'],
proc.stderr)) from error
finally:
module_counter += 1
sys.path.remove(codedir)
Expand Down

0 comments on commit 651df81

Please sign in to comment.