diff --git a/.github/workflows/ros-build-test.yml b/.github/workflows/ros-build-test.yml
index 8816680ca..720f066c5 100644
--- a/.github/workflows/ros-build-test.yml
+++ b/.github/workflows/ros-build-test.yml
@@ -30,7 +30,7 @@ jobs:
- name: System deps
run: |
apt-get update
- apt-get install -y git ninja-build liburdfdom-dev liboctomap-dev libassimp-dev checkinstall
+ apt-get install -y git ninja-build liburdfdom-dev liboctomap-dev libassimp-dev checkinstall wget rsync
- uses: actions/checkout@v2
with:
@@ -56,6 +56,12 @@ jobs:
cmake .. -DPYTHON_EXECUTABLE=$(python3 -c "import sys; print(sys.executable)")
make -j4 && checkinstall
+ - name: Install ONNX Runtime
+ run: |
+ wget https://github.com/microsoft/onnxruntime/releases/download/v1.7.0/onnxruntime-linux-x64-1.7.0.tgz -P src
+ tar xf src/onnxruntime-linux-x64-1.7.0.tgz -C src
+ rsync -a src/ocs2/ocs2_mpcnet/ocs2_mpcnet_core/misc/onnxruntime/cmake/ src/onnxruntime-linux-x64-1.7.0/cmake
+
- name: Build (${{ matrix.build_type }})
shell: bash
run: |
diff --git a/.gitignore b/.gitignore
index 15f4690f1..46f83bb33 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,3 +25,4 @@ ocs2_ddp/test/ddp_test_generated/
*.out
*.synctex.gz
.vscode/
+runs/
diff --git a/jenkins-pipeline b/jenkins-pipeline
index 7deeb22ce..02191156e 100644
--- a/jenkins-pipeline
+++ b/jenkins-pipeline
@@ -1,5 +1,5 @@
library 'continuous_integration_pipeline'
-ciPipeline("--ros-distro noetic --publish-doxygen --recipes raisimlib\
+ciPipeline("--ros-distro noetic --publish-doxygen --recipes onnxruntime raisimlib\
--dependencies 'git@github.com:leggedrobotics/hpp-fcl.git;master;git'\
'git@github.com:leggedrobotics/pinocchio.git;master;git'\
'git@github.com:leggedrobotics/ocs2_robotic_assets.git;main;git'\
diff --git a/ocs2/package.xml b/ocs2/package.xml
index 834b4090e..c088f80f2 100644
--- a/ocs2/package.xml
+++ b/ocs2/package.xml
@@ -24,6 +24,7 @@
ocs2_robotic_examples
ocs2_thirdparty
ocs2_raisim
+ ocs2_mpcnet
diff --git a/ocs2_core/include/ocs2_core/control/ControllerType.h b/ocs2_core/include/ocs2_core/control/ControllerType.h
index 65aa1962f..bbae196ee 100644
--- a/ocs2_core/include/ocs2_core/control/ControllerType.h
+++ b/ocs2_core/include/ocs2_core/control/ControllerType.h
@@ -34,6 +34,6 @@ namespace ocs2 {
/**
* Enum class for specifying controller type
*/
-enum class ControllerType { UNKNOWN, FEEDFORWARD, LINEAR, NEURAL_NETWORK };
+enum class ControllerType { UNKNOWN, FEEDFORWARD, LINEAR, ONNX, BEHAVIORAL };
} // namespace ocs2
diff --git a/ocs2_doc/docs/index.rst b/ocs2_doc/docs/index.rst
index 11450c76d..44a2bf541 100644
--- a/ocs2_doc/docs/index.rst
+++ b/ocs2_doc/docs/index.rst
@@ -11,6 +11,7 @@ Table of Contents
robotic_examples.rst
from_urdf_to_ocp.rst
profiling.rst
+ mpcnet.rst
faq.rst
.. rubric:: Reference and Index:
diff --git a/ocs2_doc/docs/installation.rst b/ocs2_doc/docs/installation.rst
index 1e8213f6e..88375bd5a 100644
--- a/ocs2_doc/docs/installation.rst
+++ b/ocs2_doc/docs/installation.rst
@@ -98,6 +98,42 @@ Optional Dependencies
* `Grid Map `__ catkin package, which may be installed with ``sudo apt install ros-noetic-grid-map-msgs``.
+* `ONNX Runtime `__ is an inferencing and training accelerator. Here, it is used for deploying learned :ref:`MPC-Net ` policies in C++ code. To locally install it, do the following:
+
+ .. code-block:: bash
+
+ wget https://github.com/microsoft/onnxruntime/releases/download/v1.7.0/onnxruntime-linux-x64-1.7.0.tgz -P ~/catkin_ws/src
+ tar xf ~/catkin_ws/src/onnxruntime-linux-x64-1.7.0.tgz -C ~/catkin_ws/src
+ rsync -a ~/catkin_ws/src/ocs2/ocs2_mpcnet/ocs2_mpcnet_core/misc/onnxruntime/cmake/ ~/catkin_ws/src/onnxruntime-linux-x64-1.7.0/cmake
+
+ We provide custom cmake config and version files to enable ``find_package(onnxruntime)`` without modifying ``LIBRARY_PATH`` and ``LD_LIBRARY_PATH``. Note that the last command above assumes that you cloned OCS2 into the folder ``git`` in your user's home directory.
+
+* `Virtual environments `__ are recommended when training :ref:`MPC-Net ` policies:
+
+ .. code-block:: bash
+
+ sudo apt-get install python3-venv
+
+ Create an environment and give it access to the system site packages:
+
+ .. code-block:: bash
+
+ mkdir venvs && cd venvs
+ python3 -m venv mpcnet
+
+ Activate the environment and install the requirements:
+
+ .. code-block:: bash
+
+ source ~/venvs/mpcnet/bin/activate
+ python3 -m pip install -r ~/git/ocs2/ocs2_mpcnet/ocs2_mpcnet_core/requirements.txt
+
+ Newer graphics cards might require a CUDA capability which is currently not supported by the standard PyTorch installation.
+ In that case check `PyTorch Start Locally `__ for a compatible version and, e.g., run:
+
+ .. code-block:: bash
+
+ pip3 install torch==1.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
.. _doxid-ocs2_doc_installation_ocs2_doc_install:
diff --git a/ocs2_doc/docs/mpcnet.rst b/ocs2_doc/docs/mpcnet.rst
new file mode 100644
index 000000000..39993c850
--- /dev/null
+++ b/ocs2_doc/docs/mpcnet.rst
@@ -0,0 +1,164 @@
+.. index:: pair: page; MPC-Net
+
+.. _doxid-ocs2_doc_mpcnet:
+
+MPC-Net
+=======
+
+MPC-Net is an imitation learning approach that uses solutions from MPC to guide the policy search.
+The main idea is to imitate MPC by minimizing the control Hamiltonian while representing the corresponding control inputs by a parametrized policy.
+MPC-Net can be used to clone a model predictive controller into a neural network policy, which can be evaluated much faster than MPC.
+Therefore, MPC-Net is a useful proxy for MPC in computationally demanding applications that do not require the most exact solution.
+
+The multi-threaded data generation and policy evaluation run asynchronously with the policy training.
+The data generation and policy evaluation are implemented in C++ and run on CPU, while the policy training is implemented in Python and runs on GPU.
+The control Hamiltonian is represented by a linear-quadratic approximation.
+Therefore, the training can run on GPU without callbacks to OCS2 C++ code running on CPU to evaluate the Hamiltonian, and one can exploit batch processing on GPU.
+
+Robots
+~~~~~~
+
+MPC-Net has been implemented for the following :ref:`Robotic Examples `:
+
+============= ================ ================= ======== =============
+Robot Recom. CPU Cores Recom. GPU Memory RaiSim Training Time
+============= ================ ================= ======== =============
+ballbot 4 2 GB No 0m 20s
+legged_robot 12 8 GB Yes / No 7m 40s
+============= ================ ================= ======== =============
+
+Setup
+~~~~~
+
+Make sure to follow the :ref:`Installation ` page.
+Follow all the instructions for the dependencies.
+Regarding the optional dependencies, make sure to follow the instruction for ONNX Runtime and the virtual environment, optionally set up RaiSim.
+
+To build all MPC-Net packages, build the meta package:
+
+.. code-block:: bash
+
+ cd
+ catkin_build ocs2_mpcnet
+
+ # Example:
+ cd ~/catkin_ws
+ catkin_build ocs2_mpcnet
+
+To build a robot-specific package, replace :code:`` with the robot name:
+
+.. code-block:: bash
+
+ cd
+ catkin_build ocs2__mpcnet
+
+ # Example:
+ cd ~/catkin_ws
+ catkin_build ocs2_ballbot_mpcnet
+
+Training
+~~~~~~~~
+
+To train an MPC-Net policy, run:
+
+.. code-block:: bash
+
+ cd /ocs2_mpcnet/ocs2__mpcnet/python/ocs2__mpcnet
+ source /devel/setup.bash
+ source /mpcnet/bin/activate
+ python3 train.py
+
+ # Example:
+ cd ~/git/ocs2/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet
+ source ~/catkin_ws/devel/setup.bash
+ source ~/venvs/mpcnet/bin/activate
+ python3 train.py
+
+To monitor the training progress with Tensorboard, run:
+
+.. code-block:: bash
+
+ cd /ocs2_mpcnet/ocs2__mpcnet/python/ocs2__mpcnet
+ source /mpcnet/bin/activate
+ tensorboard --logdir=runs
+
+ # Example:
+ cd ~/git/ocs2/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet
+ source ~/venvs/mpcnet/bin/activate
+ tensorboard --logdir=runs
+
+If you use RaiSim, you can visualize the data generation and policy evaluation rollouts with RaiSim Unity, where pre-built executables are provided in RaiSim's :code:`raisimUnity` folder. For example, on Linux run:
+
+.. code-block:: bash
+
+ /raisimUnity/linux/raisimUnity.x86_64
+
+ # Example:
+ ~/git/raisimLib/raisimUnity/linux/raisimUnity.x86_64
+
+Deployment
+~~~~~~~~~~
+
+To deploy the default policy stored in the robot-specific package's :code:`policy` folder, run:
+
+.. code-block:: bash
+
+ cd
+ source devel/setup.bash
+ roslaunch ocs2__mpcnet _mpcnet.launch
+
+ # Example:
+ cd ~/catkin_ws
+ source devel/setup.bash
+ roslaunch ocs2_ballbot_mpcnet ballbot_mpcnet.launch
+
+To deploy a new policy stored in the robot-specific package's :code:`python/ocs2__mpcnet/runs` folder, replace :code:`` with the absolute file path to the final policy and run:
+
+.. code-block:: bash
+
+ cd
+ source devel/setup.bash
+ roslaunch ocs2__mpcnet _mpcnet.launch policyFile:=
+
+ # Example:
+ cd ~/catkin_ws
+ source devel/setup.bash
+ roslaunch ocs2_ballbot_mpcnet ballbot_mpcnet.launch policyFile:='/home/user/git/ocs2/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/runs/2022-04-01_12-00-00_ballbot_description/final_policy.onnx'
+
+How to Set Up a New Robot
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Setting up MPC-Net for a new robot is relatively easy, as the **ocs2_mpcnet_core** package takes care of the data generation as well as policy evaluation rollouts and implements important learning components, such as the memory, policy, and loss function.
+
+This section assumes that you already have the packages for the robot-specific MPC implementation:
+
+1. **ocs2_**: Provides the library with the robot-specific MPC implementation.
+2. **ocs2__ros**: Wraps around the MPC implementation with ROS to define ROS nodes.
+3. **ocs2__raisim**: (Optional) interface between the robot-specific MPC implementation and RaiSim.
+
+For the actual **ocs2__mpcnet** package, follow the structure of existing robot-specific MPC-Net packages.
+The most important classes/files that have to be implemented are:
+
+* **MpcnetDefinition**: Defines how OCS2 state variables are transformed to the policy observations. and how policy actions are transformed to OCS2 control inputs.
+* **MpcnetInterface**: Provides the interface between C++ and Python, allowing to exchange data and policies.
+* **.yaml**: Stores the configuration parameters.
+* **mpcnet.py**: Adds robot-specific methods, e.g. implements the tasks that the robot should execute, for the MPC-Net training.
+* **train.py**: Starts the main training script.
+
+Known Issues
+~~~~~~~~~~~~
+
+Stiff inequality constraints can lead to very large Hamiltonians and gradients of the Hamilltonian near the log barrier.
+This can obstruct the learning process and the policy might not learn something useful.
+In that case, enable the gradient clipping in the robot's MPC-Net YAML configuration file and tune the gradient clipping value.
+
+References
+~~~~~~~~~~
+
+This part of the toolbox has been developed based on the following publications:
+
+.. bibliography::
+ :list: enumerated
+
+ carius2020mpcnet
+ reske2021imitation
diff --git a/ocs2_doc/docs/overview.rst b/ocs2_doc/docs/overview.rst
index 6c7e5f68c..6141cc135 100644
--- a/ocs2_doc/docs/overview.rst
+++ b/ocs2_doc/docs/overview.rst
@@ -60,7 +60,8 @@ Michael Spieler (ETHZ),
Jan Carius (ETHZ),
Jean-Pierre Sleiman (ETHZ).
-**Other Developers**:
+**Other Developers**:
+Alexander Reske,
Mayank Mittal,
Johannes Pankert,
Perry Franklin,
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/CMakeLists.txt b/ocs2_mpcnet/ocs2_ballbot_mpcnet/CMakeLists.txt
new file mode 100644
index 000000000..d95da9050
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/CMakeLists.txt
@@ -0,0 +1,128 @@
+cmake_minimum_required(VERSION 3.0.2)
+project(ocs2_ballbot_mpcnet)
+
+set(CATKIN_PACKAGE_DEPENDENCIES
+ ocs2_ballbot
+ ocs2_ballbot_ros
+ ocs2_mpcnet_core
+)
+
+find_package(catkin REQUIRED COMPONENTS
+ ${CATKIN_PACKAGE_DEPENDENCIES}
+)
+
+# Generate compile_commands.json for clang tools
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+###################################
+## catkin specific configuration ##
+###################################
+
+catkin_package(
+ INCLUDE_DIRS
+ include
+ LIBRARIES
+ ${PROJECT_NAME}
+ CATKIN_DEPENDS
+ ${CATKIN_PACKAGE_DEPENDENCIES}
+ DEPENDS
+)
+
+###########
+## Build ##
+###########
+
+include_directories(
+ include
+ ${catkin_INCLUDE_DIRS}
+)
+
+# main library
+add_library(${PROJECT_NAME}
+ src/BallbotMpcnetDefinition.cpp
+ src/BallbotMpcnetInterface.cpp
+)
+add_dependencies(${PROJECT_NAME}
+ ${catkin_EXPORTED_TARGETS}
+)
+target_link_libraries(${PROJECT_NAME}
+ ${catkin_LIBRARIES}
+)
+
+# python bindings
+pybind11_add_module(BallbotMpcnetPybindings SHARED
+ src/BallbotMpcnetPybindings.cpp
+)
+add_dependencies(BallbotMpcnetPybindings
+ ${PROJECT_NAME}
+ ${catkin_EXPORTED_TARGETS}
+)
+target_link_libraries(BallbotMpcnetPybindings PRIVATE
+ ${PROJECT_NAME}
+ ${catkin_LIBRARIES}
+)
+set_target_properties(BallbotMpcnetPybindings
+ PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CATKIN_DEVEL_PREFIX}/${CATKIN_PACKAGE_PYTHON_DESTINATION}
+)
+
+# MPC-Net dummy node
+add_executable(ballbot_mpcnet_dummy
+ src/BallbotMpcnetDummyNode.cpp
+)
+add_dependencies(ballbot_mpcnet_dummy
+ ${PROJECT_NAME}
+ ${catkin_EXPORTED_TARGETS}
+)
+target_link_libraries(ballbot_mpcnet_dummy
+ ${PROJECT_NAME}
+ ${catkin_LIBRARIES}
+)
+target_compile_options(ballbot_mpcnet_dummy PRIVATE ${OCS2_CXX_FLAGS})
+
+catkin_python_setup()
+
+#########################
+### CLANG TOOLING ###
+#########################
+find_package(cmake_clang_tools QUIET)
+if(cmake_clang_tools_FOUND)
+ message(STATUS "Run clang tooling for target ocs2_ballbot_mpcnet")
+ add_clang_tooling(
+ TARGETS ${PROJECT_NAME} ballbot_mpcnet_dummy
+ SOURCE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/src ${CMAKE_CURRENT_SOURCE_DIR}/include
+ CT_HEADER_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/include
+ CF_WERROR
+)
+endif(cmake_clang_tools_FOUND)
+
+#############
+## Install ##
+#############
+
+install(TARGETS ${PROJECT_NAME}
+ ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+ LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+ RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+)
+
+install(DIRECTORY include/${PROJECT_NAME}/
+ DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION})
+
+install(TARGETS BallbotMpcnetPybindings
+ ARCHIVE DESTINATION ${CATKIN_PACKAGE_PYTHON_DESTINATION}
+ LIBRARY DESTINATION ${CATKIN_PACKAGE_PYTHON_DESTINATION}
+)
+
+install(TARGETS ballbot_mpcnet_dummy
+ RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+)
+
+install(DIRECTORY launch policy
+ DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+)
+
+#############
+## Testing ##
+#############
+
+# TODO(areske)
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/include/ocs2_ballbot_mpcnet/BallbotMpcnetDefinition.h b/ocs2_mpcnet/ocs2_ballbot_mpcnet/include/ocs2_ballbot_mpcnet/BallbotMpcnetDefinition.h
new file mode 100644
index 000000000..e50cf779f
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/include/ocs2_ballbot_mpcnet/BallbotMpcnetDefinition.h
@@ -0,0 +1,71 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+
+namespace ocs2 {
+namespace ballbot {
+
+/**
+ * MPC-Net definitions for ballbot.
+ */
+class BallbotMpcnetDefinition final : public ocs2::mpcnet::MpcnetDefinitionBase {
+ public:
+ /**
+ * Default constructor.
+ */
+ BallbotMpcnetDefinition() = default;
+
+ /**
+ * Default destructor.
+ */
+ ~BallbotMpcnetDefinition() override = default;
+
+ /**
+ * @see MpcnetDefinitionBase::getObservation
+ */
+ vector_t getObservation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories) override;
+
+ /**
+ * @see MpcnetDefinitionBase::getActionTransformation
+ */
+ std::pair getActionTransformation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories) override;
+
+ /**
+ * @see MpcnetDefinitionBase::isValid
+ */
+ bool isValid(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, const TargetTrajectories& targetTrajectories) override;
+};
+
+} // namespace ballbot
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/include/ocs2_ballbot_mpcnet/BallbotMpcnetInterface.h b/ocs2_mpcnet/ocs2_ballbot_mpcnet/include/ocs2_ballbot_mpcnet/BallbotMpcnetInterface.h
new file mode 100644
index 000000000..4456b7ab5
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/include/ocs2_ballbot_mpcnet/BallbotMpcnetInterface.h
@@ -0,0 +1,66 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+#include
+
+namespace ocs2 {
+namespace ballbot {
+
+/**
+ * Ballbot MPC-Net interface between C++ and Python.
+ */
+class BallbotMpcnetInterface final : public ocs2::mpcnet::MpcnetInterfaceBase {
+ public:
+ /**
+ * Constructor.
+ * @param [in] nDataGenerationThreads : Number of data generation threads.
+ * @param [in] nPolicyEvaluationThreads : Number of policy evaluation threads.
+ * @param [in] raisim : Whether to use RaiSim for the rollouts.
+ */
+ BallbotMpcnetInterface(size_t nDataGenerationThreads, size_t nPolicyEvaluationThreads, bool raisim);
+
+ /**
+ * Default destructor.
+ */
+ ~BallbotMpcnetInterface() override = default;
+
+ private:
+ /**
+ * Helper to get the MPC.
+ * @param [in] ballbotInterface : The ballbot interface.
+ * @return Pointer to the MPC.
+ */
+ std::unique_ptr getMpc(BallbotInterface& ballbotInterface);
+};
+
+} // namespace ballbot
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/launch/ballbot_mpcnet.launch b/ocs2_mpcnet/ocs2_ballbot_mpcnet/launch/ballbot_mpcnet.launch
new file mode 100644
index 000000000..ca53ee47b
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/launch/ballbot_mpcnet.launch
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/package.xml b/ocs2_mpcnet/ocs2_ballbot_mpcnet/package.xml
new file mode 100644
index 000000000..c2448bd1b
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/package.xml
@@ -0,0 +1,22 @@
+
+
+ ocs2_ballbot_mpcnet
+ 0.0.0
+ The ocs2_ballbot_mpcnet package
+
+ Alexander Reske
+
+ Farbod Farshidian
+ Alexander Reske
+
+ BSD-3
+
+ catkin
+
+ cmake_clang_tools
+
+ ocs2_ballbot
+ ocs2_ballbot_ros
+ ocs2_mpcnet_core
+
+
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/policy/ballbot.onnx b/ocs2_mpcnet/ocs2_ballbot_mpcnet/policy/ballbot.onnx
new file mode 100644
index 000000000..56c3a7a8e
Binary files /dev/null and b/ocs2_mpcnet/ocs2_ballbot_mpcnet/policy/ballbot.onnx differ
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/policy/ballbot.pt b/ocs2_mpcnet/ocs2_ballbot_mpcnet/policy/ballbot.pt
new file mode 100644
index 000000000..06a0acb59
Binary files /dev/null and b/ocs2_mpcnet/ocs2_ballbot_mpcnet/policy/ballbot.pt differ
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/__init__.py b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/__init__.py
new file mode 100644
index 000000000..d5c07a503
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/__init__.py
@@ -0,0 +1,2 @@
+from ocs2_ballbot_mpcnet.BallbotMpcnetPybindings import MpcnetInterface
+from ocs2_ballbot_mpcnet.mpcnet import BallbotMpcnet
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/config/ballbot.yaml b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/config/ballbot.yaml
new file mode 100644
index 000000000..1626c547c
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/config/ballbot.yaml
@@ -0,0 +1,123 @@
+#
+# general
+#
+# name of the robot
+NAME: "ballbot"
+# description of the training run
+DESCRIPTION: "description"
+# state dimension
+STATE_DIM: 10
+# input dimension
+INPUT_DIM: 3
+# target trajectories state dimension
+TARGET_STATE_DIM: 10
+# target trajectories input dimension
+TARGET_INPUT_DIM: 3
+# observation dimension
+OBSERVATION_DIM: 10
+# action dimension
+ACTION_DIM: 3
+# expert number
+EXPERT_NUM: 1
+# default state
+DEFAULT_STATE:
+ - 0.0 # pose x
+ - 0.0 # pose y
+ - 0.0 # pose yaw
+ - 0.0 # pose pitch
+ - 0.0 # pose roll
+ - 0.0 # twist x
+ - 0.0 # twist y
+ - 0.0 # twist yaw
+ - 0.0 # twist pitch
+ - 0.0 # twist roll
+# default target state
+DEFAULT_TARGET_STATE:
+ - 0.0 # pose x
+ - 0.0 # pose y
+ - 0.0 # pose yaw
+ - 0.0 # pose pitch
+ - 0.0 # pose roll
+ - 0.0 # twist x
+ - 0.0 # twist y
+ - 0.0 # twist yaw
+ - 0.0 # twist pitch
+ - 0.0 # twist roll
+#
+# loss
+#
+# epsilon to improve numerical stability of logs and denominators
+EPSILON: 1.e-8
+# whether to cheat by adding the gating loss
+CHEATING: False
+# parameter to control the relative importance of both loss types
+LAMBDA: 1.0
+# dictionary for the gating loss (assigns modes to experts responsible for the corresponding contact configuration)
+EXPERT_FOR_MODE:
+ 0: 0
+# input cost for behavioral cloning
+R:
+ - 2.0 # torque
+ - 2.0 # torque
+ - 2.0 # torque
+#
+# memory
+#
+# capacity of the memory
+CAPACITY: 100000
+#
+# policy
+#
+# observation scaling
+OBSERVATION_SCALING:
+ - 1.0 # pose x
+ - 1.0 # pose y
+ - 1.0 # pose yaw
+ - 1.0 # pose pitch
+ - 1.0 # pose roll
+ - 1.0 # twist x
+ - 1.0 # twist y
+ - 1.0 # twist yaw
+ - 1.0 # twist pitch
+ - 1.0 # twist roll
+# action scaling
+ACTION_SCALING:
+ - 1.0 # torque
+ - 1.0 # torque
+ - 1.0 # torque
+#
+# rollout
+#
+# RaiSim or TimeTriggered rollout for data generation and policy evaluation
+RAISIM: False
+# settings for data generation
+DATA_GENERATION_TIME_STEP: 0.1
+DATA_GENERATION_DURATION: 3.0
+DATA_GENERATION_DATA_DECIMATION: 1
+DATA_GENERATION_THREADS: 2
+DATA_GENERATION_TASKS: 10
+DATA_GENERATION_SAMPLES: 2
+DATA_GENERATION_SAMPLING_VARIANCE:
+ - 0.01 # pose x
+ - 0.01 # pose y
+ - 0.01745329251 # pose yaw: 1.0 / 180.0 * pi
+ - 0.01745329251 # pose pitch: 1.0 / 180.0 * pi
+ - 0.01745329251 # pose roll: 1.0 / 180.0 * pi
+ - 0.05 # twist x
+ - 0.05 # twist y
+ - 0.08726646259 # twist yaw: 5.0 / 180.0 * pi
+ - 0.08726646259 # twist pitch: 5.0 / 180.0 * pi
+ - 0.08726646259 # twist roll: 5.0 / 180.0 * pi
+# settings for computing metrics
+POLICY_EVALUATION_TIME_STEP: 0.1
+POLICY_EVALUATION_DURATION: 3.0
+POLICY_EVALUATION_THREADS: 1
+POLICY_EVALUATION_TASKS: 5
+#
+# training
+#
+BATCH_SIZE: 32
+LEARNING_RATE: 1.e-2
+LEARNING_ITERATIONS: 10000
+GRADIENT_CLIPPING: False
+GRADIENT_CLIPPING_VALUE: 1.0
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/mpcnet.py b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/mpcnet.py
new file mode 100644
index 000000000..86b8fe743
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/mpcnet.py
@@ -0,0 +1,137 @@
+###############################################################################
+# Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+###############################################################################
+
+"""Ballbot MPC-Net class.
+
+Provides a class that handles the MPC-Net training for ballbot.
+"""
+
+import numpy as np
+from typing import Tuple
+
+from ocs2_mpcnet_core import helper
+from ocs2_mpcnet_core import mpcnet
+from ocs2_mpcnet_core import SystemObservationArray, ModeScheduleArray, TargetTrajectoriesArray
+
+
+class BallbotMpcnet(mpcnet.Mpcnet):
+ """Ballbot MPC-Net.
+
+ Adds robot-specific methods for the MPC-Net training.
+ """
+
+ @staticmethod
+ def get_default_event_times_and_mode_sequence(duration: float) -> Tuple[np.ndarray, np.ndarray]:
+ """Get the event times and mode sequence describing the default mode schedule.
+
+ Creates the default event times and mode sequence for a certain time duration.
+
+ Args:
+ duration: The duration of the mode schedule given by a float.
+
+ Returns:
+ A tuple containing the components of the mode schedule.
+ - event_times: The event times given by a NumPy array of shape (K-1) containing floats.
+ - mode_sequence: The mode sequence given by a NumPy array of shape (K) containing integers.
+ """
+ event_times_template = np.array([1.0], dtype=np.float64)
+ mode_sequence_template = np.array([0], dtype=np.uintp)
+ return helper.get_event_times_and_mode_sequence(0, duration, event_times_template, mode_sequence_template)
+
+ def get_random_initial_state(self) -> np.ndarray:
+ """Get a random initial state.
+
+ Samples a random initial state for the robot.
+
+ Returns:
+ x: A random initial state given by a NumPy array containing floats.
+ """
+ max_linear_velocity_x = 0.5
+ max_linear_velocity_y = 0.5
+ max_euler_angle_derivative_z = 45.0 / 180.0 * np.pi
+ max_euler_angle_derivative_y = 45.0 / 180.0 * np.pi
+ max_euler_angle_derivative_x = 45.0 / 180.0 * np.pi
+ random_deviation = np.zeros(self.config.STATE_DIM)
+ random_deviation[5] = np.random.uniform(-max_linear_velocity_x, max_linear_velocity_x)
+ random_deviation[6] = np.random.uniform(-max_linear_velocity_y, max_linear_velocity_y)
+ random_deviation[7] = np.random.uniform(-max_euler_angle_derivative_z, max_euler_angle_derivative_z)
+ random_deviation[8] = np.random.uniform(-max_euler_angle_derivative_y, max_euler_angle_derivative_y)
+ random_deviation[9] = np.random.uniform(-max_euler_angle_derivative_x, max_euler_angle_derivative_x)
+ return np.array(self.config.DEFAULT_STATE) + random_deviation
+
+ def get_random_target_state(self) -> np.ndarray:
+ """Get a random target state.
+
+ Samples a random target state for the robot.
+
+ Returns:
+ x: A random target state given by a NumPy array containing floats.
+ """
+ max_position_x = 1.0
+ max_position_y = 1.0
+ max_orientation_z = 45.0 / 180.0 * np.pi
+ random_deviation = np.zeros(self.config.TARGET_STATE_DIM)
+ random_deviation[0] = np.random.uniform(-max_position_x, max_position_x)
+ random_deviation[1] = np.random.uniform(-max_position_y, max_position_y)
+ random_deviation[2] = np.random.uniform(-max_orientation_z, max_orientation_z)
+ return np.array(self.config.DEFAULT_TARGET_STATE) + random_deviation
+
+ def get_tasks(
+ self, tasks_number: int, duration: float
+ ) -> Tuple[SystemObservationArray, ModeScheduleArray, TargetTrajectoriesArray]:
+ """Get tasks.
+
+ Get a random set of task that should be executed by the data generation or policy evaluation.
+
+ Args:
+ tasks_number: Number of tasks given by an integer.
+ duration: Duration of each task given by a float.
+
+ Returns:
+ A tuple containing the components of the task.
+ - initial_observations: The initial observations given by an OCS2 system observation array.
+ - mode_schedules: The desired mode schedules given by an OCS2 mode schedule array.
+ - target_trajectories: The desired target trajectories given by an OCS2 target trajectories array.
+ """
+ initial_mode = 0
+ initial_time = 0.0
+ initial_observations = helper.get_system_observation_array(tasks_number)
+ mode_schedules = helper.get_mode_schedule_array(tasks_number)
+ target_trajectories = helper.get_target_trajectories_array(tasks_number)
+ for i in range(tasks_number):
+ initial_observations[i] = helper.get_system_observation(
+ initial_mode, initial_time, self.get_random_initial_state(), np.zeros(self.config.INPUT_DIM)
+ )
+ mode_schedules[i] = helper.get_mode_schedule(*self.get_default_event_times_and_mode_sequence(duration))
+ target_trajectories[i] = helper.get_target_trajectories(
+ duration * np.ones((1, 1)),
+ self.get_random_target_state().reshape((1, self.config.TARGET_STATE_DIM)),
+ np.zeros((1, self.config.TARGET_INPUT_DIM)),
+ )
+ return initial_observations, mode_schedules, target_trajectories
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/train.py b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/train.py
new file mode 100644
index 000000000..4fa5c59ce
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/train.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+
+###############################################################################
+# Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+###############################################################################
+
+"""Ballbot MPC-Net.
+
+Main script for training an MPC-Net policy for ballbot.
+"""
+
+import sys
+import os
+
+from ocs2_mpcnet_core.config import Config
+from ocs2_mpcnet_core.loss import HamiltonianLoss
+from ocs2_mpcnet_core.memory import CircularMemory
+from ocs2_mpcnet_core.policy import LinearPolicy
+
+from ocs2_ballbot_mpcnet import BallbotMpcnet
+from ocs2_ballbot_mpcnet import MpcnetInterface
+
+
+def main(root_dir: str, config_file_name: str) -> None:
+ # config
+ config = Config(os.path.join(root_dir, "config", config_file_name))
+ # interface
+ interface = MpcnetInterface(config.DATA_GENERATION_THREADS, config.POLICY_EVALUATION_THREADS, config.RAISIM)
+ # loss
+ loss = HamiltonianLoss(config)
+ # memory
+ memory = CircularMemory(config)
+ # policy
+ policy = LinearPolicy(config)
+ # mpcnet
+ mpcnet = BallbotMpcnet(root_dir, config, interface, memory, policy, loss)
+ # train
+ mpcnet.train()
+
+
+if __name__ == "__main__":
+ root_dir = os.path.dirname(os.path.abspath(__file__))
+ if len(sys.argv) > 1:
+ main(root_dir, sys.argv[1])
+ else:
+ main(root_dir, "ballbot.yaml")
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/setup.py b/ocs2_mpcnet/ocs2_ballbot_mpcnet/setup.py
new file mode 100644
index 000000000..959f0cad9
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/setup.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+
+from setuptools import setup
+from catkin_pkg.python_setup import generate_distutils_setup
+
+setup_args = generate_distutils_setup(packages=["ocs2_ballbot_mpcnet"], package_dir={"": "python"})
+
+setup(**setup_args)
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDefinition.cpp b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDefinition.cpp
new file mode 100644
index 000000000..81ec0bc06
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDefinition.cpp
@@ -0,0 +1,57 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#include "ocs2_ballbot_mpcnet/BallbotMpcnetDefinition.h"
+
+namespace ocs2 {
+namespace ballbot {
+
+vector_t BallbotMpcnetDefinition::getObservation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories) {
+ vector_t observation = x - targetTrajectories.getDesiredState(t);
+ const Eigen::Matrix R =
+ (Eigen::Matrix() << cos(x(2)), -sin(x(2)), sin(x(2)), cos(x(2))).finished().transpose();
+ observation.segment<2>(0) = R * observation.segment<2>(0);
+ observation.segment<2>(5) = R * observation.segment<2>(5);
+ return observation;
+}
+
+std::pair BallbotMpcnetDefinition::getActionTransformation(scalar_t t, const vector_t& x,
+ const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories) {
+ return {matrix_t::Identity(3, 3), vector_t::Zero(3)};
+}
+
+bool BallbotMpcnetDefinition::isValid(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories) {
+ return true;
+}
+
+} // namespace ballbot
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDummyNode.cpp b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDummyNode.cpp
new file mode 100644
index 000000000..010fc2a8b
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDummyNode.cpp
@@ -0,0 +1,110 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ocs2_ballbot_mpcnet/BallbotMpcnetDefinition.h"
+
+using namespace ocs2;
+using namespace ballbot;
+
+int main(int argc, char** argv) {
+ const std::string robotName = "ballbot";
+
+ // task and policy file
+ std::vector programArgs{};
+ ::ros::removeROSArgs(argc, argv, programArgs);
+ if (programArgs.size() <= 2) {
+ throw std::runtime_error("No task name and policy file path specified. Aborting.");
+ }
+ const std::string taskFileFolderName = std::string(programArgs[1]);
+ const std::string policyFilePath = std::string(programArgs[2]);
+
+ // initialize ros node
+ ros::init(argc, argv, robotName + "_mpcnet_dummy");
+ ros::NodeHandle nodeHandle;
+
+ // ballbot interface
+ const std::string taskFile = ros::package::getPath("ocs2_ballbot") + "/config/" + taskFileFolderName + "/task.info";
+ const std::string libraryFolder = ros::package::getPath("ocs2_ballbot") + "/auto_generated";
+ BallbotInterface ballbotInterface(taskFile, libraryFolder);
+
+ // ROS reference manager
+ auto rosReferenceManagerPtr = std::make_shared(robotName, ballbotInterface.getReferenceManagerPtr());
+ rosReferenceManagerPtr->subscribe(nodeHandle);
+
+ // policy (MPC-Net controller)
+ auto onnxEnvironmentPtr = ocs2::mpcnet::createOnnxEnvironment();
+ std::shared_ptr mpcnetDefinitionPtr(new BallbotMpcnetDefinition);
+ std::unique_ptr mpcnetControllerPtr(
+ new ocs2::mpcnet::MpcnetOnnxController(mpcnetDefinitionPtr, rosReferenceManagerPtr, onnxEnvironmentPtr));
+ mpcnetControllerPtr->loadPolicyModel(policyFilePath);
+
+ // rollout
+ std::unique_ptr rolloutPtr(ballbotInterface.getRollout().clone());
+
+ // observer
+ std::shared_ptr mpcnetDummyObserverRosPtr(
+ new ocs2::mpcnet::MpcnetDummyObserverRos(nodeHandle, robotName));
+
+ // visualization
+ std::shared_ptr ballbotDummyVisualization(new BallbotDummyVisualization(nodeHandle));
+
+ // MPC-Net dummy loop ROS
+ const scalar_t controlFrequency = ballbotInterface.mpcSettings().mrtDesiredFrequency_;
+ const scalar_t rosFrequency = ballbotInterface.mpcSettings().mpcDesiredFrequency_;
+ ocs2::mpcnet::MpcnetDummyLoopRos mpcnetDummyLoopRos(controlFrequency, rosFrequency, std::move(mpcnetControllerPtr), std::move(rolloutPtr),
+ rosReferenceManagerPtr);
+ mpcnetDummyLoopRos.addObserver(mpcnetDummyObserverRosPtr);
+ mpcnetDummyLoopRos.addObserver(ballbotDummyVisualization);
+
+ // initial system observation
+ SystemObservation systemObservation;
+ systemObservation.mode = 0;
+ systemObservation.time = 0.0;
+ systemObservation.state = ballbotInterface.getInitialState();
+ systemObservation.input = vector_t::Zero(ocs2::ballbot::INPUT_DIM);
+
+ // initial target trajectories
+ const TargetTrajectories targetTrajectories({systemObservation.time}, {systemObservation.state}, {systemObservation.input});
+
+ // run MPC-Net dummy loop ROS
+ mpcnetDummyLoopRos.run(systemObservation, targetTrajectories);
+
+ // successful exit
+ return 0;
+}
\ No newline at end of file
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetInterface.cpp b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetInterface.cpp
new file mode 100644
index 000000000..dfe15a9a5
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetInterface.cpp
@@ -0,0 +1,111 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#include "ocs2_ballbot_mpcnet/BallbotMpcnetInterface.h"
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include "ocs2_ballbot_mpcnet/BallbotMpcnetDefinition.h"
+
+namespace ocs2 {
+namespace ballbot {
+
+BallbotMpcnetInterface::BallbotMpcnetInterface(size_t nDataGenerationThreads, size_t nPolicyEvaluationThreads, bool raisim) {
+ // create ONNX environment
+ auto onnxEnvironmentPtr = ocs2::mpcnet::createOnnxEnvironment();
+ // path to config file
+ const std::string taskFile = ros::package::getPath("ocs2_ballbot") + "/config/mpc/task.info";
+ // path to save auto-generated libraries
+ const std::string libraryFolder = ros::package::getPath("ocs2_ballbot") + "/auto_generated";
+ // set up MPC-Net rollout manager for data generation and policy evaluation
+ std::vector> mpcPtrs;
+ std::vector> mpcnetPtrs;
+ std::vector> rolloutPtrs;
+ std::vector> mpcnetDefinitionPtrs;
+ std::vector> referenceManagerPtrs;
+ mpcPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
+ mpcnetPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
+ rolloutPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
+ mpcnetDefinitionPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
+ referenceManagerPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
+ for (int i = 0; i < (nDataGenerationThreads + nPolicyEvaluationThreads); i++) {
+ BallbotInterface ballbotInterface(taskFile, libraryFolder);
+ std::shared_ptr mpcnetDefinitionPtr(new BallbotMpcnetDefinition);
+ mpcPtrs.push_back(getMpc(ballbotInterface));
+ mpcnetPtrs.push_back(std::unique_ptr(
+ new ocs2::mpcnet::MpcnetOnnxController(mpcnetDefinitionPtr, ballbotInterface.getReferenceManagerPtr(), onnxEnvironmentPtr)));
+ if (raisim) {
+ throw std::runtime_error("[BallbotMpcnetInterface::BallbotMpcnetInterface] raisim rollout not yet implemented for ballbot.");
+ } else {
+ rolloutPtrs.push_back(std::unique_ptr(ballbotInterface.getRollout().clone()));
+ }
+ mpcnetDefinitionPtrs.push_back(mpcnetDefinitionPtr);
+ referenceManagerPtrs.push_back(ballbotInterface.getReferenceManagerPtr());
+ }
+ mpcnetRolloutManagerPtr_.reset(new ocs2::mpcnet::MpcnetRolloutManager(nDataGenerationThreads, nPolicyEvaluationThreads,
+ std::move(mpcPtrs), std::move(mpcnetPtrs), std::move(rolloutPtrs),
+ mpcnetDefinitionPtrs, referenceManagerPtrs));
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+std::unique_ptr BallbotMpcnetInterface::getMpc(BallbotInterface& ballbotInterface) {
+ // ensure MPC and DDP settings are as needed for MPC-Net
+ const auto mpcSettings = [&]() {
+ auto settings = ballbotInterface.mpcSettings();
+ settings.debugPrint_ = false;
+ settings.coldStart_ = false;
+ return settings;
+ }();
+ const auto ddpSettings = [&]() {
+ auto settings = ballbotInterface.ddpSettings();
+ settings.algorithm_ = ocs2::ddp::Algorithm::SLQ;
+ settings.nThreads_ = 1;
+ settings.displayInfo_ = false;
+ settings.displayShortSummary_ = false;
+ settings.checkNumericalStability_ = false;
+ settings.debugPrintRollout_ = false;
+ settings.useFeedbackPolicy_ = true;
+ return settings;
+ }();
+ // create one MPC instance
+ std::unique_ptr mpcPtr(new GaussNewtonDDP_MPC(mpcSettings, ddpSettings, ballbotInterface.getRollout(),
+ ballbotInterface.getOptimalControlProblem(), ballbotInterface.getInitializer()));
+ mpcPtr->getSolverPtr()->setReferenceManager(ballbotInterface.getReferenceManagerPtr());
+ return mpcPtr;
+}
+
+} // namespace ballbot
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetPybindings.cpp b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetPybindings.cpp
new file mode 100644
index 000000000..c08e80a5d
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetPybindings.cpp
@@ -0,0 +1,34 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#include
+
+#include "ocs2_ballbot_mpcnet/BallbotMpcnetInterface.h"
+
+CREATE_ROBOT_MPCNET_PYTHON_BINDINGS(ocs2::ballbot::BallbotMpcnetInterface, BallbotMpcnetPybindings)
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/CMakeLists.txt b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/CMakeLists.txt
new file mode 100644
index 000000000..6addac114
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/CMakeLists.txt
@@ -0,0 +1,129 @@
+cmake_minimum_required(VERSION 3.0.2)
+project(ocs2_legged_robot_mpcnet)
+
+set(CATKIN_PACKAGE_DEPENDENCIES
+ ocs2_legged_robot
+ ocs2_legged_robot_raisim
+ ocs2_legged_robot_ros
+ ocs2_mpcnet_core
+)
+
+find_package(catkin REQUIRED COMPONENTS
+ ${CATKIN_PACKAGE_DEPENDENCIES}
+)
+
+# Generate compile_commands.json for clang tools
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+###################################
+## catkin specific configuration ##
+###################################
+
+catkin_package(
+ INCLUDE_DIRS
+ include
+ LIBRARIES
+ ${PROJECT_NAME}
+ CATKIN_DEPENDS
+ ${CATKIN_PACKAGE_DEPENDENCIES}
+ DEPENDS
+)
+
+###########
+## Build ##
+###########
+
+include_directories(
+ include
+ ${catkin_INCLUDE_DIRS}
+)
+
+# main library
+add_library(${PROJECT_NAME}
+ src/LeggedRobotMpcnetDefinition.cpp
+ src/LeggedRobotMpcnetInterface.cpp
+)
+add_dependencies(${PROJECT_NAME}
+ ${catkin_EXPORTED_TARGETS}
+)
+target_link_libraries(${PROJECT_NAME}
+ ${catkin_LIBRARIES}
+)
+
+# python bindings
+pybind11_add_module(LeggedRobotMpcnetPybindings SHARED
+ src/LeggedRobotMpcnetPybindings.cpp
+)
+add_dependencies(LeggedRobotMpcnetPybindings
+ ${PROJECT_NAME}
+ ${catkin_EXPORTED_TARGETS}
+)
+target_link_libraries(LeggedRobotMpcnetPybindings PRIVATE
+ ${PROJECT_NAME}
+ ${catkin_LIBRARIES}
+)
+set_target_properties(LeggedRobotMpcnetPybindings
+ PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CATKIN_DEVEL_PREFIX}/${CATKIN_PACKAGE_PYTHON_DESTINATION}
+)
+
+# MPC-Net dummy node
+add_executable(legged_robot_mpcnet_dummy
+ src/LeggedRobotMpcnetDummyNode.cpp
+)
+add_dependencies(legged_robot_mpcnet_dummy
+ ${PROJECT_NAME}
+ ${catkin_EXPORTED_TARGETS}
+)
+target_link_libraries(legged_robot_mpcnet_dummy
+ ${PROJECT_NAME}
+ ${catkin_LIBRARIES}
+)
+target_compile_options(legged_robot_mpcnet_dummy PRIVATE ${OCS2_CXX_FLAGS})
+
+catkin_python_setup()
+
+#########################
+### CLANG TOOLING ###
+#########################
+find_package(cmake_clang_tools QUIET)
+if(cmake_clang_tools_FOUND)
+ message(STATUS "Run clang tooling for target ocs2_legged_robot_mpcnet")
+ add_clang_tooling(
+ TARGETS ${PROJECT_NAME} legged_robot_mpcnet_dummy
+ SOURCE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/src ${CMAKE_CURRENT_SOURCE_DIR}/include
+ CT_HEADER_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/include
+ CF_WERROR
+)
+endif(cmake_clang_tools_FOUND)
+
+#############
+## Install ##
+#############
+
+install(TARGETS ${PROJECT_NAME}
+ ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+ LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+ RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+)
+
+install(DIRECTORY include/${PROJECT_NAME}/
+ DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION})
+
+install(TARGETS LeggedRobotMpcnetPybindings
+ ARCHIVE DESTINATION ${CATKIN_PACKAGE_PYTHON_DESTINATION}
+ LIBRARY DESTINATION ${CATKIN_PACKAGE_PYTHON_DESTINATION}
+)
+
+install(TARGETS legged_robot_mpcnet_dummy
+ RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+)
+
+install(DIRECTORY launch policy
+ DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+)
+
+#############
+## Testing ##
+#############
+
+# TODO(areske)
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/include/ocs2_legged_robot_mpcnet/LeggedRobotMpcnetDefinition.h b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/include/ocs2_legged_robot_mpcnet/LeggedRobotMpcnetDefinition.h
new file mode 100644
index 000000000..8a6b44f92
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/include/ocs2_legged_robot_mpcnet/LeggedRobotMpcnetDefinition.h
@@ -0,0 +1,81 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+#include
+
+namespace ocs2 {
+namespace legged_robot {
+
+/**
+ * MPC-Net definitions for legged robot.
+ */
+class LeggedRobotMpcnetDefinition final : public ocs2::mpcnet::MpcnetDefinitionBase {
+ public:
+ /**
+ * Constructor.
+ * @param [in] leggedRobotInterface : Legged robot interface.
+ */
+ LeggedRobotMpcnetDefinition(const LeggedRobotInterface& leggedRobotInterface)
+ : defaultState_(leggedRobotInterface.getInitialState()), centroidalModelInfo_(leggedRobotInterface.getCentroidalModelInfo()) {}
+
+ /**
+ * Default destructor.
+ */
+ ~LeggedRobotMpcnetDefinition() override = default;
+
+ /**
+ * @see MpcnetDefinitionBase::getObservation
+ */
+ vector_t getObservation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories) override;
+
+ /**
+ * @see MpcnetDefinitionBase::getActionTransformation
+ */
+ std::pair getActionTransformation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories) override;
+
+ /**
+ * @see MpcnetDefinitionBase::isValid
+ */
+ bool isValid(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, const TargetTrajectories& targetTrajectories) override;
+
+ private:
+ const scalar_t allowedHeightDeviation_ = 0.2;
+ const scalar_t allowedPitchDeviation_ = 30.0 * M_PI / 180.0;
+ const scalar_t allowedRollDeviation_ = 30.0 * M_PI / 180.0;
+ const vector_t defaultState_;
+ const CentroidalModelInfo centroidalModelInfo_;
+};
+
+} // namespace legged_robot
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/include/ocs2_legged_robot_mpcnet/LeggedRobotMpcnetInterface.h b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/include/ocs2_legged_robot_mpcnet/LeggedRobotMpcnetInterface.h
new file mode 100644
index 000000000..3acb1d0ee
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/include/ocs2_legged_robot_mpcnet/LeggedRobotMpcnetInterface.h
@@ -0,0 +1,72 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace ocs2 {
+namespace legged_robot {
+
+/**
+ * Legged robot MPC-Net interface between C++ and Python.
+ */
+class LeggedRobotMpcnetInterface final : public ocs2::mpcnet::MpcnetInterfaceBase {
+ public:
+ /**
+ * Constructor.
+ * @param [in] nDataGenerationThreads : Number of data generation threads.
+ * @param [in] nPolicyEvaluationThreads : Number of policy evaluation threads.
+ * @param [in] raisim : Whether to use RaiSim for the rollouts.
+ */
+ LeggedRobotMpcnetInterface(size_t nDataGenerationThreads, size_t nPolicyEvaluationThreads, bool raisim);
+
+ /**
+ * Default destructor.
+ */
+ ~LeggedRobotMpcnetInterface() override = default;
+
+ private:
+ /**
+ * Helper to get the MPC.
+ * @param [in] leggedRobotInterface : The legged robot interface.
+ * @return Pointer to the MPC.
+ */
+ std::unique_ptr getMpc(LeggedRobotInterface& leggedRobotInterface);
+
+ // Legged robot interface pointers (keep alive for Pinocchio interface)
+ std::vector> leggedRobotInterfacePtrs_;
+ // Legged robot RaiSim conversions pointers (keep alive for RaiSim rollout)
+ std::vector> leggedRobotRaisimConversionsPtrs_;
+};
+
+} // namespace legged_robot
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/launch/legged_robot_mpcnet.launch b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/launch/legged_robot_mpcnet.launch
new file mode 100644
index 000000000..59311840a
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/launch/legged_robot_mpcnet.launch
@@ -0,0 +1,56 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/package.xml b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/package.xml
new file mode 100644
index 000000000..c39765d14
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/package.xml
@@ -0,0 +1,23 @@
+
+
+ ocs2_legged_robot_mpcnet
+ 0.0.0
+ The ocs2_legged_robot_mpcnet package
+
+ Alexander Reske
+
+ Farbod Farshidian
+ Alexander Reske
+
+ BSD-3
+
+ catkin
+
+ cmake_clang_tools
+
+ ocs2_legged_robot
+ ocs2_legged_robot_raisim
+ ocs2_legged_robot_ros
+ ocs2_mpcnet_core
+
+
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/policy/legged_robot.onnx b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/policy/legged_robot.onnx
new file mode 100644
index 000000000..ef5939cdf
Binary files /dev/null and b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/policy/legged_robot.onnx differ
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/policy/legged_robot.pt b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/policy/legged_robot.pt
new file mode 100644
index 000000000..3eb2df09d
Binary files /dev/null and b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/policy/legged_robot.pt differ
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/__init__.py b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/__init__.py
new file mode 100644
index 000000000..688527a67
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/__init__.py
@@ -0,0 +1,2 @@
+from ocs2_legged_robot_mpcnet.LeggedRobotMpcnetPybindings import MpcnetInterface
+from ocs2_legged_robot_mpcnet.mpcnet import LeggedRobotMpcnet
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/config/legged_robot.yaml b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/config/legged_robot.yaml
new file mode 100644
index 000000000..0b6dde905
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/config/legged_robot.yaml
@@ -0,0 +1,240 @@
+#
+# general
+#
+# name of the robot
+NAME: "legged_robot"
+# description of the training run
+DESCRIPTION: "description"
+# state dimension
+STATE_DIM: 24
+# input dimension
+INPUT_DIM: 24
+# target trajectories state dimension
+TARGET_STATE_DIM: 24
+# target trajectories input dimension
+TARGET_INPUT_DIM: 24
+# observation dimension
+OBSERVATION_DIM: 36
+# action dimension
+ACTION_DIM: 24
+# expert number
+EXPERT_NUM: 3
+# default state
+DEFAULT_STATE:
+ - 0.0 # normalized linear momentum x
+ - 0.0 # normalized linear momentum y
+ - 0.0 # normalized linear momentum z
+ - 0.0 # normalized angular momentum x
+ - 0.0 # normalized angular momentum y
+ - 0.0 # normalized angular momentum z
+ - 0.0 # position x
+ - 0.0 # position y
+ - 0.575 # position z
+ - 0.0 # orientation z
+ - 0.0 # orientation y
+ - 0.0 # orientation x
+ - -0.25 # joint position LF HAA
+ - 0.6 # joint position LF HFE
+ - -0.85 # joint position LF KFE
+ - -0.25 # joint position LH HAA
+ - -0.6 # joint position LH HFE
+ - 0.85 # joint position LH KFE
+ - 0.25 # joint position RF HAA
+ - 0.6 # joint position RF HFE
+ - -0.85 # joint position RF KFE
+ - 0.25 # joint position RH HAA
+ - -0.6 # joint position RH HFE
+ - 0.85 # joint position RH KFE
+# default target state
+DEFAULT_TARGET_STATE:
+ - 0.0 # normalized linear momentum x
+ - 0.0 # normalized linear momentum y
+ - 0.0 # normalized linear momentum z
+ - 0.0 # normalized angular momentum x
+ - 0.0 # normalized angular momentum y
+ - 0.0 # normalized angular momentum z
+ - 0.0 # position x
+ - 0.0 # position y
+ - 0.575 # position z
+ - 0.0 # orientation z
+ - 0.0 # orientation y
+ - 0.0 # orientation x
+ - -0.25 # joint position LF HAA
+ - 0.6 # joint position LF HFE
+ - -0.85 # joint position LF KFE
+ - -0.25 # joint position LH HAA
+ - -0.6 # joint position LH HFE
+ - 0.85 # joint position LH KFE
+ - 0.25 # joint position RF HAA
+ - 0.6 # joint position RF HFE
+ - -0.85 # joint position RF KFE
+ - 0.25 # joint position RH HAA
+ - -0.6 # joint position RH HFE
+ - 0.85 # joint position RH KFE
+#
+# loss
+#
+# epsilon to improve numerical stability of logs and denominators
+EPSILON: 1.e-8
+# whether to cheat by adding the gating loss
+CHEATING: True
+# parameter to control the relative importance of both loss types
+LAMBDA: 1.0
+# dictionary for the gating loss (assigns modes to experts responsible for the corresponding contact configuration)
+EXPERT_FOR_MODE:
+ 6: 1 # trot
+ 9: 2 # trot
+ 15: 0 # stance
+# input cost for behavioral cloning
+R:
+ - 0.001 # contact force LF x
+ - 0.001 # contact force LF y
+ - 0.001 # contact force LF z
+ - 0.001 # contact force LH x
+ - 0.001 # contact force LH y
+ - 0.001 # contact force LH z
+ - 0.001 # contact force RF x
+ - 0.001 # contact force RF y
+ - 0.001 # contact force RF z
+ - 0.001 # contact force RH x
+ - 0.001 # contact force RH y
+ - 0.001 # contact force RH z
+ - 5.0 # joint velocity LF HAA
+ - 5.0 # joint velocity LF HFE
+ - 5.0 # joint velocity LF KFE
+ - 5.0 # joint velocity LH HAA
+ - 5.0 # joint velocity LH HFE
+ - 5.0 # joint velocity LH KFE
+ - 5.0 # joint velocity RF HAA
+ - 5.0 # joint velocity RF HFE
+ - 5.0 # joint velocity RF KFE
+ - 5.0 # joint velocity RH HAA
+ - 5.0 # joint velocity RH HFE
+ - 5.0 # joint velocity RH KFE
+#
+# memory
+#
+# capacity of the memory
+CAPACITY: 400000
+#
+# policy
+#
+# observation scaling
+OBSERVATION_SCALING:
+ - 1.0 # swing phase LF
+ - 1.0 # swing phase LH
+ - 1.0 # swing phase RF
+ - 1.0 # swing phase RH
+ - 1.0 # swing phase rate LF
+ - 1.0 # swing phase rate LH
+ - 1.0 # swing phase rate RF
+ - 1.0 # swing phase rate RH
+ - 1.0 # sinusoidal bump LF
+ - 1.0 # sinusoidal bump LH
+ - 1.0 # sinusoidal bump RF
+ - 1.0 # sinusoidal bump RH
+ - 1.0 # normalized linear momentum x
+ - 1.0 # normalized linear momentum y
+ - 1.0 # normalized linear momentum z
+ - 1.0 # normalized angular momentum x
+ - 1.0 # normalized angular momentum y
+ - 1.0 # normalized angular momentum z
+ - 1.0 # position x
+ - 1.0 # position y
+ - 1.0 # position z
+ - 1.0 # orientation z
+ - 1.0 # orientation y
+ - 1.0 # orientation x
+ - 1.0 # joint position LF HAA
+ - 1.0 # joint position LF HFE
+ - 1.0 # joint position LF KFE
+ - 1.0 # joint position LH HAA
+ - 1.0 # joint position LH HFE
+ - 1.0 # joint position LH KFE
+ - 1.0 # joint position RF HAA
+ - 1.0 # joint position RF HFE
+ - 1.0 # joint position RF KFE
+ - 1.0 # joint position RH HAA
+ - 1.0 # joint position RH HFE
+ - 1.0 # joint position RH KFE
+# action scaling
+ACTION_SCALING:
+ - 100.0 # contact force LF x
+ - 100.0 # contact force LF y
+ - 100.0 # contact force LF z
+ - 100.0 # contact force LH x
+ - 100.0 # contact force LH y
+ - 100.0 # contact force LH z
+ - 100.0 # contact force RF x
+ - 100.0 # contact force RF y
+ - 100.0 # contact force RF z
+ - 100.0 # contact force RH x
+ - 100.0 # contact force RH y
+ - 100.0 # contact force RH z
+ - 10.0 # joint velocity LF HAA
+ - 10.0 # joint velocity LF HFE
+ - 10.0 # joint velocity LF KFE
+ - 10.0 # joint velocity LH HAA
+ - 10.0 # joint velocity LH HFE
+ - 10.0 # joint velocity LH KFE
+ - 10.0 # joint velocity RF HAA
+ - 10.0 # joint velocity RF HFE
+ - 10.0 # joint velocity RF KFE
+ - 10.0 # joint velocity RH HAA
+ - 10.0 # joint velocity RH HFE
+ - 10.0 # joint velocity RH KFE
+#
+# rollout
+#
+# RaiSim or TimeTriggered rollout for data generation and policy evaluation
+RAISIM: True
+# weights defining how often a gait is chosen for rollout
+WEIGHTS_FOR_GAITS:
+ stance: 1.0
+ trot_1: 2.0
+ trot_2: 2.0
+# settings for data generation
+DATA_GENERATION_TIME_STEP: 0.0025
+DATA_GENERATION_DURATION: 4.0
+DATA_GENERATION_DATA_DECIMATION: 4
+DATA_GENERATION_THREADS: 12
+DATA_GENERATION_TASKS: 12
+DATA_GENERATION_SAMPLES: 2
+DATA_GENERATION_SAMPLING_VARIANCE:
+ - 0.05 # normalized linear momentum x
+ - 0.05 # normalized linear momentum y
+ - 0.05 # normalized linear momentum z
+ - 0.00135648942 # normalized angular momentum x: 1.62079 / 52.1348 * 2.5 / 180.0 * pi
+ - 0.00404705526 # normalized angular momentum y: 4.83559 / 52.1348 * 2.5 / 180.0 * pi
+ - 0.00395351148 # normalized angular momentum z: 4.72382 / 52.1348 * 2.5 / 180.0 * pi
+ - 0.01 # position x
+ - 0.01 # position y
+ - 0.01 # position z
+ - 0.00872664625 # orientation z: 0.5 / 180.0 * pi
+ - 0.00872664625 # orientation y: 0.5 / 180.0 * pi
+ - 0.00872664625 # orientation x: 0.5 / 180.0 * pi
+ - 0.00872664625 # joint position LF HAA: 0.5 / 180.0 * pi
+ - 0.00872664625 # joint position LF HFE: 0.5 / 180.0 * pi
+ - 0.00872664625 # joint position LF KFE: 0.5 / 180.0 * pi
+ - 0.00872664625 # joint position LH HAA: 0.5 / 180.0 * pi
+ - 0.00872664625 # joint position LH HFE: 0.5 / 180.0 * pi
+ - 0.00872664625 # joint position LH KFE: 0.5 / 180.0 * pi
+ - 0.00872664625 # joint position RF HAA: 0.5 / 180.0 * pi
+ - 0.00872664625 # joint position RF HFE: 0.5 / 180.0 * pi
+ - 0.00872664625 # joint position RF KFE: 0.5 / 180.0 * pi
+ - 0.00872664625 # joint position RH HAA: 0.5 / 180.0 * pi
+ - 0.00872664625 # joint position RH HFE: 0.5 / 180.0 * pi
+ - 0.00872664625 # joint position RH KFE: 0.5 / 180.0 * pi
+# settings for computing metrics
+POLICY_EVALUATION_TIME_STEP: 0.0025
+POLICY_EVALUATION_DURATION: 4.0
+POLICY_EVALUATION_THREADS: 3
+POLICY_EVALUATION_TASKS: 3
+#
+# training
+#
+BATCH_SIZE: 128
+LEARNING_RATE: 1.e-3
+LEARNING_ITERATIONS: 100000
+GRADIENT_CLIPPING: True
+GRADIENT_CLIPPING_VALUE: 1.0
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/mpcnet.py b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/mpcnet.py
new file mode 100644
index 000000000..213b600f8
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/mpcnet.py
@@ -0,0 +1,254 @@
+###############################################################################
+# Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+###############################################################################
+
+"""Legged robot MPC-Net class.
+
+Provides a class that handles the MPC-Net training for legged robot.
+"""
+
+import random
+import numpy as np
+from typing import Tuple
+
+from ocs2_mpcnet_core import helper
+from ocs2_mpcnet_core import mpcnet
+from ocs2_mpcnet_core import SystemObservationArray, ModeScheduleArray, TargetTrajectoriesArray
+
+
+class LeggedRobotMpcnet(mpcnet.Mpcnet):
+ """Legged robot MPC-Net.
+
+ Adds robot-specific methods for the MPC-Net training.
+ """
+
+ @staticmethod
+ def get_stance(duration: float) -> Tuple[np.ndarray, np.ndarray]:
+ """Get the stance gait.
+
+ Creates the stance event times and mode sequence for a certain time duration:
+ - contact schedule: STANCE
+ - swing schedule: -
+
+ Args:
+ duration: The duration of the mode schedule given by a float.
+
+ Returns:
+ A tuple containing the components of the mode schedule.
+ - event_times: The event times given by a NumPy array of shape (K-1) containing floats.
+ - mode_sequence: The mode sequence given by a NumPy array of shape (K) containing integers.
+ """
+ event_times_template = np.array([1.0], dtype=np.float64)
+ mode_sequence_template = np.array([15], dtype=np.uintp)
+ return helper.get_event_times_and_mode_sequence(15, duration, event_times_template, mode_sequence_template)
+
+ def get_random_initial_state_stance(self) -> np.ndarray:
+ """Get a random initial state for stance.
+
+ Samples a random initial state for the robot in the stance gait.
+
+ Returns:
+ x: A random initial state given by a NumPy array containing floats.
+ """
+ max_normalized_linear_momentum_x = 0.1
+ max_normalized_linear_momentum_y = 0.1
+ max_normalized_linear_momentum_z = 0.1
+ max_normalized_angular_momentum_x = 1.62079 / 52.1348 * 30.0 / 180.0 * np.pi
+ max_normalized_angular_momentum_y = 4.83559 / 52.1348 * 30.0 / 180.0 * np.pi
+ max_normalized_angular_momentum_z = 4.72382 / 52.1348 * 30.0 / 180.0 * np.pi
+ random_deviation = np.zeros(self.config.STATE_DIM)
+ random_deviation[0] = np.random.uniform(-max_normalized_linear_momentum_x, max_normalized_linear_momentum_x)
+ random_deviation[1] = np.random.uniform(-max_normalized_linear_momentum_y, max_normalized_linear_momentum_y)
+ random_deviation[2] = np.random.uniform(
+ -max_normalized_linear_momentum_z, max_normalized_linear_momentum_z / 2.0
+ )
+ random_deviation[3] = np.random.uniform(-max_normalized_angular_momentum_x, max_normalized_angular_momentum_x)
+ random_deviation[4] = np.random.uniform(-max_normalized_angular_momentum_y, max_normalized_angular_momentum_y)
+ random_deviation[5] = np.random.uniform(-max_normalized_angular_momentum_z, max_normalized_angular_momentum_z)
+ return np.array(self.config.DEFAULT_STATE) + random_deviation
+
+ def get_random_target_state_stance(self) -> np.ndarray:
+ """Get a random target state for stance.
+
+ Samples a random target state for the robot in the stance gait.
+
+ Returns:
+ x: A random target state given by a NumPy array containing floats.
+ """
+ max_position_z = 0.075
+ max_orientation_z = 25.0 / 180.0 * np.pi
+ max_orientation_y = 15.0 / 180.0 * np.pi
+ max_orientation_x = 25.0 / 180.0 * np.pi
+ random_deviation = np.zeros(self.config.TARGET_STATE_DIM)
+ random_deviation[8] = np.random.uniform(-max_position_z, max_position_z)
+ random_deviation[9] = np.random.uniform(-max_orientation_z, max_orientation_z)
+ random_deviation[10] = np.random.uniform(-max_orientation_y, max_orientation_y)
+ random_deviation[11] = np.random.uniform(-max_orientation_x, max_orientation_x)
+ return np.array(self.config.DEFAULT_TARGET_STATE) + random_deviation
+
+ @staticmethod
+ def get_trot_1(duration: float) -> Tuple[np.ndarray, np.ndarray]:
+ """Get the first trot gait.
+
+ Creates the first trot event times and mode sequence for a certain time duration:
+ - contact schedule: LF_RH, RF_LH
+ - swing schedule: RF_LH, LF_RH
+
+ Args:
+ duration: The duration of the mode schedule given by a float.
+
+ Returns:
+ A tuple containing the components of the mode schedule.
+ - event_times: The event times given by a NumPy array of shape (K-1) containing floats.
+ - mode_sequence: The mode sequence given by a NumPy array of shape (K) containing integers.
+ """
+ event_times_template = np.array([0.35, 0.7], dtype=np.float64)
+ mode_sequence_template = np.array([9, 6], dtype=np.uintp)
+ return helper.get_event_times_and_mode_sequence(15, duration, event_times_template, mode_sequence_template)
+
+ @staticmethod
+ def get_trot_2(duration: float) -> Tuple[np.ndarray, np.ndarray]:
+ """Get the second trot gait.
+
+ Creates the second trot event times and mode sequence for a certain time duration:
+ - contact schedule: RF_LH, LF_RH
+ - swing schedule: LF_RH, RF_LH
+
+ Args:
+ duration: The duration of the mode schedule given by a float.
+
+ Returns:
+ A tuple containing the components of the mode schedule.
+ - event_times: The event times given by a NumPy array of shape (K-1) containing floats.
+ - mode_sequence: The mode sequence given by a NumPy array of shape (K) containing integers.
+ """
+ event_times_template = np.array([0.35, 0.7], dtype=np.float64)
+ mode_sequence_template = np.array([6, 9], dtype=np.uintp)
+ return helper.get_event_times_and_mode_sequence(15, duration, event_times_template, mode_sequence_template)
+
+ def get_random_initial_state_trot(self) -> np.ndarray:
+ """Get a random initial state for trot.
+
+ Samples a random initial state for the robot in a trot gait.
+
+ Returns:
+ x: A random initial state given by a NumPy array containing floats.
+ """
+ max_normalized_linear_momentum_x = 0.5
+ max_normalized_linear_momentum_y = 0.25
+ max_normalized_linear_momentum_z = 0.25
+ max_normalized_angular_momentum_x = 1.62079 / 52.1348 * 60.0 / 180.0 * np.pi
+ max_normalized_angular_momentum_y = 4.83559 / 52.1348 * 60.0 / 180.0 * np.pi
+ max_normalized_angular_momentum_z = 4.72382 / 52.1348 * 35.0 / 180.0 * np.pi
+ random_deviation = np.zeros(self.config.STATE_DIM)
+ random_deviation[0] = np.random.uniform(-max_normalized_linear_momentum_x, max_normalized_linear_momentum_x)
+ random_deviation[1] = np.random.uniform(-max_normalized_linear_momentum_y, max_normalized_linear_momentum_y)
+ random_deviation[2] = np.random.uniform(
+ -max_normalized_linear_momentum_z, max_normalized_linear_momentum_z / 2.0
+ )
+ random_deviation[3] = np.random.uniform(-max_normalized_angular_momentum_x, max_normalized_angular_momentum_x)
+ random_deviation[4] = np.random.uniform(-max_normalized_angular_momentum_y, max_normalized_angular_momentum_y)
+ random_deviation[5] = np.random.uniform(-max_normalized_angular_momentum_z, max_normalized_angular_momentum_z)
+ return np.array(self.config.DEFAULT_STATE) + random_deviation
+
+ def get_random_target_state_trot(self) -> np.ndarray:
+ """Get a random target state for trot.
+
+ Samples a random target state for the robot in a trot gait.
+
+ Returns:
+ x: A random target state given by a NumPy array containing floats.
+ """
+ max_position_x = 0.3
+ max_position_y = 0.15
+ max_orientation_z = 30.0 / 180.0 * np.pi
+ random_deviation = np.zeros(self.config.TARGET_STATE_DIM)
+ random_deviation[6] = np.random.uniform(-max_position_x, max_position_x)
+ random_deviation[7] = np.random.uniform(-max_position_y, max_position_y)
+ random_deviation[9] = np.random.uniform(-max_orientation_z, max_orientation_z)
+ return np.array(self.config.DEFAULT_TARGET_STATE) + random_deviation
+
+ def get_tasks(
+ self, tasks_number: int, duration: float
+ ) -> Tuple[SystemObservationArray, ModeScheduleArray, TargetTrajectoriesArray]:
+ """Get tasks.
+
+ Get a random set of task that should be executed by the data generation or policy evaluation.
+
+ Args:
+ tasks_number: Number of tasks given by an integer.
+ duration: Duration of each task given by a float.
+
+ Returns:
+ A tuple containing the components of the task.
+ - initial_observations: The initial observations given by an OCS2 system observation array.
+ - mode_schedules: The desired mode schedules given by an OCS2 mode schedule array.
+ - target_trajectories: The desired target trajectories given by an OCS2 target trajectories array.
+ """
+ initial_mode = 15
+ initial_time = 0.0
+ initial_observations = helper.get_system_observation_array(tasks_number)
+ mode_schedules = helper.get_mode_schedule_array(tasks_number)
+ target_trajectories = helper.get_target_trajectories_array(tasks_number)
+ choices = random.choices(
+ list(self.config.WEIGHTS_FOR_GAITS.keys()),
+ k=tasks_number,
+ weights=list(self.config.WEIGHTS_FOR_GAITS.values()),
+ )
+ for i in range(tasks_number):
+ if choices[i] == "stance":
+ initial_observations[i] = helper.get_system_observation(
+ initial_mode, initial_time, self.get_random_initial_state_stance(), np.zeros(self.config.INPUT_DIM)
+ )
+ mode_schedules[i] = helper.get_mode_schedule(*self.get_stance(duration))
+ target_trajectories[i] = helper.get_target_trajectories(
+ duration * np.ones((1, 1)),
+ self.get_random_target_state_stance().reshape((1, self.config.TARGET_STATE_DIM)),
+ np.zeros((1, self.config.TARGET_INPUT_DIM)),
+ )
+ elif choices[i] == "trot_1":
+ initial_observations[i] = helper.get_system_observation(
+ initial_mode, initial_time, self.get_random_initial_state_trot(), np.zeros(self.config.INPUT_DIM)
+ )
+ mode_schedules[i] = helper.get_mode_schedule(*self.get_trot_1(duration))
+ target_trajectories[i] = helper.get_target_trajectories(
+ duration * np.ones((1, 1)),
+ self.get_random_target_state_trot().reshape((1, self.config.TARGET_STATE_DIM)),
+ np.zeros((1, self.config.TARGET_INPUT_DIM)),
+ )
+ elif choices[i] == "trot_2":
+ initial_observations[i] = helper.get_system_observation(
+ initial_mode, initial_time, self.get_random_initial_state_trot(), np.zeros(self.config.INPUT_DIM)
+ )
+ mode_schedules[i] = helper.get_mode_schedule(*self.get_trot_2(duration))
+ target_trajectories[i] = helper.get_target_trajectories(
+ duration * np.ones((1, 1)),
+ self.get_random_target_state_trot().reshape((1, self.config.TARGET_STATE_DIM)),
+ np.zeros((1, self.config.TARGET_INPUT_DIM)),
+ )
+ return initial_observations, mode_schedules, target_trajectories
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/train.py b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/train.py
new file mode 100644
index 000000000..f00533bba
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/train.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+
+###############################################################################
+# Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+###############################################################################
+
+"""Legged robot MPC-Net.
+
+Main script for training an MPC-Net policy for legged robot.
+"""
+
+import sys
+import os
+
+from ocs2_mpcnet_core.config import Config
+from ocs2_mpcnet_core.loss import HamiltonianLoss
+from ocs2_mpcnet_core.loss import CrossEntropyLoss
+from ocs2_mpcnet_core.memory import CircularMemory
+from ocs2_mpcnet_core.policy import MixtureOfNonlinearExpertsPolicy
+
+from ocs2_legged_robot_mpcnet import LeggedRobotMpcnet
+from ocs2_legged_robot_mpcnet import MpcnetInterface
+
+
+def main(root_dir: str, config_file_name: str) -> None:
+ # config
+ config = Config(os.path.join(root_dir, "config", config_file_name))
+ # interface
+ interface = MpcnetInterface(config.DATA_GENERATION_THREADS, config.POLICY_EVALUATION_THREADS, config.RAISIM)
+ # loss
+ experts_loss = HamiltonianLoss(config)
+ gating_loss = CrossEntropyLoss(config)
+ # memory
+ memory = CircularMemory(config)
+ # policy
+ policy = MixtureOfNonlinearExpertsPolicy(config)
+ # mpcnet
+ mpcnet = LeggedRobotMpcnet(root_dir, config, interface, memory, policy, experts_loss, gating_loss)
+ # train
+ mpcnet.train()
+
+
+if __name__ == "__main__":
+ root_dir = os.path.dirname(os.path.abspath(__file__))
+ if len(sys.argv) > 1:
+ main(root_dir, sys.argv[1])
+ else:
+ main(root_dir, "legged_robot.yaml")
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/setup.py b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/setup.py
new file mode 100644
index 000000000..2227298c6
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/setup.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+
+from setuptools import setup
+from catkin_pkg.python_setup import generate_distutils_setup
+
+setup_args = generate_distutils_setup(packages=["ocs2_legged_robot_mpcnet"], package_dir={"": "python"})
+
+setup(**setup_args)
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDefinition.cpp b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDefinition.cpp
new file mode 100644
index 000000000..8c19df882
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDefinition.cpp
@@ -0,0 +1,123 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#include "ocs2_legged_robot_mpcnet/LeggedRobotMpcnetDefinition.h"
+
+#include
+
+#include
+#include
+#include
+#include
+
+namespace ocs2 {
+namespace legged_robot {
+
+vector_t LeggedRobotMpcnetDefinition::getObservation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories) {
+ /**
+ * generalized time
+ */
+ const feet_array_t swingPhasePerLeg = getSwingPhasePerLeg(t, modeSchedule);
+ vector_t generalizedTime(3 * swingPhasePerLeg.size());
+ // phase
+ for (int i = 0; i < swingPhasePerLeg.size(); i++) {
+ if (swingPhasePerLeg[i].phase < 0.0) {
+ generalizedTime[i] = 0.0;
+ } else {
+ generalizedTime[i] = swingPhasePerLeg[i].phase;
+ }
+ }
+ // phase rate
+ for (int i = 0; i < swingPhasePerLeg.size(); i++) {
+ if (swingPhasePerLeg[i].phase < 0.0) {
+ generalizedTime[i + swingPhasePerLeg.size()] = 0.0;
+ } else {
+ generalizedTime[i + swingPhasePerLeg.size()] = 1.0 / swingPhasePerLeg[i].duration;
+ }
+ }
+ // sin(pi * phase)
+ for (int i = 0; i < swingPhasePerLeg.size(); i++) {
+ if (swingPhasePerLeg[i].phase < 0.0) {
+ generalizedTime[i + 2 * swingPhasePerLeg.size()] = 0.0;
+ } else {
+ generalizedTime[i + 2 * swingPhasePerLeg.size()] = std::sin(M_PI * swingPhasePerLeg[i].phase);
+ }
+ }
+ /**
+ * relative state
+ */
+ vector_t relativeState = x - targetTrajectories.getDesiredState(t);
+ const matrix3_t R = getRotationMatrixFromZyxEulerAngles(x.segment<3>(9)).transpose();
+ relativeState.segment<3>(0) = R * relativeState.segment<3>(0);
+ relativeState.segment<3>(3) = R * relativeState.segment<3>(3);
+ relativeState.segment<3>(6) = R * relativeState.segment<3>(6);
+ // TODO(areske): use quaternionDistance() for orientation error?
+ /**
+ * observation
+ */
+ vector_t observation(36);
+ observation << generalizedTime, relativeState;
+ return observation;
+}
+
+std::pair LeggedRobotMpcnetDefinition::getActionTransformation(scalar_t t, const vector_t& x,
+ const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories) {
+ const matrix3_t R = getRotationMatrixFromZyxEulerAngles(x.segment<3>(9));
+ matrix_t actionTransformationMatrix = matrix_t::Identity(24, 24);
+ actionTransformationMatrix.block<3, 3>(0, 0) = R;
+ actionTransformationMatrix.block<3, 3>(3, 3) = R;
+ actionTransformationMatrix.block<3, 3>(6, 6) = R;
+ actionTransformationMatrix.block<3, 3>(9, 9) = R;
+ // TODO(areske): check why less robust with weight compensating bias?
+ // const auto contactFlags = modeNumber2StanceLeg(modeSchedule.modeAtTime(t));
+ // const vector_t actionTransformationVector = weightCompensatingInput(centroidalModelInfo_, contactFlags);
+ return {actionTransformationMatrix, vector_t::Zero(24)};
+}
+
+bool LeggedRobotMpcnetDefinition::isValid(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories) {
+ const vector_t deviation = x - defaultState_;
+ if (std::abs(deviation[8]) > allowedHeightDeviation_) {
+ std::cerr << "[LeggedRobotMpcnetDefinition::isValid] height diverged: " << x[8] << "\n";
+ return false;
+ } else if (std::abs(deviation[10]) > allowedPitchDeviation_) {
+ std::cerr << "[LeggedRobotMpcnetDefinition::isValid] pitch diverged: " << x[10] << "\n";
+ return false;
+ } else if (std::abs(deviation[11]) > allowedRollDeviation_) {
+ std::cerr << "[LeggedRobotMpcnetDefinition::isValid] roll diverged: " << x[11] << "\n";
+ return false;
+ } else {
+ return true;
+ }
+}
+
+} // namespace legged_robot
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDummyNode.cpp b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDummyNode.cpp
new file mode 100644
index 000000000..e16898bac
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDummyNode.cpp
@@ -0,0 +1,165 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ocs2_legged_robot_mpcnet/LeggedRobotMpcnetDefinition.h"
+
+using namespace ocs2;
+using namespace legged_robot;
+
+int main(int argc, char** argv) {
+ const std::string robotName = "legged_robot";
+
+ // initialize ros node
+ ros::init(argc, argv, robotName + "_mpcnet_dummy");
+ ros::NodeHandle nodeHandle;
+ // Get node parameters
+ bool useRaisim;
+ std::string taskFile, urdfFile, referenceFile, raisimFile, resourcePath, policyFile;
+ nodeHandle.getParam("/taskFile", taskFile);
+ nodeHandle.getParam("/urdfFile", urdfFile);
+ nodeHandle.getParam("/referenceFile", referenceFile);
+ nodeHandle.getParam("/raisimFile", raisimFile);
+ nodeHandle.getParam("/resourcePath", resourcePath);
+ nodeHandle.getParam("/policyFile", policyFile);
+ nodeHandle.getParam("/useRaisim", useRaisim);
+
+ // legged robot interface
+ LeggedRobotInterface leggedRobotInterface(taskFile, urdfFile, referenceFile);
+
+ // gait receiver
+ auto gaitReceiverPtr =
+ std::make_shared(nodeHandle, leggedRobotInterface.getSwitchedModelReferenceManagerPtr()->getGaitSchedule(), robotName);
+
+ // ROS reference manager
+ auto rosReferenceManagerPtr = std::make_shared(robotName, leggedRobotInterface.getReferenceManagerPtr());
+ rosReferenceManagerPtr->subscribe(nodeHandle);
+
+ // policy (MPC-Net controller)
+ auto onnxEnvironmentPtr = ocs2::mpcnet::createOnnxEnvironment();
+ std::shared_ptr mpcnetDefinitionPtr(new LeggedRobotMpcnetDefinition(leggedRobotInterface));
+ std::unique_ptr mpcnetControllerPtr(
+ new ocs2::mpcnet::MpcnetOnnxController(mpcnetDefinitionPtr, rosReferenceManagerPtr, onnxEnvironmentPtr));
+ mpcnetControllerPtr->loadPolicyModel(policyFile);
+
+ // rollout
+ std::unique_ptr rolloutPtr;
+ raisim::HeightMap* terrainPtr = nullptr;
+ std::unique_ptr heightmapPub;
+ std::unique_ptr conversions;
+ if (useRaisim) {
+ conversions.reset(new LeggedRobotRaisimConversions(leggedRobotInterface.getPinocchioInterface(),
+ leggedRobotInterface.getCentroidalModelInfo(),
+ leggedRobotInterface.getInitialState()));
+ RaisimRolloutSettings raisimRolloutSettings(raisimFile, "rollout", true);
+ conversions->loadSettings(raisimFile, "rollout", true);
+ rolloutPtr.reset(new RaisimRollout(
+ urdfFile, resourcePath,
+ [&](const vector_t& state, const vector_t& input) { return conversions->stateToRaisimGenCoordGenVel(state, input); },
+ [&](const Eigen::VectorXd& q, const Eigen::VectorXd& dq) { return conversions->raisimGenCoordGenVelToState(q, dq); },
+ [&](double time, const vector_t& input, const vector_t& state, const Eigen::VectorXd& q, const Eigen::VectorXd& dq) {
+ return conversions->inputToRaisimGeneralizedForce(time, input, state, q, dq);
+ },
+ nullptr, raisimRolloutSettings,
+ [&](double time, const vector_t& input, const vector_t& state, const Eigen::VectorXd& q, const Eigen::VectorXd& dq) {
+ return conversions->inputToRaisimPdTargets(time, input, state, q, dq);
+ }));
+ // terrain
+ if (raisimRolloutSettings.generateTerrain_) {
+ raisim::TerrainProperties terrainProperties;
+ terrainProperties.zScale = raisimRolloutSettings.terrainRoughness_;
+ terrainProperties.seed = raisimRolloutSettings.terrainSeed_;
+ terrainPtr = static_cast(rolloutPtr.get())->generateTerrain(terrainProperties);
+ conversions->setTerrain(*terrainPtr);
+ heightmapPub.reset(new ocs2::RaisimHeightmapRosConverter());
+ heightmapPub->publishGridmap(*terrainPtr, "odom");
+ }
+ } else {
+ rolloutPtr.reset(leggedRobotInterface.getRollout().clone());
+ }
+
+ // observer
+ std::shared_ptr mpcnetDummyObserverRosPtr(
+ new ocs2::mpcnet::MpcnetDummyObserverRos(nodeHandle, robotName));
+
+ // visualization
+ CentroidalModelPinocchioMapping pinocchioMapping(leggedRobotInterface.getCentroidalModelInfo());
+ PinocchioEndEffectorKinematics endEffectorKinematics(leggedRobotInterface.getPinocchioInterface(), pinocchioMapping,
+ leggedRobotInterface.modelSettings().contactNames3DoF);
+ std::shared_ptr leggedRobotVisualizerPtr;
+ if (useRaisim) {
+ leggedRobotVisualizerPtr.reset(new LeggedRobotRaisimVisualizer(
+ leggedRobotInterface.getPinocchioInterface(), leggedRobotInterface.getCentroidalModelInfo(), endEffectorKinematics, nodeHandle));
+ static_cast(leggedRobotVisualizerPtr.get())->updateTerrain();
+ } else {
+ leggedRobotVisualizerPtr.reset(new LeggedRobotVisualizer(
+ leggedRobotInterface.getPinocchioInterface(), leggedRobotInterface.getCentroidalModelInfo(), endEffectorKinematics, nodeHandle));
+ }
+
+ // MPC-Net dummy loop ROS
+ const scalar_t controlFrequency = leggedRobotInterface.mpcSettings().mrtDesiredFrequency_;
+ const scalar_t rosFrequency = leggedRobotInterface.mpcSettings().mpcDesiredFrequency_;
+ ocs2::mpcnet::MpcnetDummyLoopRos mpcnetDummyLoopRos(controlFrequency, rosFrequency, std::move(mpcnetControllerPtr), std::move(rolloutPtr),
+ rosReferenceManagerPtr);
+ mpcnetDummyLoopRos.addObserver(mpcnetDummyObserverRosPtr);
+ mpcnetDummyLoopRos.addObserver(leggedRobotVisualizerPtr);
+ mpcnetDummyLoopRos.addSynchronizedModule(gaitReceiverPtr);
+
+ // initial system observation
+ SystemObservation systemObservation;
+ systemObservation.mode = ModeNumber::STANCE;
+ systemObservation.time = 0.0;
+ systemObservation.state = leggedRobotInterface.getInitialState();
+ systemObservation.input = vector_t::Zero(leggedRobotInterface.getCentroidalModelInfo().inputDim);
+
+ // initial target trajectories
+ const TargetTrajectories targetTrajectories({systemObservation.time}, {systemObservation.state}, {systemObservation.input});
+
+ // run MPC-Net dummy loop ROS
+ mpcnetDummyLoopRos.run(systemObservation, targetTrajectories);
+
+ // successful exit
+ return 0;
+}
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetInterface.cpp b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetInterface.cpp
new file mode 100644
index 000000000..6c068bca9
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetInterface.cpp
@@ -0,0 +1,143 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#include "ocs2_legged_robot_mpcnet/LeggedRobotMpcnetInterface.h"
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ocs2_legged_robot_mpcnet/LeggedRobotMpcnetDefinition.h"
+
+namespace ocs2 {
+namespace legged_robot {
+
+LeggedRobotMpcnetInterface::LeggedRobotMpcnetInterface(size_t nDataGenerationThreads, size_t nPolicyEvaluationThreads, bool raisim) {
+ // create ONNX environment
+ auto onnxEnvironmentPtr = ocs2::mpcnet::createOnnxEnvironment();
+ // paths to files
+ const std::string taskFile = ros::package::getPath("ocs2_legged_robot") + "/config/mpc/task.info";
+ const std::string urdfFile = ros::package::getPath("ocs2_robotic_assets") + "/resources/anymal_c/urdf/anymal.urdf";
+ const std::string referenceFile = ros::package::getPath("ocs2_legged_robot") + "/config/command/reference.info";
+ const std::string raisimFile = ros::package::getPath("ocs2_legged_robot_raisim") + "/config/raisim.info";
+ const std::string resourcePath = ros::package::getPath("ocs2_robotic_assets") + "/resources/anymal_c/meshes";
+ // set up MPC-Net rollout manager for data generation and policy evaluation
+ std::vector> mpcPtrs;
+ std::vector> mpcnetPtrs;
+ std::vector> rolloutPtrs;
+ std::vector> mpcnetDefinitionPtrs;
+ std::vector> referenceManagerPtrs;
+ mpcPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
+ mpcnetPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
+ rolloutPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
+ mpcnetDefinitionPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
+ referenceManagerPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
+ for (int i = 0; i < (nDataGenerationThreads + nPolicyEvaluationThreads); i++) {
+ leggedRobotInterfacePtrs_.push_back(std::unique_ptr(new LeggedRobotInterface(taskFile, urdfFile, referenceFile)));
+ std::shared_ptr mpcnetDefinitionPtr(new LeggedRobotMpcnetDefinition(*leggedRobotInterfacePtrs_[i]));
+ mpcPtrs.push_back(getMpc(*leggedRobotInterfacePtrs_[i]));
+ mpcnetPtrs.push_back(std::unique_ptr(new ocs2::mpcnet::MpcnetOnnxController(
+ mpcnetDefinitionPtr, leggedRobotInterfacePtrs_[i]->getReferenceManagerPtr(), onnxEnvironmentPtr)));
+ if (raisim) {
+ RaisimRolloutSettings raisimRolloutSettings(raisimFile, "rollout");
+ raisimRolloutSettings.portNumber_ += i;
+ leggedRobotRaisimConversionsPtrs_.push_back(std::unique_ptr(new LeggedRobotRaisimConversions(
+ leggedRobotInterfacePtrs_[i]->getPinocchioInterface(), leggedRobotInterfacePtrs_[i]->getCentroidalModelInfo(),
+ leggedRobotInterfacePtrs_[i]->getInitialState())));
+ leggedRobotRaisimConversionsPtrs_[i]->loadSettings(raisimFile, "rollout", true);
+ rolloutPtrs.push_back(std::unique_ptr(new RaisimRollout(
+ urdfFile, resourcePath,
+ [&, i](const vector_t& state, const vector_t& input) {
+ return leggedRobotRaisimConversionsPtrs_[i]->stateToRaisimGenCoordGenVel(state, input);
+ },
+ [&, i](const Eigen::VectorXd& q, const Eigen::VectorXd& dq) {
+ return leggedRobotRaisimConversionsPtrs_[i]->raisimGenCoordGenVelToState(q, dq);
+ },
+ [&, i](double time, const vector_t& input, const vector_t& state, const Eigen::VectorXd& q, const Eigen::VectorXd& dq) {
+ return leggedRobotRaisimConversionsPtrs_[i]->inputToRaisimGeneralizedForce(time, input, state, q, dq);
+ },
+ nullptr, raisimRolloutSettings,
+ [&, i](double time, const vector_t& input, const vector_t& state, const Eigen::VectorXd& q, const Eigen::VectorXd& dq) {
+ return leggedRobotRaisimConversionsPtrs_[i]->inputToRaisimPdTargets(time, input, state, q, dq);
+ })));
+ if (raisimRolloutSettings.generateTerrain_) {
+ raisim::TerrainProperties terrainProperties;
+ terrainProperties.zScale = raisimRolloutSettings.terrainRoughness_;
+ terrainProperties.seed = raisimRolloutSettings.terrainSeed_ + i;
+ auto terrainPtr = static_cast(rolloutPtrs[i].get())->generateTerrain(terrainProperties);
+ leggedRobotRaisimConversionsPtrs_[i]->setTerrain(*terrainPtr);
+ }
+ } else {
+ rolloutPtrs.push_back(std::unique_ptr(leggedRobotInterfacePtrs_[i]->getRollout().clone()));
+ }
+ mpcnetDefinitionPtrs.push_back(mpcnetDefinitionPtr);
+ referenceManagerPtrs.push_back(leggedRobotInterfacePtrs_[i]->getReferenceManagerPtr());
+ }
+ mpcnetRolloutManagerPtr_.reset(new ocs2::mpcnet::MpcnetRolloutManager(nDataGenerationThreads, nPolicyEvaluationThreads,
+ std::move(mpcPtrs), std::move(mpcnetPtrs), std::move(rolloutPtrs),
+ mpcnetDefinitionPtrs, referenceManagerPtrs));
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+std::unique_ptr LeggedRobotMpcnetInterface::getMpc(LeggedRobotInterface& leggedRobotInterface) {
+ // ensure MPC and DDP settings are as needed for MPC-Net
+ const auto mpcSettings = [&]() {
+ auto settings = leggedRobotInterface.mpcSettings();
+ settings.debugPrint_ = false;
+ settings.coldStart_ = false;
+ return settings;
+ }();
+ const auto ddpSettings = [&]() {
+ auto settings = leggedRobotInterface.ddpSettings();
+ settings.algorithm_ = ocs2::ddp::Algorithm::SLQ;
+ settings.nThreads_ = 1;
+ settings.displayInfo_ = false;
+ settings.displayShortSummary_ = false;
+ settings.checkNumericalStability_ = false;
+ settings.debugPrintRollout_ = false;
+ settings.useFeedbackPolicy_ = true;
+ return settings;
+ }();
+ // create one MPC instance
+ std::unique_ptr mpcPtr(new GaussNewtonDDP_MPC(mpcSettings, ddpSettings, leggedRobotInterface.getRollout(),
+ leggedRobotInterface.getOptimalControlProblem(),
+ leggedRobotInterface.getInitializer()));
+ mpcPtr->getSolverPtr()->setReferenceManager(leggedRobotInterface.getReferenceManagerPtr());
+ return mpcPtr;
+}
+
+} // namespace legged_robot
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetPybindings.cpp b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetPybindings.cpp
new file mode 100644
index 000000000..f170bde7a
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetPybindings.cpp
@@ -0,0 +1,34 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#include
+
+#include "ocs2_legged_robot_mpcnet/LeggedRobotMpcnetInterface.h"
+
+CREATE_ROBOT_MPCNET_PYTHON_BINDINGS(ocs2::legged_robot::LeggedRobotMpcnetInterface, LeggedRobotMpcnetPybindings)
diff --git a/ocs2_mpcnet/ocs2_mpcnet/CMakeLists.txt b/ocs2_mpcnet/ocs2_mpcnet/CMakeLists.txt
new file mode 100644
index 000000000..a9dba76cb
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet/CMakeLists.txt
@@ -0,0 +1,4 @@
+cmake_minimum_required(VERSION 3.0.2)
+project(ocs2_mpcnet)
+find_package(catkin REQUIRED)
+catkin_metapackage()
diff --git a/ocs2_mpcnet/ocs2_mpcnet/package.xml b/ocs2_mpcnet/ocs2_mpcnet/package.xml
new file mode 100644
index 000000000..7eec28861
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet/package.xml
@@ -0,0 +1,24 @@
+
+
+ ocs2_mpcnet
+ 0.0.0
+ The ocs2_mpcnet metapackage
+
+ Alexander Reske
+
+ Farbod Farshidian
+ Alexander Reske
+
+ BSD-3
+
+ catkin
+
+ ocs2_mpcnet_core
+ ocs2_ballbot_mpcnet
+ ocs2_legged_robot_mpcnet
+
+
+
+
+
+
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/CMakeLists.txt b/ocs2_mpcnet/ocs2_mpcnet_core/CMakeLists.txt
new file mode 100644
index 000000000..7b91fd000
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/CMakeLists.txt
@@ -0,0 +1,118 @@
+cmake_minimum_required(VERSION 3.0.2)
+project(ocs2_mpcnet_core)
+
+set(CATKIN_PACKAGE_DEPENDENCIES
+ pybind11_catkin
+ ocs2_mpc
+ ocs2_python_interface
+ ocs2_ros_interfaces
+)
+
+find_package(catkin REQUIRED COMPONENTS
+ ${CATKIN_PACKAGE_DEPENDENCIES}
+)
+
+find_package(onnxruntime REQUIRED)
+
+# Generate compile_commands.json for clang tools
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+###################################
+## catkin specific configuration ##
+###################################
+
+catkin_package(
+ INCLUDE_DIRS
+ include
+ LIBRARIES
+ ${PROJECT_NAME}
+ CATKIN_DEPENDS
+ ${CATKIN_PACKAGE_DEPENDENCIES}
+ DEPENDS
+ onnxruntime
+)
+
+###########
+## Build ##
+###########
+
+include_directories(
+ include
+ ${catkin_INCLUDE_DIRS}
+)
+
+# main library
+add_library(${PROJECT_NAME}
+ src/control/MpcnetBehavioralController.cpp
+ src/control/MpcnetOnnxController.cpp
+ src/dummy/MpcnetDummyLoopRos.cpp
+ src/dummy/MpcnetDummyObserverRos.cpp
+ src/rollout/MpcnetDataGeneration.cpp
+ src/rollout/MpcnetPolicyEvaluation.cpp
+ src/rollout/MpcnetRolloutBase.cpp
+ src/rollout/MpcnetRolloutManager.cpp
+ src/MpcnetInterfaceBase.cpp
+)
+add_dependencies(${PROJECT_NAME}
+ ${catkin_EXPORTED_TARGETS}
+)
+target_link_libraries(${PROJECT_NAME}
+ ${catkin_LIBRARIES}
+ onnxruntime
+)
+
+# python bindings
+pybind11_add_module(MpcnetPybindings SHARED
+ src/MpcnetPybindings.cpp
+)
+add_dependencies(MpcnetPybindings
+ ${PROJECT_NAME}
+ ${catkin_EXPORTED_TARGETS}
+)
+target_link_libraries(MpcnetPybindings PRIVATE
+ ${PROJECT_NAME}
+ ${catkin_LIBRARIES}
+)
+set_target_properties(MpcnetPybindings
+ PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CATKIN_DEVEL_PREFIX}/${CATKIN_PACKAGE_PYTHON_DESTINATION}
+)
+
+catkin_python_setup()
+
+#########################
+### CLANG TOOLING ###
+#########################
+find_package(cmake_clang_tools QUIET)
+if(cmake_clang_tools_FOUND)
+ message(STATUS "Run clang tooling for target ocs2_mpcnet_core")
+ add_clang_tooling(
+ TARGETS ${PROJECT_NAME}
+ SOURCE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/src ${CMAKE_CURRENT_SOURCE_DIR}/include
+ CT_HEADER_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/include
+ CF_WERROR
+)
+endif(cmake_clang_tools_FOUND)
+
+#############
+## Install ##
+#############
+
+install(TARGETS ${PROJECT_NAME}
+ ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+ LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+ RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+)
+
+install(DIRECTORY include/${PROJECT_NAME}/
+ DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION})
+
+install(TARGETS MpcnetPybindings
+ ARCHIVE DESTINATION ${CATKIN_PACKAGE_PYTHON_DESTINATION}
+ LIBRARY DESTINATION ${CATKIN_PACKAGE_PYTHON_DESTINATION}
+)
+
+#############
+## Testing ##
+#############
+
+# TODO(areske)
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetDefinitionBase.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetDefinitionBase.h
new file mode 100644
index 000000000..de322bcc2
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetDefinitionBase.h
@@ -0,0 +1,100 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+#include
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * Base class for MPC-Net definitions.
+ */
+class MpcnetDefinitionBase {
+ public:
+ /**
+ * Default constructor.
+ */
+ MpcnetDefinitionBase() = default;
+
+ /**
+ * Default destructor.
+ */
+ virtual ~MpcnetDefinitionBase() = default;
+
+ /**
+ * Deleted copy constructor.
+ */
+ MpcnetDefinitionBase(const MpcnetDefinitionBase&) = delete;
+
+ /**
+ * Deleted copy assignment.
+ */
+ MpcnetDefinitionBase& operator=(const MpcnetDefinitionBase&) = delete;
+
+ /**
+ * Get the observation.
+ * @note The observation o is the input to the policy.
+ * @param[in] t : Absolute time.
+ * @param[in] x : Robot state.
+ * @param[in] modeSchedule : Mode schedule.
+ * @param[in] targetTrajectories : Target trajectories.
+ * @return The observation.
+ */
+ virtual vector_t getObservation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories) = 0;
+
+ /**
+ * Get the action transformation.
+ * @note Used for computing the control input u = A * a + b from the action a predicted by the policy.
+ * @param[in] t : Absolute time.
+ * @param[in] x : Robot state.
+ * @param[in] modeSchedule : Mode schedule.
+ * @param[in] targetTrajectories : Target trajectories.
+ * @return The action transformation pair {A, b}.
+ */
+ virtual std::pair getActionTransformation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories) = 0;
+
+ /**
+ * Check if the tuple (t, x, modeSchedule, targetTrajectories) is valid.
+ * @note E.g., check if the state diverged or if tracking is poor.
+ * @param[in] t : Absolute time.
+ * @param[in] x : Robot state.
+ * @param[in] modeSchedule : Mode schedule.
+ * @param[in] targetTrajectories : Target trajectories.
+ * @return True if valid.
+ */
+ virtual bool isValid(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, const TargetTrajectories& targetTrajectories) = 0;
+};
+
+} // namespace mpcnet
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetInterfaceBase.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetInterfaceBase.h
new file mode 100644
index 000000000..ec4b72d00
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetInterfaceBase.h
@@ -0,0 +1,91 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include "ocs2_mpcnet_core/rollout/MpcnetRolloutManager.h"
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * Base class for all MPC-Net interfaces between C++ and Python.
+ */
+class MpcnetInterfaceBase {
+ public:
+ /**
+ * Default destructor.
+ */
+ virtual ~MpcnetInterfaceBase() = default;
+
+ /**
+ * @see MpcnetRolloutManager::startDataGeneration()
+ */
+ void startDataGeneration(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, size_t dataDecimation, size_t nSamples,
+ const matrix_t& samplingCovariance, const std::vector& initialObservations,
+ const std::vector& modeSchedules, const std::vector& targetTrajectories);
+
+ /**
+ * @see MpcnetRolloutManager::isDataGenerationDone()
+ */
+ bool isDataGenerationDone();
+
+ /**
+ * @see MpcnetRolloutManager::getGeneratedData()
+ */
+ data_array_t getGeneratedData();
+
+ /**
+ * @see MpcnetRolloutManager::startPolicyEvaluation()
+ */
+ void startPolicyEvaluation(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep,
+ const std::vector& initialObservations, const std::vector& modeSchedules,
+ const std::vector& targetTrajectories);
+
+ /**
+ * @see MpcnetRolloutManager::isPolicyEvaluationDone()
+ */
+ bool isPolicyEvaluationDone();
+
+ /**
+ * @see MpcnetRolloutManager::getComputedMetrics()
+ */
+ metrics_array_t getComputedMetrics();
+
+ protected:
+ /**
+ * Default constructor.
+ */
+ MpcnetInterfaceBase() = default;
+
+ std::unique_ptr mpcnetRolloutManagerPtr_;
+};
+
+} // namespace mpcnet
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetPybindMacros.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetPybindMacros.h
new file mode 100644
index 000000000..632fa4717
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetPybindMacros.h
@@ -0,0 +1,133 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include "ocs2_mpcnet_core/rollout/MpcnetData.h"
+#include "ocs2_mpcnet_core/rollout/MpcnetMetrics.h"
+
+using namespace pybind11::literals;
+
+/**
+ * Convenience macro to bind general MPC-Net functionalities and other classes with all required vectors.
+ */
+#define CREATE_MPCNET_PYTHON_BINDINGS(LIB_NAME) \
+ /* make vector types opaque so they are not converted to python lists */ \
+ PYBIND11_MAKE_OPAQUE(ocs2::size_array_t) \
+ PYBIND11_MAKE_OPAQUE(ocs2::scalar_array_t) \
+ PYBIND11_MAKE_OPAQUE(ocs2::vector_array_t) \
+ PYBIND11_MAKE_OPAQUE(ocs2::matrix_array_t) \
+ PYBIND11_MAKE_OPAQUE(std::vector) \
+ PYBIND11_MAKE_OPAQUE(std::vector) \
+ PYBIND11_MAKE_OPAQUE(std::vector) \
+ PYBIND11_MAKE_OPAQUE(ocs2::mpcnet::data_array_t) \
+ PYBIND11_MAKE_OPAQUE(ocs2::mpcnet::metrics_array_t) \
+ /* create a python module */ \
+ PYBIND11_MODULE(LIB_NAME, m) { \
+ /* bind vector types so they can be used natively in python */ \
+ VECTOR_TYPE_BINDING(ocs2::size_array_t, "size_array") \
+ VECTOR_TYPE_BINDING(ocs2::scalar_array_t, "scalar_array") \
+ VECTOR_TYPE_BINDING(ocs2::vector_array_t, "vector_array") \
+ VECTOR_TYPE_BINDING(ocs2::matrix_array_t, "matrix_array") \
+ VECTOR_TYPE_BINDING(std::vector, "SystemObservationArray") \
+ VECTOR_TYPE_BINDING(std::vector, "ModeScheduleArray") \
+ VECTOR_TYPE_BINDING(std::vector, "TargetTrajectoriesArray") \
+ VECTOR_TYPE_BINDING(ocs2::mpcnet::data_array_t, "DataArray") \
+ VECTOR_TYPE_BINDING(ocs2::mpcnet::metrics_array_t, "MetricsArray") \
+ /* bind approximation classes */ \
+ pybind11::class_(m, "ScalarFunctionQuadraticApproximation") \
+ .def_readwrite("f", &ocs2::ScalarFunctionQuadraticApproximation::f) \
+ .def_readwrite("dfdx", &ocs2::ScalarFunctionQuadraticApproximation::dfdx) \
+ .def_readwrite("dfdu", &ocs2::ScalarFunctionQuadraticApproximation::dfdu) \
+ .def_readwrite("dfdxx", &ocs2::ScalarFunctionQuadraticApproximation::dfdxx) \
+ .def_readwrite("dfdux", &ocs2::ScalarFunctionQuadraticApproximation::dfdux) \
+ .def_readwrite("dfduu", &ocs2::ScalarFunctionQuadraticApproximation::dfduu); \
+ /* bind system observation struct */ \
+ pybind11::class_(m, "SystemObservation") \
+ .def(pybind11::init<>()) \
+ .def_readwrite("mode", &ocs2::SystemObservation::mode) \
+ .def_readwrite("time", &ocs2::SystemObservation::time) \
+ .def_readwrite("state", &ocs2::SystemObservation::state) \
+ .def_readwrite("input", &ocs2::SystemObservation::input); \
+ /* bind mode schedule struct */ \
+ pybind11::class_(m, "ModeSchedule") \
+ .def(pybind11::init()) \
+ .def_readwrite("eventTimes", &ocs2::ModeSchedule::eventTimes) \
+ .def_readwrite("modeSequence", &ocs2::ModeSchedule::modeSequence); \
+ /* bind target trajectories class */ \
+ pybind11::class_(m, "TargetTrajectories") \
+ .def(pybind11::init()) \
+ .def_readwrite("timeTrajectory", &ocs2::TargetTrajectories::timeTrajectory) \
+ .def_readwrite("stateTrajectory", &ocs2::TargetTrajectories::stateTrajectory) \
+ .def_readwrite("inputTrajectory", &ocs2::TargetTrajectories::inputTrajectory); \
+ /* bind data point struct */ \
+ pybind11::class_(m, "DataPoint") \
+ .def(pybind11::init<>()) \
+ .def_readwrite("mode", &ocs2::mpcnet::data_point_t::mode) \
+ .def_readwrite("t", &ocs2::mpcnet::data_point_t::t) \
+ .def_readwrite("x", &ocs2::mpcnet::data_point_t::x) \
+ .def_readwrite("u", &ocs2::mpcnet::data_point_t::u) \
+ .def_readwrite("observation", &ocs2::mpcnet::data_point_t::observation) \
+ .def_readwrite("actionTransformation", &ocs2::mpcnet::data_point_t::actionTransformation) \
+ .def_readwrite("hamiltonian", &ocs2::mpcnet::data_point_t::hamiltonian); \
+ /* bind metrics struct */ \
+ pybind11::class_(m, "Metrics") \
+ .def(pybind11::init<>()) \
+ .def_readwrite("survivalTime", &ocs2::mpcnet::metrics_t::survivalTime) \
+ .def_readwrite("incurredHamiltonian", &ocs2::mpcnet::metrics_t::incurredHamiltonian); \
+ }
+
+/**
+ * Convenience macro to bind robot MPC-Net interface.
+ */
+#define CREATE_ROBOT_MPCNET_PYTHON_BINDINGS(MPCNET_INTERFACE, LIB_NAME) \
+ /* create a python module */ \
+ PYBIND11_MODULE(LIB_NAME, m) { \
+ /* import the general MPC-Net module */ \
+ pybind11::module::import("ocs2_mpcnet_core.MpcnetPybindings"); \
+ /* bind actual MPC-Net interface for specific robot */ \
+ pybind11::class_(m, "MpcnetInterface") \
+ .def(pybind11::init()) \
+ .def("startDataGeneration", &MPCNET_INTERFACE::startDataGeneration, "alpha"_a, "policyFilePath"_a, "timeStep"_a, \
+ "dataDecimation"_a, "nSamples"_a, "samplingCovariance"_a.noconvert(), "initialObservations"_a, "modeSchedules"_a, \
+ "targetTrajectories"_a) \
+ .def("isDataGenerationDone", &MPCNET_INTERFACE::isDataGenerationDone) \
+ .def("getGeneratedData", &MPCNET_INTERFACE::getGeneratedData) \
+ .def("startPolicyEvaluation", &MPCNET_INTERFACE::startPolicyEvaluation, "alpha"_a, "policyFilePath"_a, "timeStep"_a, \
+ "initialObservations"_a, "modeSchedules"_a, "targetTrajectories"_a) \
+ .def("isPolicyEvaluationDone", &MPCNET_INTERFACE::isPolicyEvaluationDone) \
+ .def("getComputedMetrics", &MPCNET_INTERFACE::getComputedMetrics); \
+ }
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetBehavioralController.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetBehavioralController.h
new file mode 100644
index 000000000..60ec8a6c6
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetBehavioralController.h
@@ -0,0 +1,97 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+
+#include
+
+#include "ocs2_mpcnet_core/control/MpcnetControllerBase.h"
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * A behavioral controller that computes the input based on a mixture of an optimal policy (e.g. implicitly found via MPC)
+ * and a learned policy (e.g. explicitly represented by a neural network).
+ * The behavioral policy is pi_behavioral = alpha * pi_optimal + (1 - alpha) * pi_learned with alpha in [0, 1].
+ */
+class MpcnetBehavioralController final : public ControllerBase {
+ public:
+ /**
+ * Constructor.
+ * @param [in] alpha : The mixture parameter.
+ * @param [in] optimalController : The optimal controller (this class takes ownership of a clone).
+ * @param [in] learnedController : The learned controller (this class takes ownership of a clone).
+ */
+ MpcnetBehavioralController(scalar_t alpha, const ControllerBase& optimalController, const MpcnetControllerBase& learnedController)
+ : alpha_(alpha), optimalControllerPtr_(optimalController.clone()), learnedControllerPtr_(learnedController.clone()) {}
+
+ MpcnetBehavioralController() = default;
+ ~MpcnetBehavioralController() override = default;
+ MpcnetBehavioralController* clone() const override { return new MpcnetBehavioralController(*this); }
+
+ /**
+ * Set the mixture parameter.
+ * @param [in] alpha : The mixture parameter.
+ */
+ void setAlpha(scalar_t alpha) { alpha_ = alpha; }
+
+ /**
+ * Set the optimal controller.
+ * @param [in] optimalController : The optimal controller (this class takes ownership of a clone).
+ */
+ void setOptimalController(const ControllerBase& optimalController) { optimalControllerPtr_.reset(optimalController.clone()); }
+
+ /**
+ * Set the learned controller.
+ * @param [in] learnedController : The learned controller (this class takes ownership of a clone).
+ */
+ void setLearnedController(const MpcnetControllerBase& learnedController) { learnedControllerPtr_.reset(learnedController.clone()); }
+
+ vector_t computeInput(scalar_t t, const vector_t& x) override;
+ ControllerType getType() const override { return ControllerType::BEHAVIORAL; }
+
+ int size() const override;
+ void clear() override;
+ bool empty() const override;
+ void concatenate(const ControllerBase* otherController, int index, int length) override;
+
+ private:
+ MpcnetBehavioralController(const MpcnetBehavioralController& other)
+ : MpcnetBehavioralController(other.alpha_, *other.optimalControllerPtr_, *other.learnedControllerPtr_) {}
+
+ scalar_t alpha_;
+ std::unique_ptr optimalControllerPtr_;
+ std::unique_ptr learnedControllerPtr_;
+};
+
+} // namespace mpcnet
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetControllerBase.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetControllerBase.h
new file mode 100644
index 000000000..8a669cf79
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetControllerBase.h
@@ -0,0 +1,57 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * The base class for all controllers that use a MPC-Net policy.
+ */
+class MpcnetControllerBase : public ControllerBase {
+ public:
+ MpcnetControllerBase() = default;
+ ~MpcnetControllerBase() override = default;
+ MpcnetControllerBase* clone() const override = 0;
+
+ /**
+ * Load the model of the policy.
+ * @param [in] policyFilePath : Path to the file with the model of the policy.
+ */
+ virtual void loadPolicyModel(const std::string& policyFilePath) = 0;
+
+ protected:
+ MpcnetControllerBase(const MpcnetControllerBase& other) = default;
+};
+
+} // namespace mpcnet
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetOnnxController.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetOnnxController.h
new file mode 100644
index 000000000..89c102eb3
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetOnnxController.h
@@ -0,0 +1,111 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+
+#include
+
+#include "ocs2_mpcnet_core/MpcnetDefinitionBase.h"
+#include "ocs2_mpcnet_core/control/MpcnetControllerBase.h"
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * Convenience function for creating an environment for ONNX Runtime and getting a corresponding shared pointer.
+ * @note Only one environment per process can be created. The environment offers some threading and logging options.
+ * @return Pointer to the environment for ONNX Runtime.
+ */
+inline std::shared_ptr createOnnxEnvironment() {
+ return std::make_shared(ORT_LOGGING_LEVEL_WARNING, "MpcnetOnnxController");
+}
+
+/**
+ * A neural network controller using ONNX Runtime based on the Open Neural Network Exchange (ONNX) format.
+ * The model of the policy computes u = model(t, x) with
+ * t: generalized time (1 x dimensionOfTime),
+ * x: relative state (1 x dimensionOfState),
+ * u: predicted input (1 x dimensionOfInput),
+ * @note The additional first dimension with size 1 for the variables of the model comes from batch processing during training.
+ */
+class MpcnetOnnxController final : public MpcnetControllerBase {
+ public:
+ /**
+ * Constructor.
+ * @note The class is not fully instantiated until calling loadPolicyModel().
+ * @param [in] mpcnetDefinitionPtr : Pointer to the MPC-Net definitions.
+ * @param [in] referenceManagerPtr : Pointer to the reference manager.
+ * @param [in] onnxEnvironmentPtr : Pointer to the environment for ONNX Runtime.
+ */
+ MpcnetOnnxController(std::shared_ptr mpcnetDefinitionPtr,
+ std::shared_ptr referenceManagerPtr, std::shared_ptr onnxEnvironmentPtr)
+ : mpcnetDefinitionPtr_(std::move(mpcnetDefinitionPtr)),
+ referenceManagerPtr_(std::move(referenceManagerPtr)),
+ onnxEnvironmentPtr_(std::move(onnxEnvironmentPtr)) {}
+
+ ~MpcnetOnnxController() override = default;
+ MpcnetOnnxController* clone() const override { return new MpcnetOnnxController(*this); }
+
+ void loadPolicyModel(const std::string& policyFilePath) override;
+
+ vector_t computeInput(const scalar_t t, const vector_t& x) override;
+ ControllerType getType() const override { return ControllerType::ONNX; }
+
+ int size() const override { throw std::runtime_error("[MpcnetOnnxController::size] not implemented."); }
+ void clear() override { throw std::runtime_error("[MpcnetOnnxController::clear] not implemented."); }
+ bool empty() const override { throw std::runtime_error("[MpcnetOnnxController::empty] not implemented."); }
+ void concatenate(const ControllerBase* otherController, int index, int length) override {
+ throw std::runtime_error("[MpcnetOnnxController::concatenate] not implemented.");
+ }
+
+ private:
+ using tensor_element_t = float;
+
+ MpcnetOnnxController(const MpcnetOnnxController& other)
+ : MpcnetOnnxController(other.mpcnetDefinitionPtr_, other.referenceManagerPtr_, other.onnxEnvironmentPtr_) {
+ if (!other.policyFilePath_.empty()) {
+ loadPolicyModel(other.policyFilePath_);
+ }
+ }
+
+ std::shared_ptr mpcnetDefinitionPtr_;
+ std::shared_ptr referenceManagerPtr_;
+ std::shared_ptr onnxEnvironmentPtr_;
+ std::string policyFilePath_;
+ std::unique_ptr sessionPtr_;
+ std::vector inputNames_;
+ std::vector outputNames_;
+ std::vector> inputShapes_;
+ std::vector> outputShapes_;
+};
+
+} // namespace mpcnet
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/dummy/MpcnetDummyLoopRos.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/dummy/MpcnetDummyLoopRos.h
new file mode 100644
index 000000000..b1e954e28
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/dummy/MpcnetDummyLoopRos.h
@@ -0,0 +1,112 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include "ocs2_mpcnet_core/control/MpcnetControllerBase.h"
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * Dummy loop to test a robot controlled by an MPC-Net policy.
+ */
+class MpcnetDummyLoopRos {
+ public:
+ /**
+ * Constructor.
+ * @param [in] controlFrequency : Minimum frequency at which the MPC-Net policy should be called.
+ * @param [in] rosFrequency : Frequency at which the ROS observers are updated.
+ * @param [in] mpcnetPtr : Pointer to the MPC-Net policy to be used (this class takes ownership).
+ * @param [in] rolloutPtr : Pointer to the rollout to be used (this class takes ownership).
+ * @param [in] rosReferenceManagerPtr : Pointer to the reference manager to be used (shared ownership).
+ */
+ MpcnetDummyLoopRos(scalar_t controlFrequency, scalar_t rosFrequency, std::unique_ptr mpcnetPtr,
+ std::unique_ptr rolloutPtr, std::shared_ptr rosReferenceManagerPtr);
+
+ /**
+ * Default destructor.
+ */
+ virtual ~MpcnetDummyLoopRos() = default;
+
+ /**
+ * Runs the dummy loop.
+ * @param [in] systemObservation: The initial system observation.
+ * @param [in] targetTrajectories: The initial target trajectories.
+ */
+ void run(const SystemObservation& systemObservation, const TargetTrajectories& targetTrajectories);
+
+ /**
+ * Adds one observer to the vector of observers that need to be informed about the system observation, primal solution and command data.
+ * Each observer is updated once after running a rollout.
+ * @param [in] observer : The observer to add.
+ */
+ void addObserver(std::shared_ptr observer);
+
+ /**
+ * Adds one module to the vector of modules that need to be synchronized with the policy.
+ * Each module is updated once before calling the policy.
+ * @param [in] synchronizedModule : The module to add.
+ */
+ void addSynchronizedModule(std::shared_ptr synchronizedModule);
+
+ protected:
+ /**
+ * Runs a rollout.
+ * @param [in] duration : The duration of the run.
+ * @param [in] initialSystemObservation : The initial system observation.
+ * @param [out] finalSystemObservation : The final system observation.
+ */
+ void rollout(scalar_t duration, const SystemObservation& initialSystemObservation, SystemObservation& finalSystemObservation);
+
+ /**
+ * Update the reference manager and the synchronized modules.
+ * @param [in] time : The current time.
+ * @param [in] state : The cuurent state.
+ */
+ void preSolverRun(scalar_t time, const vector_t& state);
+
+ private:
+ scalar_t controlFrequency_;
+ scalar_t rosFrequency_;
+ std::unique_ptr mpcnetPtr_;
+ std::unique_ptr rolloutPtr_;
+ std::shared_ptr rosReferenceManagerPtr_;
+ std::vector> observerPtrs_;
+ std::vector> synchronizedModulePtrs_;
+};
+
+} // namespace mpcnet
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/dummy/MpcnetDummyObserverRos.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/dummy/MpcnetDummyObserverRos.h
new file mode 100644
index 000000000..a03661bb4
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/dummy/MpcnetDummyObserverRos.h
@@ -0,0 +1,69 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+
+#include
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * Dummy observer that publishes the current system observation that is required for some target command nodes.
+ */
+class MpcnetDummyObserverRos : public DummyObserver {
+ public:
+ /**
+ * Constructor.
+ * @param [in] nodeHandle : The ROS node handle.
+ * @param [in] topicPrefix : The prefix defines the names for the observation's publishing topic "topicPrefix_mpc_observation".
+ */
+ explicit MpcnetDummyObserverRos(ros::NodeHandle& nodeHandle, std::string topicPrefix = "anonymousRobot");
+
+ /**
+ * Default destructor.
+ */
+ ~MpcnetDummyObserverRos() override = default;
+
+ /**
+ * Update and publish.
+ * @param [in] observation: The current system observation.
+ * @param [in] primalSolution: The current primal solution.
+ * @param [in] command: The given command data.
+ */
+ void update(const SystemObservation& observation, const PrimalSolution& primalSolution, const CommandData& command) override;
+
+ private:
+ ros::Publisher observationPublisher_;
+};
+
+} // namespace mpcnet
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetData.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetData.h
new file mode 100644
index 000000000..7fd63215e
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetData.h
@@ -0,0 +1,86 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+#include
+
+#include "ocs2_mpcnet_core/MpcnetDefinitionBase.h"
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * Data point collected during the data generation rollout.
+ */
+struct DataPoint {
+ /** Mode of the system. */
+ size_t mode;
+ /** Absolute time. */
+ scalar_t t;
+ /** Observed state. */
+ vector_t x;
+ /** Optimal control input. */
+ vector_t u;
+ /** Observation given as input to the policy. */
+ vector_t observation;
+ /** Action transformation applied to the output of the policy. */
+ std::pair actionTransformation;
+ /** Linear-quadratic approximation of the Hamiltonian, using x and u as development/expansion points. */
+ ScalarFunctionQuadraticApproximation hamiltonian;
+};
+using data_point_t = DataPoint;
+using data_array_t = std::vector;
+
+/**
+ * Get a data point.
+ * @param [in] mpc : The MPC with a pointer to the underlying solver.
+ * @param [in] mpcnetDefinition : The MPC-Net definitions.
+ * @param [in] deviation : The state deviation from the nominal state where to get the data point from.
+ * @return A data point.
+ */
+inline data_point_t getDataPoint(MPC_BASE& mpc, MpcnetDefinitionBase& mpcnetDefinition, const vector_t& deviation) {
+ data_point_t dataPoint;
+ const auto& referenceManager = mpc.getSolverPtr()->getReferenceManager();
+ const auto primalSolution = mpc.getSolverPtr()->primalSolution(mpc.getSolverPtr()->getFinalTime());
+ dataPoint.t = primalSolution.timeTrajectory_.front();
+ dataPoint.x = primalSolution.stateTrajectory_.front() + deviation;
+ dataPoint.u = primalSolution.controllerPtr_->computeInput(dataPoint.t, dataPoint.x);
+ dataPoint.mode = primalSolution.modeSchedule_.modeAtTime(dataPoint.t);
+ dataPoint.observation = mpcnetDefinition.getObservation(dataPoint.t, dataPoint.x, referenceManager.getModeSchedule(),
+ referenceManager.getTargetTrajectories());
+ dataPoint.actionTransformation = mpcnetDefinition.getActionTransformation(dataPoint.t, dataPoint.x, referenceManager.getModeSchedule(),
+ referenceManager.getTargetTrajectories());
+ dataPoint.hamiltonian = mpc.getSolverPtr()->getHamiltonian(dataPoint.t, dataPoint.x, dataPoint.u);
+ return dataPoint;
+}
+
+} // namespace mpcnet
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetDataGeneration.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetDataGeneration.h
new file mode 100644
index 000000000..6b0ee2a0e
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetDataGeneration.h
@@ -0,0 +1,95 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include "ocs2_mpcnet_core/rollout/MpcnetData.h"
+#include "ocs2_mpcnet_core/rollout/MpcnetRolloutBase.h"
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * A class for generating data from a system that is forward simulated with a behavioral controller.
+ * @note Usually the behavioral controller moves from the MPC policy to the MPC-Net policy throughout the training process.
+ */
+class MpcnetDataGeneration final : public MpcnetRolloutBase {
+ public:
+ /**
+ * Constructor.
+ * @param [in] mpcPtr : Pointer to the MPC solver to be used (this class takes ownership).
+ * @param [in] mpcnetPtr : Pointer to the MPC-Net policy to be used (this class takes ownership).
+ * @param [in] rolloutPtr : Pointer to the rollout to be used (this class takes ownership).
+ * @param [in] mpcnetDefinitionPtr : Pointer to the MPC-Net definitions to be used (shared ownership).
+ * @param [in] referenceManagerPtr : Pointer to the reference manager to be used (shared ownership).
+ */
+ MpcnetDataGeneration(std::unique_ptr mpcPtr, std::unique_ptr mpcnetPtr,
+ std::unique_ptr rolloutPtr, std::shared_ptr mpcnetDefinitionPtr,
+ std::shared_ptr referenceManagerPtr)
+ : MpcnetRolloutBase(std::move(mpcPtr), std::move(mpcnetPtr), std::move(rolloutPtr), std::move(mpcnetDefinitionPtr),
+ std::move(referenceManagerPtr)) {}
+
+ /**
+ * Default destructor.
+ */
+ ~MpcnetDataGeneration() override = default;
+
+ /**
+ * Deleted copy constructor.
+ */
+ MpcnetDataGeneration(const MpcnetDataGeneration&) = delete;
+
+ /**
+ * Deleted copy assignment.
+ */
+ MpcnetDataGeneration& operator=(const MpcnetDataGeneration&) = delete;
+
+ /**
+ * Run the data generation.
+ * @param [in] alpha : The mixture parameter for the behavioral controller.
+ * @param [in] policyFilePath : The path to the file with the learned policy for the behavioral controller.
+ * @param [in] timeStep : The time step for the forward simulation of the system with the behavioral controller.
+ * @param [in] dataDecimation : The integer factor used for downsampling the data signal.
+ * @param [in] nSamples : The number of samples drawn from a multivariate normal distribution around the nominal states.
+ * @param [in] samplingCovariance : The covariance matrix used for sampling from a multivariate normal distribution.
+ * @param [in] initialObservation : The initial system observation to start from (time and state required).
+ * @param [in] modeSchedule : The mode schedule providing the event times and mode sequence.
+ * @param [in] targetTrajectories : The target trajectories to be tracked.
+ * @return Pointer to the data array with the generated data.
+ */
+ const data_array_t* run(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, size_t dataDecimation, size_t nSamples,
+ const matrix_t& samplingCovariance, const SystemObservation& initialObservation, const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories);
+
+ private:
+ data_array_t dataArray_;
+};
+
+} // namespace mpcnet
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetMetrics.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetMetrics.h
new file mode 100644
index 000000000..120da00c7
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetMetrics.h
@@ -0,0 +1,50 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * Metrics computed during the policy evaluation rollout.
+ */
+struct Metrics {
+ /** Survival time. */
+ scalar_t survivalTime = 0.0;
+ /** Hamiltonian incurred over time. */
+ scalar_t incurredHamiltonian = 0.0;
+};
+using metrics_t = Metrics;
+using metrics_array_t = std::vector;
+
+} // namespace mpcnet
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetPolicyEvaluation.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetPolicyEvaluation.h
new file mode 100644
index 000000000..24d8ac57e
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetPolicyEvaluation.h
@@ -0,0 +1,88 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include "ocs2_mpcnet_core/rollout/MpcnetMetrics.h"
+#include "ocs2_mpcnet_core/rollout/MpcnetRolloutBase.h"
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * A class for evaluating a policy for a system that is forward simulated with a behavioral controller.
+ * @note Usually the behavioral controller is evaluated for the MPC-Net policy (alpha = 0).
+ */
+class MpcnetPolicyEvaluation final : public MpcnetRolloutBase {
+ public:
+ /**
+ * Constructor.
+ * @param [in] mpcPtr : Pointer to the MPC solver to be used (this class takes ownership).
+ * @param [in] mpcnetPtr : Pointer to the MPC-Net policy to be used (this class takes ownership).
+ * @param [in] rolloutPtr : Pointer to the rollout to be used (this class takes ownership).
+ * @param [in] mpcnetDefinitionPtr : Pointer to the MPC-Net definitions to be used (shared ownership).
+ * @param [in] referenceManagerPtr : Pointer to the reference manager to be used (shared ownership).
+ */
+ MpcnetPolicyEvaluation(std::unique_ptr mpcPtr, std::unique_ptr mpcnetPtr,
+ std::unique_ptr rolloutPtr, std::shared_ptr mpcnetDefinitionPtr,
+ std::shared_ptr referenceManagerPtr)
+ : MpcnetRolloutBase(std::move(mpcPtr), std::move(mpcnetPtr), std::move(rolloutPtr), std::move(mpcnetDefinitionPtr),
+ std::move(referenceManagerPtr)) {}
+
+ /**
+ * Default destructor.
+ */
+ ~MpcnetPolicyEvaluation() override = default;
+
+ /**
+ * Deleted copy constructor.
+ */
+ MpcnetPolicyEvaluation(const MpcnetPolicyEvaluation&) = delete;
+
+ /**
+ * Deleted copy assignment.
+ */
+ MpcnetPolicyEvaluation& operator=(const MpcnetPolicyEvaluation&) = delete;
+
+ /**
+ * Run the policy evaluation.
+ * @param [in] alpha : The mixture parameter for the behavioral controller.
+ * @param [in] policyFilePath : The path to the file with the learned policy for the behavioral controller.
+ * @param [in] timeStep : The time step for the forward simulation of the system with the behavioral controller.
+ * @param [in] initialObservation : The initial system observation to start from (time and state required).
+ * @param [in] modeSchedule : The mode schedule providing the event times and mode sequence.
+ * @param [in] targetTrajectories : The target trajectories to be tracked.
+ * @return The computed metrics.
+ */
+ metrics_t run(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, const SystemObservation& initialObservation,
+ const ModeSchedule& modeSchedule, const TargetTrajectories& targetTrajectories);
+};
+
+} // namespace mpcnet
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetRolloutBase.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetRolloutBase.h
new file mode 100644
index 000000000..23c16f1d2
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetRolloutBase.h
@@ -0,0 +1,117 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ocs2_mpcnet_core/MpcnetDefinitionBase.h"
+#include "ocs2_mpcnet_core/control/MpcnetBehavioralController.h"
+#include "ocs2_mpcnet_core/control/MpcnetControllerBase.h"
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * The base class for doing rollouts for a system that is forward simulated with a behavioral controller.
+ * The behavioral policy is a mixture of an MPC policy and an MPC-Net policy (e.g. a neural network).
+ */
+class MpcnetRolloutBase {
+ public:
+ /**
+ * Constructor.
+ * @param [in] mpcPtr : Pointer to the MPC solver to be used (this class takes ownership).
+ * @param [in] mpcnetPtr : Pointer to the MPC-Net policy to be used (this class takes ownership).
+ * @param [in] rolloutPtr : Pointer to the rollout to be used (this class takes ownership).
+ * @param [in] mpcnetDefinitionPtr : Pointer to the MPC-Net definitions to be used (shared ownership).
+ * @param [in] referenceManagerPtr : Pointer to the reference manager to be used (shared ownership).
+ */
+ MpcnetRolloutBase(std::unique_ptr mpcPtr, std::unique_ptr mpcnetPtr,
+ std::unique_ptr rolloutPtr, std::shared_ptr mpcnetDefinitionPtr,
+ std::shared_ptr referenceManagerPtr)
+ : mpcPtr_(std::move(mpcPtr)),
+ mpcnetPtr_(std::move(mpcnetPtr)),
+ rolloutPtr_(std::move(rolloutPtr)),
+ mpcnetDefinitionPtr_(std::move(mpcnetDefinitionPtr)),
+ referenceManagerPtr_(std::move(referenceManagerPtr)),
+ behavioralControllerPtr_(new MpcnetBehavioralController()) {}
+
+ /**
+ * Default destructor.
+ */
+ virtual ~MpcnetRolloutBase() = default;
+
+ /**
+ * Deleted copy constructor.
+ */
+ MpcnetRolloutBase(const MpcnetRolloutBase&) = delete;
+
+ /**
+ * Deleted copy assignment.
+ */
+ MpcnetRolloutBase& operator=(const MpcnetRolloutBase&) = delete;
+
+ protected:
+ /**
+ * (Re)set system components.
+ * @param [in] alpha : The mixture parameter for the behavioral controller.
+ * @param [in] policyFilePath : The path to the file with the learned policy for the controller.
+ * @param [in] initialObservation : The initial system observation to start from (time and state required).
+ * @param [in] modeSchedule : The mode schedule providing the event times and mode sequence.
+ * @param [in] targetTrajectories : The target trajectories to be tracked.
+ */
+ void set(scalar_t alpha, const std::string& policyFilePath, const SystemObservation& initialObservation, const ModeSchedule& modeSchedule,
+ const TargetTrajectories& targetTrajectories);
+
+ /**
+ * Simulate the system one step forward.
+ * @param [in] timeStep : The time step for the forward simulation of the system with the behavioral controller.
+ */
+ void step(scalar_t timeStep);
+
+ std::unique_ptr mpcPtr_;
+ std::shared_ptr mpcnetDefinitionPtr_;
+ std::unique_ptr behavioralControllerPtr_;
+ SystemObservation systemObservation_;
+ PrimalSolution primalSolution_;
+
+ private:
+ std::unique_ptr mpcnetPtr_;
+ std::unique_ptr rolloutPtr_;
+ std::shared_ptr referenceManagerPtr_;
+};
+
+} // namespace mpcnet
+} // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetRolloutManager.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetRolloutManager.h
new file mode 100644
index 000000000..c1862e9bc
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetRolloutManager.h
@@ -0,0 +1,137 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+
+#include "ocs2_mpcnet_core/rollout/MpcnetDataGeneration.h"
+#include "ocs2_mpcnet_core/rollout/MpcnetPolicyEvaluation.h"
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * A class to manage the data generation and policy evaluation rollouts for MPC-Net.
+ */
+class MpcnetRolloutManager {
+ public:
+ /**
+ * Constructor.
+ * @note The first nDataGenerationThreads pointers will be used for the data generation and the next nPolicyEvaluationThreads pointers for
+ * the policy evaluation.
+ * @param [in] nDataGenerationThreads : Number of data generation threads.
+ * @param [in] nPolicyEvaluationThreads : Number of policy evaluation threads.
+ * @param [in] mpcPtrs : Pointers to the MPC solvers to be used.
+ * @param [in] mpcnetPtrs : Pointers to the MPC-Net policies to be used.
+ * @param [in] rolloutPtrs : Pointers to the rollouts to be used.
+ * @param [in] mpcnetDefinitionPtrs : Pointers to the MPC-Net definitions to be used.
+ * @param [in] referenceManagerPtrs : Pointers to the reference managers to be used.
+ */
+ MpcnetRolloutManager(size_t nDataGenerationThreads, size_t nPolicyEvaluationThreads, std::vector> mpcPtrs,
+ std::vector> mpcnetPtrs, std::vector> rolloutPtrs,
+ std::vector> mpcnetDefinitionPtrs,
+ std::vector> referenceManagerPtrs);
+
+ /**
+ * Default destructor.
+ */
+ virtual ~MpcnetRolloutManager() = default;
+
+ /**
+ * Starts the data genration forward simulated by a behavioral controller.
+ * @param [in] alpha : The mixture parameter for the behavioral controller.
+ * @param [in] policyFilePath : The path to the file with the learned policy for the behavioral controller.
+ * @param [in] timeStep : The time step for the forward simulation of the system with the behavioral controller.
+ * @param [in] dataDecimation : The integer factor used for downsampling the data signal.
+ * @param [in] nSamples : The number of samples drawn from a multivariate normal distribution around the nominal states.
+ * @param [in] samplingCovariance : The covariance matrix used for sampling from a multivariate normal distribution.
+ * @param [in] initialObservations : The initial system observations to start from (time and state required).
+ * @param [in] modeSchedules : The mode schedules providing the event times and mode sequence.
+ * @param [in] targetTrajectories : The target trajectories to be tracked.
+ */
+ void startDataGeneration(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, size_t dataDecimation, size_t nSamples,
+ const matrix_t& samplingCovariance, const std::vector& initialObservations,
+ const std::vector& modeSchedules, const std::vector& targetTrajectories);
+
+ /**
+ * Check if data generation is done.
+ * @return True if done.
+ */
+ bool isDataGenerationDone();
+
+ /**
+ * Get the data generated from the data generation rollout.
+ * @return The generated data.
+ */
+ const data_array_t& getGeneratedData();
+
+ /**
+ * Starts the policy evaluation forward simulated by a behavioral controller.
+ * @param [in] alpha : The mixture parameter for the behavioral controller.
+ * @param [in] policyFilePath : The path to the file with the learned policy for the behavioral controller.
+ * @param [in] timeStep : The time step for the forward simulation of the system with the behavioral controller.
+ * @param [in] initialObservations : The initial system observations to start from (time and state required).
+ * @param [in] modeSchedules : The mode schedules providing the event times and mode sequence.
+ * @param [in] targetTrajectories : The target trajectories to be tracked.
+ */
+ void startPolicyEvaluation(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep,
+ const std::vector