diff --git a/.github/workflows/ros-build-test.yml b/.github/workflows/ros-build-test.yml index 8816680ca..720f066c5 100644 --- a/.github/workflows/ros-build-test.yml +++ b/.github/workflows/ros-build-test.yml @@ -30,7 +30,7 @@ jobs: - name: System deps run: | apt-get update - apt-get install -y git ninja-build liburdfdom-dev liboctomap-dev libassimp-dev checkinstall + apt-get install -y git ninja-build liburdfdom-dev liboctomap-dev libassimp-dev checkinstall wget rsync - uses: actions/checkout@v2 with: @@ -56,6 +56,12 @@ jobs: cmake .. -DPYTHON_EXECUTABLE=$(python3 -c "import sys; print(sys.executable)") make -j4 && checkinstall + - name: Install ONNX Runtime + run: | + wget https://github.com/microsoft/onnxruntime/releases/download/v1.7.0/onnxruntime-linux-x64-1.7.0.tgz -P src + tar xf src/onnxruntime-linux-x64-1.7.0.tgz -C src + rsync -a src/ocs2/ocs2_mpcnet/ocs2_mpcnet_core/misc/onnxruntime/cmake/ src/onnxruntime-linux-x64-1.7.0/cmake + - name: Build (${{ matrix.build_type }}) shell: bash run: | diff --git a/.gitignore b/.gitignore index 15f4690f1..46f83bb33 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ ocs2_ddp/test/ddp_test_generated/ *.out *.synctex.gz .vscode/ +runs/ diff --git a/jenkins-pipeline b/jenkins-pipeline index 7deeb22ce..02191156e 100644 --- a/jenkins-pipeline +++ b/jenkins-pipeline @@ -1,5 +1,5 @@ library 'continuous_integration_pipeline' -ciPipeline("--ros-distro noetic --publish-doxygen --recipes raisimlib\ +ciPipeline("--ros-distro noetic --publish-doxygen --recipes onnxruntime raisimlib\ --dependencies 'git@github.com:leggedrobotics/hpp-fcl.git;master;git'\ 'git@github.com:leggedrobotics/pinocchio.git;master;git'\ 'git@github.com:leggedrobotics/ocs2_robotic_assets.git;main;git'\ diff --git a/ocs2/package.xml b/ocs2/package.xml index 834b4090e..c088f80f2 100644 --- a/ocs2/package.xml +++ b/ocs2/package.xml @@ -24,6 +24,7 @@ ocs2_robotic_examples ocs2_thirdparty ocs2_raisim + ocs2_mpcnet diff --git 
a/ocs2_core/include/ocs2_core/control/ControllerType.h b/ocs2_core/include/ocs2_core/control/ControllerType.h index 65aa1962f..bbae196ee 100644 --- a/ocs2_core/include/ocs2_core/control/ControllerType.h +++ b/ocs2_core/include/ocs2_core/control/ControllerType.h @@ -34,6 +34,6 @@ namespace ocs2 { /** * Enum class for specifying controller type */ -enum class ControllerType { UNKNOWN, FEEDFORWARD, LINEAR, NEURAL_NETWORK }; +enum class ControllerType { UNKNOWN, FEEDFORWARD, LINEAR, ONNX, BEHAVIORAL }; } // namespace ocs2 diff --git a/ocs2_doc/docs/index.rst b/ocs2_doc/docs/index.rst index 11450c76d..44a2bf541 100644 --- a/ocs2_doc/docs/index.rst +++ b/ocs2_doc/docs/index.rst @@ -11,6 +11,7 @@ Table of Contents robotic_examples.rst from_urdf_to_ocp.rst profiling.rst + mpcnet.rst faq.rst .. rubric:: Reference and Index: diff --git a/ocs2_doc/docs/installation.rst b/ocs2_doc/docs/installation.rst index 1e8213f6e..88375bd5a 100644 --- a/ocs2_doc/docs/installation.rst +++ b/ocs2_doc/docs/installation.rst @@ -98,6 +98,42 @@ Optional Dependencies * `Grid Map `__ catkin package, which may be installed with ``sudo apt install ros-noetic-grid-map-msgs``. +* `ONNX Runtime `__ is an inferencing and training accelerator. Here, it is used for deploying learned :ref:`MPC-Net ` policies in C++ code. To locally install it, do the following: + + .. code-block:: bash + + wget https://github.com/microsoft/onnxruntime/releases/download/v1.7.0/onnxruntime-linux-x64-1.7.0.tgz -P ~/catkin_ws/src + tar xf ~/catkin_ws/src/onnxruntime-linux-x64-1.7.0.tgz -C ~/catkin_ws/src + rsync -a ~/catkin_ws/src/ocs2/ocs2_mpcnet/ocs2_mpcnet_core/misc/onnxruntime/cmake/ ~/catkin_ws/src/onnxruntime-linux-x64-1.7.0/cmake + + We provide custom cmake config and version files to enable ``find_package(onnxruntime)`` without modifying ``LIBRARY_PATH`` and ``LD_LIBRARY_PATH``. Note that the commands above assume that your catkin workspace is located at ``~/catkin_ws`` and that you cloned OCS2 into its ``src`` folder.
+ +* `Virtual environments `__ are recommended when training :ref:`MPC-Net ` policies: + + .. code-block:: bash + + sudo apt-get install python3-venv + + Create an environment and give it access to the system site packages: + + .. code-block:: bash + + mkdir venvs && cd venvs + python3 -m venv mpcnet + + Activate the environment and install the requirements: + + .. code-block:: bash + + source ~/venvs/mpcnet/bin/activate + python3 -m pip install -r ~/git/ocs2/ocs2_mpcnet/ocs2_mpcnet_core/requirements.txt + + Newer graphics cards might require a CUDA capability which is currently not supported by the standard PyTorch installation. + In that case check `PyTorch Start Locally `__ for a compatible version and, e.g., run: + + .. code-block:: bash + + pip3 install torch==1.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html .. _doxid-ocs2_doc_installation_ocs2_doc_install: diff --git a/ocs2_doc/docs/mpcnet.rst b/ocs2_doc/docs/mpcnet.rst new file mode 100644 index 000000000..39993c850 --- /dev/null +++ b/ocs2_doc/docs/mpcnet.rst @@ -0,0 +1,164 @@ +.. index:: pair: page; MPC-Net + +.. _doxid-ocs2_doc_mpcnet: + +MPC-Net +======= + +MPC-Net is an imitation learning approach that uses solutions from MPC to guide the policy search. +The main idea is to imitate MPC by minimizing the control Hamiltonian while representing the corresponding control inputs by a parametrized policy. +MPC-Net can be used to clone a model predictive controller into a neural network policy, which can be evaluated much faster than MPC. +Therefore, MPC-Net is a useful proxy for MPC in computationally demanding applications that do not require the most exact solution. + +The multi-threaded data generation and policy evaluation run asynchronously with the policy training. +The data generation and policy evaluation are implemented in C++ and run on CPU, while the policy training is implemented in Python and runs on GPU. 
+The control Hamiltonian is represented by a linear-quadratic approximation. +Therefore, the training can run on GPU without callbacks to OCS2 C++ code running on CPU to evaluate the Hamiltonian, and one can exploit batch processing on GPU. + +Robots +~~~~~~ + +MPC-Net has been implemented for the following :ref:`Robotic Examples `: + +============= ================ ================= ======== ============= +Robot Recom. CPU Cores Recom. GPU Memory RaiSim Training Time +============= ================ ================= ======== ============= +ballbot 4 2 GB No 0m 20s +legged_robot 12 8 GB Yes / No 7m 40s +============= ================ ================= ======== ============= + +Setup +~~~~~ + +Make sure to follow the :ref:`Installation ` page. +Follow all the instructions for the dependencies. +Regarding the optional dependencies, make sure to follow the instruction for ONNX Runtime and the virtual environment, optionally set up RaiSim. + +To build all MPC-Net packages, build the meta package: + +.. code-block:: bash + + cd + catkin_build ocs2_mpcnet + + # Example: + cd ~/catkin_ws + catkin_build ocs2_mpcnet + +To build a robot-specific package, replace :code:`` with the robot name: + +.. code-block:: bash + + cd + catkin_build ocs2__mpcnet + + # Example: + cd ~/catkin_ws + catkin_build ocs2_ballbot_mpcnet + +Training +~~~~~~~~ + +To train an MPC-Net policy, run: + +.. code-block:: bash + + cd /ocs2_mpcnet/ocs2__mpcnet/python/ocs2__mpcnet + source /devel/setup.bash + source /mpcnet/bin/activate + python3 train.py + + # Example: + cd ~/git/ocs2/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet + source ~/catkin_ws/devel/setup.bash + source ~/venvs/mpcnet/bin/activate + python3 train.py + +To monitor the training progress with Tensorboard, run: + +.. 
code-block:: bash + + cd /ocs2_mpcnet/ocs2__mpcnet/python/ocs2__mpcnet + source /mpcnet/bin/activate + tensorboard --logdir=runs + + # Example: + cd ~/git/ocs2/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet + source ~/venvs/mpcnet/bin/activate + tensorboard --logdir=runs + +If you use RaiSim, you can visualize the data generation and policy evaluation rollouts with RaiSim Unity, where pre-built executables are provided in RaiSim's :code:`raisimUnity` folder. For example, on Linux run: + +.. code-block:: bash + + /raisimUnity/linux/raisimUnity.x86_64 + + # Example: + ~/git/raisimLib/raisimUnity/linux/raisimUnity.x86_64 + +Deployment +~~~~~~~~~~ + +To deploy the default policy stored in the robot-specific package's :code:`policy` folder, run: + +.. code-block:: bash + + cd + source devel/setup.bash + roslaunch ocs2__mpcnet _mpcnet.launch + + # Example: + cd ~/catkin_ws + source devel/setup.bash + roslaunch ocs2_ballbot_mpcnet ballbot_mpcnet.launch + +To deploy a new policy stored in the robot-specific package's :code:`python/ocs2__mpcnet/runs` folder, replace :code:`` with the absolute file path to the final policy and run: + +.. code-block:: bash + + cd + source devel/setup.bash + roslaunch ocs2__mpcnet _mpcnet.launch policyFile:= + + # Example: + cd ~/catkin_ws + source devel/setup.bash + roslaunch ocs2_ballbot_mpcnet ballbot_mpcnet.launch policyFile:='/home/user/git/ocs2/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/runs/2022-04-01_12-00-00_ballbot_description/final_policy.onnx' + +How to Set Up a New Robot +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Setting up MPC-Net for a new robot is relatively easy, as the **ocs2_mpcnet_core** package takes care of the data generation as well as policy evaluation rollouts and implements important learning components, such as the memory, policy, and loss function. + +This section assumes that you already have the packages for the robot-specific MPC implementation: + +1. 
**ocs2_**: Provides the library with the robot-specific MPC implementation. +2. **ocs2__ros**: Wraps around the MPC implementation with ROS to define ROS nodes. +3. **ocs2__raisim**: (Optional) interface between the robot-specific MPC implementation and RaiSim. + +For the actual **ocs2__mpcnet** package, follow the structure of existing robot-specific MPC-Net packages. +The most important classes/files that have to be implemented are: + +* **MpcnetDefinition**: Defines how OCS2 state variables are transformed to the policy observations and how policy actions are transformed to OCS2 control inputs. +* **MpcnetInterface**: Provides the interface between C++ and Python, allowing to exchange data and policies. +* **.yaml**: Stores the configuration parameters. +* **mpcnet.py**: Adds robot-specific methods, e.g. implements the tasks that the robot should execute, for the MPC-Net training. +* **train.py**: Starts the main training script. + +Known Issues +~~~~~~~~~~~~ + +Stiff inequality constraints can lead to very large Hamiltonians and gradients of the Hamiltonian near the log barrier. +This can obstruct the learning process and the policy might not learn something useful. +In that case, enable the gradient clipping in the robot's MPC-Net YAML configuration file and tune the gradient clipping value. + +References +~~~~~~~~~~ + +This part of the toolbox has been developed based on the following publications: + +.. bibliography:: + :list: enumerated + + carius2020mpcnet + reske2021imitation diff --git a/ocs2_doc/docs/overview.rst b/ocs2_doc/docs/overview.rst index 6c7e5f68c..6141cc135 100644 --- a/ocs2_doc/docs/overview.rst +++ b/ocs2_doc/docs/overview.rst @@ -60,7 +60,8 @@ Michael Spieler (ETHZ), Jan Carius (ETHZ), Jean-Pierre Sleiman (ETHZ). 
-**Other Developers**: +**Other Developers**: +Alexander Reske, Mayank Mittal, Johannes Pankert, Perry Franklin, diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/CMakeLists.txt b/ocs2_mpcnet/ocs2_ballbot_mpcnet/CMakeLists.txt new file mode 100644 index 000000000..d95da9050 --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/CMakeLists.txt @@ -0,0 +1,128 @@ +cmake_minimum_required(VERSION 3.0.2) +project(ocs2_ballbot_mpcnet) + +set(CATKIN_PACKAGE_DEPENDENCIES + ocs2_ballbot + ocs2_ballbot_ros + ocs2_mpcnet_core +) + +find_package(catkin REQUIRED COMPONENTS + ${CATKIN_PACKAGE_DEPENDENCIES} +) + +# Generate compile_commands.json for clang tools +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +################################### +## catkin specific configuration ## +################################### + +catkin_package( + INCLUDE_DIRS + include + LIBRARIES + ${PROJECT_NAME} + CATKIN_DEPENDS + ${CATKIN_PACKAGE_DEPENDENCIES} + DEPENDS +) + +########### +## Build ## +########### + +include_directories( + include + ${catkin_INCLUDE_DIRS} +) + +# main library +add_library(${PROJECT_NAME} + src/BallbotMpcnetDefinition.cpp + src/BallbotMpcnetInterface.cpp +) +add_dependencies(${PROJECT_NAME} + ${catkin_EXPORTED_TARGETS} +) +target_link_libraries(${PROJECT_NAME} + ${catkin_LIBRARIES} +) + +# python bindings +pybind11_add_module(BallbotMpcnetPybindings SHARED + src/BallbotMpcnetPybindings.cpp +) +add_dependencies(BallbotMpcnetPybindings + ${PROJECT_NAME} + ${catkin_EXPORTED_TARGETS} +) +target_link_libraries(BallbotMpcnetPybindings PRIVATE + ${PROJECT_NAME} + ${catkin_LIBRARIES} +) +set_target_properties(BallbotMpcnetPybindings + PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CATKIN_DEVEL_PREFIX}/${CATKIN_PACKAGE_PYTHON_DESTINATION} +) + +# MPC-Net dummy node +add_executable(ballbot_mpcnet_dummy + src/BallbotMpcnetDummyNode.cpp +) +add_dependencies(ballbot_mpcnet_dummy + ${PROJECT_NAME} + ${catkin_EXPORTED_TARGETS} +) +target_link_libraries(ballbot_mpcnet_dummy + ${PROJECT_NAME} + 
${catkin_LIBRARIES} +) +target_compile_options(ballbot_mpcnet_dummy PRIVATE ${OCS2_CXX_FLAGS}) + +catkin_python_setup() + +######################### +### CLANG TOOLING ### +######################### +find_package(cmake_clang_tools QUIET) +if(cmake_clang_tools_FOUND) + message(STATUS "Run clang tooling for target ocs2_ballbot_mpcnet") + add_clang_tooling( + TARGETS ${PROJECT_NAME} ballbot_mpcnet_dummy + SOURCE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/src ${CMAKE_CURRENT_SOURCE_DIR}/include + CT_HEADER_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/include + CF_WERROR +) +endif(cmake_clang_tools_FOUND) + +############# +## Install ## +############# + +install(TARGETS ${PROJECT_NAME} + ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + +install(DIRECTORY include/${PROJECT_NAME}/ + DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}) + +install(TARGETS BallbotMpcnetPybindings + ARCHIVE DESTINATION ${CATKIN_PACKAGE_PYTHON_DESTINATION} + LIBRARY DESTINATION ${CATKIN_PACKAGE_PYTHON_DESTINATION} +) + +install(TARGETS ballbot_mpcnet_dummy + RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + +install(DIRECTORY launch policy + DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +) + +############# +## Testing ## +############# + +# TODO(areske) diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/include/ocs2_ballbot_mpcnet/BallbotMpcnetDefinition.h b/ocs2_mpcnet/ocs2_ballbot_mpcnet/include/ocs2_ballbot_mpcnet/BallbotMpcnetDefinition.h new file mode 100644 index 000000000..e50cf779f --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/include/ocs2_ballbot_mpcnet/BallbotMpcnetDefinition.h @@ -0,0 +1,71 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include + +namespace ocs2 { +namespace ballbot { + +/** + * MPC-Net definitions for ballbot. + */ +class BallbotMpcnetDefinition final : public ocs2::mpcnet::MpcnetDefinitionBase { + public: + /** + * Default constructor. + */ + BallbotMpcnetDefinition() = default; + + /** + * Default destructor. 
+ */ + ~BallbotMpcnetDefinition() override = default; + + /** + * @see MpcnetDefinitionBase::getObservation + */ + vector_t getObservation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories) override; + + /** + * @see MpcnetDefinitionBase::getActionTransformation + */ + std::pair getActionTransformation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories) override; + + /** + * @see MpcnetDefinitionBase::isValid + */ + bool isValid(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, const TargetTrajectories& targetTrajectories) override; +}; + +} // namespace ballbot +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/include/ocs2_ballbot_mpcnet/BallbotMpcnetInterface.h b/ocs2_mpcnet/ocs2_ballbot_mpcnet/include/ocs2_ballbot_mpcnet/BallbotMpcnetInterface.h new file mode 100644 index 000000000..4456b7ab5 --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/include/ocs2_ballbot_mpcnet/BallbotMpcnetInterface.h @@ -0,0 +1,66 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include +#include + +namespace ocs2 { +namespace ballbot { + +/** + * Ballbot MPC-Net interface between C++ and Python. + */ +class BallbotMpcnetInterface final : public ocs2::mpcnet::MpcnetInterfaceBase { + public: + /** + * Constructor. + * @param [in] nDataGenerationThreads : Number of data generation threads. + * @param [in] nPolicyEvaluationThreads : Number of policy evaluation threads. + * @param [in] raisim : Whether to use RaiSim for the rollouts. + */ + BallbotMpcnetInterface(size_t nDataGenerationThreads, size_t nPolicyEvaluationThreads, bool raisim); + + /** + * Default destructor. + */ + ~BallbotMpcnetInterface() override = default; + + private: + /** + * Helper to get the MPC. + * @param [in] ballbotInterface : The ballbot interface. + * @return Pointer to the MPC. 
+ */ + std::unique_ptr getMpc(BallbotInterface& ballbotInterface); +}; + +} // namespace ballbot +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/launch/ballbot_mpcnet.launch b/ocs2_mpcnet/ocs2_ballbot_mpcnet/launch/ballbot_mpcnet.launch new file mode 100644 index 000000000..ca53ee47b --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/launch/ballbot_mpcnet.launch @@ -0,0 +1,20 @@ + + + + + + + + + + + + + + + + + + diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/package.xml b/ocs2_mpcnet/ocs2_ballbot_mpcnet/package.xml new file mode 100644 index 000000000..c2448bd1b --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/package.xml @@ -0,0 +1,22 @@ + + + ocs2_ballbot_mpcnet + 0.0.0 + The ocs2_ballbot_mpcnet package + + Alexander Reske + + Farbod Farshidian + Alexander Reske + + BSD-3 + + catkin + + cmake_clang_tools + + ocs2_ballbot + ocs2_ballbot_ros + ocs2_mpcnet_core + + diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/policy/ballbot.onnx b/ocs2_mpcnet/ocs2_ballbot_mpcnet/policy/ballbot.onnx new file mode 100644 index 000000000..56c3a7a8e Binary files /dev/null and b/ocs2_mpcnet/ocs2_ballbot_mpcnet/policy/ballbot.onnx differ diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/policy/ballbot.pt b/ocs2_mpcnet/ocs2_ballbot_mpcnet/policy/ballbot.pt new file mode 100644 index 000000000..06a0acb59 Binary files /dev/null and b/ocs2_mpcnet/ocs2_ballbot_mpcnet/policy/ballbot.pt differ diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/__init__.py b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/__init__.py new file mode 100644 index 000000000..d5c07a503 --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/__init__.py @@ -0,0 +1,2 @@ +from ocs2_ballbot_mpcnet.BallbotMpcnetPybindings import MpcnetInterface +from ocs2_ballbot_mpcnet.mpcnet import BallbotMpcnet diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/config/ballbot.yaml 
b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/config/ballbot.yaml new file mode 100644 index 000000000..1626c547c --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/config/ballbot.yaml @@ -0,0 +1,123 @@ +# +# general +# +# name of the robot +NAME: "ballbot" +# description of the training run +DESCRIPTION: "description" +# state dimension +STATE_DIM: 10 +# input dimension +INPUT_DIM: 3 +# target trajectories state dimension +TARGET_STATE_DIM: 10 +# target trajectories input dimension +TARGET_INPUT_DIM: 3 +# observation dimension +OBSERVATION_DIM: 10 +# action dimension +ACTION_DIM: 3 +# expert number +EXPERT_NUM: 1 +# default state +DEFAULT_STATE: + - 0.0 # pose x + - 0.0 # pose y + - 0.0 # pose yaw + - 0.0 # pose pitch + - 0.0 # pose roll + - 0.0 # twist x + - 0.0 # twist y + - 0.0 # twist yaw + - 0.0 # twist pitch + - 0.0 # twist roll +# default target state +DEFAULT_TARGET_STATE: + - 0.0 # pose x + - 0.0 # pose y + - 0.0 # pose yaw + - 0.0 # pose pitch + - 0.0 # pose roll + - 0.0 # twist x + - 0.0 # twist y + - 0.0 # twist yaw + - 0.0 # twist pitch + - 0.0 # twist roll +# +# loss +# +# epsilon to improve numerical stability of logs and denominators +EPSILON: 1.e-8 +# whether to cheat by adding the gating loss +CHEATING: False +# parameter to control the relative importance of both loss types +LAMBDA: 1.0 +# dictionary for the gating loss (assigns modes to experts responsible for the corresponding contact configuration) +EXPERT_FOR_MODE: + 0: 0 +# input cost for behavioral cloning +R: + - 2.0 # torque + - 2.0 # torque + - 2.0 # torque +# +# memory +# +# capacity of the memory +CAPACITY: 100000 +# +# policy +# +# observation scaling +OBSERVATION_SCALING: + - 1.0 # pose x + - 1.0 # pose y + - 1.0 # pose yaw + - 1.0 # pose pitch + - 1.0 # pose roll + - 1.0 # twist x + - 1.0 # twist y + - 1.0 # twist yaw + - 1.0 # twist pitch + - 1.0 # twist roll +# action scaling +ACTION_SCALING: + - 1.0 # torque + - 1.0 # torque + - 1.0 # 
torque +# +# rollout +# +# RaiSim or TimeTriggered rollout for data generation and policy evaluation +RAISIM: False +# settings for data generation +DATA_GENERATION_TIME_STEP: 0.1 +DATA_GENERATION_DURATION: 3.0 +DATA_GENERATION_DATA_DECIMATION: 1 +DATA_GENERATION_THREADS: 2 +DATA_GENERATION_TASKS: 10 +DATA_GENERATION_SAMPLES: 2 +DATA_GENERATION_SAMPLING_VARIANCE: + - 0.01 # pose x + - 0.01 # pose y + - 0.01745329251 # pose yaw: 1.0 / 180.0 * pi + - 0.01745329251 # pose pitch: 1.0 / 180.0 * pi + - 0.01745329251 # pose roll: 1.0 / 180.0 * pi + - 0.05 # twist x + - 0.05 # twist y + - 0.08726646259 # twist yaw: 5.0 / 180.0 * pi + - 0.08726646259 # twist pitch: 5.0 / 180.0 * pi + - 0.08726646259 # twist roll: 5.0 / 180.0 * pi +# settings for computing metrics +POLICY_EVALUATION_TIME_STEP: 0.1 +POLICY_EVALUATION_DURATION: 3.0 +POLICY_EVALUATION_THREADS: 1 +POLICY_EVALUATION_TASKS: 5 +# +# training +# +BATCH_SIZE: 32 +LEARNING_RATE: 1.e-2 +LEARNING_ITERATIONS: 10000 +GRADIENT_CLIPPING: False +GRADIENT_CLIPPING_VALUE: 1.0 diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/mpcnet.py b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/mpcnet.py new file mode 100644 index 000000000..86b8fe743 --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/mpcnet.py @@ -0,0 +1,137 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Ballbot MPC-Net class. + +Provides a class that handles the MPC-Net training for ballbot. +""" + +import numpy as np +from typing import Tuple + +from ocs2_mpcnet_core import helper +from ocs2_mpcnet_core import mpcnet +from ocs2_mpcnet_core import SystemObservationArray, ModeScheduleArray, TargetTrajectoriesArray + + +class BallbotMpcnet(mpcnet.Mpcnet): + """Ballbot MPC-Net. + + Adds robot-specific methods for the MPC-Net training. + """ + + @staticmethod + def get_default_event_times_and_mode_sequence(duration: float) -> Tuple[np.ndarray, np.ndarray]: + """Get the event times and mode sequence describing the default mode schedule. + + Creates the default event times and mode sequence for a certain time duration. + + Args: + duration: The duration of the mode schedule given by a float. + + Returns: + A tuple containing the components of the mode schedule. 
+ - event_times: The event times given by a NumPy array of shape (K-1) containing floats. + - mode_sequence: The mode sequence given by a NumPy array of shape (K) containing integers. + """ + event_times_template = np.array([1.0], dtype=np.float64) + mode_sequence_template = np.array([0], dtype=np.uintp) + return helper.get_event_times_and_mode_sequence(0, duration, event_times_template, mode_sequence_template) + + def get_random_initial_state(self) -> np.ndarray: + """Get a random initial state. + + Samples a random initial state for the robot. + + Returns: + x: A random initial state given by a NumPy array containing floats. + """ + max_linear_velocity_x = 0.5 + max_linear_velocity_y = 0.5 + max_euler_angle_derivative_z = 45.0 / 180.0 * np.pi + max_euler_angle_derivative_y = 45.0 / 180.0 * np.pi + max_euler_angle_derivative_x = 45.0 / 180.0 * np.pi + random_deviation = np.zeros(self.config.STATE_DIM) + random_deviation[5] = np.random.uniform(-max_linear_velocity_x, max_linear_velocity_x) + random_deviation[6] = np.random.uniform(-max_linear_velocity_y, max_linear_velocity_y) + random_deviation[7] = np.random.uniform(-max_euler_angle_derivative_z, max_euler_angle_derivative_z) + random_deviation[8] = np.random.uniform(-max_euler_angle_derivative_y, max_euler_angle_derivative_y) + random_deviation[9] = np.random.uniform(-max_euler_angle_derivative_x, max_euler_angle_derivative_x) + return np.array(self.config.DEFAULT_STATE) + random_deviation + + def get_random_target_state(self) -> np.ndarray: + """Get a random target state. + + Samples a random target state for the robot. + + Returns: + x: A random target state given by a NumPy array containing floats. 
+ """ + max_position_x = 1.0 + max_position_y = 1.0 + max_orientation_z = 45.0 / 180.0 * np.pi + random_deviation = np.zeros(self.config.TARGET_STATE_DIM) + random_deviation[0] = np.random.uniform(-max_position_x, max_position_x) + random_deviation[1] = np.random.uniform(-max_position_y, max_position_y) + random_deviation[2] = np.random.uniform(-max_orientation_z, max_orientation_z) + return np.array(self.config.DEFAULT_TARGET_STATE) + random_deviation + + def get_tasks( + self, tasks_number: int, duration: float + ) -> Tuple[SystemObservationArray, ModeScheduleArray, TargetTrajectoriesArray]: + """Get tasks. + + Get a random set of task that should be executed by the data generation or policy evaluation. + + Args: + tasks_number: Number of tasks given by an integer. + duration: Duration of each task given by a float. + + Returns: + A tuple containing the components of the task. + - initial_observations: The initial observations given by an OCS2 system observation array. + - mode_schedules: The desired mode schedules given by an OCS2 mode schedule array. + - target_trajectories: The desired target trajectories given by an OCS2 target trajectories array. 
+ """ + initial_mode = 0 + initial_time = 0.0 + initial_observations = helper.get_system_observation_array(tasks_number) + mode_schedules = helper.get_mode_schedule_array(tasks_number) + target_trajectories = helper.get_target_trajectories_array(tasks_number) + for i in range(tasks_number): + initial_observations[i] = helper.get_system_observation( + initial_mode, initial_time, self.get_random_initial_state(), np.zeros(self.config.INPUT_DIM) + ) + mode_schedules[i] = helper.get_mode_schedule(*self.get_default_event_times_and_mode_sequence(duration)) + target_trajectories[i] = helper.get_target_trajectories( + duration * np.ones((1, 1)), + self.get_random_target_state().reshape((1, self.config.TARGET_STATE_DIM)), + np.zeros((1, self.config.TARGET_INPUT_DIM)), + ) + return initial_observations, mode_schedules, target_trajectories diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/train.py b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/train.py new file mode 100644 index 000000000..4fa5c59ce --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/python/ocs2_ballbot_mpcnet/train.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Ballbot MPC-Net. + +Main script for training an MPC-Net policy for ballbot. 
+""" + +import sys +import os + +from ocs2_mpcnet_core.config import Config +from ocs2_mpcnet_core.loss import HamiltonianLoss +from ocs2_mpcnet_core.memory import CircularMemory +from ocs2_mpcnet_core.policy import LinearPolicy + +from ocs2_ballbot_mpcnet import BallbotMpcnet +from ocs2_ballbot_mpcnet import MpcnetInterface + + +def main(root_dir: str, config_file_name: str) -> None: + # config + config = Config(os.path.join(root_dir, "config", config_file_name)) + # interface + interface = MpcnetInterface(config.DATA_GENERATION_THREADS, config.POLICY_EVALUATION_THREADS, config.RAISIM) + # loss + loss = HamiltonianLoss(config) + # memory + memory = CircularMemory(config) + # policy + policy = LinearPolicy(config) + # mpcnet + mpcnet = BallbotMpcnet(root_dir, config, interface, memory, policy, loss) + # train + mpcnet.train() + + +if __name__ == "__main__": + root_dir = os.path.dirname(os.path.abspath(__file__)) + if len(sys.argv) > 1: + main(root_dir, sys.argv[1]) + else: + main(root_dir, "ballbot.yaml") diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/setup.py b/ocs2_mpcnet/ocs2_ballbot_mpcnet/setup.py new file mode 100644 index 000000000..959f0cad9 --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/setup.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python + +from setuptools import setup +from catkin_pkg.python_setup import generate_distutils_setup + +setup_args = generate_distutils_setup(packages=["ocs2_ballbot_mpcnet"], package_dir={"": "python"}) + +setup(**setup_args) diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDefinition.cpp b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDefinition.cpp new file mode 100644 index 000000000..81ec0bc06 --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDefinition.cpp @@ -0,0 +1,57 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/ + +#include "ocs2_ballbot_mpcnet/BallbotMpcnetDefinition.h" + +namespace ocs2 { +namespace ballbot { + +vector_t BallbotMpcnetDefinition::getObservation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories) { + vector_t observation = x - targetTrajectories.getDesiredState(t); + const Eigen::Matrix R = + (Eigen::Matrix() << cos(x(2)), -sin(x(2)), sin(x(2)), cos(x(2))).finished().transpose(); + observation.segment<2>(0) = R * observation.segment<2>(0); + observation.segment<2>(5) = R * observation.segment<2>(5); + return observation; +} + +std::pair BallbotMpcnetDefinition::getActionTransformation(scalar_t t, const vector_t& x, + const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories) { + return {matrix_t::Identity(3, 3), vector_t::Zero(3)}; +} + +bool BallbotMpcnetDefinition::isValid(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories) { + return true; +} + +} // namespace ballbot +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDummyNode.cpp b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDummyNode.cpp new file mode 100644 index 000000000..010fc2a8b --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDummyNode.cpp @@ -0,0 +1,110 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. 
+ + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "ocs2_ballbot_mpcnet/BallbotMpcnetDefinition.h" + +using namespace ocs2; +using namespace ballbot; + +int main(int argc, char** argv) { + const std::string robotName = "ballbot"; + + // task and policy file + std::vector programArgs{}; + ::ros::removeROSArgs(argc, argv, programArgs); + if (programArgs.size() <= 2) { + throw std::runtime_error("No task name and policy file path specified. 
Aborting."); + } + const std::string taskFileFolderName = std::string(programArgs[1]); + const std::string policyFilePath = std::string(programArgs[2]); + + // initialize ros node + ros::init(argc, argv, robotName + "_mpcnet_dummy"); + ros::NodeHandle nodeHandle; + + // ballbot interface + const std::string taskFile = ros::package::getPath("ocs2_ballbot") + "/config/" + taskFileFolderName + "/task.info"; + const std::string libraryFolder = ros::package::getPath("ocs2_ballbot") + "/auto_generated"; + BallbotInterface ballbotInterface(taskFile, libraryFolder); + + // ROS reference manager + auto rosReferenceManagerPtr = std::make_shared(robotName, ballbotInterface.getReferenceManagerPtr()); + rosReferenceManagerPtr->subscribe(nodeHandle); + + // policy (MPC-Net controller) + auto onnxEnvironmentPtr = ocs2::mpcnet::createOnnxEnvironment(); + std::shared_ptr mpcnetDefinitionPtr(new BallbotMpcnetDefinition); + std::unique_ptr mpcnetControllerPtr( + new ocs2::mpcnet::MpcnetOnnxController(mpcnetDefinitionPtr, rosReferenceManagerPtr, onnxEnvironmentPtr)); + mpcnetControllerPtr->loadPolicyModel(policyFilePath); + + // rollout + std::unique_ptr rolloutPtr(ballbotInterface.getRollout().clone()); + + // observer + std::shared_ptr mpcnetDummyObserverRosPtr( + new ocs2::mpcnet::MpcnetDummyObserverRos(nodeHandle, robotName)); + + // visualization + std::shared_ptr ballbotDummyVisualization(new BallbotDummyVisualization(nodeHandle)); + + // MPC-Net dummy loop ROS + const scalar_t controlFrequency = ballbotInterface.mpcSettings().mrtDesiredFrequency_; + const scalar_t rosFrequency = ballbotInterface.mpcSettings().mpcDesiredFrequency_; + ocs2::mpcnet::MpcnetDummyLoopRos mpcnetDummyLoopRos(controlFrequency, rosFrequency, std::move(mpcnetControllerPtr), std::move(rolloutPtr), + rosReferenceManagerPtr); + mpcnetDummyLoopRos.addObserver(mpcnetDummyObserverRosPtr); + mpcnetDummyLoopRos.addObserver(ballbotDummyVisualization); + + // initial system observation + SystemObservation 
systemObservation; + systemObservation.mode = 0; + systemObservation.time = 0.0; + systemObservation.state = ballbotInterface.getInitialState(); + systemObservation.input = vector_t::Zero(ocs2::ballbot::INPUT_DIM); + + // initial target trajectories + const TargetTrajectories targetTrajectories({systemObservation.time}, {systemObservation.state}, {systemObservation.input}); + + // run MPC-Net dummy loop ROS + mpcnetDummyLoopRos.run(systemObservation, targetTrajectories); + + // successful exit + return 0; +} \ No newline at end of file diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetInterface.cpp b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetInterface.cpp new file mode 100644 index 000000000..dfe15a9a5 --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetInterface.cpp @@ -0,0 +1,111 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#include "ocs2_ballbot_mpcnet/BallbotMpcnetInterface.h" + +#include + +#include +#include +#include +#include + +#include "ocs2_ballbot_mpcnet/BallbotMpcnetDefinition.h" + +namespace ocs2 { +namespace ballbot { + +BallbotMpcnetInterface::BallbotMpcnetInterface(size_t nDataGenerationThreads, size_t nPolicyEvaluationThreads, bool raisim) { + // create ONNX environment + auto onnxEnvironmentPtr = ocs2::mpcnet::createOnnxEnvironment(); + // path to config file + const std::string taskFile = ros::package::getPath("ocs2_ballbot") + "/config/mpc/task.info"; + // path to save auto-generated libraries + const std::string libraryFolder = ros::package::getPath("ocs2_ballbot") + "/auto_generated"; + // set up MPC-Net rollout manager for data generation and policy evaluation + std::vector> mpcPtrs; + std::vector> mpcnetPtrs; + std::vector> rolloutPtrs; + std::vector> mpcnetDefinitionPtrs; + std::vector> referenceManagerPtrs; + mpcPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads); + mpcnetPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads); + rolloutPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads); + mpcnetDefinitionPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads); + referenceManagerPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads); + for (int i = 0; i < (nDataGenerationThreads + nPolicyEvaluationThreads); i++) { + 
BallbotInterface ballbotInterface(taskFile, libraryFolder); + std::shared_ptr mpcnetDefinitionPtr(new BallbotMpcnetDefinition); + mpcPtrs.push_back(getMpc(ballbotInterface)); + mpcnetPtrs.push_back(std::unique_ptr( + new ocs2::mpcnet::MpcnetOnnxController(mpcnetDefinitionPtr, ballbotInterface.getReferenceManagerPtr(), onnxEnvironmentPtr))); + if (raisim) { + throw std::runtime_error("[BallbotMpcnetInterface::BallbotMpcnetInterface] raisim rollout not yet implemented for ballbot."); + } else { + rolloutPtrs.push_back(std::unique_ptr(ballbotInterface.getRollout().clone())); + } + mpcnetDefinitionPtrs.push_back(mpcnetDefinitionPtr); + referenceManagerPtrs.push_back(ballbotInterface.getReferenceManagerPtr()); + } + mpcnetRolloutManagerPtr_.reset(new ocs2::mpcnet::MpcnetRolloutManager(nDataGenerationThreads, nPolicyEvaluationThreads, + std::move(mpcPtrs), std::move(mpcnetPtrs), std::move(rolloutPtrs), + mpcnetDefinitionPtrs, referenceManagerPtrs)); +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +std::unique_ptr BallbotMpcnetInterface::getMpc(BallbotInterface& ballbotInterface) { + // ensure MPC and DDP settings are as needed for MPC-Net + const auto mpcSettings = [&]() { + auto settings = ballbotInterface.mpcSettings(); + settings.debugPrint_ = false; + settings.coldStart_ = false; + return settings; + }(); + const auto ddpSettings = [&]() { + auto settings = ballbotInterface.ddpSettings(); + settings.algorithm_ = ocs2::ddp::Algorithm::SLQ; + settings.nThreads_ = 1; + settings.displayInfo_ = false; + settings.displayShortSummary_ = false; + settings.checkNumericalStability_ = false; + settings.debugPrintRollout_ = false; + settings.useFeedbackPolicy_ = true; + return settings; + }(); + // create one 
MPC instance + std::unique_ptr mpcPtr(new GaussNewtonDDP_MPC(mpcSettings, ddpSettings, ballbotInterface.getRollout(), + ballbotInterface.getOptimalControlProblem(), ballbotInterface.getInitializer())); + mpcPtr->getSolverPtr()->setReferenceManager(ballbotInterface.getReferenceManagerPtr()); + return mpcPtr; +} + +} // namespace ballbot +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetPybindings.cpp b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetPybindings.cpp new file mode 100644 index 000000000..c08e80a5d --- /dev/null +++ b/ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetPybindings.cpp @@ -0,0 +1,34 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#include + +#include "ocs2_ballbot_mpcnet/BallbotMpcnetInterface.h" + +CREATE_ROBOT_MPCNET_PYTHON_BINDINGS(ocs2::ballbot::BallbotMpcnetInterface, BallbotMpcnetPybindings) diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/CMakeLists.txt b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/CMakeLists.txt new file mode 100644 index 000000000..6addac114 --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/CMakeLists.txt @@ -0,0 +1,129 @@ +cmake_minimum_required(VERSION 3.0.2) +project(ocs2_legged_robot_mpcnet) + +set(CATKIN_PACKAGE_DEPENDENCIES + ocs2_legged_robot + ocs2_legged_robot_raisim + ocs2_legged_robot_ros + ocs2_mpcnet_core +) + +find_package(catkin REQUIRED COMPONENTS + ${CATKIN_PACKAGE_DEPENDENCIES} +) + +# Generate compile_commands.json for clang tools +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +################################### +## catkin specific configuration ## +################################### + +catkin_package( + INCLUDE_DIRS + include + LIBRARIES + ${PROJECT_NAME} + CATKIN_DEPENDS + ${CATKIN_PACKAGE_DEPENDENCIES} + DEPENDS +) + +########### +## Build ## +########### + +include_directories( + include + ${catkin_INCLUDE_DIRS} +) + +# main library +add_library(${PROJECT_NAME} + src/LeggedRobotMpcnetDefinition.cpp + src/LeggedRobotMpcnetInterface.cpp +) +add_dependencies(${PROJECT_NAME} + ${catkin_EXPORTED_TARGETS} +) +target_link_libraries(${PROJECT_NAME} + 
${catkin_LIBRARIES} +) + +# python bindings +pybind11_add_module(LeggedRobotMpcnetPybindings SHARED + src/LeggedRobotMpcnetPybindings.cpp +) +add_dependencies(LeggedRobotMpcnetPybindings + ${PROJECT_NAME} + ${catkin_EXPORTED_TARGETS} +) +target_link_libraries(LeggedRobotMpcnetPybindings PRIVATE + ${PROJECT_NAME} + ${catkin_LIBRARIES} +) +set_target_properties(LeggedRobotMpcnetPybindings + PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CATKIN_DEVEL_PREFIX}/${CATKIN_PACKAGE_PYTHON_DESTINATION} +) + +# MPC-Net dummy node +add_executable(legged_robot_mpcnet_dummy + src/LeggedRobotMpcnetDummyNode.cpp +) +add_dependencies(legged_robot_mpcnet_dummy + ${PROJECT_NAME} + ${catkin_EXPORTED_TARGETS} +) +target_link_libraries(legged_robot_mpcnet_dummy + ${PROJECT_NAME} + ${catkin_LIBRARIES} +) +target_compile_options(legged_robot_mpcnet_dummy PRIVATE ${OCS2_CXX_FLAGS}) + +catkin_python_setup() + +######################### +### CLANG TOOLING ### +######################### +find_package(cmake_clang_tools QUIET) +if(cmake_clang_tools_FOUND) + message(STATUS "Run clang tooling for target ocs2_legged_robot_mpcnet") + add_clang_tooling( + TARGETS ${PROJECT_NAME} legged_robot_mpcnet_dummy + SOURCE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/src ${CMAKE_CURRENT_SOURCE_DIR}/include + CT_HEADER_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/include + CF_WERROR +) +endif(cmake_clang_tools_FOUND) + +############# +## Install ## +############# + +install(TARGETS ${PROJECT_NAME} + ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + +install(DIRECTORY include/${PROJECT_NAME}/ + DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}) + +install(TARGETS LeggedRobotMpcnetPybindings + ARCHIVE DESTINATION ${CATKIN_PACKAGE_PYTHON_DESTINATION} + LIBRARY DESTINATION ${CATKIN_PACKAGE_PYTHON_DESTINATION} +) + +install(TARGETS legged_robot_mpcnet_dummy + RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + 
+install(DIRECTORY launch policy + DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +) + +############# +## Testing ## +############# + +# TODO(areske) diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/include/ocs2_legged_robot_mpcnet/LeggedRobotMpcnetDefinition.h b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/include/ocs2_legged_robot_mpcnet/LeggedRobotMpcnetDefinition.h new file mode 100644 index 000000000..8a6b44f92 --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/include/ocs2_legged_robot_mpcnet/LeggedRobotMpcnetDefinition.h @@ -0,0 +1,81 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include +#include + +namespace ocs2 { +namespace legged_robot { + +/** + * MPC-Net definitions for legged robot. + */ +class LeggedRobotMpcnetDefinition final : public ocs2::mpcnet::MpcnetDefinitionBase { + public: + /** + * Constructor. + * @param [in] leggedRobotInterface : Legged robot interface. + */ + LeggedRobotMpcnetDefinition(const LeggedRobotInterface& leggedRobotInterface) + : defaultState_(leggedRobotInterface.getInitialState()), centroidalModelInfo_(leggedRobotInterface.getCentroidalModelInfo()) {} + + /** + * Default destructor. 
+ */ + ~LeggedRobotMpcnetDefinition() override = default; + + /** + * @see MpcnetDefinitionBase::getObservation + */ + vector_t getObservation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories) override; + + /** + * @see MpcnetDefinitionBase::getActionTransformation + */ + std::pair getActionTransformation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories) override; + + /** + * @see MpcnetDefinitionBase::isValid + */ + bool isValid(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, const TargetTrajectories& targetTrajectories) override; + + private: + const scalar_t allowedHeightDeviation_ = 0.2; + const scalar_t allowedPitchDeviation_ = 30.0 * M_PI / 180.0; + const scalar_t allowedRollDeviation_ = 30.0 * M_PI / 180.0; + const vector_t defaultState_; + const CentroidalModelInfo centroidalModelInfo_; +}; + +} // namespace legged_robot +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/include/ocs2_legged_robot_mpcnet/LeggedRobotMpcnetInterface.h b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/include/ocs2_legged_robot_mpcnet/LeggedRobotMpcnetInterface.h new file mode 100644 index 000000000..3acb1d0ee --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/include/ocs2_legged_robot_mpcnet/LeggedRobotMpcnetInterface.h @@ -0,0 +1,72 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. 
+ + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include +#include +#include + +namespace ocs2 { +namespace legged_robot { + +/** + * Legged robot MPC-Net interface between C++ and Python. + */ +class LeggedRobotMpcnetInterface final : public ocs2::mpcnet::MpcnetInterfaceBase { + public: + /** + * Constructor. + * @param [in] nDataGenerationThreads : Number of data generation threads. + * @param [in] nPolicyEvaluationThreads : Number of policy evaluation threads. + * @param [in] raisim : Whether to use RaiSim for the rollouts. + */ + LeggedRobotMpcnetInterface(size_t nDataGenerationThreads, size_t nPolicyEvaluationThreads, bool raisim); + + /** + * Default destructor. 
+ */ + ~LeggedRobotMpcnetInterface() override = default; + + private: + /** + * Helper to get the MPC. + * @param [in] leggedRobotInterface : The legged robot interface. + * @return Pointer to the MPC. + */ + std::unique_ptr getMpc(LeggedRobotInterface& leggedRobotInterface); + + // Legged robot interface pointers (keep alive for Pinocchio interface) + std::vector> leggedRobotInterfacePtrs_; + // Legged robot RaiSim conversions pointers (keep alive for RaiSim rollout) + std::vector> leggedRobotRaisimConversionsPtrs_; +}; + +} // namespace legged_robot +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/launch/legged_robot_mpcnet.launch b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/launch/legged_robot_mpcnet.launch new file mode 100644 index 000000000..59311840a --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/launch/legged_robot_mpcnet.launch @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/package.xml b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/package.xml new file mode 100644 index 000000000..c39765d14 --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/package.xml @@ -0,0 +1,23 @@ + + + ocs2_legged_robot_mpcnet + 0.0.0 + The ocs2_legged_robot_mpcnet package + + Alexander Reske + + Farbod Farshidian + Alexander Reske + + BSD-3 + + catkin + + cmake_clang_tools + + ocs2_legged_robot + ocs2_legged_robot_raisim + ocs2_legged_robot_ros + ocs2_mpcnet_core + + diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/policy/legged_robot.onnx b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/policy/legged_robot.onnx new file mode 100644 index 000000000..ef5939cdf Binary files /dev/null and b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/policy/legged_robot.onnx differ diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/policy/legged_robot.pt b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/policy/legged_robot.pt new file mode 100644 index 000000000..3eb2df09d Binary 
files /dev/null and b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/policy/legged_robot.pt differ diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/__init__.py b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/__init__.py new file mode 100644 index 000000000..688527a67 --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/__init__.py @@ -0,0 +1,2 @@ +from ocs2_legged_robot_mpcnet.LeggedRobotMpcnetPybindings import MpcnetInterface +from ocs2_legged_robot_mpcnet.mpcnet import LeggedRobotMpcnet diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/config/legged_robot.yaml b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/config/legged_robot.yaml new file mode 100644 index 000000000..0b6dde905 --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/config/legged_robot.yaml @@ -0,0 +1,240 @@ +# +# general +# +# name of the robot +NAME: "legged_robot" +# description of the training run +DESCRIPTION: "description" +# state dimension +STATE_DIM: 24 +# input dimension +INPUT_DIM: 24 +# target trajectories state dimension +TARGET_STATE_DIM: 24 +# target trajectories input dimension +TARGET_INPUT_DIM: 24 +# observation dimension +OBSERVATION_DIM: 36 +# action dimension +ACTION_DIM: 24 +# expert number +EXPERT_NUM: 3 +# default state +DEFAULT_STATE: + - 0.0 # normalized linear momentum x + - 0.0 # normalized linear momentum y + - 0.0 # normalized linear momentum z + - 0.0 # normalized angular momentum x + - 0.0 # normalized angular momentum y + - 0.0 # normalized angular momentum z + - 0.0 # position x + - 0.0 # position y + - 0.575 # position z + - 0.0 # orientation z + - 0.0 # orientation y + - 0.0 # orientation x + - -0.25 # joint position LF HAA + - 0.6 # joint position LF HFE + - -0.85 # joint position LF KFE + - -0.25 # joint position LH HAA + - -0.6 # joint position LH HFE + - 0.85 # joint position LH KFE + - 
0.25 # joint position RF HAA + - 0.6 # joint position RF HFE + - -0.85 # joint position RF KFE + - 0.25 # joint position RH HAA + - -0.6 # joint position RH HFE + - 0.85 # joint position RH KFE +# default target state +DEFAULT_TARGET_STATE: + - 0.0 # normalized linear momentum x + - 0.0 # normalized linear momentum y + - 0.0 # normalized linear momentum z + - 0.0 # normalized angular momentum x + - 0.0 # normalized angular momentum y + - 0.0 # normalized angular momentum z + - 0.0 # position x + - 0.0 # position y + - 0.575 # position z + - 0.0 # orientation z + - 0.0 # orientation y + - 0.0 # orientation x + - -0.25 # joint position LF HAA + - 0.6 # joint position LF HFE + - -0.85 # joint position LF KFE + - -0.25 # joint position LH HAA + - -0.6 # joint position LH HFE + - 0.85 # joint position LH KFE + - 0.25 # joint position RF HAA + - 0.6 # joint position RF HFE + - -0.85 # joint position RF KFE + - 0.25 # joint position RH HAA + - -0.6 # joint position RH HFE + - 0.85 # joint position RH KFE +# +# loss +# +# epsilon to improve numerical stability of logs and denominators +EPSILON: 1.e-8 +# whether to cheat by adding the gating loss +CHEATING: True +# parameter to control the relative importance of both loss types +LAMBDA: 1.0 +# dictionary for the gating loss (assigns modes to experts responsible for the corresponding contact configuration) +EXPERT_FOR_MODE: + 6: 1 # trot + 9: 2 # trot + 15: 0 # stance +# input cost for behavioral cloning +R: + - 0.001 # contact force LF x + - 0.001 # contact force LF y + - 0.001 # contact force LF z + - 0.001 # contact force LH x + - 0.001 # contact force LH y + - 0.001 # contact force LH z + - 0.001 # contact force RF x + - 0.001 # contact force RF y + - 0.001 # contact force RF z + - 0.001 # contact force RH x + - 0.001 # contact force RH y + - 0.001 # contact force RH z + - 5.0 # joint velocity LF HAA + - 5.0 # joint velocity LF HFE + - 5.0 # joint velocity LF KFE + - 5.0 # joint velocity LH HAA + - 5.0 # joint velocity 
LH HFE + - 5.0 # joint velocity LH KFE + - 5.0 # joint velocity RF HAA + - 5.0 # joint velocity RF HFE + - 5.0 # joint velocity RF KFE + - 5.0 # joint velocity RH HAA + - 5.0 # joint velocity RH HFE + - 5.0 # joint velocity RH KFE +# +# memory +# +# capacity of the memory +CAPACITY: 400000 +# +# policy +# +# observation scaling +OBSERVATION_SCALING: + - 1.0 # swing phase LF + - 1.0 # swing phase LH + - 1.0 # swing phase RF + - 1.0 # swing phase RH + - 1.0 # swing phase rate LF + - 1.0 # swing phase rate LH + - 1.0 # swing phase rate RF + - 1.0 # swing phase rate RH + - 1.0 # sinusoidal bump LF + - 1.0 # sinusoidal bump LH + - 1.0 # sinusoidal bump RF + - 1.0 # sinusoidal bump RH + - 1.0 # normalized linear momentum x + - 1.0 # normalized linear momentum y + - 1.0 # normalized linear momentum z + - 1.0 # normalized angular momentum x + - 1.0 # normalized angular momentum y + - 1.0 # normalized angular momentum z + - 1.0 # position x + - 1.0 # position y + - 1.0 # position z + - 1.0 # orientation z + - 1.0 # orientation y + - 1.0 # orientation x + - 1.0 # joint position LF HAA + - 1.0 # joint position LF HFE + - 1.0 # joint position LF KFE + - 1.0 # joint position LH HAA + - 1.0 # joint position LH HFE + - 1.0 # joint position LH KFE + - 1.0 # joint position RF HAA + - 1.0 # joint position RF HFE + - 1.0 # joint position RF KFE + - 1.0 # joint position RH HAA + - 1.0 # joint position RH HFE + - 1.0 # joint position RH KFE +# action scaling +ACTION_SCALING: + - 100.0 # contact force LF x + - 100.0 # contact force LF y + - 100.0 # contact force LF z + - 100.0 # contact force LH x + - 100.0 # contact force LH y + - 100.0 # contact force LH z + - 100.0 # contact force RF x + - 100.0 # contact force RF y + - 100.0 # contact force RF z + - 100.0 # contact force RH x + - 100.0 # contact force RH y + - 100.0 # contact force RH z + - 10.0 # joint velocity LF HAA + - 10.0 # joint velocity LF HFE + - 10.0 # joint velocity LF KFE + - 10.0 # joint velocity LH HAA + - 10.0 # joint 
velocity LH HFE + - 10.0 # joint velocity LH KFE + - 10.0 # joint velocity RF HAA + - 10.0 # joint velocity RF HFE + - 10.0 # joint velocity RF KFE + - 10.0 # joint velocity RH HAA + - 10.0 # joint velocity RH HFE + - 10.0 # joint velocity RH KFE +# +# rollout +# +# RaiSim or TimeTriggered rollout for data generation and policy evaluation +RAISIM: True +# weights defining how often a gait is chosen for rollout +WEIGHTS_FOR_GAITS: + stance: 1.0 + trot_1: 2.0 + trot_2: 2.0 +# settings for data generation +DATA_GENERATION_TIME_STEP: 0.0025 +DATA_GENERATION_DURATION: 4.0 +DATA_GENERATION_DATA_DECIMATION: 4 +DATA_GENERATION_THREADS: 12 +DATA_GENERATION_TASKS: 12 +DATA_GENERATION_SAMPLES: 2 +DATA_GENERATION_SAMPLING_VARIANCE: + - 0.05 # normalized linear momentum x + - 0.05 # normalized linear momentum y + - 0.05 # normalized linear momentum z + - 0.00135648942 # normalized angular momentum x: 1.62079 / 52.1348 * 2.5 / 180.0 * pi + - 0.00404705526 # normalized angular momentum y: 4.83559 / 52.1348 * 2.5 / 180.0 * pi + - 0.00395351148 # normalized angular momentum z: 4.72382 / 52.1348 * 2.5 / 180.0 * pi + - 0.01 # position x + - 0.01 # position y + - 0.01 # position z + - 0.00872664625 # orientation z: 0.5 / 180.0 * pi + - 0.00872664625 # orientation y: 0.5 / 180.0 * pi + - 0.00872664625 # orientation x: 0.5 / 180.0 * pi + - 0.00872664625 # joint position LF HAA: 0.5 / 180.0 * pi + - 0.00872664625 # joint position LF HFE: 0.5 / 180.0 * pi + - 0.00872664625 # joint position LF KFE: 0.5 / 180.0 * pi + - 0.00872664625 # joint position LH HAA: 0.5 / 180.0 * pi + - 0.00872664625 # joint position LH HFE: 0.5 / 180.0 * pi + - 0.00872664625 # joint position LH KFE: 0.5 / 180.0 * pi + - 0.00872664625 # joint position RF HAA: 0.5 / 180.0 * pi + - 0.00872664625 # joint position RF HFE: 0.5 / 180.0 * pi + - 0.00872664625 # joint position RF KFE: 0.5 / 180.0 * pi + - 0.00872664625 # joint position RH HAA: 0.5 / 180.0 * pi + - 0.00872664625 # joint position RH HFE: 0.5 / 180.0 * pi + - 
0.00872664625 # joint position RH KFE: 0.5 / 180.0 * pi +# settings for computing metrics +POLICY_EVALUATION_TIME_STEP: 0.0025 +POLICY_EVALUATION_DURATION: 4.0 +POLICY_EVALUATION_THREADS: 3 +POLICY_EVALUATION_TASKS: 3 +# +# training +# +BATCH_SIZE: 128 +LEARNING_RATE: 1.e-3 +LEARNING_ITERATIONS: 100000 +GRADIENT_CLIPPING: True +GRADIENT_CLIPPING_VALUE: 1.0 diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/mpcnet.py b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/mpcnet.py new file mode 100644 index 000000000..213b600f8 --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/mpcnet.py @@ -0,0 +1,254 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
###############################################################################

"""Legged robot MPC-Net class.

Provides a class that handles the MPC-Net training for legged robot.
"""

import random
import numpy as np
from typing import Tuple

from ocs2_mpcnet_core import helper
from ocs2_mpcnet_core import mpcnet
from ocs2_mpcnet_core import SystemObservationArray, ModeScheduleArray, TargetTrajectoriesArray


class LeggedRobotMpcnet(mpcnet.Mpcnet):
    """Legged robot MPC-Net.

    Adds robot-specific methods for the MPC-Net training: the gait (mode schedule) templates, the
    random sampling of initial and target states, and the assembly of the rollout tasks.
    """

    def _sample_initial_state(
        self,
        max_normalized_linear_momentum: Tuple[float, float, float],
        max_normalized_angular_momentum: Tuple[float, float, float],
    ) -> np.ndarray:
        """Sample an initial state by perturbing the momentum entries of the default state.

        Entries 0-2 (normalized linear momentum x, y, z) and 3-5 (normalized angular momentum x, y, z)
        of config.DEFAULT_STATE are perturbed uniformly; all other entries are left at their default.
        The samples are drawn in index order, one np.random.uniform call per entry.

        Args:
            max_normalized_linear_momentum: Bounds for the x, y and z linear momentum deviation.
            max_normalized_angular_momentum: Bounds for the x, y and z angular momentum deviation.

        Returns:
            x: A random initial state given by a NumPy array containing floats.
        """
        max_linear_x, max_linear_y, max_linear_z = max_normalized_linear_momentum
        bounds = [
            (-max_linear_x, max_linear_x),
            (-max_linear_y, max_linear_y),
            # the z interval is not symmetric: its upper bound is half the magnitude of its lower bound
            (-max_linear_z, max_linear_z / 2.0),
        ]
        bounds += [(-bound, bound) for bound in max_normalized_angular_momentum]
        random_deviation = np.zeros(self.config.STATE_DIM)
        for index, (lower, upper) in enumerate(bounds):
            random_deviation[index] = np.random.uniform(lower, upper)
        return np.array(self.config.DEFAULT_STATE) + random_deviation

    def _sample_target_state(self, max_deviation_for_index: dict) -> np.ndarray:
        """Sample a target state by perturbing selected entries of the default target state.

        Args:
            max_deviation_for_index: Maps a target state index to the bound b of its deviation, which is
                drawn uniformly from [-b, b). Entries are sampled in the iteration order of the dictionary.

        Returns:
            x: A random target state given by a NumPy array containing floats.
        """
        random_deviation = np.zeros(self.config.TARGET_STATE_DIM)
        for index, bound in max_deviation_for_index.items():
            random_deviation[index] = np.random.uniform(-bound, bound)
        return np.array(self.config.DEFAULT_TARGET_STATE) + random_deviation

    @staticmethod
    def get_stance(duration: float) -> Tuple[np.ndarray, np.ndarray]:
        """Get the stance gait.

        Creates the stance event times and mode sequence for a certain time duration:
        - contact schedule: STANCE
        - swing schedule: -

        Args:
            duration: The duration of the mode schedule given by a float.

        Returns:
            A tuple containing the components of the mode schedule.
                - event_times: The event times given by a NumPy array of shape (K-1) containing floats.
                - mode_sequence: The mode sequence given by a NumPy array of shape (K) containing integers.
        """
        event_times_template = np.array([1.0], dtype=np.float64)
        mode_sequence_template = np.array([15], dtype=np.uintp)
        return helper.get_event_times_and_mode_sequence(15, duration, event_times_template, mode_sequence_template)

    def get_random_initial_state_stance(self) -> np.ndarray:
        """Get a random initial state for stance.

        Samples a random initial state for the robot in the stance gait.

        Returns:
            x: A random initial state given by a NumPy array containing floats.
        """
        return self._sample_initial_state(
            max_normalized_linear_momentum=(0.1, 0.1, 0.1),
            max_normalized_angular_momentum=(
                1.62079 / 52.1348 * 30.0 / 180.0 * np.pi,
                4.83559 / 52.1348 * 30.0 / 180.0 * np.pi,
                4.72382 / 52.1348 * 30.0 / 180.0 * np.pi,
            ),
        )

    def get_random_target_state_stance(self) -> np.ndarray:
        """Get a random target state for stance.

        Samples a random target state for the robot in the stance gait.

        Returns:
            x: A random target state given by a NumPy array containing floats.
        """
        return self._sample_target_state(
            {
                8: 0.075,  # position z
                9: 25.0 / 180.0 * np.pi,  # orientation z
                10: 15.0 / 180.0 * np.pi,  # orientation y
                11: 25.0 / 180.0 * np.pi,  # orientation x
            }
        )

    @staticmethod
    def get_trot_1(duration: float) -> Tuple[np.ndarray, np.ndarray]:
        """Get the first trot gait.

        Creates the first trot event times and mode sequence for a certain time duration:
        - contact schedule: LF_RH, RF_LH
        - swing schedule: RF_LH, LF_RH

        Args:
            duration: The duration of the mode schedule given by a float.

        Returns:
            A tuple containing the components of the mode schedule.
                - event_times: The event times given by a NumPy array of shape (K-1) containing floats.
                - mode_sequence: The mode sequence given by a NumPy array of shape (K) containing integers.
        """
        event_times_template = np.array([0.35, 0.7], dtype=np.float64)
        mode_sequence_template = np.array([9, 6], dtype=np.uintp)
        return helper.get_event_times_and_mode_sequence(15, duration, event_times_template, mode_sequence_template)

    @staticmethod
    def get_trot_2(duration: float) -> Tuple[np.ndarray, np.ndarray]:
        """Get the second trot gait.

        Creates the second trot event times and mode sequence for a certain time duration:
        - contact schedule: RF_LH, LF_RH
        - swing schedule: LF_RH, RF_LH

        Args:
            duration: The duration of the mode schedule given by a float.

        Returns:
            A tuple containing the components of the mode schedule.
                - event_times: The event times given by a NumPy array of shape (K-1) containing floats.
                - mode_sequence: The mode sequence given by a NumPy array of shape (K) containing integers.
        """
        event_times_template = np.array([0.35, 0.7], dtype=np.float64)
        mode_sequence_template = np.array([6, 9], dtype=np.uintp)
        return helper.get_event_times_and_mode_sequence(15, duration, event_times_template, mode_sequence_template)

    def get_random_initial_state_trot(self) -> np.ndarray:
        """Get a random initial state for trot.

        Samples a random initial state for the robot in a trot gait.

        Returns:
            x: A random initial state given by a NumPy array containing floats.
        """
        return self._sample_initial_state(
            max_normalized_linear_momentum=(0.5, 0.25, 0.25),
            max_normalized_angular_momentum=(
                1.62079 / 52.1348 * 60.0 / 180.0 * np.pi,
                4.83559 / 52.1348 * 60.0 / 180.0 * np.pi,
                4.72382 / 52.1348 * 35.0 / 180.0 * np.pi,
            ),
        )

    def get_random_target_state_trot(self) -> np.ndarray:
        """Get a random target state for trot.

        Samples a random target state for the robot in a trot gait.

        Returns:
            x: A random target state given by a NumPy array containing floats.
        """
        return self._sample_target_state(
            {
                6: 0.3,  # position x
                7: 0.15,  # position y
                9: 30.0 / 180.0 * np.pi,  # orientation z
            }
        )

    def get_tasks(
        self, tasks_number: int, duration: float
    ) -> Tuple[SystemObservationArray, ModeScheduleArray, TargetTrajectoriesArray]:
        """Get tasks.

        Get a random set of tasks that should be executed by the data generation or policy evaluation.
        The gait of each task is drawn according to the weights in config.WEIGHTS_FOR_GAITS.

        Args:
            tasks_number: Number of tasks given by an integer.
            duration: Duration of each task given by a float.

        Returns:
            A tuple containing the components of the task.
                - initial_observations: The initial observations given by an OCS2 system observation array.
                - mode_schedules: The desired mode schedules given by an OCS2 mode schedule array.
                - target_trajectories: The desired target trajectories given by an OCS2 target trajectories array.

        Raises:
            ValueError: If config.WEIGHTS_FOR_GAITS names a gait other than "stance", "trot_1" or "trot_2".
        """
        # gait name -> (initial state sampler, mode schedule factory, target state sampler)
        gaits = {
            "stance": (self.get_random_initial_state_stance, self.get_stance, self.get_random_target_state_stance),
            "trot_1": (self.get_random_initial_state_trot, self.get_trot_1, self.get_random_target_state_trot),
            "trot_2": (self.get_random_initial_state_trot, self.get_trot_2, self.get_random_target_state_trot),
        }
        # fail early: an unknown gait would otherwise leave its task slots silently unassigned
        unknown_gaits = [gait for gait in self.config.WEIGHTS_FOR_GAITS if gait not in gaits]
        if unknown_gaits:
            raise ValueError(f"Unknown gait(s) in WEIGHTS_FOR_GAITS: {unknown_gaits}; expected a subset of {list(gaits)}.")

        initial_mode = 15
        initial_time = 0.0
        initial_observations = helper.get_system_observation_array(tasks_number)
        mode_schedules = helper.get_mode_schedule_array(tasks_number)
        target_trajectories = helper.get_target_trajectories_array(tasks_number)
        choices = random.choices(
            list(self.config.WEIGHTS_FOR_GAITS.keys()),
            k=tasks_number,
            weights=list(self.config.WEIGHTS_FOR_GAITS.values()),
        )
        for i, choice in enumerate(choices):
            sample_initial_state, get_gait, sample_target_state = gaits[choice]
            # the initial state is sampled before the target state, keeping the order of the random draws fixed
            initial_observations[i] = helper.get_system_observation(
                initial_mode, initial_time, sample_initial_state(), np.zeros(self.config.INPUT_DIM)
            )
            mode_schedules[i] = helper.get_mode_schedule(*get_gait(duration))
            target_trajectories[i] = helper.get_target_trajectories(
                duration * np.ones((1, 1)),
                sample_target_state().reshape((1, self.config.TARGET_STATE_DIM)),
                np.zeros((1, self.config.TARGET_INPUT_DIM)),
            )
        return initial_observations, mode_schedules, target_trajectories
diff --git
a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/train.py b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/train.py new file mode 100644 index 000000000..f00533bba --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/python/ocs2_legged_robot_mpcnet/train.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 + +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+############################################################################### + +"""Legged robot MPC-Net. + +Main script for training an MPC-Net policy for legged robot. +""" + +import sys +import os + +from ocs2_mpcnet_core.config import Config +from ocs2_mpcnet_core.loss import HamiltonianLoss +from ocs2_mpcnet_core.loss import CrossEntropyLoss +from ocs2_mpcnet_core.memory import CircularMemory +from ocs2_mpcnet_core.policy import MixtureOfNonlinearExpertsPolicy + +from ocs2_legged_robot_mpcnet import LeggedRobotMpcnet +from ocs2_legged_robot_mpcnet import MpcnetInterface + + +def main(root_dir: str, config_file_name: str) -> None: + # config + config = Config(os.path.join(root_dir, "config", config_file_name)) + # interface + interface = MpcnetInterface(config.DATA_GENERATION_THREADS, config.POLICY_EVALUATION_THREADS, config.RAISIM) + # loss + experts_loss = HamiltonianLoss(config) + gating_loss = CrossEntropyLoss(config) + # memory + memory = CircularMemory(config) + # policy + policy = MixtureOfNonlinearExpertsPolicy(config) + # mpcnet + mpcnet = LeggedRobotMpcnet(root_dir, config, interface, memory, policy, experts_loss, gating_loss) + # train + mpcnet.train() + + +if __name__ == "__main__": + root_dir = os.path.dirname(os.path.abspath(__file__)) + if len(sys.argv) > 1: + main(root_dir, sys.argv[1]) + else: + main(root_dir, "legged_robot.yaml") diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/setup.py b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/setup.py new file mode 100644 index 000000000..2227298c6 --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/setup.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python + +from setuptools import setup +from catkin_pkg.python_setup import generate_distutils_setup + +setup_args = generate_distutils_setup(packages=["ocs2_legged_robot_mpcnet"], package_dir={"": "python"}) + +setup(**setup_args) diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDefinition.cpp 
b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDefinition.cpp new file mode 100644 index 000000000..8c19df882 --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDefinition.cpp @@ -0,0 +1,123 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/ + +#include "ocs2_legged_robot_mpcnet/LeggedRobotMpcnetDefinition.h" + +#include + +#include +#include +#include +#include + +namespace ocs2 { +namespace legged_robot { + +vector_t LeggedRobotMpcnetDefinition::getObservation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories) { + /** + * generalized time + */ + const feet_array_t swingPhasePerLeg = getSwingPhasePerLeg(t, modeSchedule); + vector_t generalizedTime(3 * swingPhasePerLeg.size()); + // phase + for (int i = 0; i < swingPhasePerLeg.size(); i++) { + if (swingPhasePerLeg[i].phase < 0.0) { + generalizedTime[i] = 0.0; + } else { + generalizedTime[i] = swingPhasePerLeg[i].phase; + } + } + // phase rate + for (int i = 0; i < swingPhasePerLeg.size(); i++) { + if (swingPhasePerLeg[i].phase < 0.0) { + generalizedTime[i + swingPhasePerLeg.size()] = 0.0; + } else { + generalizedTime[i + swingPhasePerLeg.size()] = 1.0 / swingPhasePerLeg[i].duration; + } + } + // sin(pi * phase) + for (int i = 0; i < swingPhasePerLeg.size(); i++) { + if (swingPhasePerLeg[i].phase < 0.0) { + generalizedTime[i + 2 * swingPhasePerLeg.size()] = 0.0; + } else { + generalizedTime[i + 2 * swingPhasePerLeg.size()] = std::sin(M_PI * swingPhasePerLeg[i].phase); + } + } + /** + * relative state + */ + vector_t relativeState = x - targetTrajectories.getDesiredState(t); + const matrix3_t R = getRotationMatrixFromZyxEulerAngles(x.segment<3>(9)).transpose(); + relativeState.segment<3>(0) = R * relativeState.segment<3>(0); + relativeState.segment<3>(3) = R * relativeState.segment<3>(3); + relativeState.segment<3>(6) = R * relativeState.segment<3>(6); + // TODO(areske): use quaternionDistance() for orientation error? 
+ /** + * observation + */ + vector_t observation(36); + observation << generalizedTime, relativeState; + return observation; +} + +std::pair LeggedRobotMpcnetDefinition::getActionTransformation(scalar_t t, const vector_t& x, + const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories) { + const matrix3_t R = getRotationMatrixFromZyxEulerAngles(x.segment<3>(9)); + matrix_t actionTransformationMatrix = matrix_t::Identity(24, 24); + actionTransformationMatrix.block<3, 3>(0, 0) = R; + actionTransformationMatrix.block<3, 3>(3, 3) = R; + actionTransformationMatrix.block<3, 3>(6, 6) = R; + actionTransformationMatrix.block<3, 3>(9, 9) = R; + // TODO(areske): check why less robust with weight compensating bias? + // const auto contactFlags = modeNumber2StanceLeg(modeSchedule.modeAtTime(t)); + // const vector_t actionTransformationVector = weightCompensatingInput(centroidalModelInfo_, contactFlags); + return {actionTransformationMatrix, vector_t::Zero(24)}; +} + +bool LeggedRobotMpcnetDefinition::isValid(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories) { + const vector_t deviation = x - defaultState_; + if (std::abs(deviation[8]) > allowedHeightDeviation_) { + std::cerr << "[LeggedRobotMpcnetDefinition::isValid] height diverged: " << x[8] << "\n"; + return false; + } else if (std::abs(deviation[10]) > allowedPitchDeviation_) { + std::cerr << "[LeggedRobotMpcnetDefinition::isValid] pitch diverged: " << x[10] << "\n"; + return false; + } else if (std::abs(deviation[11]) > allowedRollDeviation_) { + std::cerr << "[LeggedRobotMpcnetDefinition::isValid] roll diverged: " << x[11] << "\n"; + return false; + } else { + return true; + } +} + +} // namespace legged_robot +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDummyNode.cpp b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDummyNode.cpp new file mode 100644 index 000000000..e16898bac --- 
/dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDummyNode.cpp @@ -0,0 +1,165 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ocs2_legged_robot_mpcnet/LeggedRobotMpcnetDefinition.h" + +using namespace ocs2; +using namespace legged_robot; +
/**
 * Dummy node that runs a learned MPC-Net policy (ONNX) in closed loop with a rollout and publishes the result.
 *
 * Reads the file paths and the RaiSim flag from the ROS parameter server, builds the legged robot interface, the
 * MPC-Net controller, the rollout (RaiSim or the rollout of the legged robot interface), the observers, and then
 * runs the MPC-Net dummy loop starting from the initial state of the legged robot interface.
 *
 * NOTE(review): the template arguments in this function were missing from the reviewed text. They are taken from
 * the constructor calls next to them where possible; GaitReceiver, RosReferenceManager and RolloutBase are inferred
 * from the variable names and the usage and should be confirmed against the included headers.
 */
int main(int argc, char** argv) {
  const std::string robotName = "legged_robot";

  // initialize ros node
  ros::init(argc, argv, robotName + "_mpcnet_dummy");
  ros::NodeHandle nodeHandle;
  // Get node parameters
  // getParam() does not write its output when the parameter is not set, so useRaisim needs a defined value;
  // without it, a missing "/useRaisim" parameter leads to reading an uninitialized bool below.
  bool useRaisim = false;
  std::string taskFile, urdfFile, referenceFile, raisimFile, resourcePath, policyFile;
  nodeHandle.getParam("/taskFile", taskFile);
  nodeHandle.getParam("/urdfFile", urdfFile);
  nodeHandle.getParam("/referenceFile", referenceFile);
  nodeHandle.getParam("/raisimFile", raisimFile);
  nodeHandle.getParam("/resourcePath", resourcePath);
  nodeHandle.getParam("/policyFile", policyFile);
  nodeHandle.getParam("/useRaisim", useRaisim);

  // legged robot interface
  LeggedRobotInterface leggedRobotInterface(taskFile, urdfFile, referenceFile);

  // gait receiver
  auto gaitReceiverPtr = std::make_shared<GaitReceiver>(
      nodeHandle, leggedRobotInterface.getSwitchedModelReferenceManagerPtr()->getGaitSchedule(), robotName);

  // ROS reference manager
  auto rosReferenceManagerPtr = std::make_shared<RosReferenceManager>(robotName, leggedRobotInterface.getReferenceManagerPtr());
  rosReferenceManagerPtr->subscribe(nodeHandle);

  // policy (MPC-Net controller)
  auto onnxEnvironmentPtr = ocs2::mpcnet::createOnnxEnvironment();
  std::shared_ptr<LeggedRobotMpcnetDefinition> mpcnetDefinitionPtr(new LeggedRobotMpcnetDefinition(leggedRobotInterface));
  std::unique_ptr<ocs2::mpcnet::MpcnetOnnxController> mpcnetControllerPtr(
      new ocs2::mpcnet::MpcnetOnnxController(mpcnetDefinitionPtr, rosReferenceManagerPtr, onnxEnvironmentPtr));
  mpcnetControllerPtr->loadPolicyModel(policyFile);

  // rollout
  std::unique_ptr<RolloutBase> rolloutPtr;
  raisim::HeightMap* terrainPtr = nullptr;
  std::unique_ptr<ocs2::RaisimHeightmapRosConverter> heightmapPub;
  // declared outside the branch: the rollout callbacks below capture it by reference
  std::unique_ptr<LeggedRobotRaisimConversions> conversions;
  if (useRaisim) {
    conversions.reset(new LeggedRobotRaisimConversions(leggedRobotInterface.getPinocchioInterface(),
                                                       leggedRobotInterface.getCentroidalModelInfo(),
                                                       leggedRobotInterface.getInitialState()));
    RaisimRolloutSettings raisimRolloutSettings(raisimFile, "rollout", true);
    conversions->loadSettings(raisimFile, "rollout", true);
    rolloutPtr.reset(new RaisimRollout(
        urdfFile, resourcePath,
        [&](const vector_t& state, const vector_t& input) { return conversions->stateToRaisimGenCoordGenVel(state, input); },
        [&](const Eigen::VectorXd& q, const Eigen::VectorXd& dq) { return conversions->raisimGenCoordGenVelToState(q, dq); },
        [&](double time, const vector_t& input, const vector_t& state, const Eigen::VectorXd& q, const Eigen::VectorXd& dq) {
          return conversions->inputToRaisimGeneralizedForce(time, input, state, q, dq);
        },
        nullptr, raisimRolloutSettings,
        [&](double time, const vector_t& input, const vector_t& state, const Eigen::VectorXd& q, const Eigen::VectorXd& dq) {
          return conversions->inputToRaisimPdTargets(time, input, state, q, dq);
        }));
    // terrain
    if (raisimRolloutSettings.generateTerrain_) {
      raisim::TerrainProperties terrainProperties;
      terrainProperties.zScale = raisimRolloutSettings.terrainRoughness_;
      terrainProperties.seed = raisimRolloutSettings.terrainSeed_;
      // rolloutPtr was just reset to a RaisimRollout, so the downcast is safe in this branch
      terrainPtr = static_cast<RaisimRollout*>(rolloutPtr.get())->generateTerrain(terrainProperties);
      conversions->setTerrain(*terrainPtr);
      heightmapPub.reset(new ocs2::RaisimHeightmapRosConverter());
      heightmapPub->publishGridmap(*terrainPtr, "odom");
    }
  } else {
    rolloutPtr.reset(leggedRobotInterface.getRollout().clone());
  }

  // observer
  std::shared_ptr<ocs2::mpcnet::MpcnetDummyObserverRos> mpcnetDummyObserverRosPtr(
      new ocs2::mpcnet::MpcnetDummyObserverRos(nodeHandle, robotName));

  // visualization
  CentroidalModelPinocchioMapping pinocchioMapping(leggedRobotInterface.getCentroidalModelInfo());
  PinocchioEndEffectorKinematics endEffectorKinematics(leggedRobotInterface.getPinocchioInterface(), pinocchioMapping,
                                                       leggedRobotInterface.modelSettings().contactNames3DoF);
  std::shared_ptr<LeggedRobotVisualizer> leggedRobotVisualizerPtr;
  if (useRaisim) {
    leggedRobotVisualizerPtr.reset(new LeggedRobotRaisimVisualizer(
        leggedRobotInterface.getPinocchioInterface(), leggedRobotInterface.getCentroidalModelInfo(), endEffectorKinematics, nodeHandle));
    // the pointer was just reset to a LeggedRobotRaisimVisualizer, so the downcast is safe in this branch
    static_cast<LeggedRobotRaisimVisualizer*>(leggedRobotVisualizerPtr.get())->updateTerrain();
  } else {
    leggedRobotVisualizerPtr.reset(new LeggedRobotVisualizer(
        leggedRobotInterface.getPinocchioInterface(), leggedRobotInterface.getCentroidalModelInfo(), endEffectorKinematics, nodeHandle));
  }

  // MPC-Net dummy loop ROS
  const scalar_t controlFrequency = leggedRobotInterface.mpcSettings().mrtDesiredFrequency_;
  const scalar_t rosFrequency = leggedRobotInterface.mpcSettings().mpcDesiredFrequency_;
  ocs2::mpcnet::MpcnetDummyLoopRos mpcnetDummyLoopRos(controlFrequency, rosFrequency, std::move(mpcnetControllerPtr), std::move(rolloutPtr),
                                                      rosReferenceManagerPtr);
  mpcnetDummyLoopRos.addObserver(mpcnetDummyObserverRosPtr);
  mpcnetDummyLoopRos.addObserver(leggedRobotVisualizerPtr);
  mpcnetDummyLoopRos.addSynchronizedModule(gaitReceiverPtr);

  // initial system observation
  SystemObservation systemObservation;
  systemObservation.mode = ModeNumber::STANCE;
  systemObservation.time = 0.0;
  systemObservation.state = leggedRobotInterface.getInitialState();
  systemObservation.input = vector_t::Zero(leggedRobotInterface.getCentroidalModelInfo().inputDim);

  // initial target trajectories
  const TargetTrajectories targetTrajectories({systemObservation.time}, {systemObservation.state}, {systemObservation.input});

  // run MPC-Net dummy loop ROS
  mpcnetDummyLoopRos.run(systemObservation, targetTrajectories);

  // successful exit
  return 0;
}
diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetInterface.cpp
b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetInterface.cpp new file mode 100644 index 000000000..6c068bca9 --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetInterface.cpp @@ -0,0 +1,143 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/
+
+#include "ocs2_legged_robot_mpcnet/LeggedRobotMpcnetInterface.h"
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ocs2_legged_robot_mpcnet/LeggedRobotMpcnetDefinition.h"
+
+namespace ocs2 {
+namespace legged_robot {
+
+/**
+ * Constructor.
+ * Creates, for each of the (nDataGenerationThreads + nPolicyEvaluationThreads) workers, its own legged robot
+ * interface, MPC instance, ONNX controller and rollout, and hands them to a new MpcnetRolloutManager.
+ * @param [in] nDataGenerationThreads : Number of data generation threads.
+ * @param [in] nPolicyEvaluationThreads : Number of policy evaluation threads.
+ * @param [in] raisim : If true, every worker gets a RaisimRollout (port number and terrain seed offset by the
+ *                      worker index); otherwise a clone of the interface's own rollout is used.
+ */
+LeggedRobotMpcnetInterface::LeggedRobotMpcnetInterface(size_t nDataGenerationThreads, size_t nPolicyEvaluationThreads, bool raisim) {
+  // create ONNX environment
+  auto onnxEnvironmentPtr = ocs2::mpcnet::createOnnxEnvironment();
+  // paths to files
+  const std::string taskFile = ros::package::getPath("ocs2_legged_robot") + "/config/mpc/task.info";
+  const std::string urdfFile = ros::package::getPath("ocs2_robotic_assets") + "/resources/anymal_c/urdf/anymal.urdf";
+  const std::string referenceFile = ros::package::getPath("ocs2_legged_robot") + "/config/command/reference.info";
+  const std::string raisimFile = ros::package::getPath("ocs2_legged_robot_raisim") + "/config/raisim.info";
+  const std::string resourcePath = ros::package::getPath("ocs2_robotic_assets") + "/resources/anymal_c/meshes";
+  // set up MPC-Net rollout manager for data generation and policy evaluation
+  // total number of workers, computed once and kept unsigned so the loop bound is not a signed/unsigned comparison
+  const size_t nThreads = nDataGenerationThreads + nPolicyEvaluationThreads;
+  std::vector> mpcPtrs;
+  std::vector> mpcnetPtrs;
+  std::vector> rolloutPtrs;
+  std::vector> mpcnetDefinitionPtrs;
+  std::vector> referenceManagerPtrs;
+  mpcPtrs.reserve(nThreads);
+  mpcnetPtrs.reserve(nThreads);
+  rolloutPtrs.reserve(nThreads);
+  mpcnetDefinitionPtrs.reserve(nThreads);
+  referenceManagerPtrs.reserve(nThreads);
+  for (size_t i = 0; i < nThreads; i++) {
+    leggedRobotInterfacePtrs_.push_back(std::unique_ptr(new LeggedRobotInterface(taskFile, urdfFile, referenceFile)));
+    std::shared_ptr mpcnetDefinitionPtr(new LeggedRobotMpcnetDefinition(*leggedRobotInterfacePtrs_[i]));
+    mpcPtrs.push_back(getMpc(*leggedRobotInterfacePtrs_[i]));
+    mpcnetPtrs.push_back(std::unique_ptr(new ocs2::mpcnet::MpcnetOnnxController(
+        mpcnetDefinitionPtr, leggedRobotInterfacePtrs_[i]->getReferenceManagerPtr(), onnxEnvironmentPtr)));
+    if (raisim) {
+      RaisimRolloutSettings raisimRolloutSettings(raisimFile, "rollout");
+      // one port per worker; the cast keeps the arithmetic in int as before
+      raisimRolloutSettings.portNumber_ += static_cast<int>(i);
+      leggedRobotRaisimConversionsPtrs_.push_back(std::unique_ptr(new LeggedRobotRaisimConversions(
+          leggedRobotInterfacePtrs_[i]->getPinocchioInterface(), leggedRobotInterfacePtrs_[i]->getCentroidalModelInfo(),
+          leggedRobotInterfacePtrs_[i]->getInitialState())));
+      leggedRobotRaisimConversionsPtrs_[i]->loadSettings(raisimFile, "rollout", true);
+      // The callbacks are stored inside the rollout and therefore outlive this constructor: capture only
+      // `this` (for the member vector) and the worker index by value, never the constructor's locals.
+      rolloutPtrs.push_back(std::unique_ptr(new RaisimRollout(
+          urdfFile, resourcePath,
+          [this, i](const vector_t& state, const vector_t& input) {
+            return leggedRobotRaisimConversionsPtrs_[i]->stateToRaisimGenCoordGenVel(state, input);
+          },
+          [this, i](const Eigen::VectorXd& q, const Eigen::VectorXd& dq) {
+            return leggedRobotRaisimConversionsPtrs_[i]->raisimGenCoordGenVelToState(q, dq);
+          },
+          [this, i](double time, const vector_t& input, const vector_t& state, const Eigen::VectorXd& q, const Eigen::VectorXd& dq) {
+            return leggedRobotRaisimConversionsPtrs_[i]->inputToRaisimGeneralizedForce(time, input, state, q, dq);
+          },
+          nullptr, raisimRolloutSettings,
+          [this, i](double time, const vector_t& input, const vector_t& state, const Eigen::VectorXd& q, const Eigen::VectorXd& dq) {
+            return leggedRobotRaisimConversionsPtrs_[i]->inputToRaisimPdTargets(time, input, state, q, dq);
+          })));
+      if (raisimRolloutSettings.generateTerrain_) {
+        raisim::TerrainProperties terrainProperties;
+        terrainProperties.zScale = raisimRolloutSettings.terrainRoughness_;
+        // a different seed per worker
+        terrainProperties.seed = raisimRolloutSettings.terrainSeed_ + static_cast<int>(i);
+        auto terrainPtr = static_cast(rolloutPtrs[i].get())->generateTerrain(terrainProperties);
+        leggedRobotRaisimConversionsPtrs_[i]->setTerrain(*terrainPtr);
+      }
+    } else {
+      rolloutPtrs.push_back(std::unique_ptr(leggedRobotInterfacePtrs_[i]->getRollout().clone()));
+    }
+    mpcnetDefinitionPtrs.push_back(mpcnetDefinitionPtr);
+    referenceManagerPtrs.push_back(leggedRobotInterfacePtrs_[i]->getReferenceManagerPtr());
+  }
+  mpcnetRolloutManagerPtr_.reset(new ocs2::mpcnet::MpcnetRolloutManager(nDataGenerationThreads, nPolicyEvaluationThreads,
+                                                                        std::move(mpcPtrs), std::move(mpcnetPtrs), std::move(rolloutPtrs),
+                                                                        mpcnetDefinitionPtrs, referenceManagerPtrs));
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+/**
+ * Creates one GaussNewtonDDP_MPC instance for the given interface.
+ * The interface's MPC and DDP settings are copied and then overridden (single-threaded SLQ, feedback policy on,
+ * cold start and all printing/stability checks off) before the solver is built; the interface's reference
+ * manager is attached to the solver.
+ * @param [in] leggedRobotInterface : The legged robot interface providing settings, rollout, problem and initializer.
+ * @return The MPC instance.
+ */
+std::unique_ptr LeggedRobotMpcnetInterface::getMpc(LeggedRobotInterface& leggedRobotInterface) {
+  // ensure MPC and DDP settings are as needed for MPC-Net
+  const auto mpcSettings = [&]() {
+    auto settings = leggedRobotInterface.mpcSettings();
+    settings.debugPrint_ = false;
+    settings.coldStart_ = false;
+    return settings;
+  }();
+  const auto ddpSettings = [&]() {
+    auto settings = leggedRobotInterface.ddpSettings();
+    settings.algorithm_ = ocs2::ddp::Algorithm::SLQ;
+    settings.nThreads_ = 1;
+    settings.displayInfo_ = false;
+    settings.displayShortSummary_ = false;
+    settings.checkNumericalStability_ = false;
+    settings.debugPrintRollout_ = false;
+    settings.useFeedbackPolicy_ = true;
+    return settings;
+  }();
+  // create one MPC instance
+  std::unique_ptr mpcPtr(new GaussNewtonDDP_MPC(mpcSettings, ddpSettings, leggedRobotInterface.getRollout(),
+                                                leggedRobotInterface.getOptimalControlProblem(),
+                                                leggedRobotInterface.getInitializer()));
+  mpcPtr->getSolverPtr()->setReferenceManager(leggedRobotInterface.getReferenceManagerPtr());
+  return mpcPtr;
+}
+
+}  // namespace legged_robot
+}  //
namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetPybindings.cpp b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetPybindings.cpp new file mode 100644 index 000000000..f170bde7a --- /dev/null +++ b/ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetPybindings.cpp @@ -0,0 +1,34 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/ + +#include + +#include "ocs2_legged_robot_mpcnet/LeggedRobotMpcnetInterface.h" + +CREATE_ROBOT_MPCNET_PYTHON_BINDINGS(ocs2::legged_robot::LeggedRobotMpcnetInterface, LeggedRobotMpcnetPybindings) diff --git a/ocs2_mpcnet/ocs2_mpcnet/CMakeLists.txt b/ocs2_mpcnet/ocs2_mpcnet/CMakeLists.txt new file mode 100644 index 000000000..a9dba76cb --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet/CMakeLists.txt @@ -0,0 +1,4 @@ +cmake_minimum_required(VERSION 3.0.2) +project(ocs2_mpcnet) +find_package(catkin REQUIRED) +catkin_metapackage() diff --git a/ocs2_mpcnet/ocs2_mpcnet/package.xml b/ocs2_mpcnet/ocs2_mpcnet/package.xml new file mode 100644 index 000000000..7eec28861 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet/package.xml @@ -0,0 +1,24 @@ + + + ocs2_mpcnet + 0.0.0 + The ocs2_mpcnet metapackage + + Alexander Reske + + Farbod Farshidian + Alexander Reske + + BSD-3 + + catkin + + ocs2_mpcnet_core + ocs2_ballbot_mpcnet + ocs2_legged_robot_mpcnet + + + + + + diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/CMakeLists.txt b/ocs2_mpcnet/ocs2_mpcnet_core/CMakeLists.txt new file mode 100644 index 000000000..7b91fd000 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/CMakeLists.txt @@ -0,0 +1,118 @@ +cmake_minimum_required(VERSION 3.0.2) +project(ocs2_mpcnet_core) + +set(CATKIN_PACKAGE_DEPENDENCIES + pybind11_catkin + ocs2_mpc + ocs2_python_interface + ocs2_ros_interfaces +) + +find_package(catkin REQUIRED COMPONENTS + ${CATKIN_PACKAGE_DEPENDENCIES} +) + +find_package(onnxruntime REQUIRED) + +# Generate compile_commands.json for clang tools +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +################################### +## catkin specific configuration ## +################################### + +catkin_package( + INCLUDE_DIRS + include + LIBRARIES + ${PROJECT_NAME} + CATKIN_DEPENDS + ${CATKIN_PACKAGE_DEPENDENCIES} + DEPENDS + onnxruntime +) + +########### +## Build ## +########### + +include_directories( + include + 
${catkin_INCLUDE_DIRS} +) + +# main library +add_library(${PROJECT_NAME} + src/control/MpcnetBehavioralController.cpp + src/control/MpcnetOnnxController.cpp + src/dummy/MpcnetDummyLoopRos.cpp + src/dummy/MpcnetDummyObserverRos.cpp + src/rollout/MpcnetDataGeneration.cpp + src/rollout/MpcnetPolicyEvaluation.cpp + src/rollout/MpcnetRolloutBase.cpp + src/rollout/MpcnetRolloutManager.cpp + src/MpcnetInterfaceBase.cpp +) +add_dependencies(${PROJECT_NAME} + ${catkin_EXPORTED_TARGETS} +) +target_link_libraries(${PROJECT_NAME} + ${catkin_LIBRARIES} + onnxruntime +) + +# python bindings +pybind11_add_module(MpcnetPybindings SHARED + src/MpcnetPybindings.cpp +) +add_dependencies(MpcnetPybindings + ${PROJECT_NAME} + ${catkin_EXPORTED_TARGETS} +) +target_link_libraries(MpcnetPybindings PRIVATE + ${PROJECT_NAME} + ${catkin_LIBRARIES} +) +set_target_properties(MpcnetPybindings + PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CATKIN_DEVEL_PREFIX}/${CATKIN_PACKAGE_PYTHON_DESTINATION} +) + +catkin_python_setup() + +######################### +### CLANG TOOLING ### +######################### +find_package(cmake_clang_tools QUIET) +if(cmake_clang_tools_FOUND) + message(STATUS "Run clang tooling for target ocs2_mpcnet_core") + add_clang_tooling( + TARGETS ${PROJECT_NAME} + SOURCE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/src ${CMAKE_CURRENT_SOURCE_DIR}/include + CT_HEADER_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/include + CF_WERROR +) +endif(cmake_clang_tools_FOUND) + +############# +## Install ## +############# + +install(TARGETS ${PROJECT_NAME} + ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + +install(DIRECTORY include/${PROJECT_NAME}/ + DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}) + +install(TARGETS MpcnetPybindings + ARCHIVE DESTINATION ${CATKIN_PACKAGE_PYTHON_DESTINATION} + LIBRARY DESTINATION ${CATKIN_PACKAGE_PYTHON_DESTINATION} +) + +############# +## Testing ## +############# 
+ +# TODO(areske) diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetDefinitionBase.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetDefinitionBase.h new file mode 100644 index 000000000..de322bcc2 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetDefinitionBase.h @@ -0,0 +1,100 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/
+
+#pragma once
+
+#include
+#include
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * Base class for MPC-Net definitions.
+ * A definition declares three pure virtual queries, all taking the same tuple (t, x, modeSchedule, targetTrajectories):
+ * the observation fed to the policy, the affine transformation that maps the policy's action to a control input,
+ * and a validity check. The class is not copyable.
+ */
+class MpcnetDefinitionBase {
+ public:
+  /**
+   * Default constructor.
+   */
+  MpcnetDefinitionBase() = default;
+
+  /**
+   * Default destructor.
+   */
+  virtual ~MpcnetDefinitionBase() = default;
+
+  /**
+   * Deleted copy constructor.
+   */
+  MpcnetDefinitionBase(const MpcnetDefinitionBase&) = delete;
+
+  /**
+   * Deleted copy assignment.
+   */
+  MpcnetDefinitionBase& operator=(const MpcnetDefinitionBase&) = delete;
+
+  /**
+   * Get the observation.
+   * @note The observation o is the input to the policy.
+   * @param[in] t : Absolute time.
+   * @param[in] x : Robot state.
+   * @param[in] modeSchedule : Mode schedule.
+   * @param[in] targetTrajectories : Target trajectories.
+   * @return The observation.
+   */
+  virtual vector_t getObservation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule,
+                                  const TargetTrajectories& targetTrajectories) = 0;
+
+  /**
+   * Get the action transformation.
+   * @note Used for computing the control input u = A * a + b from the action a predicted by the policy.
+   * @param[in] t : Absolute time.
+   * @param[in] x : Robot state.
+   * @param[in] modeSchedule : Mode schedule.
+   * @param[in] targetTrajectories : Target trajectories.
+   * @return The action transformation pair {A, b}.
+   */
+  virtual std::pair getActionTransformation(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule,
+                                            const TargetTrajectories& targetTrajectories) = 0;
+
+  /**
+   * Check if the tuple (t, x, modeSchedule, targetTrajectories) is valid.
+   * @note E.g., check if the state diverged or if tracking is poor.
+   * @param[in] t : Absolute time.
+   * @param[in] x : Robot state.
+   * @param[in] modeSchedule : Mode schedule.
+   * @param[in] targetTrajectories : Target trajectories.
+   * @return True if valid.
+   */
+  virtual bool isValid(scalar_t t, const vector_t& x, const ModeSchedule& modeSchedule, const TargetTrajectories& targetTrajectories) = 0;
+};
+
+}  // namespace mpcnet
+}  // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetInterfaceBase.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetInterfaceBase.h
new file mode 100644
index 000000000..ec4b72d00
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetInterfaceBase.h
@@ -0,0 +1,91 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include "ocs2_mpcnet_core/rollout/MpcnetRolloutManager.h"
+
+namespace ocs2 {
+namespace mpcnet {
+
+/**
+ * Base class for all MPC-Net interfaces between C++ and Python.
+ * The constructor is protected, so only derived (robot-specific) interfaces can be created. The public methods
+ * mirror the MpcnetRolloutManager methods named in their @see tags.
+ * NOTE(review): mpcnetRolloutManagerPtr_ is not set in this header; derived classes presumably have to
+ * initialize it before any of the methods below is called - confirm against MpcnetInterfaceBase.cpp.
+ */
+class MpcnetInterfaceBase {
+ public:
+  /**
+   * Default destructor.
+   */
+  virtual ~MpcnetInterfaceBase() = default;
+
+  /**
+   * Start the data generation.
+   * @see MpcnetRolloutManager::startDataGeneration()
+   */
+  void startDataGeneration(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, size_t dataDecimation, size_t nSamples,
+                           const matrix_t& samplingCovariance, const std::vector& initialObservations,
+                           const std::vector& modeSchedules, const std::vector& targetTrajectories);
+
+  /**
+   * Query whether the data generation is done.
+   * @see MpcnetRolloutManager::isDataGenerationDone()
+   */
+  bool isDataGenerationDone();
+
+  /**
+   * Get the generated data.
+   * @see MpcnetRolloutManager::getGeneratedData()
+   */
+  data_array_t getGeneratedData();
+
+  /**
+   * Start the policy evaluation.
+   * @see MpcnetRolloutManager::startPolicyEvaluation()
+   */
+  void startPolicyEvaluation(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep,
+                             const std::vector& initialObservations, const std::vector& modeSchedules,
+                             const std::vector& targetTrajectories);
+
+  /**
+   * Query whether the policy evaluation is done.
+   * @see MpcnetRolloutManager::isPolicyEvaluationDone()
+   */
+  bool isPolicyEvaluationDone();
+
+  /**
+   * Get the computed metrics.
+   * @see MpcnetRolloutManager::getComputedMetrics()
+   */
+  metrics_array_t getComputedMetrics();
+
+ protected:
+  /**
+   * Default constructor.
+   */
+  MpcnetInterfaceBase() = default;
+
+  // Rollout manager owned by this interface; left empty by the default constructor.
+  std::unique_ptr mpcnetRolloutManagerPtr_;
+};
+
+}  // namespace mpcnet
+}  // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetPybindMacros.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetPybindMacros.h
new file mode 100644
index 000000000..632fa4717
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/MpcnetPybindMacros.h
@@ -0,0 +1,133 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#pragma once
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include "ocs2_mpcnet_core/rollout/MpcnetData.h"
+#include "ocs2_mpcnet_core/rollout/MpcnetMetrics.h"
+
+// Brings the "name"_a argument literal into scope; it is used by CREATE_ROBOT_MPCNET_PYTHON_BINDINGS below.
+// NOTE(review): a using-directive at header scope is also visible in every file that includes this header.
+using namespace pybind11::literals;
+
+/**
+ * Convenience macro to bind general MPC-Net functionalities and other classes with all required vectors.
+ * Expands to the opaque declarations for the vector types followed by a complete PYBIND11_MODULE definition,
+ * so it has to be used at namespace scope.
+ * The module exposes the array types "size_array", "scalar_array", "vector_array", "matrix_array",
+ * "SystemObservationArray", "ModeScheduleArray", "TargetTrajectoriesArray", "DataArray" and "MetricsArray", and the
+ * classes "ScalarFunctionQuadraticApproximation", "SystemObservation", "ModeSchedule", "TargetTrajectories",
+ * "DataPoint" and "Metrics" with read/write access to the listed members.
+ * @param LIB_NAME : Name of the python module, forwarded to PYBIND11_MODULE.
+ */
+#define CREATE_MPCNET_PYTHON_BINDINGS(LIB_NAME) \
+  /* make vector types opaque so they are not converted to python lists */ \
+  PYBIND11_MAKE_OPAQUE(ocs2::size_array_t) \
+  PYBIND11_MAKE_OPAQUE(ocs2::scalar_array_t) \
+  PYBIND11_MAKE_OPAQUE(ocs2::vector_array_t) \
+  PYBIND11_MAKE_OPAQUE(ocs2::matrix_array_t) \
+  PYBIND11_MAKE_OPAQUE(std::vector) \
+  PYBIND11_MAKE_OPAQUE(std::vector) \
+  PYBIND11_MAKE_OPAQUE(std::vector) \
+  PYBIND11_MAKE_OPAQUE(ocs2::mpcnet::data_array_t) \
+  PYBIND11_MAKE_OPAQUE(ocs2::mpcnet::metrics_array_t) \
+  /* create a python module */ \
+  PYBIND11_MODULE(LIB_NAME, m) { \
+    /* bind vector types so they can be used natively in python */ \
+    VECTOR_TYPE_BINDING(ocs2::size_array_t, "size_array") \
+    VECTOR_TYPE_BINDING(ocs2::scalar_array_t, "scalar_array") \
+    VECTOR_TYPE_BINDING(ocs2::vector_array_t, "vector_array") \
+    VECTOR_TYPE_BINDING(ocs2::matrix_array_t, "matrix_array") \
+    VECTOR_TYPE_BINDING(std::vector, "SystemObservationArray") \
+    VECTOR_TYPE_BINDING(std::vector, "ModeScheduleArray") \
+    VECTOR_TYPE_BINDING(std::vector, "TargetTrajectoriesArray") \
+    VECTOR_TYPE_BINDING(ocs2::mpcnet::data_array_t, "DataArray") \
+    VECTOR_TYPE_BINDING(ocs2::mpcnet::metrics_array_t, "MetricsArray") \
+    /* bind approximation classes */ \
+    pybind11::class_(m, "ScalarFunctionQuadraticApproximation") \
+        .def_readwrite("f", &ocs2::ScalarFunctionQuadraticApproximation::f) \
+        .def_readwrite("dfdx", &ocs2::ScalarFunctionQuadraticApproximation::dfdx) \
+        .def_readwrite("dfdu", &ocs2::ScalarFunctionQuadraticApproximation::dfdu) \
+        .def_readwrite("dfdxx", &ocs2::ScalarFunctionQuadraticApproximation::dfdxx) \
+        .def_readwrite("dfdux", &ocs2::ScalarFunctionQuadraticApproximation::dfdux) \
+        .def_readwrite("dfduu", &ocs2::ScalarFunctionQuadraticApproximation::dfduu); \
+    /* bind system observation struct */ \
+    pybind11::class_(m, "SystemObservation") \
+        .def(pybind11::init<>()) \
+        .def_readwrite("mode", &ocs2::SystemObservation::mode) \
+        .def_readwrite("time", &ocs2::SystemObservation::time) \
+        .def_readwrite("state", &ocs2::SystemObservation::state) \
+        .def_readwrite("input", &ocs2::SystemObservation::input); \
+    /* bind mode schedule struct */ \
+    pybind11::class_(m, "ModeSchedule") \
+        .def(pybind11::init()) \
+        .def_readwrite("eventTimes", &ocs2::ModeSchedule::eventTimes) \
+        .def_readwrite("modeSequence", &ocs2::ModeSchedule::modeSequence); \
+    /* bind target trajectories class */ \
+    pybind11::class_(m, "TargetTrajectories") \
+        .def(pybind11::init()) \
+        .def_readwrite("timeTrajectory", &ocs2::TargetTrajectories::timeTrajectory) \
+        .def_readwrite("stateTrajectory", &ocs2::TargetTrajectories::stateTrajectory) \
+        .def_readwrite("inputTrajectory", &ocs2::TargetTrajectories::inputTrajectory); \
+    /* bind data point struct */ \
+    pybind11::class_(m, "DataPoint") \
+        .def(pybind11::init<>()) \
+        .def_readwrite("mode", &ocs2::mpcnet::data_point_t::mode) \
+        .def_readwrite("t", &ocs2::mpcnet::data_point_t::t) \
+        .def_readwrite("x", &ocs2::mpcnet::data_point_t::x) \
+        .def_readwrite("u", &ocs2::mpcnet::data_point_t::u) \
+        .def_readwrite("observation", &ocs2::mpcnet::data_point_t::observation) \
+        .def_readwrite("actionTransformation", &ocs2::mpcnet::data_point_t::actionTransformation) \
+        .def_readwrite("hamiltonian", &ocs2::mpcnet::data_point_t::hamiltonian); \
+    /* bind metrics struct */ \
+    pybind11::class_(m, "Metrics") \
+        .def(pybind11::init<>()) \
+        .def_readwrite("survivalTime", &ocs2::mpcnet::metrics_t::survivalTime) \
+        .def_readwrite("incurredHamiltonian", &ocs2::mpcnet::metrics_t::incurredHamiltonian); \
+  }
+
+/**
+ * Convenience macro to bind robot MPC-Net interface.
+ * Expands to a complete PYBIND11_MODULE definition that first imports the python module
+ * "ocs2_mpcnet_core.MpcnetPybindings" and then exposes MPCNET_INTERFACE as the class "MpcnetInterface" with the
+ * methods startDataGeneration, isDataGenerationDone, getGeneratedData, startPolicyEvaluation,
+ * isPolicyEvaluationDone and getComputedMetrics. The start* methods are bound with named arguments;
+ * "samplingCovariance" is marked noconvert().
+ * @param MPCNET_INTERFACE : The robot-specific MPC-Net interface class to bind.
+ * @param LIB_NAME : Name of the python module, forwarded to PYBIND11_MODULE.
+ */
+#define CREATE_ROBOT_MPCNET_PYTHON_BINDINGS(MPCNET_INTERFACE, LIB_NAME) \
+  /* create a python module */ \
+  PYBIND11_MODULE(LIB_NAME, m) { \
+    /* import the general MPC-Net module */ \
+    pybind11::module::import("ocs2_mpcnet_core.MpcnetPybindings"); \
+    /* bind actual MPC-Net interface for specific robot */ \
+    pybind11::class_(m, "MpcnetInterface") \
+        .def(pybind11::init()) \
+        .def("startDataGeneration", &MPCNET_INTERFACE::startDataGeneration, "alpha"_a, "policyFilePath"_a, "timeStep"_a, \
+             "dataDecimation"_a, "nSamples"_a, "samplingCovariance"_a.noconvert(), "initialObservations"_a, "modeSchedules"_a, \
+             "targetTrajectories"_a) \
+        .def("isDataGenerationDone", &MPCNET_INTERFACE::isDataGenerationDone) \
+        .def("getGeneratedData", &MPCNET_INTERFACE::getGeneratedData) \
+        .def("startPolicyEvaluation", &MPCNET_INTERFACE::startPolicyEvaluation, "alpha"_a, "policyFilePath"_a, "timeStep"_a, \
+             "initialObservations"_a, "modeSchedules"_a, "targetTrajectories"_a) \
+        .def("isPolicyEvaluationDone", &MPCNET_INTERFACE::isPolicyEvaluationDone) \
+        .def("getComputedMetrics", &MPCNET_INTERFACE::getComputedMetrics); \
+  }
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetBehavioralController.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetBehavioralController.h
new file mode 100644
index
000000000..60ec8a6c6 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetBehavioralController.h @@ -0,0 +1,97 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/ + +#pragma once + +#include + +#include + +#include "ocs2_mpcnet_core/control/MpcnetControllerBase.h" + +namespace ocs2 { +namespace mpcnet { + +/** + * A behavioral controller that computes the input based on a mixture of an optimal policy (e.g. implicitly found via MPC) + * and a learned policy (e.g. explicitly represented by a neural network). + * The behavioral policy is pi_behavioral = alpha * pi_optimal + (1 - alpha) * pi_learned with alpha in [0, 1]. NOTE(review): a default-constructed instance holds no controllers, and clone() dereferences both controller pointers, so set alpha and both controllers before using or cloning it. + */ +class MpcnetBehavioralController final : public ControllerBase { + public: + /** + * Constructor. + * @param [in] alpha : The mixture parameter, expected in [0, 1] (not validated here). + * @param [in] optimalController : The optimal controller (this class takes ownership of a clone). + * @param [in] learnedController : The learned controller (this class takes ownership of a clone). + */ + MpcnetBehavioralController(scalar_t alpha, const ControllerBase& optimalController, const MpcnetControllerBase& learnedController) + : alpha_(alpha), optimalControllerPtr_(optimalController.clone()), learnedControllerPtr_(learnedController.clone()) {} + + MpcnetBehavioralController() = default; + ~MpcnetBehavioralController() override = default; + MpcnetBehavioralController* clone() const override { return new MpcnetBehavioralController(*this); } + + /** + * Set the mixture parameter. + * @param [in] alpha : The mixture parameter, expected in [0, 1] (not validated here). + */ + void setAlpha(scalar_t alpha) { alpha_ = alpha; } + + /** + * Set the optimal controller. + * @param [in] optimalController : The optimal controller (this class takes ownership of a clone). 
+ */ + void setOptimalController(const ControllerBase& optimalController) { optimalControllerPtr_.reset(optimalController.clone()); } + + /** + * Set the learned controller. + * @param [in] learnedController : The learned controller (this class takes ownership of a clone).
+ */ + void setLearnedController(const MpcnetControllerBase& learnedController) { learnedControllerPtr_.reset(learnedController.clone()); } + + vector_t computeInput(scalar_t t, const vector_t& x) override; + ControllerType getType() const override { return ControllerType::BEHAVIORAL; } + + int size() const override; + void clear() override; + bool empty() const override; + void concatenate(const ControllerBase* otherController, int index, int length) override; + + private: + MpcnetBehavioralController(const MpcnetBehavioralController& other) + : MpcnetBehavioralController(other.alpha_, *other.optimalControllerPtr_, *other.learnedControllerPtr_) {} + + scalar_t alpha_; + std::unique_ptr optimalControllerPtr_; + std::unique_ptr learnedControllerPtr_; +}; + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetControllerBase.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetControllerBase.h new file mode 100644 index 000000000..8a669cf79 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetControllerBase.h @@ -0,0 +1,57 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission.
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include + +namespace ocs2 { +namespace mpcnet { + +/** + * The abstract base class for all controllers that use an MPC-Net policy. + */ +class MpcnetControllerBase : public ControllerBase { + public: + MpcnetControllerBase() = default; + ~MpcnetControllerBase() override = default; + MpcnetControllerBase* clone() const override = 0; + + /** + * Load the model of the policy. + * @param [in] policyFilePath : Path to the file with the model of the policy.
+ */ + virtual void loadPolicyModel(const std::string& policyFilePath) = 0; + + protected: + MpcnetControllerBase(const MpcnetControllerBase& other) = default; +}; + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetOnnxController.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetOnnxController.h new file mode 100644 index 000000000..89c102eb3 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/control/MpcnetOnnxController.h @@ -0,0 +1,111 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include + +#include + +#include "ocs2_mpcnet_core/MpcnetDefinitionBase.h" +#include "ocs2_mpcnet_core/control/MpcnetControllerBase.h" + +namespace ocs2 { +namespace mpcnet { + +/** + * Convenience function for creating an environment for ONNX Runtime and getting a corresponding shared pointer. + * @note Only one environment per process can be created. The environment offers some threading and logging options. + * @return Shared pointer to the environment for ONNX Runtime (created with warning-level logging). + */ +inline std::shared_ptr createOnnxEnvironment() { + return std::make_shared(ORT_LOGGING_LEVEL_WARNING, "MpcnetOnnxController"); +} + +/** + * A neural network controller using ONNX Runtime based on the Open Neural Network Exchange (ONNX) format. + * The model of the policy computes u = model(t, x) with + * t: generalized time (1 x dimensionOfTime), + * x: relative state (1 x dimensionOfState), + * u: predicted input (1 x dimensionOfInput). + * @note The additional first dimension with size 1 for the variables of the model comes from batch processing during training. + */ +class MpcnetOnnxController final : public MpcnetControllerBase { + public: + /** + * Constructor. 
+ * @note The class is not fully instantiated until calling loadPolicyModel(); clone() calls loadPolicyModel() again when the source has a non-empty policy file path. + * @param [in] mpcnetDefinitionPtr : Pointer to the MPC-Net definitions. + * @param [in] referenceManagerPtr : Pointer to the reference manager.
+ * @param [in] onnxEnvironmentPtr : Pointer to the environment for ONNX Runtime. + */ + MpcnetOnnxController(std::shared_ptr mpcnetDefinitionPtr, + std::shared_ptr referenceManagerPtr, std::shared_ptr onnxEnvironmentPtr) + : mpcnetDefinitionPtr_(std::move(mpcnetDefinitionPtr)), + referenceManagerPtr_(std::move(referenceManagerPtr)), + onnxEnvironmentPtr_(std::move(onnxEnvironmentPtr)) {} + + ~MpcnetOnnxController() override = default; + MpcnetOnnxController* clone() const override { return new MpcnetOnnxController(*this); } + + void loadPolicyModel(const std::string& policyFilePath) override; + + vector_t computeInput(const scalar_t t, const vector_t& x) override; + ControllerType getType() const override { return ControllerType::ONNX; } + + int size() const override { throw std::runtime_error("[MpcnetOnnxController::size] not implemented."); } + void clear() override { throw std::runtime_error("[MpcnetOnnxController::clear] not implemented."); } + bool empty() const override { throw std::runtime_error("[MpcnetOnnxController::empty] not implemented."); } + void concatenate(const ControllerBase* otherController, int index, int length) override { + throw std::runtime_error("[MpcnetOnnxController::concatenate] not implemented."); + } + + private: + using tensor_element_t = float; + + MpcnetOnnxController(const MpcnetOnnxController& other) + : MpcnetOnnxController(other.mpcnetDefinitionPtr_, other.referenceManagerPtr_, other.onnxEnvironmentPtr_) { + if (!other.policyFilePath_.empty()) { + loadPolicyModel(other.policyFilePath_); + } + } + + std::shared_ptr mpcnetDefinitionPtr_; + std::shared_ptr referenceManagerPtr_; + std::shared_ptr onnxEnvironmentPtr_; + std::string policyFilePath_; + std::unique_ptr sessionPtr_; + std::vector inputNames_; + std::vector outputNames_; + std::vector> inputShapes_; + std::vector> outputShapes_; +}; + +} // namespace mpcnet +} // namespace ocs2 diff --git 
a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/dummy/MpcnetDummyLoopRos.h
b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/dummy/MpcnetDummyLoopRos.h new file mode 100644 index 000000000..b1e954e28 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/dummy/MpcnetDummyLoopRos.h @@ -0,0 +1,112 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +#include "ocs2_mpcnet_core/control/MpcnetControllerBase.h" + +namespace ocs2 { +namespace mpcnet { + +/** + * Dummy loop to test a robot controlled by an MPC-Net policy. + */ +class MpcnetDummyLoopRos { + public: + /** + * Constructor. + * @param [in] controlFrequency : Minimum frequency at which the MPC-Net policy should be called. + * @param [in] rosFrequency : Frequency at which the ROS observers are updated. + * @param [in] mpcnetPtr : Pointer to the MPC-Net policy to be used (this class takes ownership). + * @param [in] rolloutPtr : Pointer to the rollout to be used (this class takes ownership). + * @param [in] rosReferenceManagerPtr : Pointer to the reference manager to be used (shared ownership). + */ + MpcnetDummyLoopRos(scalar_t controlFrequency, scalar_t rosFrequency, std::unique_ptr mpcnetPtr, + std::unique_ptr rolloutPtr, std::shared_ptr rosReferenceManagerPtr); + + /** + * Default destructor. + */ + virtual ~MpcnetDummyLoopRos() = default; + + /** + * Runs the dummy loop. + * @param [in] systemObservation: The initial system observation. + * @param [in] targetTrajectories: The initial target trajectories. + */ + void run(const SystemObservation& systemObservation, const TargetTrajectories& targetTrajectories); + + /** + * Adds one observer to the vector of observers that need to be informed about the system observation, primal solution and command data. + * Each observer is updated once after running a rollout. + * @param [in] observer : The observer to add. + */ + void addObserver(std::shared_ptr observer); + + /** + * Adds one module to the vector of modules that need to be synchronized with the policy. + * Each module is updated once before calling the policy. + * @param [in] synchronizedModule : The module to add.
+ */ + void addSynchronizedModule(std::shared_ptr synchronizedModule); + + protected: + /** + * Runs a rollout. + * @param [in] duration : The duration of the run. + * @param [in] initialSystemObservation : The initial system observation. + * @param [out] finalSystemObservation : The final system observation. + */ + void rollout(scalar_t duration, const SystemObservation& initialSystemObservation, SystemObservation& finalSystemObservation); + + /** + * Update the reference manager and the synchronized modules. + * @param [in] time : The current time. + * @param [in] state : The current state. + */ + void preSolverRun(scalar_t time, const vector_t& state); + + private: + scalar_t controlFrequency_; + scalar_t rosFrequency_; + std::unique_ptr mpcnetPtr_; + std::unique_ptr rolloutPtr_; + std::shared_ptr rosReferenceManagerPtr_; + std::vector> observerPtrs_; + std::vector> synchronizedModulePtrs_; +}; + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/dummy/MpcnetDummyObserverRos.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/dummy/MpcnetDummyObserverRos.h new file mode 100644 index 000000000..a03661bb4 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/dummy/MpcnetDummyObserverRos.h @@ -0,0 +1,69 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution.
+ + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include + +#include + +namespace ocs2 { +namespace mpcnet { + +/** + * Dummy observer that publishes the current system observation that is required for some target command nodes. + */ +class MpcnetDummyObserverRos : public DummyObserver { + public: + /** + * Constructor. + * @param [in] nodeHandle : The ROS node handle. + * @param [in] topicPrefix : The prefix defines the name of the observation's publishing topic "topicPrefix_mpc_observation" (defaults to "anonymousRobot"). + */ + explicit MpcnetDummyObserverRos(ros::NodeHandle& nodeHandle, std::string topicPrefix = "anonymousRobot"); + + /** + * Default destructor. + */ + ~MpcnetDummyObserverRos() override = default; + + /** + * Update and publish the system observation. + * @param [in] observation: The current system observation. + * @param [in] primalSolution: The current primal solution. 
+ * @param [in] command: The given command data.
+ */ + void update(const SystemObservation& observation, const PrimalSolution& primalSolution, const CommandData& command) override; + + private: + ros::Publisher observationPublisher_; +}; + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetData.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetData.h new file mode 100644 index 000000000..7fd63215e --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetData.h @@ -0,0 +1,86 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include +#include + +#include "ocs2_mpcnet_core/MpcnetDefinitionBase.h" + +namespace ocs2 { +namespace mpcnet { + +/** + * Data point collected during the data generation rollout. + */ +struct DataPoint { + /** Mode of the system at time t. */ + size_t mode; + /** Absolute time. */ + scalar_t t; + /** Observed state. */ + vector_t x; + /** Optimal control input. */ + vector_t u; + /** Observation given as input to the policy. */ + vector_t observation; + /** Action transformation applied to the output of the policy. */ + std::pair actionTransformation; + /** Linear-quadratic approximation of the Hamiltonian, using x and u as development/expansion points. */ + ScalarFunctionQuadraticApproximation hamiltonian; +}; +using data_point_t = DataPoint; +using data_array_t = std::vector; + +/** + * Get a data point at the first time of the current primal solution of the MPC solver. + * @param [in] mpc : The MPC with a pointer to the underlying solver. + * @param [in] mpcnetDefinition : The MPC-Net definitions. + * @param [in] deviation : The state deviation that is added to the nominal state (first state of the primal solution) to obtain the state of the data point. + * @return A data point.
+ */ +inline data_point_t getDataPoint(MPC_BASE& mpc, MpcnetDefinitionBase& mpcnetDefinition, const vector_t& deviation) { + data_point_t dataPoint; + const auto& referenceManager = mpc.getSolverPtr()->getReferenceManager(); + const auto primalSolution = mpc.getSolverPtr()->primalSolution(mpc.getSolverPtr()->getFinalTime()); + dataPoint.t = primalSolution.timeTrajectory_.front(); + dataPoint.x = primalSolution.stateTrajectory_.front() + deviation; + dataPoint.u = primalSolution.controllerPtr_->computeInput(dataPoint.t, dataPoint.x); + dataPoint.mode = primalSolution.modeSchedule_.modeAtTime(dataPoint.t); + dataPoint.observation = mpcnetDefinition.getObservation(dataPoint.t, dataPoint.x, referenceManager.getModeSchedule(), + referenceManager.getTargetTrajectories()); + dataPoint.actionTransformation = mpcnetDefinition.getActionTransformation(dataPoint.t, dataPoint.x, referenceManager.getModeSchedule(), + referenceManager.getTargetTrajectories()); + dataPoint.hamiltonian = mpc.getSolverPtr()->getHamiltonian(dataPoint.t, dataPoint.x, dataPoint.u); + return dataPoint; +} + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetDataGeneration.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetDataGeneration.h new file mode 100644 index 000000000..6b0ee2a0e --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetDataGeneration.h @@ -0,0 +1,95 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer.
+ + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include "ocs2_mpcnet_core/rollout/MpcnetData.h" +#include "ocs2_mpcnet_core/rollout/MpcnetRolloutBase.h" + +namespace ocs2 { +namespace mpcnet { + +/** + * A class for generating data from a system that is forward simulated with a behavioral controller. + * @note Usually the behavioral controller moves from the MPC policy to the MPC-Net policy throughout the training process. + */ +class MpcnetDataGeneration final : public MpcnetRolloutBase { + public: + /** + * Constructor. + * @param [in] mpcPtr : Pointer to the MPC solver to be used (this class takes ownership). + * @param [in] mpcnetPtr : Pointer to the MPC-Net policy to be used (this class takes ownership).
+ * @param [in] rolloutPtr : Pointer to the rollout to be used (this class takes ownership). + * @param [in] mpcnetDefinitionPtr : Pointer to the MPC-Net definitions to be used (shared ownership). + * @param [in] referenceManagerPtr : Pointer to the reference manager to be used (shared ownership). + */ + MpcnetDataGeneration(std::unique_ptr mpcPtr, std::unique_ptr mpcnetPtr, + std::unique_ptr rolloutPtr, std::shared_ptr mpcnetDefinitionPtr, + std::shared_ptr referenceManagerPtr) + : MpcnetRolloutBase(std::move(mpcPtr), std::move(mpcnetPtr), std::move(rolloutPtr), std::move(mpcnetDefinitionPtr), + std::move(referenceManagerPtr)) {} + + /** + * Default destructor. + */ + ~MpcnetDataGeneration() override = default; + + /** + * Deleted copy constructor. + */ + MpcnetDataGeneration(const MpcnetDataGeneration&) = delete; + + /** + * Deleted copy assignment. + */ + MpcnetDataGeneration& operator=(const MpcnetDataGeneration&) = delete; + + /** + * Run the data generation. + * @param [in] alpha : The mixture parameter for the behavioral controller. + * @param [in] policyFilePath : The path to the file with the learned policy for the behavioral controller. + * @param [in] timeStep : The time step for the forward simulation of the system with the behavioral controller. + * @param [in] dataDecimation : The integer factor used for downsampling the data signal. + * @param [in] nSamples : The number of samples drawn from a multivariate normal distribution around the nominal states. + * @param [in] samplingCovariance : The covariance matrix used for sampling from a multivariate normal distribution. + * @param [in] initialObservation : The initial system observation to start from (time and state required). + * @param [in] modeSchedule : The mode schedule providing the event times and mode sequence. + * @param [in] targetTrajectories : The target trajectories to be tracked. 
+ * @return Pointer to the data array with the generated data (presumably the member dataArray_, so only valid while this object lives; confirm in the implementation).
+ */ + const data_array_t* run(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, size_t dataDecimation, size_t nSamples, + const matrix_t& samplingCovariance, const SystemObservation& initialObservation, const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories); + + private: + data_array_t dataArray_; +}; + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetMetrics.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetMetrics.h new file mode 100644 index 000000000..120da00c7 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetMetrics.h @@ -0,0 +1,50 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include + +namespace ocs2 { +namespace mpcnet { + +/** + * Metrics computed during the policy evaluation rollout. + */ +struct Metrics { + /** Survival time (defaults to 0.0). */ + scalar_t survivalTime = 0.0; + /** Hamiltonian incurred over time (defaults to 0.0). */ + scalar_t incurredHamiltonian = 0.0; +}; +using metrics_t = Metrics; +using metrics_array_t = std::vector; + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetPolicyEvaluation.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetPolicyEvaluation.h new file mode 100644 index 000000000..24d8ac57e --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetPolicyEvaluation.h @@ -0,0 +1,88 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer.
+ + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include "ocs2_mpcnet_core/rollout/MpcnetMetrics.h" +#include "ocs2_mpcnet_core/rollout/MpcnetRolloutBase.h" + +namespace ocs2 { +namespace mpcnet { + +/** + * A class for evaluating a policy for a system that is forward simulated with a behavioral controller. + * @note Usually the behavioral controller is evaluated for the MPC-Net policy (alpha = 0). + */ +class MpcnetPolicyEvaluation final : public MpcnetRolloutBase { + public: + /** + * Constructor. + * @param [in] mpcPtr : Pointer to the MPC solver to be used (this class takes ownership). + * @param [in] mpcnetPtr : Pointer to the MPC-Net policy to be used (this class takes ownership).
+ * @param [in] rolloutPtr : Pointer to the rollout to be used (this class takes ownership). + * @param [in] mpcnetDefinitionPtr : Pointer to the MPC-Net definitions to be used (shared ownership). + * @param [in] referenceManagerPtr : Pointer to the reference manager to be used (shared ownership). + */ + MpcnetPolicyEvaluation(std::unique_ptr mpcPtr, std::unique_ptr mpcnetPtr, + std::unique_ptr rolloutPtr, std::shared_ptr mpcnetDefinitionPtr, + std::shared_ptr referenceManagerPtr) + : MpcnetRolloutBase(std::move(mpcPtr), std::move(mpcnetPtr), std::move(rolloutPtr), std::move(mpcnetDefinitionPtr), + std::move(referenceManagerPtr)) {} + + /** + * Default destructor. + */ + ~MpcnetPolicyEvaluation() override = default; + + /** + * Deleted copy constructor. + */ + MpcnetPolicyEvaluation(const MpcnetPolicyEvaluation&) = delete; + + /** + * Deleted copy assignment. + */ + MpcnetPolicyEvaluation& operator=(const MpcnetPolicyEvaluation&) = delete; + + /** + * Run the policy evaluation. + * @param [in] alpha : The mixture parameter for the behavioral controller. + * @param [in] policyFilePath : The path to the file with the learned policy for the behavioral controller. + * @param [in] timeStep : The time step for the forward simulation of the system with the behavioral controller. + * @param [in] initialObservation : The initial system observation to start from (time and state required). + * @param [in] modeSchedule : The mode schedule providing the event times and mode sequence. + * @param [in] targetTrajectories : The target trajectories to be tracked. + * @return The computed metrics, returned by value.
+ */ + metrics_t run(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, const SystemObservation& initialObservation, + const ModeSchedule& modeSchedule, const TargetTrajectories& targetTrajectories); +}; + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetRolloutBase.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetRolloutBase.h new file mode 100644 index 000000000..23c16f1d2 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetRolloutBase.h @@ -0,0 +1,117 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "ocs2_mpcnet_core/MpcnetDefinitionBase.h" +#include "ocs2_mpcnet_core/control/MpcnetBehavioralController.h" +#include "ocs2_mpcnet_core/control/MpcnetControllerBase.h" + +namespace ocs2 { +namespace mpcnet { + +/** + * The base class for doing rollouts for a system that is forward simulated with a behavioral controller. + * The behavioral policy is a mixture of an MPC policy and an MPC-Net policy (e.g. a neural network). + */ +class MpcnetRolloutBase { + public: + /** + * Constructor. + * @param [in] mpcPtr : Pointer to the MPC solver to be used (this class takes ownership). + * @param [in] mpcnetPtr : Pointer to the MPC-Net policy to be used (this class takes ownership). + * @param [in] rolloutPtr : Pointer to the rollout to be used (this class takes ownership). + * @param [in] mpcnetDefinitionPtr : Pointer to the MPC-Net definitions to be used (shared ownership). + * @param [in] referenceManagerPtr : Pointer to the reference manager to be used (shared ownership). 
+ */ + MpcnetRolloutBase(std::unique_ptr mpcPtr, std::unique_ptr mpcnetPtr, + std::unique_ptr rolloutPtr, std::shared_ptr mpcnetDefinitionPtr, + std::shared_ptr referenceManagerPtr) + : mpcPtr_(std::move(mpcPtr)), + mpcnetPtr_(std::move(mpcnetPtr)), + rolloutPtr_(std::move(rolloutPtr)), + mpcnetDefinitionPtr_(std::move(mpcnetDefinitionPtr)), + referenceManagerPtr_(std::move(referenceManagerPtr)), + behavioralControllerPtr_(new MpcnetBehavioralController()) {} + + /** + * Default destructor. + */ + virtual ~MpcnetRolloutBase() = default; + + /** + * Deleted copy constructor. + */ + MpcnetRolloutBase(const MpcnetRolloutBase&) = delete; + + /** + * Deleted copy assignment. + */ + MpcnetRolloutBase& operator=(const MpcnetRolloutBase&) = delete; + + protected: + /** + * (Re)set system components. + * @param [in] alpha : The mixture parameter for the behavioral controller. + * @param [in] policyFilePath : The path to the file with the learned policy for the controller. + * @param [in] initialObservation : The initial system observation to start from (time and state required). + * @param [in] modeSchedule : The mode schedule providing the event times and mode sequence. + * @param [in] targetTrajectories : The target trajectories to be tracked. + */ + void set(scalar_t alpha, const std::string& policyFilePath, const SystemObservation& initialObservation, const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories); + + /** + * Simulate the system one step forward. + * @param [in] timeStep : The time step for the forward simulation of the system with the behavioral controller. 
+ */ + void step(scalar_t timeStep); + + std::unique_ptr mpcPtr_; + std::shared_ptr mpcnetDefinitionPtr_; + std::unique_ptr behavioralControllerPtr_; + SystemObservation systemObservation_; + PrimalSolution primalSolution_; + + private: + std::unique_ptr mpcnetPtr_; + std::unique_ptr rolloutPtr_; + std::shared_ptr referenceManagerPtr_; +}; + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetRolloutManager.h b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetRolloutManager.h new file mode 100644 index 000000000..c1862e9bc --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/include/ocs2_mpcnet_core/rollout/MpcnetRolloutManager.h @@ -0,0 +1,137 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include + +#include "ocs2_mpcnet_core/rollout/MpcnetDataGeneration.h" +#include "ocs2_mpcnet_core/rollout/MpcnetPolicyEvaluation.h" + +namespace ocs2 { +namespace mpcnet { + +/** + * A class to manage the data generation and policy evaluation rollouts for MPC-Net. + */ +class MpcnetRolloutManager { + public: + /** + * Constructor. + * @note The first nDataGenerationThreads pointers will be used for the data generation and the next nPolicyEvaluationThreads pointers for + * the policy evaluation. + * @param [in] nDataGenerationThreads : Number of data generation threads. + * @param [in] nPolicyEvaluationThreads : Number of policy evaluation threads. + * @param [in] mpcPtrs : Pointers to the MPC solvers to be used. + * @param [in] mpcnetPtrs : Pointers to the MPC-Net policies to be used. + * @param [in] rolloutPtrs : Pointers to the rollouts to be used. + * @param [in] mpcnetDefinitionPtrs : Pointers to the MPC-Net definitions to be used. + * @param [in] referenceManagerPtrs : Pointers to the reference managers to be used. + */ + MpcnetRolloutManager(size_t nDataGenerationThreads, size_t nPolicyEvaluationThreads, std::vector> mpcPtrs, + std::vector> mpcnetPtrs, std::vector> rolloutPtrs, + std::vector> mpcnetDefinitionPtrs, + std::vector> referenceManagerPtrs); + + /** + * Default destructor. 
+ */ + virtual ~MpcnetRolloutManager() = default; + + /** + * Starts the data generation forward simulated by a behavioral controller. + * @param [in] alpha : The mixture parameter for the behavioral controller. + * @param [in] policyFilePath : The path to the file with the learned policy for the behavioral controller. + * @param [in] timeStep : The time step for the forward simulation of the system with the behavioral controller. + * @param [in] dataDecimation : The integer factor used for downsampling the data signal. + * @param [in] nSamples : The number of samples drawn from a multivariate normal distribution around the nominal states. + * @param [in] samplingCovariance : The covariance matrix used for sampling from a multivariate normal distribution. + * @param [in] initialObservations : The initial system observations to start from (time and state required). + * @param [in] modeSchedules : The mode schedules providing the event times and mode sequence. + * @param [in] targetTrajectories : The target trajectories to be tracked. + */ + void startDataGeneration(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, size_t dataDecimation, size_t nSamples, + const matrix_t& samplingCovariance, const std::vector& initialObservations, + const std::vector& modeSchedules, const std::vector& targetTrajectories); + + /** + * Check if data generation is done. + * @return True if done. + */ + bool isDataGenerationDone(); + + /** + * Get the data generated from the data generation rollout. + * @return The generated data. + */ + const data_array_t& getGeneratedData(); + + /** + * Starts the policy evaluation forward simulated by a behavioral controller. + * @param [in] alpha : The mixture parameter for the behavioral controller. + * @param [in] policyFilePath : The path to the file with the learned policy for the behavioral controller. + * @param [in] timeStep : The time step for the forward simulation of the system with the behavioral controller.
+ * @param [in] initialObservations : The initial system observations to start from (time and state required). + * @param [in] modeSchedules : The mode schedules providing the event times and mode sequence. + * @param [in] targetTrajectories : The target trajectories to be tracked. + */ + void startPolicyEvaluation(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, + const std::vector& initialObservations, const std::vector& modeSchedules, + const std::vector& targetTrajectories); + + /** + * Check if policy evaluation is done. + * @return True if done. + */ + bool isPolicyEvaluationDone(); + + /** + * Get the metrics computed from the policy evaluation rollout. + * @return The computed metrics. + */ + metrics_array_t getComputedMetrics(); + + private: + // data generation variables + size_t nDataGenerationThreads_; + std::atomic_int nDataGenerationTasksDone_; + std::unique_ptr dataGenerationThreadPoolPtr_; + std::vector> dataGenerationPtrs_; + std::vector> dataGenerationFtrs_; + data_array_t dataArray_; + // policy evaluation variables + size_t nPolicyEvaluationThreads_; + std::atomic_int nPolicyEvaluationTasksDone_; + std::unique_ptr policyEvaluationThreadPoolPtr_; + std::vector> policyEvaluationPtrs_; + std::vector> policyEvaluationFtrs_; +}; + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/misc/onnxruntime/cmake/onnxruntimeConfig.cmake b/ocs2_mpcnet/ocs2_mpcnet_core/misc/onnxruntime/cmake/onnxruntimeConfig.cmake new file mode 100644 index 000000000..28cba4e94 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/misc/onnxruntime/cmake/onnxruntimeConfig.cmake @@ -0,0 +1,28 @@ +# Custom cmake config file to enable find_package(onnxruntime) without modifying LIBRARY_PATH and LD_LIBRARY_PATH +# +# This will define the following variables: +# onnxruntime_FOUND -- True if the system has the onnxruntime library +# onnxruntime_INCLUDE_DIRS -- The include directories for onnxruntime +# onnxruntime_LIBRARIES -- 
Libraries to link against +# onnxruntime_CXX_FLAGS -- Additional (required) compiler flags + +include(FindPackageHandleStandardArgs) + +# Assume we are in /share/cmake/onnxruntime/onnxruntimeConfig.cmake +get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) +get_filename_component(onnxruntime_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) + +set(onnxruntime_INCLUDE_DIRS ${onnxruntime_INSTALL_PREFIX}/include) +set(onnxruntime_LIBRARIES onnxruntime) +set(onnxruntime_CXX_FLAGS "") # no flags needed + +find_library(onnxruntime_LIBRARY onnxruntime + PATHS "${onnxruntime_INSTALL_PREFIX}/lib" +) + +add_library(onnxruntime SHARED IMPORTED) +set_property(TARGET onnxruntime PROPERTY IMPORTED_LOCATION "${onnxruntime_LIBRARY}") +set_property(TARGET onnxruntime PROPERTY INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_INCLUDE_DIRS}") +set_property(TARGET onnxruntime PROPERTY INTERFACE_COMPILE_OPTIONS "${onnxruntime_CXX_FLAGS}") + +find_package_handle_standard_args(onnxruntime DEFAULT_MSG onnxruntime_LIBRARY onnxruntime_INCLUDE_DIRS) diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/misc/onnxruntime/cmake/onnxruntimeVersion.cmake b/ocs2_mpcnet/ocs2_mpcnet_core/misc/onnxruntime/cmake/onnxruntimeVersion.cmake new file mode 100644 index 000000000..8dbbb0498 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/misc/onnxruntime/cmake/onnxruntimeVersion.cmake @@ -0,0 +1,13 @@ +# Custom cmake version file + +set(PACKAGE_VERSION "1.7.0") + +# Check whether the requested PACKAGE_FIND_VERSION is compatible +if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_COMPATIBLE FALSE) +else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + if("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_EXACT TRUE) + endif() +endif() diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/package.xml b/ocs2_mpcnet/ocs2_mpcnet_core/package.xml new file mode 100644 index 000000000..48eb0efea --- /dev/null +++ 
b/ocs2_mpcnet/ocs2_mpcnet_core/package.xml @@ -0,0 +1,24 @@ + + + ocs2_mpcnet_core + 0.0.0 + The ocs2_mpcnet_core package + + Alexander Reske + + Farbod Farshidian + Alexander Reske + + BSD-3 + + catkin + + cmake_clang_tools + + pybind11_catkin + + ocs2_mpc + ocs2_python_interface + ocs2_ros_interfaces + + diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/__init__.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/__init__.py new file mode 100644 index 000000000..0b2bd554f --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/__init__.py @@ -0,0 +1,7 @@ +from ocs2_mpcnet_core.MpcnetPybindings import size_array, scalar_array, vector_array, matrix_array +from ocs2_mpcnet_core.MpcnetPybindings import ScalarFunctionQuadraticApproximation +from ocs2_mpcnet_core.MpcnetPybindings import SystemObservation, SystemObservationArray +from ocs2_mpcnet_core.MpcnetPybindings import ModeSchedule, ModeScheduleArray +from ocs2_mpcnet_core.MpcnetPybindings import TargetTrajectories, TargetTrajectoriesArray +from ocs2_mpcnet_core.MpcnetPybindings import DataPoint, DataArray +from ocs2_mpcnet_core.MpcnetPybindings import Metrics, MetricsArray diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/config.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/config.py new file mode 100644 index 000000000..1943898aa --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/config.py @@ -0,0 +1,70 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. 
+# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Configuration class. + +Provides a class that handles the configuration parameters. +""" + +import yaml +import torch + + +class Config: + """Config. + + Loads configuration parameters from a YAML file and provides access to them as attributes of this class. + + Attributes: + DTYPE: The PyTorch data type. + DEVICE: The PyTorch device to select. + """ + + def __init__(self, config_file_path: str) -> None: + """Initializes the Config class. + + Initializes the Config class by setting fixed attributes and by loading attributes from a YAML file. + + Args: + config_file_path: A string with the path to the configuration file. 
+ """ + # + # class config + # + # data type for tensor elements + self.DTYPE = torch.float + # device on which tensors will be allocated + self.DEVICE = torch.device("cuda") + # + # yaml config + # + with open(config_file_path, "r") as stream: + config = yaml.safe_load(stream) + for key, value in config.items(): + setattr(self, key, value) diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/helper.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/helper.py new file mode 100644 index 000000000..345e7d96f --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/helper.py @@ -0,0 +1,309 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Helper functions. + +Provides helper functions, such as convenience functions for batch-wise operations or access to OCS2 types. +""" + +import torch +import numpy as np +from typing import Tuple, Dict + +from ocs2_mpcnet_core import ( + size_array, + scalar_array, + vector_array, + SystemObservation, + SystemObservationArray, + ModeSchedule, + ModeScheduleArray, + TargetTrajectories, + TargetTrajectoriesArray, +) + + +def bdot(bv1: torch.Tensor, bv2: torch.Tensor) -> torch.Tensor: + """Batch-wise dot product. + + Performs a batch-wise dot product between two batches of vectors with batch size B and dimension N. Supports + broadcasting for the batch dimension. + + Args: + bv1: A (B,N) tensor containing a batch of vectors. + bv2: A (B,N) tensor containing a batch of vectors. + + Returns: + A (B) tensor containing the batch-wise dot product. + """ + return torch.sum(torch.mul(bv1, bv2), dim=1) + + +def bmv(bm: torch.Tensor, bv: torch.Tensor) -> torch.Tensor: + """Batch-wise matrix-vector product. + + Performs a batch-wise matrix-vector product between a batch of MxN matrices and a batch of vectors of dimension N, + each with batch size B. Supports broadcasting for the batch dimension. + + Args: + bm: A (B,M,N) tensor containing a batch of matrices. + bv: A (B,N) tensor containing a batch of vectors.
+ + Returns: + A (B,M) tensor containing the batch-wise matrix-vector product. + """ + return torch.matmul(bm, bv.unsqueeze(dim=2)).squeeze(dim=2) + + +def bmm(bm1: torch.Tensor, bm2: torch.Tensor) -> torch.Tensor: + """Batch-wise matrix-matrix product. + + Performs a batch-wise matrix-matrix product between a batch of MxK matrices and a batch of KxN matrices, each with + batch size B. Supports broadcasting for the batch dimension (unlike torch.bmm). + + Args: + bm1: A (B,M,K) tensor containing a batch of matrices. + bm2: A (B,K,N) tensor containing a batch of matrices. + + Returns: + A (B,M,N) tensor containing the batch-wise matrix-matrix product. + """ + return torch.matmul(bm1, bm2) + + +def get_size_array(data: np.ndarray) -> size_array: + """Get an OCS2 size array. + + Creates an OCS2 size array and fills it with integer data from a NumPy array. + + Args: + data: A NumPy array of shape (N) containing integers. + + Returns: + An OCS2 size array of length N. + """ + my_size_array = size_array() + my_size_array.resize(len(data)) + for i in range(len(data)): + my_size_array[i] = data[i] + return my_size_array + + +def get_scalar_array(data: np.ndarray) -> scalar_array: + """Get an OCS2 scalar array. + + Creates an OCS2 scalar array and fills it with float data from a NumPy array. + + Args: + data: A NumPy array of shape (N) containing floats. + + Returns: + An OCS2 scalar array of length N. + """ + my_scalar_array = scalar_array() + my_scalar_array.resize(len(data)) + for i in range(len(data)): + my_scalar_array[i] = data[i] + return my_scalar_array + + +def get_vector_array(data: np.ndarray) -> vector_array: + """Get an OCS2 vector array. + + Creates an OCS2 vector array and fills it with float data from a NumPy array. + + Args: + data: A NumPy array of shape (M,N) containing floats. + + Returns: + An OCS2 vector array of length M with vectors of dimension N. 
+ """ + my_vector_array = vector_array() + my_vector_array.resize(len(data)) + for i in range(len(data)): + my_vector_array[i] = data[i] + return my_vector_array + + +def get_system_observation(mode: int, time: float, state: np.ndarray, input: np.ndarray) -> SystemObservation: + """Get an OCS2 system observation object. + + Creates an OCS2 system observation object and fills it with data. + + Args: + mode: The observed mode given by an integer. + time: The observed time given by a float. + state: The observed state given by a NumPy array of shape (M) containing floats. + input: The observed input given by a NumPy array of shape (N) containing floats. + + Returns: + An OCS2 system observation object. + """ + system_observation = SystemObservation() + system_observation.mode = mode + system_observation.time = time + system_observation.state = state + system_observation.input = input + return system_observation + + +def get_system_observation_array(length: int) -> SystemObservationArray: + """Get an OCS2 system observation array. + + Creates an OCS2 system observation array but does not fill it with data. + + Args: + length: The length that the array should have given by an integer. + + Returns: + An OCS2 system observation array of the desired length. + """ + system_observation_array = SystemObservationArray() + system_observation_array.resize(length) + return system_observation_array + + +def get_target_trajectories( + time_trajectory: np.ndarray, state_trajectory: np.ndarray, input_trajectory: np.ndarray +) -> TargetTrajectories: + """Get an OCS2 target trajectories object. + + Creates an OCS2 target trajectories object and fills it with data. + + Args: + time_trajectory: The target time trajectory given by a NumPy array of shape (K) containing floats. + state_trajectory: The target state trajectory given by a NumPy array of shape (K,M) containing floats. + input_trajectory: The target input trajectory given by a NumPy array of shape (K,N) containing floats. 
+ + Returns: + An OCS2 target trajectories object. + """ + time_trajectory_array = get_scalar_array(time_trajectory) + state_trajectory_array = get_vector_array(state_trajectory) + input_trajectory_array = get_vector_array(input_trajectory) + return TargetTrajectories(time_trajectory_array, state_trajectory_array, input_trajectory_array) + + +def get_target_trajectories_array(length: int) -> TargetTrajectoriesArray: + """Get an OCS2 target trajectories array. + + Creates an OCS2 target trajectories array but does not fill it with data. + + Args: + length: The length that the array should have given by an integer. + + Returns: + An OCS2 target trajectories array of the desired length. + """ + target_trajectories_array = TargetTrajectoriesArray() + target_trajectories_array.resize(length) + return target_trajectories_array + + +def get_mode_schedule(event_times: np.ndarray, mode_sequence: np.ndarray) -> ModeSchedule: + """Get an OCS2 mode schedule object. + + Creates an OCS2 mode schedule object and fills it with data. + + Args: + event_times: The event times given by a NumPy array of shape (K-1) containing floats. + mode_sequence: The mode sequence given by a NumPy array of shape (K) containing integers. + + Returns: + An OCS2 mode schedule object. + """ + event_times_array = get_scalar_array(event_times) + mode_sequence_array = get_size_array(mode_sequence) + return ModeSchedule(event_times_array, mode_sequence_array) + + +def get_mode_schedule_array(length: int) -> ModeScheduleArray: + """Get an OCS2 mode schedule array. + + Creates an OCS2 mode schedule array but does not fill it with data. + + Args: + length: The length that the array should have given by an integer. + + Returns: + An OCS2 mode schedule array of the desired length. 
+ """ + mode_schedule_array = ModeScheduleArray() + mode_schedule_array.resize(length) + return mode_schedule_array + + +def get_event_times_and_mode_sequence( + default_mode: int, duration: float, event_times_template: np.ndarray, mode_sequence_template: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + """Get the event times and mode sequence describing a mode schedule. + + Creates the event times and mode sequence for a certain time duration from a template (e.g. a gait). + + Args: + default_mode: The default mode prepended and appended to the mode schedule and given by an integer. + duration: The duration of the mode schedule given by a float. + event_times_template: The event times template given by a NumPy array of shape (T) containing floats. + mode_sequence_template: The mode sequence template given by a NumPy array of shape (T) containing integers. + + Returns: + A tuple containing the components of the mode schedule. + - event_times: The event times given by a NumPy array of shape (K-1) containing floats. + - mode_sequence: The mode sequence given by a NumPy array of shape (K) containing integers. + """ + gait_cycle_duration = event_times_template[-1] + num_gait_cycles = int(np.floor(duration / gait_cycle_duration)) + event_times = np.array([0.0], dtype=np.float64) + mode_sequence = np.array([default_mode], dtype=np.uintp) + for _ in range(num_gait_cycles): + event_times = np.append( + event_times, event_times[-1] * np.ones(len(event_times_template)) + event_times_template + ) + mode_sequence = np.append(mode_sequence, mode_sequence_template) + mode_sequence = np.append(mode_sequence, np.array([default_mode], dtype=np.uintp)) + return event_times, mode_sequence + + +def get_one_hot(mode: int, expert_number: int, expert_for_mode: Dict[int, int]) -> np.ndarray: + """Get one hot encoding of mode. 
+ + Get a one hot encoding of a mode represented by a discrete probability distribution, where the sample space is the + set of P individually identified items given by the set of E individually identified experts. + + Args: + mode: The mode of the system given by an integer. + expert_number: The number of experts given by an integer. + expert_for_mode: A dictionary that assigns modes to experts. + + Returns: + p: Discrete probability distribution given by a NumPy array of shape (P) containing floats. + """ + one_hot = np.zeros(expert_number) + one_hot[expert_for_mode[mode]] = 1.0 + return one_hot diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/__init__.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/__init__.py new file mode 100644 index 000000000..c343a60cb --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/__init__.py @@ -0,0 +1,6 @@ +from .base import BaseLoss +from .behavioral_cloning import BehavioralCloningLoss +from .cross_entropy import CrossEntropyLoss +from .hamiltonian import HamiltonianLoss + +__all__ = ["BaseLoss", "BehavioralCloningLoss", "CrossEntropyLoss", "HamiltonianLoss"] diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/base.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/base.py new file mode 100644 index 000000000..329c58089 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/base.py @@ -0,0 +1,94 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. 
+# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Base loss. + +Provides a base class for all loss classes. +""" + +import torch +from abc import ABCMeta, abstractmethod + +from ocs2_mpcnet_core.config import Config + + +class BaseLoss(metaclass=ABCMeta): + """Base loss. + + Provides the interface to all loss classes. + """ + + def __init__(self, config: Config) -> None: + """Initializes the BaseLoss class. + + Initializes the BaseLoss class. + + Args: + config: An instance of the configuration class. 
+ """ + pass + + @abstractmethod + def __call__( + self, + x_query: torch.Tensor, + x_nominal: torch.Tensor, + u_query: torch.Tensor, + u_nominal: torch.Tensor, + p_query: torch.Tensor, + p_nominal: torch.Tensor, + dHdxx: torch.Tensor, + dHdux: torch.Tensor, + dHduu: torch.Tensor, + dHdx: torch.Tensor, + dHdu: torch.Tensor, + H: torch.Tensor, + ) -> torch.Tensor: + """Computes the loss. + + Computes the mean loss for a batch. + + Args: + x_query: A (B,X) tensor with the query (e.g. predicted) states. + x_nominal: A (B,X) tensor with the nominal (e.g. target) states. + u_query: A (B,U) tensor with the query (e.g. predicted) inputs. + u_nominal: A (B,U) tensor with the nominal (e.g. target) inputs. + p_query: A (B,P) tensor with the query (e.g. predicted) discrete probability distributions. + p_nominal: A (B,P) tensor with the nominal (e.g. target) discrete probability distributions. + dHdxx: A (B,X,X) tensor with the state-state Hessians of the Hamiltonian approximations. + dHdux: A (B,U,X) tensor with the input-state Hessians of the Hamiltonian approximations. + dHduu: A (B,U,U) tensor with the input-input Hessians of the Hamiltonian approximations. + dHdx: A (B,X) tensor with the state gradients of the Hamiltonian approximations. + dHdu: A (B,U) tensor with the input gradients of the Hamiltonian approximations. + H: A (B) tensor with the Hamiltonians at the nominal points. + + Returns: + A (1) tensor containing the mean loss. + """ + pass diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/behavioral_cloning.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/behavioral_cloning.py new file mode 100644 index 000000000..5dda2c2d3 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/behavioral_cloning.py @@ -0,0 +1,119 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Behavioral cloning loss. + +Provides a class that implements a simple behavioral cloning loss for benchmarking MPC-Net. +""" + +import torch + +from ocs2_mpcnet_core.config import Config +from ocs2_mpcnet_core.loss.base import BaseLoss +from ocs2_mpcnet_core.helper import bdot, bmv + + +class BehavioralCloningLoss(BaseLoss): + r"""Behavioral cloning loss. + + Uses a simple quadratic function as loss: + + .. 
math:: + + BC(u) = \delta u^T R \, \delta u, + + where the input u is of dimension U. + + Attributes: + R: A (1,U,U) tensor with the input cost matrix. + """ + + def __init__(self, config: Config) -> None: + """Initializes the BehavioralCloningLoss class. + + Initializes the BehavioralCloningLoss class by setting fixed attributes. + + Args: + config: An instance of the configuration class. + """ + super().__init__(config) + self.R = torch.tensor(config.R, device=config.DEVICE, dtype=config.DTYPE).diag().unsqueeze(dim=0) + + def __call__( + self, + x_query: torch.Tensor, + x_nominal: torch.Tensor, + u_query: torch.Tensor, + u_nominal: torch.Tensor, + p_query: torch.Tensor, + p_nominal: torch.Tensor, + dHdxx: torch.Tensor, + dHdux: torch.Tensor, + dHduu: torch.Tensor, + dHdx: torch.Tensor, + dHdu: torch.Tensor, + H: torch.Tensor, + ) -> torch.Tensor: + """Computes the loss. + + Computes the mean loss for a batch. + + Args: + x_query: A (B,X) tensor with the query (e.g. predicted) states. + x_nominal: A (B,X) tensor with the nominal (e.g. target) states. + u_query: A (B,U) tensor with the query (e.g. predicted) inputs. + u_nominal: A (B,U) tensor with the nominal (e.g. target) inputs. + p_query: A (B,P) tensor with the query (e.g. predicted) discrete probability distributions. + p_nominal: A (B,P) tensor with the nominal (e.g. target) discrete probability distributions. + dHdxx: A (B,X,X) tensor with the state-state Hessians of the Hamiltonian approximations. + dHdux: A (B,U,X) tensor with the input-state Hessians of the Hamiltonian approximations. + dHduu: A (B,U,U) tensor with the input-input Hessians of the Hamiltonian approximations. + dHdx: A (B,X) tensor with the state gradients of the Hamiltonian approximations. + dHdu: A (B,U) tensor with the input gradients of the Hamiltonian approximations. + H: A (B) tensor with the Hamiltonians at the nominal points. + + Returns: + A (1) tensor containing the mean loss. 
+ """ + return torch.mean(self.compute(u_query, u_nominal)) + + def compute(self, u_predicted: torch.Tensor, u_target: torch.Tensor) -> torch.Tensor: + """Computes the behavioral cloning losses for a batch. + + Computes the behavioral cloning losses for a batch of size B using the cost matrix. + + Args: + u_predicted: A (B, U) tensor with the predicted inputs. + u_target: A (B, U) tensor with the target inputs. + + Returns: + A (B) tensor containing the behavioral cloning losses. + """ + du = torch.sub(u_predicted, u_target) + return bdot(du, bmv(self.R, du)) diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/cross_entropy.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/cross_entropy.py new file mode 100644 index 000000000..93cc706c0 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/cross_entropy.py @@ -0,0 +1,118 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Cross entropy loss. + +Provides a class that implements the cross entropy loss for training a gating network of a mixture of experts network. +""" + +import torch + +from ocs2_mpcnet_core.config import Config +from ocs2_mpcnet_core.loss.base import BaseLoss +from ocs2_mpcnet_core.helper import bdot, bmv + + +class CrossEntropyLoss(BaseLoss): + r"""Cross entropy loss. + + Uses the cross entropy between two discrete probability distributions as loss: + + .. math:: + + CE(p_{target}, p_{predicted}) = - \sum_{i=1}^{P} (p_{target,i} \log(p_{predicted,i} + \varepsilon)), + + where the sample space is the set of P individually identified items. + + Attributes: + epsilon: A (1) tensor with a small epsilon used to stabilize the logarithm. + """ + + def __init__(self, config: Config) -> None: + """Initializes the CrossEntropyLoss class. + + Initializes the CrossEntropyLoss class by setting fixed attributes. + + Args: + config: An instance of the configuration class. 
+ """ + super().__init__(config) + self.epsilon = torch.tensor(config.EPSILON, device=config.DEVICE, dtype=config.DTYPE) + + def __call__( + self, + x_query: torch.Tensor, + x_nominal: torch.Tensor, + u_query: torch.Tensor, + u_nominal: torch.Tensor, + p_query: torch.Tensor, + p_nominal: torch.Tensor, + dHdxx: torch.Tensor, + dHdux: torch.Tensor, + dHduu: torch.Tensor, + dHdx: torch.Tensor, + dHdu: torch.Tensor, + H: torch.Tensor, + ) -> torch.Tensor: + """Computes the loss. + + Computes the mean loss for a batch. + + Args: + x_query: A (B,X) tensor with the query (e.g. predicted) states. + x_nominal: A (B,X) tensor with the nominal (e.g. target) states. + u_query: A (B,U) tensor with the query (e.g. predicted) inputs. + u_nominal: A (B,U) tensor with the nominal (e.g. target) inputs. + p_query: A (B,P) tensor with the query (e.g. predicted) discrete probability distributions. + p_nominal: A (B,P) tensor with the nominal (e.g. target) discrete probability distributions. + dHdxx: A (B,X,X) tensor with the state-state Hessians of the Hamiltonian approximations. + dHdux: A (B,U,X) tensor with the input-state Hessians of the Hamiltonian approximations. + dHduu: A (B,U,U) tensor with the input-input Hessians of the Hamiltonian approximations. + dHdx: A (B,X) tensor with the state gradients of the Hamiltonian approximations. + dHdu: A (B,U) tensor with the input gradients of the Hamiltonian approximations. + H: A (B) tensor with the Hamiltonians at the nominal points. + + Returns: + A (1) tensor containing the mean loss. + """ + return torch.mean(self.compute(p_query, p_nominal)) + + def compute(self, p_predicted: torch.Tensor, p_target: torch.Tensor) -> torch.Tensor: + """Computes the cross entropy losses for a batch. + + Computes the cross entropy losses for a batch, where the logarithm is stabilized by a small epsilon. + + Args: + p_predicted: A (B,P) tensor with the predicted discrete probability distributions. 
+ p_target: A (B,P) tensor with the target discrete probability distributions. + + Returns: + A (B) tensor containing the cross entropy losses. + """ + return -bdot(p_target, torch.log(p_predicted + self.epsilon)) diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/hamiltonian.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/hamiltonian.py new file mode 100644 index 000000000..3def1f538 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/loss/hamiltonian.py @@ -0,0 +1,151 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Hamiltonian loss. + +Provides a class that implements the Hamiltonian loss for MPC-Net. +""" + +import torch + +from ocs2_mpcnet_core.config import Config +from ocs2_mpcnet_core.loss.base import BaseLoss +from ocs2_mpcnet_core.helper import bdot, bmv + + +class HamiltonianLoss(BaseLoss): + r"""Hamiltonian loss. + + Uses the linear quadratic approximation of the Hamiltonian as loss: + + .. math:: + + H(x,u) = \frac{1}{2} \delta x^T \partial H_{xx} \delta x +\delta u^T \partial H_{ux} \delta x + \frac{1}{2} + \delta u^T \partial H_{uu} \delta u + \partial H_{x}^T \delta x + \partial H_{u}^T \delta u + H, + + where the state x is of dimension X and the input u is of dimension U. + """ + + def __init__(self, config: Config) -> None: + """Initializes the HamiltonianLoss class. + + Initializes the HamiltonianLoss class. + + Args: + config: An instance of the configuration class. + """ + super().__init__(config) + + def __call__( + self, + x_query: torch.Tensor, + x_nominal: torch.Tensor, + u_query: torch.Tensor, + u_nominal: torch.Tensor, + p_query: torch.Tensor, + p_nominal: torch.Tensor, + dHdxx: torch.Tensor, + dHdux: torch.Tensor, + dHduu: torch.Tensor, + dHdx: torch.Tensor, + dHdu: torch.Tensor, + H: torch.Tensor, + ) -> torch.Tensor: + """Computes the loss. + + Computes the mean loss for a batch. + + Args: + x_query: A (B,X) tensor with the query (e.g. 
predicted) states. + x_nominal: A (B,X) tensor with the nominal (e.g. target) states. + u_query: A (B,U) tensor with the query (e.g. predicted) inputs. + u_nominal: A (B,U) tensor with the nominal (e.g. target) inputs. + p_query: A (B,P) tensor with the query (e.g. predicted) discrete probability distributions. + p_nominal: A (B,P) tensor with the nominal (e.g. target) discrete probability distributions. + dHdxx: A (B,X,X) tensor with the state-state Hessians of the Hamiltonian approximations. + dHdux: A (B,U,X) tensor with the input-state Hessians of the Hamiltonian approximations. + dHduu: A (B,U,U) tensor with the input-input Hessians of the Hamiltonian approximations. + dHdx: A (B,X) tensor with the state gradients of the Hamiltonian approximations. + dHdu: A (B,U) tensor with the input gradients of the Hamiltonian approximations. + H: A (B) tensor with the Hamiltonians at the nominal points. + + Returns: + A (1) tensor containing the mean loss. + """ + return torch.mean(self.compute(x_query, x_nominal, u_query, u_nominal, dHdxx, dHdux, dHduu, dHdx, dHdu, H)) + + @staticmethod + def compute( + x_query: torch.Tensor, + x_nominal: torch.Tensor, + u_query: torch.Tensor, + u_nominal: torch.Tensor, + dHdxx: torch.Tensor, + dHdux: torch.Tensor, + dHduu: torch.Tensor, + dHdx: torch.Tensor, + dHdu: torch.Tensor, + H: torch.Tensor, + ) -> torch.Tensor: + """Computes the Hamiltonian losses for a batch. + + Computes the Hamiltonian losses for a batch of size B using the provided linear quadratic approximations. + + Args: + x_query: A (B,X) tensor with the states the Hamiltonian loss should be computed for. + x_nominal: A (B,X) tensor with the states that were used as development/expansion points. + u_query: A (B,U) tensor with the inputs the Hamiltonian loss should be computed for. + u_nominal: A (B,U) tensor with the inputs that were used as development/expansion points. + dHdxx: A (B,X,X) tensor with the state-state Hessians of the Hamiltonian approximations.
+ dHdux: A (B,U,X) tensor with the input-state Hessians of the Hamiltonian approximations. + dHduu: A (B,U,U) tensor with the input-input Hessians of the Hamiltonian approximations. + dHdx: A (B,X) tensor with the state gradients of the Hamiltonian approximations. + dHdu: A (B,U) tensor with the input gradients of the Hamiltonian approximations. + H: A (B) tensor with the Hamiltonians at the development/expansion points. + + Returns: + A (B) tensor containing the computed Hamiltonian losses. + """ + if torch.equal(x_query, x_nominal): + du = torch.sub(u_query, u_nominal) + return 0.5 * bdot(du, bmv(dHduu, du)) + bdot(dHdu, du) + H + elif torch.equal(u_query, u_nominal): + dx = torch.sub(x_query, x_nominal) + return 0.5 * bdot(dx, bmv(dHdxx, dx)) + bdot(dHdx, dx) + H + else: + dx = torch.sub(x_query, x_nominal) + du = torch.sub(u_query, u_nominal) + return ( + 0.5 * bdot(dx, bmv(dHdxx, dx)) + + bdot(du, bmv(dHdux, dx)) + + 0.5 * bdot(du, bmv(dHduu, du)) + + bdot(dHdx, dx) + + bdot(dHdu, du) + + H + ) diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/memory/__init__.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/memory/__init__.py new file mode 100644 index 000000000..7d777af72 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/memory/__init__.py @@ -0,0 +1,4 @@ +from .base import BaseMemory +from .circular import CircularMemory + +__all__ = ["BaseMemory", "CircularMemory"] diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/memory/base.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/memory/base.py new file mode 100644 index 000000000..5c1b26fec --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/memory/base.py @@ -0,0 +1,122 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Base memory. + +Provides a base class for all memory classes. +""" + +import torch +import numpy as np +from typing import Tuple, List +from abc import ABCMeta, abstractmethod + +from ocs2_mpcnet_core.config import Config +from ocs2_mpcnet_core import ScalarFunctionQuadraticApproximation + + +class BaseMemory(metaclass=ABCMeta): + """Base memory. + + Provides the interface to all memory classes. 
+ """ + def __init__(self, config: Config) -> None: + """Initializes the BaseMemory class. + + Initializes the BaseMemory class. + + Args: + config: An instance of the configuration class. + """ + pass + + @abstractmethod + def push( + self, + t: float, + x: np.ndarray, + u: np.ndarray, + p: np.ndarray, + observation: np.ndarray, + action_transformation: List[np.ndarray], + hamiltonian: ScalarFunctionQuadraticApproximation, + ) -> None: + """Pushes data into the memory. + + Pushes one data sample into the memory. + + Args: + t: A float with the time. + x: A NumPy array of shape (X) with the observed state. + u: A NumPy array of shape (U) with the optimal input. + p: A NumPy array of shape (P) with the observed discrete probability distribution of the modes. + observation: A NumPy array of shape (O) with the observation. + action_transformation: A list containing NumPy arrays of shape (U,A) and (U) with the action transformation. + hamiltonian: An OCS2 scalar function quadratic approximation representing the Hamiltonian around x and u. + """ + pass + + @abstractmethod + def sample(self, batch_size: int) -> Tuple[torch.Tensor, ...]: + """Samples data from the memory. + + Samples a batch of data from the memory. + + Args: + batch_size: An integer defining the batch size B. + + Returns: + A tuple containing the sampled batch of data. + - t_batch: A (B) tensor with the times. + - x_batch: A (B,X) tensor with the observed states. + - u_batch: A (B,U) tensor with the optimal inputs. + - p_batch: A (B,P) tensor with the observed discrete probability distributions of the modes. + - observation_batch: A (B,O) tensor with the observations. + - action_transformation_matrix_batch: A (B,U,A) tensor with the action transformation matrices. + - action_transformation_vector_batch: A (B,U) tensor with the action transformation vectors. + - dHdxx_batch: A (B,X,X) tensor with the state-state Hessians of the Hamiltonian approximations.
+ - dHdux_batch: A (B,U,X) tensor with the input-state Hessians of the Hamiltonian approximations. + - dHduu_batch: A (B,U,U) tensor with the input-input Hessians of the Hamiltonian approximations. + - dHdx_batch: A (B,X) tensor with the state gradients of the Hamiltonian approximations. + - dHdu_batch: A (B,U) tensor with the input gradients of the Hamiltonian approximations. + - H_batch: A (B) tensor with the Hamiltonians at the development/expansion points. + """ + pass + + @abstractmethod + def __len__(self) -> int: + """The length of the memory. + + Return the length of the memory given by the current size. + + Returns: + An integer describing the length of the memory. + """ + pass diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/memory/circular.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/memory/circular.py new file mode 100644 index 000000000..70decc31b --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/memory/circular.py @@ -0,0 +1,218 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Circular memory. + +Provides a class that implements a circular memory. +""" + +import torch +import numpy as np +from typing import Tuple, List + +from ocs2_mpcnet_core.config import Config +from ocs2_mpcnet_core.memory.base import BaseMemory +from ocs2_mpcnet_core import ScalarFunctionQuadraticApproximation + + +class CircularMemory(BaseMemory): + """Circular memory. + + Stores data in a circular memory that overwrites old data if the size of the memory reaches its capacity. + + Attributes: + capacity: An integer defining the capacity of the memory. + size: An integer giving the current size of the memory. + position: An integer giving the current position in the memory. + t: A (C) tensor for the times. + x: A (C,X) tensor for the observed states. + u: A (C,U) tensor for the optimal inputs. + p: A (C,P) tensor for the observed discrete probability distributions of the modes. + observation: A (C,O) tensor for the observations. + action_transformation_matrix: A (C,U,A) tensor for the action transformation matrices. + action_transformation_vector: A (C,U) tensor for the action transformation vectors. 
+ dHdxx: A (C,X,X) tensor for the state-state Hessians of the Hamiltonian approximations. + dHdux: A (C,U,X) tensor for the input-state Hessians of the Hamiltonian approximations. + dHduu: A (C,U,U) tensor for the input-input Hessians of the Hamiltonian approximations. + dHdx: A (C,X) tensor for the state gradients of the Hamiltonian approximations. + dHdu: A (C,U) tensor for the input gradients of the Hamiltonian approximations. + H: A (C) tensor for the Hamiltonians at the development/expansion points. + """ + + def __init__(self, config: Config) -> None: + """Initializes the CircularMemory class. + + Initializes the CircularMemory class by setting fixed attributes, initializing variable attributes and + pre-allocating memory. + + Args: + config: An instance of the configuration class. + """ + # init variables + self.device = config.DEVICE + self.capacity = config.CAPACITY + self.size = 0 + self.position = 0 + # pre-allocate memory + self.t = torch.zeros(config.CAPACITY, device=config.DEVICE, dtype=config.DTYPE) + self.x = torch.zeros(config.CAPACITY, config.STATE_DIM, device=config.DEVICE, dtype=config.DTYPE) + self.u = torch.zeros(config.CAPACITY, config.INPUT_DIM, device=config.DEVICE, dtype=config.DTYPE) + self.p = torch.zeros(config.CAPACITY, config.EXPERT_NUM, device=config.DEVICE, dtype=config.DTYPE) + self.observation = torch.zeros( + config.CAPACITY, config.OBSERVATION_DIM, device=config.DEVICE, dtype=config.DTYPE + ) + self.action_transformation_matrix = torch.zeros( + config.CAPACITY, config.INPUT_DIM, config.ACTION_DIM, device=config.DEVICE, dtype=config.DTYPE + ) + self.action_transformation_vector = torch.zeros( + config.CAPACITY, config.INPUT_DIM, device=config.DEVICE, dtype=config.DTYPE + ) + self.dHdxx = torch.zeros( + config.CAPACITY, config.STATE_DIM, config.STATE_DIM, device=config.DEVICE, dtype=config.DTYPE + ) + self.dHdux = torch.zeros( + config.CAPACITY, config.INPUT_DIM, config.STATE_DIM, device=config.DEVICE, dtype=config.DTYPE + ) + 
self.dHduu = torch.zeros( + config.CAPACITY, config.INPUT_DIM, config.INPUT_DIM, device=config.DEVICE, dtype=config.DTYPE + ) + self.dHdx = torch.zeros(config.CAPACITY, config.STATE_DIM, device=config.DEVICE, dtype=config.DTYPE) + self.dHdu = torch.zeros(config.CAPACITY, config.INPUT_DIM, device=config.DEVICE, dtype=config.DTYPE) + self.H = torch.zeros(config.CAPACITY, device=config.DEVICE, dtype=config.DTYPE) + + def push( + self, + t: float, + x: np.ndarray, + u: np.ndarray, + p: np.ndarray, + observation: np.ndarray, + action_transformation: List[np.ndarray], + hamiltonian: ScalarFunctionQuadraticApproximation, + ) -> None: + """Pushes data into the memory. + + Pushes one data sample into the memory. + + Args: + t: A float with the time. + x: A NumPy array of shape (X) with the observed state. + u: A NumPy array of shape (U) with the optimal input. + p: A NumPy array of shape (P) with the observed discrete probability distribution of the modes. + observation: A NumPy array of shape (O) with the observation. + action_transformation: A list containing NumPy arrays of shape (U,A) and (U) with the action transformation. + hamiltonian: An OCS2 scalar function quadratic approximation representing the Hamiltonian around x and u.
+ """ + # push data into memory + # note: - torch.as_tensor: no copy as data is a ndarray of the corresponding dtype and the device is the cpu + # - torch.Tensor.copy_: copy performed together with potential dtype and device change + self.t[self.position].copy_(torch.as_tensor(t, dtype=None, device=torch.device("cpu"))) + self.x[self.position].copy_(torch.as_tensor(x, dtype=None, device=torch.device("cpu"))) + self.u[self.position].copy_(torch.as_tensor(u, dtype=None, device=torch.device("cpu"))) + self.p[self.position].copy_(torch.as_tensor(p, dtype=None, device=torch.device("cpu"))) + self.observation[self.position].copy_(torch.as_tensor(observation, dtype=None, device=torch.device("cpu"))) + self.action_transformation_matrix[self.position].copy_( + torch.as_tensor(action_transformation[0], dtype=None, device=torch.device("cpu")) + ) + self.action_transformation_vector[self.position].copy_( + torch.as_tensor(action_transformation[1], dtype=None, device=torch.device("cpu")) + ) + self.dHdxx[self.position].copy_(torch.as_tensor(hamiltonian.dfdxx, dtype=None, device=torch.device("cpu"))) + self.dHdux[self.position].copy_(torch.as_tensor(hamiltonian.dfdux, dtype=None, device=torch.device("cpu"))) + self.dHduu[self.position].copy_(torch.as_tensor(hamiltonian.dfduu, dtype=None, device=torch.device("cpu"))) + self.dHdx[self.position].copy_(torch.as_tensor(hamiltonian.dfdx, dtype=None, device=torch.device("cpu"))) + self.dHdu[self.position].copy_(torch.as_tensor(hamiltonian.dfdu, dtype=None, device=torch.device("cpu"))) + self.H[self.position].copy_(torch.as_tensor(hamiltonian.f, dtype=None, device=torch.device("cpu"))) + # update size and position + self.size = min(self.size + 1, self.capacity) + self.position = (self.position + 1) % self.capacity + + def sample(self, batch_size: int) -> Tuple[torch.Tensor, ...]: + """Samples data from the memory. + + Samples a batch of data from the memory. + + Args: + batch_size: An integer defining the batch size B. 
+ + Returns: + A tuple containing the sampled batch of data. + - t_batch: A (B) tensor with the times. + - x_batch: A (B,X) tensor with the observed states. + - u_batch: A (B,U) tensor with the optimal inputs. + - p_batch: A (B,P) tensor with the observed discrete probability distributions of the modes. + - observation_batch: A (B,O) tensor with the observations. + - action_transformation_matrix_batch: A (B,U,A) tensor with the action transformation matrices. + - action_transformation_vector_batch: A (B,U) tensor with the action transformation vectors. + - dHdxx_batch: A (B,X,X) tensor with the state-state Hessians of the Hamiltonian approximations. + - dHdux_batch: A (B,U,X) tensor with the input-state Hessians of the Hamiltonian approximations. + - dHduu_batch: A (B,U,U) tensor with the input-input Hessians of the Hamiltonian approximations. + - dHdx_batch: A (B,X) tensor with the state gradients of the Hamiltonian approximations. + - dHdu_batch: A (B,U) tensor with the input gradients of the Hamiltonian approximations. + - H_batch: A (B) tensor with the Hamiltonians at the development/expansion points. 
+ """ + indices = torch.randint(low=0, high=self.size, size=(batch_size,), device=self.device) + t_batch = self.t[indices] + x_batch = self.x[indices] + u_batch = self.u[indices] + p_batch = self.p[indices] + observation_batch = self.observation[indices] + action_transformation_matrix_batch = self.action_transformation_matrix[indices] + action_transformation_vector_batch = self.action_transformation_vector[indices] + dHdxx_batch = self.dHdxx[indices] + dHdux_batch = self.dHdux[indices] + dHduu_batch = self.dHduu[indices] + dHdx_batch = self.dHdx[indices] + dHdu_batch = self.dHdu[indices] + H_batch = self.H[indices] + return ( + t_batch, + x_batch, + u_batch, + p_batch, + observation_batch, + action_transformation_matrix_batch, + action_transformation_vector_batch, + dHdxx_batch, + dHdux_batch, + dHduu_batch, + dHdx_batch, + dHdu_batch, + H_batch, + ) + + def __len__(self) -> int: + """The length of the memory. + + Return the length of the memory given by the current size. + + Returns: + An integer describing the length of the memory. + """ + return self.size diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/mpcnet.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/mpcnet.py new file mode 100644 index 000000000..a79fbc9d7 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/mpcnet.py @@ -0,0 +1,329 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. 
+# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""MPC-Net class. + +Provides a class that handles the MPC-Net training. +""" + +import os +import time +import datetime +import torch +import numpy as np +from typing import Optional, Tuple +from abc import ABCMeta, abstractmethod +from torch.utils.tensorboard import SummaryWriter + + +from ocs2_mpcnet_core import helper +from ocs2_mpcnet_core import SystemObservationArray, ModeScheduleArray, TargetTrajectoriesArray +from ocs2_mpcnet_core.config import Config +from ocs2_mpcnet_core.loss import BaseLoss +from ocs2_mpcnet_core.memory import BaseMemory +from ocs2_mpcnet_core.policy import BasePolicy + + +class Mpcnet(metaclass=ABCMeta): + """MPC-Net. 
+ + Implements the main methods for the MPC-Net training. + + Takes a specific configuration, interface, memory, policy and loss function(s). + The task formulation has to be implemented in a robot-specific class derived from this class. + Provides the main training loop for MPC-Net. + """ + + def __init__( + self, + root_dir: str, + config: Config, + interface: object, + memory: BaseMemory, + policy: BasePolicy, + experts_loss: BaseLoss, + gating_loss: Optional[BaseLoss] = None, + ) -> None: + """Initializes the Mpcnet class. + + Initializes the Mpcnet class by setting fixed and variable attributes. + + Args: + root_dir: The absolute path to the root directory. + config: An instance of the configuration class. + interface: An instance of the interface class. + memory: An instance of a memory class. + policy: An instance of a policy class. + experts_loss: An instance of a loss class used as experts loss. + gating_loss: An instance of a loss class used as gating loss. + """ + # config + self.config = config + # interface + self.interface = interface + # logging + timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + self.log_dir = os.path.join(root_dir, "runs", f"{timestamp}_{config.NAME}_{config.DESCRIPTION}") + self.writer = SummaryWriter(self.log_dir) + # loss + self.experts_loss = experts_loss + self.gating_loss = gating_loss + # memory + self.memory = memory + # policy + self.policy = policy + self.policy.to(config.DEVICE) + self.dummy_observation = torch.randn(1, config.OBSERVATION_DIM, device=config.DEVICE, dtype=config.DTYPE) + # optimizer + self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=config.LEARNING_RATE) + + @abstractmethod + def get_tasks( + self, tasks_number: int, duration: float + ) -> Tuple[SystemObservationArray, ModeScheduleArray, TargetTrajectoriesArray]: + """Get tasks. + + Get a random set of tasks that should be executed by the data generation or policy evaluation. 
+ + Args: + tasks_number: Number of tasks given by an integer. + duration: Duration of each task given by a float. + + Returns: + A tuple containing the components of the task. + - initial_observations: The initial observations given by an OCS2 system observation array. + - mode_schedules: The desired mode schedules given by an OCS2 mode schedule array. + - target_trajectories: The desired target trajectories given by an OCS2 target trajectories array. + """ + pass + + def start_data_generation(self, policy: BasePolicy, alpha: float = 1.0): + """Start data generation. + + Start the data generation rollouts to receive new data. + + Args: + policy: The current learned policy. + alpha: The weight of the MPC policy in the rollouts. + """ + policy_file_path = "/tmp/data_generation_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".onnx" + torch.onnx.export(model=policy, args=self.dummy_observation, f=policy_file_path) + initial_observations, mode_schedules, target_trajectories = self.get_tasks( + self.config.DATA_GENERATION_TASKS, self.config.DATA_GENERATION_DURATION + ) + self.interface.startDataGeneration( + alpha, + policy_file_path, + self.config.DATA_GENERATION_TIME_STEP, + self.config.DATA_GENERATION_DATA_DECIMATION, + self.config.DATA_GENERATION_SAMPLES, + np.diag(np.power(np.array(self.config.DATA_GENERATION_SAMPLING_VARIANCE), 2)), + initial_observations, + mode_schedules, + target_trajectories, + ) + + def start_policy_evaluation(self, policy: BasePolicy, alpha: float = 0.0): + """Start policy evaluation. + + Start the policy evaluation rollouts to validate the current performance. + + Args: + policy: The current learned policy. + alpha: The weight of the MPC policy in the rollouts. 
+ """ + policy_file_path = "/tmp/policy_evaluation_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".onnx" + torch.onnx.export(model=policy, args=self.dummy_observation, f=policy_file_path) + initial_observations, mode_schedules, target_trajectories = self.get_tasks( + self.config.POLICY_EVALUATION_TASKS, self.config.POLICY_EVALUATION_DURATION + ) + self.interface.startPolicyEvaluation( + alpha, + policy_file_path, + self.config.POLICY_EVALUATION_TIME_STEP, + initial_observations, + mode_schedules, + target_trajectories, + ) + + def train(self) -> None: + """Train. + + Run the main training loop of MPC-Net. + """ + try: + # save initial policy + save_path = self.log_dir + "/initial_policy" + torch.onnx.export(model=self.policy, args=self.dummy_observation, f=save_path + ".onnx") + torch.save(obj=self.policy, f=save_path + ".pt") + + print("==============\nWaiting for first data.\n==============") + self.start_data_generation(self.policy) + self.start_policy_evaluation(self.policy) + while not self.interface.isDataGenerationDone(): + time.sleep(1.0) + + print("==============\nStarting training.\n==============") + for iteration in range(self.config.LEARNING_ITERATIONS): + alpha = 1.0 - 1.0 * iteration / self.config.LEARNING_ITERATIONS + + # data generation + if self.interface.isDataGenerationDone(): + # get generated data + data = self.interface.getGeneratedData() + for i in range(len(data)): + # push t, x, u, p, observation, action transformation, Hamiltonian into memory + self.memory.push( + data[i].t, + data[i].x, + data[i].u, + helper.get_one_hot(data[i].mode, self.config.EXPERT_NUM, self.config.EXPERT_FOR_MODE), + data[i].observation, + data[i].actionTransformation, + data[i].hamiltonian, + ) + # logging + self.writer.add_scalar("data/new_data_points", len(data), iteration) + self.writer.add_scalar("data/total_data_points", len(self.memory), iteration) + print("iteration", iteration, "received data points", len(data), "requesting with alpha", alpha) 
+ # start new data generation + self.start_data_generation(self.policy, alpha) + + # policy evaluation + if self.interface.isPolicyEvaluationDone(): + # get computed metrics + metrics = self.interface.getComputedMetrics() + survival_time = np.mean([metrics[i].survivalTime for i in range(len(metrics))]) + incurred_hamiltonian = np.mean([metrics[i].incurredHamiltonian for i in range(len(metrics))]) + # logging + self.writer.add_scalar("metric/survival_time", survival_time, iteration) + self.writer.add_scalar("metric/incurred_hamiltonian", incurred_hamiltonian, iteration) + print( + "iteration", + iteration, + "received metrics:", + "incurred_hamiltonian", + incurred_hamiltonian, + "survival_time", + survival_time, + ) + # start new policy evaluation + self.start_policy_evaluation(self.policy) + + # save intermediate policy + if (iteration % int(0.1 * self.config.LEARNING_ITERATIONS) == 0) and (iteration > 0): + save_path = self.log_dir + "/intermediate_policy_" + str(iteration) + torch.onnx.export(model=self.policy, args=self.dummy_observation, f=save_path + ".onnx") + torch.save(obj=self.policy, f=save_path + ".pt") + + # extract batch from memory + ( + t, + x, + u, + p, + observation, + action_transformation_matrix, + action_transformation_vector, + dHdxx, + dHdux, + dHduu, + dHdx, + dHdu, + H, + ) = self.memory.sample(self.config.BATCH_SIZE) + + # normal closure only evaluating the experts loss + def normal_closure(): + # clear the gradients + self.optimizer.zero_grad() + # prediction + action = self.policy(observation)[0] + input = helper.bmv(action_transformation_matrix, action) + action_transformation_vector + # compute the empirical loss + empirical_loss = self.experts_loss(x, x, input, u, p, p, dHdxx, dHdux, dHduu, dHdx, dHdu, H) + # compute the gradients + empirical_loss.backward() + # clip the gradients + if self.config.GRADIENT_CLIPPING: + torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.GRADIENT_CLIPPING_VALUE) + # logging + 
self.writer.add_scalar("objective/empirical_loss", empirical_loss.item(), iteration) + # return empirical loss + return empirical_loss + + # cheating closure also adding the gating loss (only relevant for mixture of experts networks) + def cheating_closure(): + # clear the gradients + self.optimizer.zero_grad() + # prediction + action, weights = self.policy(observation)[:2] + input = helper.bmv(action_transformation_matrix, action) + action_transformation_vector + # compute the empirical loss + empirical_experts_loss = self.experts_loss(x, x, input, u, p, p, dHdxx, dHdux, dHduu, dHdx, dHdu, H) + empirical_gating_loss = self.gating_loss(x, x, u, u, weights, p, dHdxx, dHdux, dHduu, dHdx, dHdu, H) + empirical_loss = empirical_experts_loss + self.config.LAMBDA * empirical_gating_loss + # compute the gradients + empirical_loss.backward() + # clip the gradients + if self.config.GRADIENT_CLIPPING: + torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.GRADIENT_CLIPPING_VALUE) + # logging + self.writer.add_scalar("objective/empirical_experts_loss", empirical_experts_loss.item(), iteration) + self.writer.add_scalar("objective/empirical_gating_loss", empirical_gating_loss.item(), iteration) + self.writer.add_scalar("objective/empirical_loss", empirical_loss.item(), iteration) + # return empirical loss + return empirical_loss + + # take an optimization step + if self.config.CHEATING: + self.optimizer.step(cheating_closure) + else: + self.optimizer.step(normal_closure) + + # let data generation and policy evaluation finish in last iteration (to avoid a segmentation fault) + if iteration == self.config.LEARNING_ITERATIONS - 1: + while (not self.interface.isDataGenerationDone()) or (not self.interface.isPolicyEvaluationDone()): + time.sleep(1.0) + + print("==============\nTraining completed.\n==============") + + # save final policy + save_path = self.log_dir + "/final_policy" + torch.onnx.export(model=self.policy, args=self.dummy_observation, f=save_path + 
".onnx") + torch.save(obj=self.policy, f=save_path + ".pt") + + except KeyboardInterrupt: + # let data generation and policy evaluation finish (to avoid a segmentation fault) + while (not self.interface.isDataGenerationDone()) or (not self.interface.isPolicyEvaluationDone()): + time.sleep(1.0) + print("==============\nTraining interrupted.\n==============") + pass + + self.writer.close() diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/__init__.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/__init__.py new file mode 100644 index 000000000..3e178b783 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/__init__.py @@ -0,0 +1,13 @@ +from .base import BasePolicy +from .linear import LinearPolicy +from .mixture_of_linear_experts import MixtureOfLinearExpertsPolicy +from .mixture_of_nonlinear_experts import MixtureOfNonlinearExpertsPolicy +from .nonlinear import NonlinearPolicy + +__all__ = [ + "BasePolicy", + "LinearPolicy", + "MixtureOfLinearExpertsPolicy", + "MixtureOfNonlinearExpertsPolicy", + "NonlinearPolicy", +] diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/base.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/base.py new file mode 100644 index 000000000..7b9c584c2 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/base.py @@ -0,0 +1,107 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. 
+# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Base policy. + +Provides a base class for all policy classes. +""" + +import torch +from typing import Tuple +from abc import ABCMeta, abstractmethod + +from ocs2_mpcnet_core.config import Config +from ocs2_mpcnet_core.helper import bmv + + +class BasePolicy(torch.nn.Module, metaclass=ABCMeta): + """Base policy. + + Provides the interface to all policy classes. + + Attributes: + observation_scaling: A (1,O,O) tensor for the observation scaling. + action_scaling: A (1,A,A) tensor for the action scaling. + """ + + def __init__(self, config: Config) -> None: + """Initializes the BasePolicy class. + + Initializes the BasePolicy class. + + Args: + config: An instance of the configuration class. 
+ """ + super().__init__() + self.observation_scaling = ( + torch.tensor(config.OBSERVATION_SCALING, device=config.DEVICE, dtype=config.DTYPE).diag().unsqueeze(dim=0) + ) + self.action_scaling = ( + torch.tensor(config.ACTION_SCALING, device=config.DEVICE, dtype=config.DTYPE).diag().unsqueeze(dim=0) + ) + + @abstractmethod + def forward(self, observation: torch.Tensor) -> Tuple[torch.Tensor, ...]: + """Forward method. + + Defines the computation performed at every call. Computes the output tensors from the input tensors. + + Args: + observation: A (B,O) tensor with the observations. + + Returns: + tuple: A tuple with the predictions, e.g. containing a (B,A) tensor with the predicted actions. + """ + pass + + def scale_observation(self, observation: torch.Tensor) -> torch.Tensor: + """Scale observation. + + Scale the observation with a fixed diagonal matrix. + + Args: + observation: A (B,O) tensor with the observations. + + Returns: + scaled_observation: A (B,O) tensor with the scaled observations. + """ + return bmv(self.observation_scaling, observation) + + def scale_action(self, action: torch.Tensor) -> torch.Tensor: + """Scale action. + + Scale the action with a fixed diagonal matrix. + + Args: + action: A (B,A) tensor with the actions. + + Returns: + scaled_action: A (B,A) tensor with the scaled actions. + """ + return bmv(self.action_scaling, action) diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/linear.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/linear.py new file mode 100644 index 000000000..ba9c9fe9c --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/linear.py @@ -0,0 +1,82 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Linear policy. + +Provides a class that implements a linear policy. +""" + +import torch +from typing import Tuple + +from ocs2_mpcnet_core.config import Config +from ocs2_mpcnet_core.policy.base import BasePolicy + + +class LinearPolicy(BasePolicy): + """Linear policy. + + Class for a simple linear neural network policy. + + Attributes: + name: A string with the name of the policy. 
+ observation_dimension: An integer defining the observation (i.e. input) dimension of the policy. + action_dimension: An integer defining the action (i.e. output) dimension of the policy. + linear: The linear neural network layer. + """ + + def __init__(self, config: Config) -> None: + """Initializes the LinearPolicy class. + + Initializes the LinearPolicy class by setting fixed and variable attributes. + + Args: + config: An instance of the configuration class. + """ + super().__init__(config) + self.name = "LinearPolicy" + self.observation_dimension = config.OBSERVATION_DIM + self.action_dimension = config.ACTION_DIM + self.linear = torch.nn.Linear(self.observation_dimension, self.action_dimension) + + def forward(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: + """Forward method. + + Defines the computation performed at every call. Computes the output tensors from the input tensors. + + Args: + observation: A (B,O) tensor with the observations. + + Returns: + action: A (B,A) tensor with the predicted actions. + """ + scaled_observation = self.scale_observation(observation) + unscaled_action = self.linear(scaled_observation) + action = self.scale_action(unscaled_action) + return (action,) diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/mixture_of_linear_experts.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/mixture_of_linear_experts.py new file mode 100644 index 000000000..2527d5e67 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/mixture_of_linear_experts.py @@ -0,0 +1,140 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +############################################################################### + +"""Mixture of linear experts policy. + +Provides classes that implement a mixture of linear experts policy. +""" + +import torch +from typing import Tuple + +from ocs2_mpcnet_core.config import Config +from ocs2_mpcnet_core.policy.base import BasePolicy +from ocs2_mpcnet_core.helper import bmv + + +class MixtureOfLinearExpertsPolicy(BasePolicy): + """Mixture of linear experts policy. 
+ + Class for a mixture of experts neural network policy with linear experts. + + Attributes: + name: A string with the name of the policy. + observation_dimension: An integer defining the observation (i.e. input) dimension of the policy. + action_dimension: An integer defining the action (i.e. output) dimension of the policy. + expert_number: An integer defining the number of experts. + gating_net: The gating network. + expert_nets: The expert networks. + """ + + def __init__(self, config: Config) -> None: + """Initializes the MixtureOfLinearExpertsPolicy class. + + Initializes the MixtureOfLinearExpertsPolicy class by setting fixed and variable attributes. + + Args: + config: An instance of the configuration class. + """ + super().__init__(config) + self.name = "MixtureOfLinearExpertsPolicy" + self.observation_dimension = config.OBSERVATION_DIM + self.action_dimension = config.ACTION_DIM + self.expert_number = config.EXPERT_NUM + # gating + self.gating_net = torch.nn.Sequential( + torch.nn.Linear(self.observation_dimension, self.expert_number), torch.nn.Softmax(dim=1) + ) + # experts + self.expert_nets = torch.nn.ModuleList( + [_LinearExpert(i, self.observation_dimension, self.action_dimension) for i in range(self.expert_number)] + ) + + def forward(self, observation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward method. + + Defines the computation performed at every call. Computes the output tensors from the input tensors. + + Args: + observation: A (B,O) tensor with the observations. + + Returns: + action: A (B,A) tensor with the predicted actions. + expert_weights: A (B,E) tensor with the predicted expert weights. 
+ """ + scaled_observation = self.scale_observation(observation) + expert_weights = self.gating_net(scaled_observation) + expert_actions = torch.stack( + [self.expert_nets[i](scaled_observation) for i in range(self.expert_number)], dim=2 + ) + unscaled_action = bmv(expert_actions, expert_weights) + action = self.scale_action(unscaled_action) + return action, expert_weights + + +class _LinearExpert(torch.nn.Module): + """Linear expert. + + Class for a simple linear neural network expert. + + Attributes: + name: A string with the name of the expert. + input_dimension: An integer defining the input dimension of the expert. + output_dimension: An integer defining the output dimension of the expert. + linear: The linear neural network layer. + """ + + def __init__(self, index: int, input_dimension: int, output_dimension: int) -> None: + """Initializes the _LinearExpert class. + + Initializes the _LinearExpert class by setting fixed and variable attributes. + + Args: + index: An integer with the index of the expert. + input_dimension: An integer defining the input dimension of the expert. + output_dimension: An integer defining the output dimension of the expert. + """ + super().__init__() + self.name = "LinearExpert" + str(index) + self.input_dimension = input_dimension + self.output_dimension = output_dimension + self.linear = torch.nn.Linear(self.input_dimension, self.output_dimension) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """Forward method. + + Defines the computation performed at every call. Computes the output tensors from the input tensors. + + Args: + input: A (B,I) tensor with the inputs. + + Returns: + output: A (B,O) tensor with the outputs. 
+ """ + return self.linear(input) diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/mixture_of_nonlinear_experts.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/mixture_of_nonlinear_experts.py new file mode 100644 index 000000000..6fed0da20 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/mixture_of_nonlinear_experts.py @@ -0,0 +1,159 @@ +############################################################################### +# Copyright (c) 2022, Farbod Farshidian. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+###############################################################################
+
+"""Mixture of nonlinear experts policy.
+
+Provides the MixtureOfNonlinearExpertsPolicy class and the private _NonlinearExpert module it is built from.
+"""
+
+import torch
+from typing import Tuple
+
+from ocs2_mpcnet_core.config import Config
+from ocs2_mpcnet_core.policy.base import BasePolicy
+from ocs2_mpcnet_core.helper import bmv
+
+
+class MixtureOfNonlinearExpertsPolicy(BasePolicy):
+    """Mixture of nonlinear experts policy.
+
+    Class for a mixture of experts neural network policy with nonlinear experts. Hidden layer dimensions are truncated
+    means: (observation + expert number) / 2 for the gating network and (observation + action) / 2 for every expert.
+
+    Attributes:
+        name: A string with the name of the policy.
+        observation_dimension: An integer defining the observation (i.e. input) dimension of the policy.
+        gating_hidden_dimension: An integer defining the dimension of the hidden layer for the gating network.
+        expert_hidden_dimension: An integer defining the dimension of the hidden layer for the expert networks.
+        action_dimension: An integer defining the action (i.e. output) dimension of the policy.
+        expert_number: An integer defining the number of experts.
+        gating_net: The gating network (Linear, Tanh, Linear, Softmax over dim 1).
+        expert_nets: A ModuleList holding expert_number _NonlinearExpert networks.
+    """
+
+    def __init__(self, config: Config) -> None:
+        """Initializes the MixtureOfNonlinearExpertsPolicy class.
+
+        Reads the dimensions from the configuration and builds the gating network and the expert networks.
+
+        Args:
+            config: An instance of the configuration class; OBSERVATION_DIM, ACTION_DIM and EXPERT_NUM are read here.
+        """
+        super().__init__(config)
+        self.name = "MixtureOfNonlinearExpertsPolicy"
+        self.observation_dimension = config.OBSERVATION_DIM
+        self.gating_hidden_dimension = int((config.OBSERVATION_DIM + config.EXPERT_NUM) / 2)  # mean, truncated to int
+        self.expert_hidden_dimension = int((config.OBSERVATION_DIM + config.ACTION_DIM) / 2)  # mean, truncated to int
+        self.action_dimension = config.ACTION_DIM
+        self.expert_number = config.EXPERT_NUM
+        # gating: maps an observation to one weight per expert; the softmax runs over dim 1 of the gating output
+        self.gating_net = torch.nn.Sequential(
+            torch.nn.Linear(self.observation_dimension, self.gating_hidden_dimension),
+            torch.nn.Tanh(),
+            torch.nn.Linear(self.gating_hidden_dimension, self.expert_number),
+            torch.nn.Softmax(dim=1),
+        )
+        # experts: expert_number independent networks that all share the same input, hidden and output dimensions
+        self.expert_nets = torch.nn.ModuleList(
+            [
+                _NonlinearExpert(i, self.observation_dimension, self.expert_hidden_dimension, self.action_dimension)
+                for i in range(self.expert_number)
+            ]
+        )
+
+    def forward(self, observation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward method.
+
+        Scales the observation, evaluates the gating and expert networks, mixes the expert actions and scales the result.
+
+        Args:
+            observation: A (B,O) tensor with the observations.
+
+        Returns:
+            action: A (B,A) tensor with the predicted actions.
+            expert_weights: A (B,E) tensor with the predicted expert weights.
+        """
+        scaled_observation = self.scale_observation(observation)  # scale_observation is inherited, not defined here
+        expert_weights = self.gating_net(scaled_observation)
+        expert_actions = torch.stack(  # one (B,A) tensor per expert, stacked along a new axis 2
+            [self.expert_nets[i](scaled_observation) for i in range(self.expert_number)], dim=2
+        )
+        unscaled_action = bmv(expert_actions, expert_weights)  # presumably a batched matrix-vector product - TODO confirm in helper
+        action = self.scale_action(unscaled_action)  # scale_action is inherited, not defined here
+        return action, expert_weights
+
+
+class _NonlinearExpert(torch.nn.Module):
+    """Nonlinear expert.
+
+    Class for a simple nonlinear neural network expert with a single hidden layer, whose dimension is chosen by the
+    caller and passed to the constructor.
+
+    Attributes:
+        name: A string with the name of the expert.
+        input_dimension: An integer defining the input dimension of the expert.
+        hidden_dimension: An integer defining the dimension of the hidden layer.
+        output_dimension: An integer defining the output dimension of the expert.
+        linear1: The first linear neural network layer.
+        activation: The activation to get the hidden layer.
+        linear2: The second linear neural network layer.
+    """
+
+    def __init__(self, index: int, input_dimension: int, hidden_dimension: int, output_dimension: int) -> None:
+        """Initializes the _NonlinearExpert class.
+
+        Stores the dimensions and creates the two linear layers and the Tanh activation between them.
+
+        Args:
+            index: An integer with the index of the expert; it is only used to build the name.
+            input_dimension: An integer defining the input dimension of the expert.
+            hidden_dimension: An integer defining the dimension of the hidden layer.
+            output_dimension: An integer defining the output dimension of the expert.
+        """
+        super().__init__()
+        self.name = "NonlinearExpert" + str(index)
+        self.input_dimension = input_dimension
+        self.hidden_dimension = hidden_dimension
+        self.output_dimension = output_dimension
+        self.linear1 = torch.nn.Linear(self.input_dimension, self.hidden_dimension)
+        self.activation = torch.nn.Tanh()
+        self.linear2 = torch.nn.Linear(self.hidden_dimension, self.output_dimension)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """Forward method.
+
+        Applies linear1, the activation and linear2 in sequence.
+
+        Args:
+            input: A (B,I) tensor with the inputs.
+
+        Returns:
+            output: A (B,O) tensor with the outputs.
+        """
+        return self.linear2(self.activation(self.linear1(input)))
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/nonlinear.py b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/nonlinear.py
new file mode 100644
index 000000000..f2f9f95e1
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/python/ocs2_mpcnet_core/policy/nonlinear.py
@@ -0,0 +1,89 @@
+###############################################################################
+# Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+###############################################################################
+
+"""Nonlinear policy.
+
+Provides a class that implements a nonlinear policy.
+"""
+
+import torch
+from typing import Tuple
+
+from ocs2_mpcnet_core.config import Config
+from ocs2_mpcnet_core.policy.base import BasePolicy
+
+
+class NonlinearPolicy(BasePolicy):
+    """Nonlinear policy.
+
+    Class for a simple nonlinear neural network policy with a single hidden layer, whose dimension is the truncated
+    mean of the observation and action dimensions.
+
+    Attributes:
+        name: A string with the name of the policy.
+        observation_dimension: An integer defining the observation (i.e. input) dimension of the policy.
+        hidden_dimension: An integer defining the dimension of the hidden layer.
+        action_dimension: An integer defining the action (i.e. output) dimension of the policy.
+        linear1: The first linear neural network layer.
+        activation: The activation to get the hidden layer.
+        linear2: The second linear neural network layer.
+    """
+
+    def __init__(self, config: Config) -> None:
+        """Initializes the NonlinearPolicy class.
+
+        Reads the dimensions from the configuration and creates the two linear layers and the Tanh activation.
+
+        Args:
+            config: An instance of the configuration class; OBSERVATION_DIM and ACTION_DIM are read here.
+        """
+        super().__init__(config)
+        self.name = "NonlinearPolicy"
+        self.observation_dimension = config.OBSERVATION_DIM
+        self.hidden_dimension = int((config.OBSERVATION_DIM + config.ACTION_DIM) / 2)  # mean, truncated to int
+        self.action_dimension = config.ACTION_DIM
+        self.linear1 = torch.nn.Linear(self.observation_dimension, self.hidden_dimension)
+        self.activation = torch.nn.Tanh()
+        self.linear2 = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+
+    def forward(self, observation: torch.Tensor) -> Tuple[torch.Tensor]:
+        """Forward method.
+
+        Scales the observation, applies linear1, the activation and linear2, and scales the resulting action.
+
+        Args:
+            observation: A (B,O) tensor with the observations.
+
+        Returns:
+            action: A (B,A) tensor with the predicted actions, wrapped in a one-element tuple.
+        """
+        scaled_observation = self.scale_observation(observation)  # scale_observation is inherited, not defined here
+        unscaled_action = self.linear2(self.activation(self.linear1(scaled_observation)))
+        action = self.scale_action(unscaled_action)  # scale_action is inherited, not defined here
+        return (action,)
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/requirements.txt b/ocs2_mpcnet/ocs2_mpcnet_core/requirements.txt
new file mode 100644
index 000000000..243bf94a5
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/requirements.txt
@@ -0,0 +1,14 @@
+#
+####### requirements.txt #######
+#
+###### Requirements without version specifiers ######
+black
+numpy
+pyyaml
+tensorboard
+torch
+#
+###### Requirements with version specifiers ######
+#
+###### Refer to other requirements files ######
+#
\ No newline at end of file
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/setup.py b/ocs2_mpcnet/ocs2_mpcnet_core/setup.py
new file mode 100644
index 000000000..61478769f
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/setup.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+
+from setuptools import setup
+from catkin_pkg.python_setup import generate_distutils_setup
+
+setup_args = generate_distutils_setup(packages=["ocs2_mpcnet_core"], package_dir={"": "python"})  # package sources are looked up under ./python
+
+setup(**setup_args)
diff --git
a/ocs2_mpcnet/ocs2_mpcnet_core/src/MpcnetInterfaceBase.cpp b/ocs2_mpcnet/ocs2_mpcnet_core/src/MpcnetInterfaceBase.cpp new file mode 100644 index 000000000..e02f861d9 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/src/MpcnetInterfaceBase.cpp @@ -0,0 +1,86 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/
+
+#include "ocs2_mpcnet_core/MpcnetInterfaceBase.h"
+
+namespace ocs2 {
+namespace mpcnet {
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+void MpcnetInterfaceBase::startDataGeneration(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, size_t dataDecimation,
+                                              size_t nSamples, const matrix_t& samplingCovariance,
+                                              const std::vector& initialObservations,  // NOTE(review): the std::vector element types look stripped in this copy of the patch
+                                              const std::vector& modeSchedules,
+                                              const std::vector& targetTrajectories) {
+  mpcnetRolloutManagerPtr_->startDataGeneration(alpha, policyFilePath, timeStep, dataDecimation, nSamples, samplingCovariance,  // thin forwarder: every argument is handed on unchanged
+                                                initialObservations, modeSchedules, targetTrajectories);  // mpcnetRolloutManagerPtr_ is dereferenced without a null check
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+bool MpcnetInterfaceBase::isDataGenerationDone() {
+  return mpcnetRolloutManagerPtr_->isDataGenerationDone();  // forwards the rollout manager's answer
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+data_array_t MpcnetInterfaceBase::getGeneratedData() {
+  return mpcnetRolloutManagerPtr_->getGeneratedData();  // forwards the rollout manager's result by value
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+void MpcnetInterfaceBase::startPolicyEvaluation(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep,
+                                                const std::vector& initialObservations,  // NOTE(review): the std::vector element types look stripped in this copy of the patch
+                                                const std::vector& modeSchedules,
+                                                const std::vector& targetTrajectories) {
+  mpcnetRolloutManagerPtr_->startPolicyEvaluation(alpha, policyFilePath, timeStep, initialObservations, modeSchedules, targetTrajectories);  // thin forwarder, no null check on the manager pointer
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+bool MpcnetInterfaceBase::isPolicyEvaluationDone() {
+  return mpcnetRolloutManagerPtr_->isPolicyEvaluationDone();  // forwards the rollout manager's answer
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+metrics_array_t MpcnetInterfaceBase::getComputedMetrics() {
+  return mpcnetRolloutManagerPtr_->getComputedMetrics();  // forwards the rollout manager's result by value
+}
+
+}  // namespace mpcnet
+}  // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/src/MpcnetPybindings.cpp b/ocs2_mpcnet/ocs2_mpcnet_core/src/MpcnetPybindings.cpp
new file mode 100644
index 000000000..38fec2cf0
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/src/MpcnetPybindings.cpp
@@ -0,0 +1,34 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/
+
+#include "ocs2_mpcnet_core/MpcnetPybindMacros.h"
+
+#include "ocs2_mpcnet_core/MpcnetInterfaceBase.h"
+
+CREATE_MPCNET_PYTHON_BINDINGS(MpcnetPybindings)  // presumably expands to the Python module definition named MpcnetPybindings; the macro body lives in MpcnetPybindMacros.h (not shown) - TODO confirm
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/src/control/MpcnetBehavioralController.cpp b/ocs2_mpcnet/ocs2_mpcnet_core/src/control/MpcnetBehavioralController.cpp
new file mode 100644
index 000000000..3c3ef8437
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/src/control/MpcnetBehavioralController.cpp
@@ -0,0 +1,110 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#include "ocs2_mpcnet_core/control/MpcnetBehavioralController.h"
+
+#include  // NOTE(review): include target missing in this copy of the patch (angle-bracket text looks stripped) - restore before applying
+
+namespace ocs2 {
+namespace mpcnet {
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+vector_t MpcnetBehavioralController::computeInput(scalar_t t, const vector_t& x) {  // mixes the two wrapped controllers: alpha_ * optimal + (1 - alpha_) * learned
+  if (optimalControllerPtr_ != nullptr && learnedControllerPtr_ != nullptr) {  // both controllers must be set, whatever alpha_ is
+    if (numerics::almost_eq(alpha_, 0.0)) {  // alpha_ ~ 0: learned input only, the optimal controller is not evaluated
+      return learnedControllerPtr_->computeInput(t, x);
+    } else if (numerics::almost_eq(alpha_, 1.0)) {  // alpha_ ~ 1: optimal input only, the learned controller is not evaluated
+      return optimalControllerPtr_->computeInput(t, x);
+    } else {
+      return alpha_ * optimalControllerPtr_->computeInput(t, x) + (1 - alpha_) * learnedControllerPtr_->computeInput(t, x);
+    }
+  } else {
+    throw std::runtime_error(  // raised when either controller pointer is null
+        "[MpcnetBehavioralController::computeInput] cannot compute input, since optimal and/or learned controller not set.");
+  }
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+int MpcnetBehavioralController::size() const {  // larger of the two wrapped sizes; the size of the only controller set; 0 when none is set
+  if (optimalControllerPtr_ != nullptr && learnedControllerPtr_ != nullptr) {
+    return std::max(optimalControllerPtr_->size(), learnedControllerPtr_->size());
+  } else if (optimalControllerPtr_ != nullptr) {
+    return optimalControllerPtr_->size();
+  } else if (learnedControllerPtr_ != nullptr) {
+    return learnedControllerPtr_->size();
+  } else {
+    return 0;
+  }
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+void MpcnetBehavioralController::clear() {  // clears every wrapped controller that is set; the pointers themselves are kept
+  if (optimalControllerPtr_ != nullptr) {
+    optimalControllerPtr_->clear();
+  }
+  if (learnedControllerPtr_ != nullptr) {
+    learnedControllerPtr_->clear();
+  }
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+bool MpcnetBehavioralController::empty() const {  // true only if every controller that is set reports empty; true when none is set
+  if (optimalControllerPtr_ != nullptr && learnedControllerPtr_ != nullptr) {
+    return optimalControllerPtr_->empty() && learnedControllerPtr_->empty();
+  } else if (optimalControllerPtr_ != nullptr) {
+    return optimalControllerPtr_->empty();
+  } else if (learnedControllerPtr_ != nullptr) {
+    return learnedControllerPtr_->empty();
+  } else {
+    return true;
+  }
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+void MpcnetBehavioralController::concatenate(const ControllerBase* otherController, int index, int length) {  // forwards the call to every wrapped controller that is set
+  if (optimalControllerPtr_ != nullptr) {
+    optimalControllerPtr_->concatenate(otherController, index, length);
+  }
+  if (learnedControllerPtr_ != nullptr) {
+    learnedControllerPtr_->concatenate(otherController, index, length);  // NOTE(review): the same otherController goes to both controllers - confirm both accept its concrete type
+  }
+}
+
+}  // namespace mpcnet
+}  // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/src/control/MpcnetOnnxController.cpp b/ocs2_mpcnet/ocs2_mpcnet_core/src/control/MpcnetOnnxController.cpp
new file mode 100644
index 000000000..b77c73d2b
--- /dev/null
+++ b/ocs2_mpcnet/ocs2_mpcnet_core/src/control/MpcnetOnnxController.cpp
@@ -0,0 +1,89 @@
+/******************************************************************************
+Copyright (c) 2022, Farbod Farshidian. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+******************************************************************************/
+
+#include "ocs2_mpcnet_core/control/MpcnetOnnxController.h"
+
+namespace ocs2 {
+namespace mpcnet {
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+void MpcnetOnnxController::loadPolicyModel(const std::string& policyFilePath) {  // (re)creates the session for the given file and refreshes the cached tensor names and shapes
+  policyFilePath_ = policyFilePath;
+  // create session (one intra-op and one inter-op thread)
+  Ort::SessionOptions sessionOptions;
+  sessionOptions.SetIntraOpNumThreads(1);
+  sessionOptions.SetInterOpNumThreads(1);
+  sessionPtr_.reset(new Ort::Session(*onnxEnvironmentPtr_, policyFilePath_.c_str(), sessionOptions));
+  // get input and output info (cached names and shapes are reused by computeInput)
+  inputNames_.clear();
+  outputNames_.clear();
+  inputShapes_.clear();
+  outputShapes_.clear();
+  Ort::AllocatorWithDefaultOptions allocator;
+  for (int i = 0; i < sessionPtr_->GetInputCount(); i++) {  // NOTE(review): int index compared against the count - presumably an unsigned type, TODO confirm
+    inputNames_.push_back(sessionPtr_->GetInputName(i, allocator));  // NOTE(review): name storage presumably comes from the allocator and is not released in this file - verify ownership
+    inputShapes_.push_back(sessionPtr_->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
+  }
+  for (int i = 0; i < sessionPtr_->GetOutputCount(); i++) {
+    outputNames_.push_back(sessionPtr_->GetOutputName(i, allocator));  // same ownership question as for the input names
+    outputShapes_.push_back(sessionPtr_->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
+  }
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+vector_t MpcnetOnnxController::computeInput(const scalar_t t, const vector_t& x) {  // evaluates the loaded policy on the observation for (t, x) and returns the transformed action
+  if (sessionPtr_ == nullptr) {  // loadPolicyModel has to be called first
+    throw std::runtime_error("[MpcnetOnnxController::computeInput] cannot compute input, since policy model is not loaded.");
+  }
+  // create input tensor object (wraps the observation buffer together with the cached shape of input 0)
+  Eigen::Matrix observation =  // NOTE(review): template arguments on this and the following Eigen/std/Ort types look stripped in this copy of the patch
+      mpcnetDefinitionPtr_->getObservation(t, x, referenceManagerPtr_->getModeSchedule(), referenceManagerPtr_->getTargetTrajectories())
+          .cast();
+  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
+  std::vector inputValues;
+  inputValues.push_back(Ort::Value::CreateTensor(memoryInfo, observation.data(), observation.size(),  // observation.size() is not checked against inputShapes_[0] here
+                                                 inputShapes_[0].data(), inputShapes_[0].size()));
+  // run inference with exactly one input and one output tensor
+  Ort::RunOptions runOptions;
+  std::vector outputValues = sessionPtr_->Run(runOptions, inputNames_.data(), inputValues.data(), 1, outputNames_.data(), 1);
+  // evaluate output tensor object: view output 0 with shape[1] rows and shape[0] columns
+  Eigen::Map> action(outputValues[0].GetTensorMutableData(),
+                     outputShapes_[0][1], outputShapes_[0][0]);  // presumably the output shape is (1, action dimension) - TODO confirm against the exported model
+  std::pair actionTransformation = mpcnetDefinitionPtr_->getActionTransformation(
+      t, x, referenceManagerPtr_->getModeSchedule(), referenceManagerPtr_->getTargetTrajectories());
+  // transform action: first * action + second
+  return actionTransformation.first * action.cast() + actionTransformation.second;
+}
+
+}  // namespace mpcnet
+}  // namespace ocs2
diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/src/dummy/MpcnetDummyLoopRos.cpp b/ocs2_mpcnet/ocs2_mpcnet_core/src/dummy/MpcnetDummyLoopRos.cpp
new file mode 100644
index
000000000..80e138143 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/src/dummy/MpcnetDummyLoopRos.cpp @@ -0,0 +1,151 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/
+
+#include "ocs2_mpcnet_core/dummy/MpcnetDummyLoopRos.h"
+
+#include  // NOTE(review): include target missing in this copy of the patch (angle-bracket text looks stripped) - restore before applying
+
+namespace ocs2 {
+namespace mpcnet {
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+MpcnetDummyLoopRos::MpcnetDummyLoopRos(scalar_t controlFrequency, scalar_t rosFrequency, std::unique_ptr mpcnetPtr,  // NOTE(review): smart-pointer element types look stripped in this copy of the patch
+                                       std::unique_ptr rolloutPtr, std::shared_ptr rosReferenceManagerPtr)
+    : controlFrequency_(controlFrequency),  // its inverse is the rollout time step
+      rosFrequency_(rosFrequency),  // wall rate of the run loop; its inverse is the simulated duration per iteration
+      mpcnetPtr_(std::move(mpcnetPtr)),
+      rolloutPtr_(std::move(rolloutPtr)),
+      rosReferenceManagerPtr_(std::move(rosReferenceManagerPtr)) {}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+void MpcnetDummyLoopRos::run(const SystemObservation& systemObservation, const TargetTrajectories& targetTrajectories) {  // loops until ROS shuts down or the master is unreachable
+  ros::WallRate rosRate(rosFrequency_);
+  scalar_t duration = 1.0 / rosFrequency_;  // simulated time span requested from rollout() in every iteration
+
+  // initialize: both observations start as copies of the given one, and the targets are handed to the reference manager
+  SystemObservation initialSystemObservation = systemObservation;
+  SystemObservation finalSystemObservation = systemObservation;
+  rosReferenceManagerPtr_->setTargetTrajectories(targetTrajectories);
+
+  // start of while loop
+  while (::ros::ok() && ::ros::master::check()) {
+    // update system observation: the previous final observation becomes the new initial one
+    swap(initialSystemObservation, finalSystemObservation);
+
+    // update reference manager and synchronized modules
+    preSolverRun(initialSystemObservation.time, initialSystemObservation.state);
+
+    // rollout: overwrites time, state and input of finalSystemObservation
+    rollout(duration, initialSystemObservation, finalSystemObservation);
+
+    // update observers with a single-point primal solution built from the final observation
+    PrimalSolution primalSolution;
+    primalSolution.timeTrajectory_ = {finalSystemObservation.time};
+    primalSolution.stateTrajectory_ = {finalSystemObservation.state};
+    primalSolution.inputTrajectory_ = {finalSystemObservation.input};
+    primalSolution.modeSchedule_ = rosReferenceManagerPtr_->getModeSchedule();
+    primalSolution.controllerPtr_ = std::unique_ptr(mpcnetPtr_->clone());  // a fresh clone of the controller is made in every iteration
+    CommandData commandData;
+    commandData.mpcInitObservation_ = initialSystemObservation;
+    commandData.mpcTargetTrajectories_ = rosReferenceManagerPtr_->getTargetTrajectories();
+    for (auto& observer : observerPtrs_) {
+      observer->update(finalSystemObservation, primalSolution, commandData);
+    }
+
+    // process callbacks and sleep for the remainder of the wall-rate period
+    ::ros::spinOnce();
+    rosRate.sleep();
+  }  // end of while loop
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+void MpcnetDummyLoopRos::addObserver(std::shared_ptr observer) {  // registered observers are updated once per run-loop iteration
+  observerPtrs_.push_back(std::move(observer));
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/
+void MpcnetDummyLoopRos::addSynchronizedModule(std::shared_ptr synchronizedModule) {  // registered modules are called from preSolverRun
+  synchronizedModulePtrs_.push_back(std::move(synchronizedModule));
+}
+
+/******************************************************************************************************/
+/******************************************************************************************************/
+/******************************************************************************************************/ +void MpcnetDummyLoopRos::rollout(scalar_t duration, const SystemObservation& initialSystemObservation, + SystemObservation& finalSystemObservation) { + scalar_t timeStep = 1.0 / controlFrequency_; + + // initial time, state and input + scalar_t time = initialSystemObservation.time; + vector_t state = initialSystemObservation.state; + vector_t input = initialSystemObservation.input; + + // start of while loop + while (time <= initialSystemObservation.time + duration) { + // forward simulate system + ModeSchedule modeSchedule = rosReferenceManagerPtr_->getModeSchedule(); + scalar_array_t timeTrajectory; + size_array_t postEventIndicesStock; + vector_array_t stateTrajectory; + vector_array_t inputTrajectory; + rolloutPtr_->run(time, state, time + timeStep, mpcnetPtr_.get(), modeSchedule, timeTrajectory, postEventIndicesStock, stateTrajectory, + inputTrajectory); + + // update time, state and input + time = timeTrajectory.back(); + state = stateTrajectory.back(); + input = inputTrajectory.back(); + } // end of while loop + + // final time, state and input + finalSystemObservation.time = time; + finalSystemObservation.state = state; + finalSystemObservation.input = input; +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +void MpcnetDummyLoopRos::preSolverRun(scalar_t time, const vector_t& state) { + rosReferenceManagerPtr_->preSolverRun(time, time + scalar_t(1.0), state); + for (auto& module : synchronizedModulePtrs_) { + module->preSolverRun(time, time + scalar_t(1.0), state, *rosReferenceManagerPtr_); + } +} + +} // namespace mpcnet +} // namespace ocs2 diff --git 
a/ocs2_mpcnet/ocs2_mpcnet_core/src/dummy/MpcnetDummyObserverRos.cpp b/ocs2_mpcnet/ocs2_mpcnet_core/src/dummy/MpcnetDummyObserverRos.cpp new file mode 100644 index 000000000..6005b62c9 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/src/dummy/MpcnetDummyObserverRos.cpp @@ -0,0 +1,56 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/ + +#include "ocs2_mpcnet_core/dummy/MpcnetDummyObserverRos.h" + +#include + +#include + +namespace ocs2 { +namespace mpcnet { + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +MpcnetDummyObserverRos::MpcnetDummyObserverRos(ros::NodeHandle& nodeHandle, std::string topicPrefix) { + observationPublisher_ = nodeHandle.advertise(topicPrefix + "_mpc_observation", 1); +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +void MpcnetDummyObserverRos::update(const SystemObservation& observation, const PrimalSolution& primalSolution, + const CommandData& command) { + auto observationMsg = ros_msg_conversions::createObservationMsg(observation); + observationPublisher_.publish(observationMsg); +} + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetDataGeneration.cpp b/ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetDataGeneration.cpp new file mode 100644 index 000000000..0cfc884eb --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetDataGeneration.cpp @@ -0,0 +1,94 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/ + +#include "ocs2_mpcnet_core/rollout/MpcnetDataGeneration.h" + +#include + +#include "ocs2_mpcnet_core/control/MpcnetBehavioralController.h" + +namespace ocs2 { +namespace mpcnet { + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +const data_array_t* MpcnetDataGeneration::run(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, size_t dataDecimation, + size_t nSamples, const matrix_t& samplingCovariance, + const SystemObservation& initialObservation, const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories) { + // clear data array + dataArray_.clear(); + + // set system + set(alpha, policyFilePath, initialObservation, modeSchedule, targetTrajectories); + + // set up scalar standard normal generator and compute Cholesky decomposition of covariance matrix + std::random_device randomDevice; + std::default_random_engine pseudoRandomNumberGenerator(randomDevice()); + std::normal_distribution standardNormalDistribution(scalar_t(0.0), scalar_t(1.0)); + auto standardNormalNullaryOp = [&](scalar_t) -> scalar_t { return standardNormalDistribution(pseudoRandomNumberGenerator); }; + const matrix_t L = samplingCovariance.llt().matrixL(); + + // run data generation + int iteration = 0; + try { + while (systemObservation_.time <= targetTrajectories.timeTrajectory.back()) { + // step system + step(timeStep); + + // downsample the data signal by an integer factor + if (iteration % dataDecimation == 0) { + // get nominal data point + const vector_t deviation = vector_t::Zero(primalSolution_.stateTrajectory_.front().size()); + dataArray_.push_back(getDataPoint(*mpcPtr_, *mpcnetDefinitionPtr_, 
deviation)); + + // get samples around nominal data point + for (int i = 0; i < nSamples; i++) { + const vector_t deviation = L * vector_t::NullaryExpr(primalSolution_.stateTrajectory_.front().size(), standardNormalNullaryOp); + dataArray_.push_back(getDataPoint(*mpcPtr_, *mpcnetDefinitionPtr_, deviation)); + } + } + + // update iteration + ++iteration; + } + } catch (const std::exception& e) { + // print error for exceptions + std::cerr << "[MpcnetDataGeneration::run] a standard exception was caught, with message: " << e.what() << "\n"; + // this data generation run failed, clear data + dataArray_.clear(); + } + + // return pointer to the data array + return &dataArray_; +} + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetPolicyEvaluation.cpp b/ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetPolicyEvaluation.cpp new file mode 100644 index 000000000..d1847fc04 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetPolicyEvaluation.cpp @@ -0,0 +1,74 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#include "ocs2_mpcnet_core/rollout/MpcnetPolicyEvaluation.h" + +namespace ocs2 { +namespace mpcnet { + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +metrics_t MpcnetPolicyEvaluation::run(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, + const SystemObservation& initialObservation, const ModeSchedule& modeSchedule, + const TargetTrajectories& targetTrajectories) { + // declare metrics + metrics_t metrics; + + // set system + set(alpha, policyFilePath, initialObservation, modeSchedule, targetTrajectories); + + // run policy evaluation + try { + while (systemObservation_.time <= targetTrajectories.timeTrajectory.back()) { + // step system + step(timeStep); + + // incurred quantities + const scalar_t time = primalSolution_.timeTrajectory_.front(); + const vector_t state = primalSolution_.stateTrajectory_.front(); + const vector_t input = 
behavioralControllerPtr_->computeInput(time, state); + metrics.incurredHamiltonian += mpcPtr_->getSolverPtr()->getHamiltonian(time, state, input).f * timeStep; + } + } catch (const std::exception& e) { + // print error for exceptions + std::cerr << "[MpcnetPolicyEvaluation::run] a standard exception was caught, with message: " << e.what() << "\n"; + // this policy evaluation run failed, incurred quantities are not reported + metrics.incurredHamiltonian = std::numeric_limits::quiet_NaN(); + } + + // report survival time + metrics.survivalTime = systemObservation_.time; + + // return metrics + return metrics; +} + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetRolloutBase.cpp b/ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetRolloutBase.cpp new file mode 100644 index 000000000..8037e5541 --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetRolloutBase.cpp @@ -0,0 +1,101 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#include "ocs2_mpcnet_core/rollout/MpcnetRolloutBase.h" + +#include "ocs2_mpcnet_core/control/MpcnetBehavioralController.h" + +namespace ocs2 { +namespace mpcnet { + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +void MpcnetRolloutBase::set(scalar_t alpha, const std::string& policyFilePath, const SystemObservation& initialObservation, + const ModeSchedule& modeSchedule, const TargetTrajectories& targetTrajectories) { + // init system observation + systemObservation_ = initialObservation; + + // reset mpc + mpcPtr_->reset(); + + // prepare learned controller + mpcnetPtr_->loadPolicyModel(policyFilePath); + + // reset rollout, i.e. reset the internal simulator state (e.g. 
relevant for RaiSim) + rolloutPtr_->resetRollout(); + + // update the reference manager + referenceManagerPtr_->setModeSchedule(modeSchedule); + referenceManagerPtr_->setTargetTrajectories(targetTrajectories); + + // set up behavioral controller with mixture parameter alpha and learned controller + behavioralControllerPtr_->setAlpha(alpha); + behavioralControllerPtr_->setLearnedController(*mpcnetPtr_); +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +void MpcnetRolloutBase::step(scalar_t timeStep) { + // run mpc + if (!mpcPtr_->run(systemObservation_.time, systemObservation_.state)) { + throw std::runtime_error("[MpcnetRolloutBase::step] main routine of MPC returned false."); + } + + // update primal solution + primalSolution_ = mpcPtr_->getSolverPtr()->primalSolution(mpcPtr_->getSolverPtr()->getFinalTime()); + + // update behavioral controller with MPC controller + behavioralControllerPtr_->setOptimalController(*primalSolution_.controllerPtr_); + + // forward simulate system with behavioral controller + scalar_array_t timeTrajectory; + size_array_t postEventIndicesStock; + vector_array_t stateTrajectory; + vector_array_t inputTrajectory; + rolloutPtr_->run(primalSolution_.timeTrajectory_.front(), primalSolution_.stateTrajectory_.front(), + primalSolution_.timeTrajectory_.front() + timeStep, behavioralControllerPtr_.get(), primalSolution_.modeSchedule_, + timeTrajectory, postEventIndicesStock, stateTrajectory, inputTrajectory); + + // update system observation + systemObservation_.time = timeTrajectory.back(); + systemObservation_.state = stateTrajectory.back(); + systemObservation_.input = inputTrajectory.back(); + systemObservation_.mode = 
primalSolution_.modeSchedule_.modeAtTime(systemObservation_.time); + + // check forward simulated system + if (!mpcnetDefinitionPtr_->isValid(systemObservation_.time, systemObservation_.state, referenceManagerPtr_->getModeSchedule(), + referenceManagerPtr_->getTargetTrajectories())) { + throw std::runtime_error("MpcnetDataGeneration::run Tuple (time, state, modeSchedule, targetTrajectories) is not valid."); + } +} + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetRolloutManager.cpp b/ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetRolloutManager.cpp new file mode 100644 index 000000000..9d3f43d2b --- /dev/null +++ b/ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetRolloutManager.cpp @@ -0,0 +1,243 @@ +/****************************************************************************** +Copyright (c) 2022, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#include "ocs2_mpcnet_core/rollout/MpcnetRolloutManager.h" + +namespace ocs2 { +namespace mpcnet { + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +MpcnetRolloutManager::MpcnetRolloutManager(size_t nDataGenerationThreads, size_t nPolicyEvaluationThreads, + std::vector> mpcPtrs, + std::vector> mpcnetPtrs, + std::vector> rolloutPtrs, + std::vector> mpcnetDefinitionPtrs, + std::vector> referenceManagerPtrs) { + // data generation + nDataGenerationThreads_ = nDataGenerationThreads; + if (nDataGenerationThreads_ > 0) { + dataGenerationThreadPoolPtr_.reset(new ThreadPool(nDataGenerationThreads_)); + dataGenerationPtrs_.reserve(nDataGenerationThreads); + for (int i = 0; i < nDataGenerationThreads; i++) { + dataGenerationPtrs_.push_back(std::unique_ptr( + new MpcnetDataGeneration(std::move(mpcPtrs.at(i)), std::move(mpcnetPtrs.at(i)), std::move(rolloutPtrs.at(i)), + std::move(mpcnetDefinitionPtrs.at(i)), referenceManagerPtrs.at(i)))); + } + } + + // policy evaluation + nPolicyEvaluationThreads_ = nPolicyEvaluationThreads; + if (nPolicyEvaluationThreads_ > 0) { + policyEvaluationThreadPoolPtr_.reset(new 
ThreadPool(nPolicyEvaluationThreads_)); + policyEvaluationPtrs_.reserve(nPolicyEvaluationThreads_); + for (int i = nDataGenerationThreads_; i < (nDataGenerationThreads_ + nPolicyEvaluationThreads_); i++) { + policyEvaluationPtrs_.push_back(std::unique_ptr( + new MpcnetPolicyEvaluation(std::move(mpcPtrs.at(i)), std::move(mpcnetPtrs.at(i)), std::move(rolloutPtrs.at(i)), + std::move(mpcnetDefinitionPtrs.at(i)), referenceManagerPtrs.at(i)))); + } + } +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +void MpcnetRolloutManager::startDataGeneration(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, size_t dataDecimation, + size_t nSamples, const matrix_t& samplingCovariance, + const std::vector& initialObservations, + const std::vector& modeSchedules, + const std::vector& targetTrajectories) { + if (nDataGenerationThreads_ <= 0) { + throw std::runtime_error("[MpcnetRolloutManager::startDataGeneration] cannot work without at least one data generation thread."); + } + + // reset variables + dataGenerationFtrs_.clear(); + nDataGenerationTasksDone_ = 0; + + // push tasks into pool + for (int i = 0; i < initialObservations.size(); i++) { + dataGenerationFtrs_.push_back(dataGenerationThreadPoolPtr_->run([=](int threadNumber) { + const auto* result = + dataGenerationPtrs_[threadNumber]->run(alpha, policyFilePath, timeStep, dataDecimation, nSamples, samplingCovariance, + initialObservations.at(i), modeSchedules.at(i), targetTrajectories.at(i)); + nDataGenerationTasksDone_++; + // print thread and task number + std::cerr << "Data generation thread " << threadNumber << " finished task " << nDataGenerationTasksDone_ << "\n"; + return result; + })); + } +} + 
+/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +bool MpcnetRolloutManager::isDataGenerationDone() { + if (nDataGenerationThreads_ <= 0) { + throw std::runtime_error("[MpcnetRolloutManager::isDataGenerationDone] cannot work without at least one data generation thread."); + } + if (dataGenerationFtrs_.size() <= 0) { + throw std::runtime_error( + "[MpcnetRolloutManager::isDataGenerationDone] cannot return if startDataGeneration has not been triggered once."); + } + + // check if done + if (nDataGenerationTasksDone_ < dataGenerationFtrs_.size()) { + return false; + } else if (nDataGenerationTasksDone_ == dataGenerationFtrs_.size()) { + return true; + } else { + throw std::runtime_error("[MpcnetRolloutManager::isDataGenerationDone] error since more tasks done than futures available."); + } +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +const data_array_t& MpcnetRolloutManager::getGeneratedData() { + if (nDataGenerationThreads_ <= 0) { + throw std::runtime_error("[MpcnetRolloutManager::getGeneratedData] cannot work without at least one data generation thread."); + } + if (!isDataGenerationDone()) { + throw std::runtime_error("[MpcnetRolloutManager::getGeneratedData] cannot get data when data generation is not done."); + } + + // clear data array + dataArray_.clear(); + + // get pointers to data + std::vector dataPtrs; + dataPtrs.reserve(dataGenerationFtrs_.size()); + for (auto& dataGenerationFtr : dataGenerationFtrs_) { + try { + // 
get results from futures of the tasks + dataPtrs.push_back(dataGenerationFtr.get()); + } catch (const std::exception& e) { + // print error for exceptions + std::cerr << "[MpcnetRolloutManager::getGeneratedData] a standard exception was caught, with message: " << e.what() << "\n"; + } + } + + // find number of data points + int nDataPoints = 0; + for (int i = 0; i < dataPtrs.size(); i++) { + nDataPoints += dataPtrs[i]->size(); + } + + // fill data array + dataArray_.reserve(nDataPoints); + for (const auto dataPtr : dataPtrs) { + dataArray_.insert(dataArray_.end(), dataPtr->begin(), dataPtr->end()); + } + + // return data array + return dataArray_; +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +void MpcnetRolloutManager::startPolicyEvaluation(scalar_t alpha, const std::string& policyFilePath, scalar_t timeStep, + const std::vector& initialObservations, + const std::vector& modeSchedules, + const std::vector& targetTrajectories) { + if (nPolicyEvaluationThreads_ <= 0) { + throw std::runtime_error("[MpcnetRolloutManager::startPolicyEvaluation] cannot work without at least one policy evaluation thread."); + } + + // reset variables + policyEvaluationFtrs_.clear(); + nPolicyEvaluationTasksDone_ = 0; + + // push tasks into pool + for (int i = 0; i < initialObservations.size(); i++) { + policyEvaluationFtrs_.push_back(policyEvaluationThreadPoolPtr_->run([=](int threadNumber) { + const auto result = policyEvaluationPtrs_[threadNumber]->run(alpha, policyFilePath, timeStep, initialObservations.at(i), + modeSchedules.at(i), targetTrajectories.at(i)); + nPolicyEvaluationTasksDone_++; + // print thread and task number + std::cerr << "Policy evaluation thread " << threadNumber << " finished task " << 
nPolicyEvaluationTasksDone_ << "\n"; + return result; + })); + } +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +bool MpcnetRolloutManager::isPolicyEvaluationDone() { + if (nPolicyEvaluationThreads_ <= 0) { + throw std::runtime_error("[MpcnetRolloutManager::isPolicyEvaluationDone] cannot work without at least one policy evaluation thread."); + } + if (policyEvaluationFtrs_.size() <= 0) { + throw std::runtime_error( + "[MpcnetRolloutManager::isPolicyEvaluationDone] cannot return if startPolicyEvaluation has not been triggered once."); + } + + // check if done + if (nPolicyEvaluationTasksDone_ < policyEvaluationFtrs_.size()) { + return false; + } else if (nPolicyEvaluationTasksDone_ == policyEvaluationFtrs_.size()) { + return true; + } else { + throw std::runtime_error("[MpcnetRolloutManager::isPolicyEvaluationDone] error since more tasks done than futures available."); + } +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +metrics_array_t MpcnetRolloutManager::getComputedMetrics() { + if (nPolicyEvaluationThreads_ <= 0) { + throw std::runtime_error("[MpcnetRolloutManager::getComputedMetrics] cannot work without at least one policy evaluation thread."); + } + if (!isPolicyEvaluationDone()) { + throw std::runtime_error("[MpcnetRolloutManager::getComputedMetrics] cannot get metrics when policy evaluation is not done."); + } + + // get metrics and fill metrics array + metrics_array_t metricsArray; + 
metricsArray.reserve(policyEvaluationFtrs_.size()); + for (auto& policyEvaluationFtr : policyEvaluationFtrs_) { + try { + // get results from futures of the tasks + metricsArray.push_back(policyEvaluationFtr.get()); + } catch (const std::exception& e) { + // print error for exceptions + std::cerr << "[MpcnetRolloutManager::getComputedMetrics] a standard exception was caught, with message: " << e.what() << "\n"; + } + } + + // return metrics array + return metricsArray; +} + +} // namespace mpcnet +} // namespace ocs2 diff --git a/ocs2_raisim/ocs2_legged_robot_raisim/config/raisim.info b/ocs2_raisim/ocs2_legged_robot_raisim/config/raisim.info index b4c9212ea..404503643 100644 --- a/ocs2_raisim/ocs2_legged_robot_raisim/config/raisim.info +++ b/ocs2_raisim/ocs2_legged_robot_raisim/config/raisim.info @@ -30,7 +30,7 @@ rollout [11] RH_KFE } - controlMode 0 ; 0: FORCE_AND_TORQUE, 1: PD_PLUS_FEEDFORWARD_TORQUE + controlMode 1 ; 0: FORCE_AND_TORQUE, 1: PD_PLUS_FEEDFORWARD_TORQUE ; PD control on torque level (if controlMode = 1) pGains diff --git a/ocs2_raisim/ocs2_raisim_core/include/ocs2_raisim_core/RaisimRollout.h b/ocs2_raisim/ocs2_raisim_core/include/ocs2_raisim_core/RaisimRollout.h index 88888ed28..3d0311abb 100644 --- a/ocs2_raisim/ocs2_raisim_core/include/ocs2_raisim_core/RaisimRollout.h +++ b/ocs2_raisim/ocs2_raisim_core/include/ocs2_raisim_core/RaisimRollout.h @@ -81,6 +81,9 @@ class RaisimRollout final : public RolloutBase { //! Copy constructor RaisimRollout(const RaisimRollout& other); + //! 
Destructor + ~RaisimRollout() override; + void resetRollout() override { raisimRolloutSettings_.setSimulatorStateOnRolloutRunOnce_ = true; } RaisimRollout* clone() const override { return new RaisimRollout(*this); } diff --git a/ocs2_raisim/ocs2_raisim_core/src/RaisimRollout.cpp b/ocs2_raisim/ocs2_raisim_core/src/RaisimRollout.cpp index 14790e77a..122b483ec 100644 --- a/ocs2_raisim/ocs2_raisim_core/src/RaisimRollout.cpp +++ b/ocs2_raisim/ocs2_raisim_core/src/RaisimRollout.cpp @@ -94,6 +94,15 @@ RaisimRollout::RaisimRollout(const RaisimRollout& other) } } +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +RaisimRollout::~RaisimRollout() { + if (raisimRolloutSettings_.raisimServer_) { + serverPtr_->killServer(); + } +} + /******************************************************************************************************/ /******************************************************************************************************/ /******************************************************************************************************/ diff --git a/ocs2_robotic_examples/ocs2_ballbot_ros/CMakeLists.txt b/ocs2_robotic_examples/ocs2_ballbot_ros/CMakeLists.txt index 5121de843..6d0eebfc6 100644 --- a/ocs2_robotic_examples/ocs2_ballbot_ros/CMakeLists.txt +++ b/ocs2_robotic_examples/ocs2_ballbot_ros/CMakeLists.txt @@ -42,6 +42,8 @@ catkin_package( ${EIGEN3_INCLUDE_DIRS} CATKIN_DEPENDS ${CATKIN_PACKAGE_DEPENDENCIES} + LIBRARIES + ${PROJECT_NAME} DEPENDS Boost ) @@ -57,6 +59,18 @@ include_directories( ${catkin_INCLUDE_DIRS} ) +# main library +add_library(${PROJECT_NAME} + src/BallbotDummyVisualization.cpp +) +add_dependencies(${PROJECT_NAME} + ${catkin_EXPORTED_TARGETS} +) +target_link_libraries(${PROJECT_NAME} + 
${catkin_LIBRARIES} +) +target_compile_options(${PROJECT_NAME} PUBLIC ${OCS2_CXX_FLAGS}) + # Mpc node add_executable(ballbot_mpc src/BallbotMpcNode.cpp @@ -72,12 +86,13 @@ target_compile_options(ballbot_mpc PRIVATE ${OCS2_CXX_FLAGS}) # Dummy node add_executable(ballbot_dummy_test src/DummyBallbotNode.cpp - src/BallbotDummyVisualization.cpp ) add_dependencies(ballbot_dummy_test + ${PROJECT_NAME} ${catkin_EXPORTED_TARGETS} ) target_link_libraries(ballbot_dummy_test + ${PROJECT_NAME} ${catkin_LIBRARIES} ) target_compile_options(ballbot_dummy_test PRIVATE ${OCS2_CXX_FLAGS}) @@ -129,6 +144,7 @@ if(cmake_clang_tools_FOUND) message(STATUS "Run clang tooling for target ocs2_ballbot") add_clang_tooling( TARGETS + ${PROJECT_NAME} ballbot_mpc ballbot_dummy_test ballbot_target @@ -142,6 +158,12 @@ endif(cmake_clang_tools_FOUND) ## Install ## ############# +install(TARGETS ${PROJECT_NAME} + ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + install(DIRECTORY include/${PROJECT_NAME}/ DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} ) diff --git a/ocs2_robotic_examples/ocs2_legged_robot/CMakeLists.txt b/ocs2_robotic_examples/ocs2_legged_robot/CMakeLists.txt index 19f7c45d5..8a19765f3 100644 --- a/ocs2_robotic_examples/ocs2_legged_robot/CMakeLists.txt +++ b/ocs2_robotic_examples/ocs2_legged_robot/CMakeLists.txt @@ -94,6 +94,7 @@ add_library(${PROJECT_NAME} src/foot_planner/SwingTrajectoryPlanner.cpp src/gait/Gait.cpp src/gait/GaitSchedule.cpp + src/gait/LegLogic.cpp src/gait/ModeSequenceTemplate.cpp src/LeggedRobotInterface.cpp src/LeggedRobotPreComputation.cpp diff --git a/ocs2_robotic_examples/ocs2_legged_robot/config/mpc/task.info b/ocs2_robotic_examples/ocs2_legged_robot/config/mpc/task.info index c1cf596d7..b7eae9ae6 100644 --- a/ocs2_robotic_examples/ocs2_legged_robot/config/mpc/task.info +++ b/ocs2_robotic_examples/ocs2_legged_robot/config/mpc/task.info 
@@ -9,7 +9,7 @@ legged_robot_interface model_settings { - positionErrorGain 20.0 + positionErrorGain 0.0 ; 20.0 phaseTransitionStanceTime 0.4 verboseCppAd true diff --git a/ocs2_robotic_examples/ocs2_legged_robot/include/ocs2_legged_robot/gait/GaitSchedule.h b/ocs2_robotic_examples/ocs2_legged_robot/include/ocs2_legged_robot/gait/GaitSchedule.h index 1dd064b01..fe97199f4 100644 --- a/ocs2_robotic_examples/ocs2_legged_robot/include/ocs2_legged_robot/gait/GaitSchedule.h +++ b/ocs2_robotic_examples/ocs2_legged_robot/include/ocs2_legged_robot/gait/GaitSchedule.h @@ -44,6 +44,15 @@ class GaitSchedule { GaitSchedule(ModeSchedule initModeSchedule, ModeSequenceTemplate initModeSequenceTemplate, scalar_t phaseTransitionStanceTime); /** + * Sets the mode schedule. + * + * @param [in] modeSchedule: The mode schedule to be used. + */ + void setModeSchedule(const ModeSchedule& modeSchedule) { modeSchedule_ = modeSchedule; } + + /** + * Gets the mode schedule. + * * @param [in] lowerBoundTime: The smallest time for which the ModeSchedule should be defined. * @param [in] upperBoundTime: The greatest time for which the ModeSchedule should be defined. */ diff --git a/ocs2_robotic_examples/ocs2_legged_robot/include/ocs2_legged_robot/gait/LegLogic.h b/ocs2_robotic_examples/ocs2_legged_robot/include/ocs2_legged_robot/gait/LegLogic.h new file mode 100644 index 000000000..6f105b980 --- /dev/null +++ b/ocs2_robotic_examples/ocs2_legged_robot/include/ocs2_legged_robot/gait/LegLogic.h @@ -0,0 +1,112 @@ +/****************************************************************************** +Copyright (c) 2021, Farbod Farshidian. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. 
+ + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +******************************************************************************/ + +#pragma once + +#include + +#include "ocs2_legged_robot/common/Types.h" + +namespace ocs2 { +namespace legged_robot { + +struct LegPhase { + scalar_t phase; + scalar_t duration; +}; + +struct ContactTiming { + scalar_t start; + scalar_t end; +}; + +struct SwingTiming { + scalar_t start; + scalar_t end; +}; + +/** + * @brief Get the contact phase for all legs. + * If leg in contact, returns a value between 0.0 (at start of contact phase) and 1.0 (at end of contact phase). + * If leg not in contact (i.e. in swing), returns -1.0. + * If mode schedule starts with contact phase, returns 1.0 during this phase. + * If mode schedule ends with contact phase, returns 0.0 during this phase. + * @param [in] time : Query time. 
+ * @param [in] modeSchedule : Mode schedule. + * @return Contact phases for all legs. + */ +feet_array_t getContactPhasePerLeg(scalar_t time, const ocs2::ModeSchedule& modeSchedule); + +/** + * @brief Get the swing phase for all legs. + * If leg in swing, returns a value between 0.0 (at start of swing phase) and 1.0 (at end of swing phase). + * If leg not in swing (i.e. in contact), returns -1.0. + * If mode schedule starts with swing phase, returns 1.0 during this phase. + * If mode schedule ends with swing phase, returns 0.0 during this phase. + * @param [in] time : Query time. + * @param [in] modeSchedule : Mode schedule. + * @return Swing phases for all legs. + */ +feet_array_t getSwingPhasePerLeg(scalar_t time, const ocs2::ModeSchedule& modeSchedule); + +/** Extracts the contact timings for all legs from a modeSchedule */ +feet_array_t> extractContactTimingsPerLeg(const ocs2::ModeSchedule& modeSchedule); + +/** Extracts the swing timings for all legs from a modeSchedule */ +feet_array_t> extractSwingTimingsPerLeg(const ocs2::ModeSchedule& modeSchedule); + +/** Returns time of the next lift off. Returns nan if leg is not lifting off */ +scalar_t getTimeOfNextLiftOff(scalar_t currentTime, const std::vector& contactTimings); + +/** Returns time of the touch down for all legs from a modeschedule. Returns nan if leg does not touch down */ +scalar_t getTimeOfNextTouchDown(scalar_t currentTime, const std::vector& contactTimings); + +/** + * Get {startTime, endTime} for all contact phases. Swingphases are always implied in between: endTime[i] < startTime[i+1] + * times are NaN if they cannot be identified at the boundaries + * Vector is empty if there are no contact phases + */ +std::vector extractContactTimings(const std::vector& eventTimes, const std::vector& contactFlags); + +/** + * Get {startTime, endTime} for all swing phases. 
Contact phases are always implied in between: endTime[i] < startTime[i+1] + * times are NaN if they cannot be identified at the boundaries + * Vector is empty if there are no swing phases + */ +std::vector extractSwingTimings(const std::vector& eventTimes, const std::vector& contactFlags); + +/** + * Extracts for each leg the contact sequence over the motion phase sequence. + * @param modeSequence : Sequence of contact modes. + * @return Sequence of contact flags per leg. + */ +feet_array_t> extractContactFlags(const std::vector& modeSequence); + +} // namespace legged_robot +} // namespace ocs2 diff --git a/ocs2_robotic_examples/ocs2_legged_robot/include/ocs2_legged_robot/reference_manager/SwitchedModelReferenceManager.h b/ocs2_robotic_examples/ocs2_legged_robot/include/ocs2_legged_robot/reference_manager/SwitchedModelReferenceManager.h index a1ee327d9..9f47651c0 100644 --- a/ocs2_robotic_examples/ocs2_legged_robot/include/ocs2_legged_robot/reference_manager/SwitchedModelReferenceManager.h +++ b/ocs2_robotic_examples/ocs2_legged_robot/include/ocs2_legged_robot/reference_manager/SwitchedModelReferenceManager.h @@ -48,6 +48,8 @@ class SwitchedModelReferenceManager : public ReferenceManager { ~SwitchedModelReferenceManager() override = default; + void setModeSchedule(const ModeSchedule& modeSchedule) override; + contact_flag_t getContactFlags(scalar_t time) const; const std::shared_ptr& getGaitSchedule() { return gaitSchedulePtr_; } diff --git a/ocs2_robotic_examples/ocs2_legged_robot/src/gait/LegLogic.cpp b/ocs2_robotic_examples/ocs2_legged_robot/src/gait/LegLogic.cpp new file mode 100644 index 000000000..d61ae5b71 --- /dev/null +++ b/ocs2_robotic_examples/ocs2_legged_robot/src/gait/LegLogic.cpp @@ -0,0 +1,340 @@ +/****************************************************************************** +Copyright (c) 2021, Farbod Farshidian. All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+******************************************************************************/ + +#include "ocs2_legged_robot/gait/LegLogic.h" + +#include "ocs2_legged_robot/gait/MotionPhaseDefinition.h" + +namespace { + +inline ocs2::scalar_t timingNaN() { + return std::numeric_limits::quiet_NaN(); +} + +inline bool hasStartTime(const ocs2::legged_robot::ContactTiming& timing) { + return !std::isnan(timing.start); +} +inline bool hasEndTime(const ocs2::legged_robot::ContactTiming& timing) { + return !std::isnan(timing.end); +} + +inline bool hasStartTime(const ocs2::legged_robot::SwingTiming& timing) { + return !std::isnan(timing.start); +} +inline bool hasEndTime(const ocs2::legged_robot::SwingTiming& timing) { + return !std::isnan(timing.end); +} + +inline bool startsWithSwingPhase(const std::vector& timings) { + return timings.empty() || hasStartTime(timings.front()); +} +inline bool startsWithContactPhase(const std::vector& timings) { + return !startsWithSwingPhase(timings); +} +inline bool endsWithSwingPhase(const std::vector& timings) { + return timings.empty() || hasEndTime(timings.back()); +} +inline bool endsWithContactPhase(const std::vector& timings) { + return !endsWithSwingPhase(timings); +} + +inline bool startsWithContactPhase(const std::vector& timings) { + return timings.empty() || hasStartTime(timings.front()); +} +inline bool startsWithSwingPhase(const std::vector& timings) { + return !startsWithContactPhase(timings); +} +inline bool endsWithContactPhase(const std::vector& timings) { + return timings.empty() || hasEndTime(timings.back()); +} +inline bool endsWithSwingPhase(const std::vector& timings) { + return !endsWithContactPhase(timings); +} + +inline bool touchesDownAtLeastOnce(const std::vector& timings) { + return std::any_of(timings.begin(), timings.end(), [](const ocs2::legged_robot::ContactTiming& timing) { return hasStartTime(timing); }); +} + +inline bool liftsOffAtLeastOnce(const std::vector& timings) { + return !timings.empty() && 
hasEndTime(timings.front()); +} + +inline bool touchesDownAtLeastOnce(const std::vector& timings) { + return !timings.empty() && hasEndTime(timings.front()); +} + +inline bool liftsOffAtLeastOnce(const std::vector& timings) { + return std::any_of(timings.begin(), timings.end(), [](const ocs2::legged_robot::SwingTiming& timing) { return hasStartTime(timing); }); +} + +} // anonymous namespace + +namespace ocs2 { +namespace legged_robot { + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +feet_array_t getContactPhasePerLeg(scalar_t time, const ocs2::ModeSchedule& modeSchedule) { + feet_array_t contactPhasePerLeg; + + // Convert mode sequence to a contact timing vector per leg + const auto contactTimingsPerLeg = extractContactTimingsPerLeg(modeSchedule); + + // Extract contact phases per leg + for (size_t leg = 0; leg < contactPhasePerLeg.size(); ++leg) { + if (contactTimingsPerLeg[leg].empty()) { + // Leg is always in swing phase + contactPhasePerLeg[leg].phase = -1.0; + contactPhasePerLeg[leg].duration = std::numeric_limits::quiet_NaN(); + } else if (startsWithContactPhase(contactTimingsPerLeg[leg]) && (time <= contactTimingsPerLeg[leg].front().end)) { + // It is assumed that contact phase started at minus infinity, so current time will be always close to ContactTiming.end + contactPhasePerLeg[leg].phase = 1.0; + contactPhasePerLeg[leg].duration = std::numeric_limits::infinity(); + } else if (endsWithContactPhase(contactTimingsPerLeg[leg]) && (time >= contactTimingsPerLeg[leg].back().start)) { + // It is assumed that contact phase ends at infinity, so current time will be always close to ContactTiming.start + contactPhasePerLeg[leg].phase = 0.0; + contactPhasePerLeg[leg].duration = 
std::numeric_limits::infinity(); + } else { + // Check if leg is in contact interval at current time + auto it = std::find_if(contactTimingsPerLeg[leg].begin(), contactTimingsPerLeg[leg].end(), + [time](ContactTiming timing) { return (timing.start <= time) && (time <= timing.end); }); + if (it == contactTimingsPerLeg[leg].end()) { + // Leg is not in contact for current time + contactPhasePerLeg[leg].phase = -1.0; + contactPhasePerLeg[leg].duration = std::numeric_limits::quiet_NaN(); + } else { + // Leg is in contact for current time + const auto& currentContactTiming = *it; + contactPhasePerLeg[leg].phase = (time - currentContactTiming.start) / (currentContactTiming.end - currentContactTiming.start); + contactPhasePerLeg[leg].duration = currentContactTiming.end - currentContactTiming.start; + } + } + } + + return contactPhasePerLeg; +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +feet_array_t getSwingPhasePerLeg(scalar_t time, const ocs2::ModeSchedule& modeSchedule) { + feet_array_t swingPhasePerLeg; + + // Convert mode sequence to a swing timing vector per leg + const auto swingTimingsPerLeg = extractSwingTimingsPerLeg(modeSchedule); + + // Extract swing phases per leg + for (size_t leg = 0; leg < swingPhasePerLeg.size(); ++leg) { + if (swingTimingsPerLeg[leg].empty()) { + // Leg is always in contact phase + swingPhasePerLeg[leg].phase = -1.0; + swingPhasePerLeg[leg].duration = std::numeric_limits::quiet_NaN(); + } else if (startsWithSwingPhase(swingTimingsPerLeg[leg]) && (time <= swingTimingsPerLeg[leg].front().end)) { + // It is assumed that swing phase started at minus infinity, so current time will be always close to SwingTiming.end + swingPhasePerLeg[leg].phase = 1.0; + 
swingPhasePerLeg[leg].duration = std::numeric_limits::infinity(); + } else if (endsWithSwingPhase(swingTimingsPerLeg[leg]) && (time >= swingTimingsPerLeg[leg].back().start)) { + // It is assumed that swing phase ends at infinity, so current time will be always close to SwingTiming.start + swingPhasePerLeg[leg].phase = 0.0; + swingPhasePerLeg[leg].duration = std::numeric_limits::infinity(); + } else { + // Check if leg is in swing interval at current time + auto it = std::find_if(swingTimingsPerLeg[leg].begin(), swingTimingsPerLeg[leg].end(), + [time](SwingTiming timing) { return (timing.start <= time) && (time <= timing.end); }); + if (it == swingTimingsPerLeg[leg].end()) { + // Leg is not swinging for current time + swingPhasePerLeg[leg].phase = scalar_t(-1.0); + swingPhasePerLeg[leg].duration = std::numeric_limits::quiet_NaN(); + } else { + // Leg is swinging for current time + const auto& currentSwingTiming = *it; + swingPhasePerLeg[leg].phase = (time - currentSwingTiming.start) / (currentSwingTiming.end - currentSwingTiming.start); + swingPhasePerLeg[leg].duration = currentSwingTiming.end - currentSwingTiming.start; + } + } + } + + return swingPhasePerLeg; +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +feet_array_t> extractContactTimingsPerLeg(const ocs2::ModeSchedule& modeSchedule) { + feet_array_t> contactTimingsPerLeg; + + // Convert mode sequence to a contact flag vector per leg + const auto contactSequencePerLeg = extractContactFlags(modeSchedule.modeSequence); + + // Extract timings per leg + for (size_t leg = 0; leg < contactTimingsPerLeg.size(); ++leg) { + contactTimingsPerLeg[leg] = extractContactTimings(modeSchedule.eventTimes, contactSequencePerLeg[leg]); + } + + return 
contactTimingsPerLeg; +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +feet_array_t> extractSwingTimingsPerLeg(const ocs2::ModeSchedule& modeSchedule) { + feet_array_t> swingTimingsPerLeg; + + // Convert mode sequence to a contact flag vector per leg + const auto contactSequencePerLeg = extractContactFlags(modeSchedule.modeSequence); + + // Extract timings per leg + for (size_t leg = 0; leg < swingTimingsPerLeg.size(); ++leg) { + swingTimingsPerLeg[leg] = extractSwingTimings(modeSchedule.eventTimes, contactSequencePerLeg[leg]); + } + + return swingTimingsPerLeg; +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +scalar_t getTimeOfNextLiftOff(scalar_t currentTime, const std::vector& contactTimings) { + for (const auto& contactPhase : contactTimings) { + if (hasEndTime(contactPhase) && contactPhase.end > currentTime) { + return contactPhase.end; + } + } + return timingNaN(); +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +scalar_t getTimeOfNextTouchDown(scalar_t currentTime, const std::vector& contactTimings) { + for (const auto& contactPhase : contactTimings) { + if (hasStartTime(contactPhase) && contactPhase.start > currentTime) { + return 
contactPhase.start; + } + } + return timingNaN(); +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +std::vector extractContactTimings(const std::vector& eventTimes, const std::vector& contactFlags) { + assert(eventTimes.size() + 1 == contactFlags.size()); + const int numPhases = contactFlags.size(); + + std::vector contactTimings; + contactTimings.reserve(1 + eventTimes.size() / 2); // Approximate upper bound + int currentPhase = 0; + + while (currentPhase < numPhases) { + // Search where contact phase starts + while (currentPhase < numPhases && !contactFlags[currentPhase]) { + ++currentPhase; + } + if (currentPhase >= numPhases) { + break; // No more contact phases + } + + // Register start of the contact phase + const scalar_t startTime = (currentPhase == 0) ? std::numeric_limits::quiet_NaN() : eventTimes[currentPhase - 1]; + + // Find when the contact phase ends + while (currentPhase + 1 < numPhases && contactFlags[currentPhase + 1]) { + ++currentPhase; + } + + // Register end of the contact phase + const scalar_t endTime = (currentPhase + 1 >= numPhases) ? 
std::numeric_limits::quiet_NaN() : eventTimes[currentPhase]; + + // Add to phases + contactTimings.push_back({startTime, endTime}); + ++currentPhase; + } + return contactTimings; +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +std::vector extractSwingTimings(const std::vector& eventTimes, const std::vector& contactFlags) { + assert(eventTimes.size() + 1 == contactFlags.size()); + const int numPhases = contactFlags.size(); + + std::vector swingTimings; + swingTimings.reserve(1 + eventTimes.size() / 2); // Approximate upper bound + int currentPhase = 0; + + while (currentPhase < numPhases) { + // Search where swing phase starts + while (currentPhase < numPhases && contactFlags[currentPhase]) { + ++currentPhase; + } + if (currentPhase >= numPhases) { + break; // No more swing phases + } + + // Register start of the swing phase + const scalar_t startTime = (currentPhase == 0) ? std::numeric_limits::quiet_NaN() : eventTimes[currentPhase - 1]; + + // Find when the swing phase ends + while (currentPhase + 1 < numPhases && !contactFlags[currentPhase + 1]) { + ++currentPhase; + } + + // Register end of the contact phase + const scalar_t endTime = (currentPhase + 1 >= numPhases) ? 
std::numeric_limits::quiet_NaN() : eventTimes[currentPhase]; + + // Add to phases + swingTimings.push_back({startTime, endTime}); + ++currentPhase; + } + return swingTimings; +} + +/******************************************************************************************************/ +/******************************************************************************************************/ +/******************************************************************************************************/ +feet_array_t> extractContactFlags(const std::vector& modeSequence) { + const size_t numPhases = modeSequence.size(); + + feet_array_t> contactFlagStock; + std::fill(contactFlagStock.begin(), contactFlagStock.end(), std::vector(numPhases)); + + for (size_t i = 0; i < numPhases; i++) { + const auto contactFlag = modeNumber2StanceLeg(modeSequence[i]); + for (size_t j = 0; j < contactFlagStock.size(); j++) { + contactFlagStock[j][i] = contactFlag[j]; + } + } + return contactFlagStock; +} + +} // namespace legged_robot +} // namespace ocs2 diff --git a/ocs2_robotic_examples/ocs2_legged_robot/src/reference_manager/SwitchedModelReferenceManager.cpp b/ocs2_robotic_examples/ocs2_legged_robot/src/reference_manager/SwitchedModelReferenceManager.cpp index 0fbf8a61f..2e7907ed8 100644 --- a/ocs2_robotic_examples/ocs2_legged_robot/src/reference_manager/SwitchedModelReferenceManager.cpp +++ b/ocs2_robotic_examples/ocs2_legged_robot/src/reference_manager/SwitchedModelReferenceManager.cpp @@ -41,6 +41,14 @@ SwitchedModelReferenceManager::SwitchedModelReferenceManager(std::shared_ptrsetModeSchedule(modeSchedule); +} + /******************************************************************************************************/ /******************************************************************************************************/ /******************************************************************************************************/