Skip to content

Commit

Permalink
do-mpc model can now be pickled
Browse files Browse the repository at this point in the history
Pickling is only supported for SX symbolic variables
Unfortunately, it requires a minor "hack" of the struct_SX class from the casadi.tools package. There is a bug which hopefully gets resolved in the future. 

Pickling is included in unit testing.
  • Loading branch information
Felix-Mac committed Aug 19, 2022
1 parent 250605d commit 1195360
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 43 deletions.
59 changes: 42 additions & 17 deletions do_mpc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from casadi.tools import *
import pdb
import warnings
from do_mpc.tools.casstructure import _SymVar, _struct_MX, _struct_SX


class IteratedVariables:
Expand Down Expand Up @@ -215,22 +216,6 @@ def t0(self,val):
raise Exception('Passing object of type {} to set the current time. Must be of type {}'.format(type(val), types))


class _SymVar:
def __init__(self, symvar_type):
assert symvar_type in ['SX', 'MX'], 'symvar_type must be either SX or MX, you have: {}'.format(symvar_type)

if symvar_type == 'MX':
self.sym = MX.sym
self.struct = struct_MX
self.sym_struct = struct_symMX
self.dtype = MX
if symvar_type == 'SX':
self.sym = SX.sym
self.struct = struct_SX
self.sym_struct = struct_symSX
self.dtype = SX


class Model:
"""The **do-mpc** model class. This class holds the full model description and is at the core of
:py:class:`do_mpc.simulator.Simulator`, :py:class:`do_mpc.controller.MPC` and :py:class:`do_mpc.estimator.Estimator`.
Expand Down Expand Up @@ -333,6 +318,41 @@ def __init__(self, model_type=None, symvar_type='SX'):
'setup': False
}

def __getstate__(self):
"""
Returns the state of the :py:class:`Model` for pickling.
.. warning::
The :py:class:`Model` class supports pickling only if:
1. The model is configured with ``SX`` variables.
2. The model is setup with :py:func:`setup`.
"""
# Raise exception if model is using MX symvars
if self.symvar_type == 'MX':
raise Exception('Pickling of models using MX symvars is not supported.')
# Raise exception if model is not setup
if not self.flags['setup']:
raise Exception('Pickling of unsetup models is not supported.')

state = self.__dict__.copy()
return state

def __setstate__(self, state):
"""
Sets the state of the :py:class:`Model` for unpickling. Please see :py:func:`__getstate__` for details and restrictions on pickling.
"""
self.__dict__.update(state)

# Update expressions with new symbolic variables created when unpickling:
self._rhs = self._rhs(self._rhs_fun(self._x, self._u, self._z, self._tvp, self._p, self._w))
self._alg = self._alg(self._alg_fun(self._x, self._u, self._z, self._tvp, self._p, self._w))
self._aux_expression = self._aux_expression(self._aux_expression_fun(self._x, self._u, self._z, self._tvp, self._p))
self._y_expression = self._y_expression(self._meas_fun(self._x, self._u, self._z, self._tvp, self._p, self._v))


def __getitem__(self, ind):
"""The :py:class:`Model` class supports the ``__getitem__`` method,
which can be used to retrieve the model variables (see attribute list).
Expand Down Expand Up @@ -1064,7 +1084,12 @@ def _substitute_struct_vars(self, var_dict_list, sym_struct_list, expr):
for var, name in zip(var_dict['var'], var_dict['name']):
subs = substitute(subs, var, sym_struct[name])

return expr(subs)
if self.symvar_type == 'MX':
expr = expr(subs)
else:
expr.master = subs

return expr

def _substitute_exported_vars(self, var_dict_list, sym_struct_list):
"""Helper function for :py:func:`setup`. Not part of the public API.
Expand Down
27 changes: 16 additions & 11 deletions do_mpc/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,26 +768,31 @@ def _setup_discretization(self):
There is no point in calling this method as part of the public API.
"""
# Scaled variables
_x, _u, _z, _tvp, _p, _w = self.model['x', 'u', 'z', 'tvp', 'p', 'w']

rhs = substitute(self.model._rhs, _x, _x*self._x_scaling.cat)
rhs = substitute(rhs, _u, _u*self._u_scaling.cat)
rhs = substitute(rhs, _z, _z*self._z_scaling.cat)
rhs = substitute(rhs, _p, _p*self._p_scaling.cat) # only meaningful for MHE.
# Unscale variables
_x_unscaled = _x*self._x_scaling.cat
_u_unscaled = _u*self._u_scaling.cat
_z_unscaled = _z*self._z_scaling.cat
_p_unscaled = _p*self._p_scaling.cat

# Create _rhs and _alg
_rhs = self.model._rhs_fun(_x_unscaled, _u_unscaled, _z_unscaled, _tvp, _p_unscaled, _w)
_alg = self.model._alg_fun(_x_unscaled, _u_unscaled, _z_unscaled, _tvp, _p_unscaled, _w)

# Scale (only _rhs)
_rhs_scaled = _rhs/self._x_scaling.cat

alg = substitute(self.model._alg, _x, _x*self._x_scaling.cat)
alg = substitute(alg, _u, _u*self._u_scaling.cat)
alg = substitute(alg, _z, _z*self._z_scaling.cat)
alg = substitute(alg, _p, _p*self._p_scaling.cat) # only meaningful for MHE.

if self.model.model_type == 'discrete':
_i = self.model.sv.sym('i', 0)
# discrete integrator ifcs mimics the API the collocation ifcn.
ifcn = Function('ifcn', [_x, _i, _u, _z, _tvp, _p, _w], [alg, rhs/self._x_scaling.cat])
ifcn = Function('ifcn', [_x, _i, _u, _z, _tvp, _p, _w], [_alg, _rhs_scaled])
n_total_coll_points = 0
elif self.state_discretization == 'collocation':
ffcn = Function('ffcn', [_x, _u, _z, _tvp, _p, _w], [rhs/self._x_scaling.cat])
afcn = Function('afcn', [_x, _u, _z, _tvp, _p, _w], [alg])
ffcn = Function('ffcn', [_x, _u, _z, _tvp, _p, _w], [_rhs_scaled])
afcn = Function('afcn', [_x, _u, _z, _tvp, _p, _w], [_alg])
# Get collocation information
coll = self.collocation_type
deg = self.collocation_deg
Expand Down
28 changes: 28 additions & 0 deletions do_mpc/tools/casstructure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from casadi import *
from casadi.tools import *

class _struct_SX(struct_SX):
"""Updated structure class for CasADi structures (SX). This class fixes a bug that prevents unpickeling of the structure."""
def __init__(self, *args, **kwargs):
kwargs.pop('order', None)
super().__init__(*args, **kwargs)

class _struct_MX(struct_MX):
def __init__(self, *args, **kwargs):
kwargs.pop('order', None)
super().__init__(*args, **kwargs)

class _SymVar:
def __init__(self, symvar_type):
assert symvar_type in ['SX', 'MX'], 'symvar_type must be either SX or MX, you have: {}'.format(symvar_type)

if symvar_type == 'MX':
self.sym = MX.sym
self.struct = _struct_MX
self.sym_struct = struct_symMX
self.dtype = MX
if symvar_type == 'SX':
self.sym = SX.sym
self.struct = _struct_SX
self.sym_struct = struct_symSX
self.dtype = SX
24 changes: 19 additions & 5 deletions testing/test_CSTR.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import pdb
import sys
import unittest
import pickle

from importlib import reload
import copy
Expand Down Expand Up @@ -59,18 +60,31 @@ def setUp(self):

def test_SX(self):
print('Testing SX implementation')
self.CSTR('SX')
model = self.template_model.template_model('SX')
self.CSTR(model)

def test_MX(self):
print('Testing MX implementation')
self.CSTR('MX')

def CSTR(self, symvar_type):
model = self.template_model.template_model('MX')
self.CSTR(model)

def test_pickle_unpickle(self):
print('Testing SX implementation with pickle / unpickle')
# Test if pickling / unpickling works for the SX model:
model = self.template_model.template_model('SX')
with open('model.pkl', 'wb') as f:
pickle.dump(model, f)

# Load the casadi structure
with open('model.pkl', 'rb') as f:
model_unpickled = pickle.load(f)
self.CSTR(model_unpickled)

def CSTR(self, model):
"""
Get configured do-mpc modules:
"""

model = self.template_model.template_model(symvar_type)
mpc = self.template_mpc.template_mpc(model)
simulator = self.template_simulator.template_simulator(model)
estimator = do_mpc.estimator.StateFeedback(model)
Expand Down
25 changes: 19 additions & 6 deletions testing/test_oscillating_masses_discrete_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import pdb
import sys
import unittest
import pickle

from importlib import reload
import copy
Expand Down Expand Up @@ -57,18 +58,30 @@ def setUp(self):

def test_SX(self):
print('Testing SX implementation')
self.oscillating_masses_discrete('SX')
model = self.template_model.template_model('SX')
self.oscillating_masses_discrete(model)

def test_MX(self):
print('Testing MX implementation')
self.oscillating_masses_discrete('MX')

def oscillating_masses_discrete(self, symvar_type):
model = self.template_model.template_model('MX')
self.oscillating_masses_discrete(model)

def test_pickle_unpickle(self):
print('Testing pickle and unpickle')
model = self.template_model.template_model('SX')
with open('model.pkl', 'wb') as f:
pickle.dump(model, f)

# Load the casadi structure
with open('model.pkl', 'rb') as f:
model_unpickled = pickle.load(f)
self.oscillating_masses_discrete(model_unpickled)


def oscillating_masses_discrete(self, model):
"""
Get configured do-mpc modules:
"""

model = self.template_model.template_model(symvar_type)
mpc = self.template_mpc.template_mpc(model)
simulator = self.template_simulator.template_simulator(model)
estimator = do_mpc.estimator.StateFeedback(model)
Expand Down
20 changes: 16 additions & 4 deletions testing/test_rotating_oscillating_masses_mhe_mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import pdb
import sys
import unittest
import pickle

from importlib import reload
import copy
Expand Down Expand Up @@ -58,17 +59,28 @@ def setUp(self):
sys.path = default_path

def test_SX(self):
self.RotatingMasses('SX')
model = self.template_model.template_model('SX')
self.RotatingMasses(model)

def test_MX(self):
self.RotatingMasses('MX')
model = self.template_model.template_model('MX')
self.RotatingMasses(model)

def RotatingMasses(self, symvar_type):
def test_pickle_unpickle(self):
model = self.template_model.template_model('SX')
with open('model.pkl', 'wb') as f:
pickle.dump(model, f)

# Load the casadi structure
with open('model.pkl', 'rb') as f:
model_unpickled = pickle.load(f)
self.RotatingMasses(model_unpickled)

def RotatingMasses(self, model):
"""
Get configured do-mpc modules:
"""

model = self.template_model.template_model(symvar_type)
mpc = self.template_mpc.template_mpc(model)
simulator = self.template_simulator.template_simulator(model)
mhe = self.template_mhe.template_mhe(model)
Expand Down

0 comments on commit 1195360

Please sign in to comment.