diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 00000000..23c6316b
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,47 @@
+cmake_minimum_required(VERSION 3.8)
+
+set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake")
+
+project(tarantella VERSION 0.6.0)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_FLAGS "-O3 -Wall -Wextra -Werror")
+
+option(LINK_IB "Defines whether to link against InfiniBand drivers [default: disabled]" off)
+option(ENABLE_TESTING "Compile tests [default: disabled]" off)
+option(BUILD_DOCS "Build documentation [default: disabled]" off)
+
+set(SRC_DIR "${CMAKE_SOURCE_DIR}/src")
+set(CMAKE_BUILD_DIR "${CMAKE_SOURCE_DIR}/build")
+set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
+
+set(INSTALL_LIB_DIR "${CMAKE_INSTALL_PREFIX}/lib/tarantella")
+set(INSTALL_BIN_DIR "${CMAKE_INSTALL_PREFIX}/bin")
+
+find_package(GPI2 REQUIRED)
+find_package(pybind11 REQUIRED)
+find_package(Tensorflow REQUIRED)
+
+add_subdirectory(${SRC_DIR})
+add_subdirectory(${SRC_DIR}/gpi_comm_lib/gpi)
+add_subdirectory(${SRC_DIR}/gpi_comm_lib/collectives)
+add_subdirectory(${SRC_DIR}/gpi_comm_lib)
+add_subdirectory(${SRC_DIR}/gpi_comm_lib/tf_ops)
+
+if (BUILD_DOCS)
+ find_package(Sphinx)
+ add_subdirectory(docs)
+endif()
+
+if (ENABLE_TESTING)
+ find_package(Boost 1.61 REQUIRED COMPONENTS
+ unit_test_framework)
+ find_package(PythonModules REQUIRED COMPONENTS
+ numpy
+ pytest)
+ enable_testing()
+ set(SLEEP_TIME_AFTER_TEST 4)
+ add_subdirectory(${CMAKE_SOURCE_DIR}/test)
+endif()
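+
+# Example configuration (a sketch; paths and option values are placeholders):
+#   cmake -DCMAKE_INSTALL_PREFIX=/opt/tarantella -DENABLE_TESTING=ON /path/to/tarantella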
+
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 00000000..67da183f
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,101 @@
+TARANTELLA END USER LICENSE AGREEMENT
+October 21, 2020
+
+PLEASE READ THIS LICENSE AGREEMENT CAREFULLY. BY USING THE SOFTWARE TARANTELLA YOU
+ACCEPT ALL TERMS OF THE LICENSE AGREEMENT. IF YOU DO NOT AGREE TO THE TERMS OF
+THIS LICENSE, DO NOT INSTALL, COPY, OR USE THE SOFTWARE.
+
+1.) DEFINITIONS
+
+1.1) LICENSOR: Fraunhofer Gesellschaft zur Foerderung der angewandten Forschung
+e.V., Hansastr. 27c, 80686 Muenchen, Germany, as legal entity of Fraunhofer-
+Institut fuer Techno- und Wirtschaftsmathematik, Fraunhofer-Platz 1,
+67663 Kaiserslautern, Germany.
+
+1.2) LICENSEE: The user of Tarantella under this License Agreement.
+
+1.3) LICENSED SOFTWARE: The Software Tarantella in source code and object code form
+including all executable programs.
+
+1.4) DOCUMENTATION: The Tarantella documentation, user's guide, e-mails and other explanatory
+materials accompanying the LICENSED SOFTWARE in printed or electronic form.
+
+2.) OWNERSHIP / INTELLECTUAL PROPERTY RIGHTS
+
+LICENSEE acknowledges that ownership and all intellectual property rights
+related to the LICENSED SOFTWARE and to the DOCUMENTATION, including patents,
+copyright, company or trade secrets remain with the LICENSOR.
+
+LICENSEE promises to keep and not to modify the copyright notices of the
+LICENSOR.
+
+3.) SCOPE OF LICENSE
+
+3.1) Provided LICENSEE accepts all terms of this License Agreement, LICENSEE
+is granted a non-exclusive, non-assignable right to use the LICENSED SOFTWARE,
+which means LICENSEE may use the software for an unrestricted number of users,
+as well as use the accompanying DOCUMENTATION by the actual number of users.
+
+3.2) Without prior written consent of LICENSOR or an authorized partner,
+LICENSEE may modify the source code and use the modified version of the LICENSED
+SOFTWARE for internal use only.
+
+3.2.1) LICENSEE must inform users of modified versions about the fact that the
+software differs from the original version.
+
+3.2.2) The LICENSED SOFTWARE and the modifications generated by LICENSEE shall
+remain the property of LICENSOR and no rights, including but not limited to the
+right to apply for industrial property rights, are granted to LICENSEE.
+
+3.3) Without prior written consent of LICENSOR or an authorized partner,
+LICENSEE may not:
+- use, copy or distribute the LICENSED SOFTWARE except as provided for under
+ sections 3.1 and 3.2.
+- provide commercial turn-key solutions based on the LICENSED SOFTWARE or
+ commercial services for the LICENSED SOFTWARE to any third party.
+- rent or lease the LICENSED SOFTWARE and DOCUMENTATION to any third party.
+- modify, adapt, or translate the LICENSED SOFTWARE for any third party.
+
+3.4) The license under this License Agreement relates to the LICENSED SOFTWARE.
+
+4.) LIMITED WARRANTY AND LIABILITY
+
+4.1) LICENSOR confirms that the LICENSED SOFTWARE has been developed without
+infringement of any rights of third parties, in particular patents, copyrights
+or other intellectual property rights of third parties. Nevertheless LICENSOR
+does not warrant that the use of the LICENSED SOFTWARE by LICENSEE does not
+infringe any third party intellectual property rights.
+
+4.2) LICENSEE is aware that there is a risk that the LICENSED SOFTWARE might
+damage the data or the computer of the LICENSEE or even other computers on the
+network in unpredictable ways. The use of the LICENSED SOFTWARE is at the
+exclusive risk of the LICENSEE. LICENSOR does not offer any warranty either
+expressed or implied and is not liable for any damages resulting from the use of
+the LICENSED SOFTWARE or DOCUMENTATION such as, but not limited to, data loss.
+
+4.3) Notwithstanding sections 4.1 and 4.2, the liability of the LICENSOR, its
+legal representatives and employees resulting from breach of duty or tort is
+restricted to damages caused intentionally or by gross negligence. In any case,
+the liability under this section is limited by typical, foreseeable, direct
+damages. The liability is unrestricted for damages of the body, life or health.
+
+5.) MISCELLANEOUS
+
+This License Agreement in English is the original one. The terms of this
+Agreement can only be modified or amended in writing. In case of interpretation
+controversies the terms of this Agreement shall prevail over the respective
+terms of any other agreements.
+
+This Agreement is construed under the Law of the Federal Republic of Germany.
+Therefore, any and all controversies resulting out of this Agreement shall be
+resolved under the Law of the Federal Republic of Germany excluding the German
+International Private Law Rules. The application of the UN-Convention of the
+International Sales of Goods (CISG) is explicitly excluded. Exclusive venue of
+jurisdiction for both parties shall be Munich, Germany.
+
+In case that one or several of the terms of this Agreement should be or become
+invalid or unenforceable, the validity of the other terms shall remain
+unaffected. In such a case, the parties shall replace the invalid or
+unenforceable condition by another legally effective provision meeting the
+purpose of the abolished provision to the greatest extent. The same applies in
+case of a gap of regulation.
diff --git a/README.md b/README.md
new file mode 100644
index 00000000..86c77563
--- /dev/null
+++ b/README.md
@@ -0,0 +1,41 @@
+![Tarantella](docs/source/pics/tnt_logo_text.png)
+
+
+
+Tarantella is an open-source, distributed Deep Learning framework built on top of TensorFlow 2,
+providing scalable Deep Neural Network training on CPU and GPU compute clusters.
+
+Tarantella is easy to use, allows users to re-use existing TensorFlow 2/Keras models,
+and does not require any knowledge of parallel computing.
+
+
+## Goals
+
+Tarantella is designed to meet the following goals:
+
+* strong scalability
+* ease of use
+* synchronous training scheme
+* seamless integration with existing Keras models
+* support for GPU and CPU systems
+
+## Install
+
+To build Tarantella from source, the following dependencies are required:
+
+* [TensorFlow 2](https://www.tensorflow.org/install) (supported versions TF2.2, TF2.1, TF2.0)
+* [GPI-2](https://github.com/cc-hpc-itwm/GPI-2) (version 1.4.0)
+* [pybind11](https://github.com/pybind/pybind11) (from version 2.4.3)
+* C++ compiler (e.g., `gcc` from version 7.4.0)
+* CMake (from version 3.8)
+
+Detailed installation instructions can be found in the [technical docs](https://tarantella.readthedocs.io/en/latest/installation.html).
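+
+A minimal out-of-source build could look as follows (a sketch; the repository URL
+and install prefix are assumptions to adapt to your setup):
+
+```bash
+git clone https://github.com/cc-hpc-itwm/tarantella.git
+cd tarantella
+mkdir build && cd build
+cmake -DCMAKE_INSTALL_PREFIX=/opt/tarantella ..
+make install
+```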
+
+## Resources
+
+* [Official website](https://www.tarantella.org)
+* [Technical documentation](https://tarantella.readthedocs.io/en/latest)
+
+## License
+
+[License](LICENSE)
diff --git a/cmake/FindDNNL.cmake b/cmake/FindDNNL.cmake
new file mode 100644
index 00000000..1f015465
--- /dev/null
+++ b/cmake/FindDNNL.cmake
@@ -0,0 +1,37 @@
+# Finds Intel DNNL library
+# Martin Kuehn May 2020
+
+find_path(DNNL_INCLUDE_DIR
+ NAMES dnnl.hpp
+ PATHS ${DNNL_ROOT}
+ ENV DNNL_ROOT
+ ${DNNL_ROOT_DIR}
+ ENV DNNL_ROOT_DIR
+ PATH_SUFFIXES include
+ DOC "DNNL header files"
+)
+
+find_library(DNNL_LIBRARY dnnl
+ PATHS ${DNNL_ROOT}
+ ENV DNNL_ROOT
+ ${DNNL_ROOT_DIR}
+ ENV DNNL_ROOT_DIR
+ PATH_SUFFIXES lib lib64
+ DOC "DNNL library files")
+
+include (FindPackageHandleStandardArgs)
+find_package_handle_standard_args(DNNL
+ DEFAULT_MSG
+ DNNL_LIBRARY
+ DNNL_INCLUDE_DIR)
+
+mark_as_advanced(DNNL_INCLUDE_DIR DNNL_LIBRARY)
+
+set(DNNL_INCLUDE_DIRS ${DNNL_INCLUDE_DIR})
+set(DNNL_LIBRARIES ${DNNL_LIBRARY})
+
+if(DNNL_FOUND AND NOT TARGET dnnl)
+ add_library(dnnl SHARED IMPORTED GLOBAL)
+ target_include_directories(dnnl INTERFACE ${DNNL_INCLUDE_DIRS})
+ set_property(TARGET dnnl PROPERTY IMPORTED_LOCATION ${DNNL_LIBRARIES})
+endif()
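+
+# Typical usage (a sketch; `my_target` is a placeholder):
+#   find_package(DNNL REQUIRED)
+#   target_link_libraries(my_target PRIVATE dnnl)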
diff --git a/cmake/FindGPI2.cmake b/cmake/FindGPI2.cmake
new file mode 100644
index 00000000..d4bd3360
--- /dev/null
+++ b/cmake/FindGPI2.cmake
@@ -0,0 +1,133 @@
+
+#[=======================================================================[.rst:
+FindGPI2
+-------
+
+Finds the GPI2 library.
+
+Imported Targets
+^^^^^^^^^^^^^^^^
+
+This module provides the following imported targets, if found:
+
+``GPI2::GPI2``
+ The GPI2 library
+
+Result Variables
+^^^^^^^^^^^^^^^^
+
+This will define the following variables:
+
+``GPI2_FOUND``
+ True if the system has the GPI2 library.
+``GPI2_INCLUDE_DIRS``
+ Include directories needed to use GPI2.
+``GPI2_LIBRARIES``
+ Libraries needed to link to GPI2.
+``GPI2_DBG_LIBRARIES``
+ Libraries needed to link to the Debug version of GPI2.
+``GPI2_GASPI_RUN``
+ Path to ``gaspi_run``.
+
+Cache Variables
+^^^^^^^^^^^^^^^
+
+The following cache variables may also be set:
+
+``GPI2_INCLUDE_DIR``
+ The directory containing ``gaspi.h``.
+``GPI2_LIBRARY``
+ The path to the GPI2 library.
+
+#]=======================================================================]
+
+set(GPI2_LIBRARY_NAME "GPI2")
+set(GPI2_DBG_LIBRARY_NAME "GPI2-dbg")
+
+FIND_PROGRAM(GASPIRUN_PATH gaspi_run
+ PATHS
+ $ENV{PATH}
+ $ENV{LIB_DIR}/bin
+ /usr/local/bin/
+ /usr/bin/
+ )
+
+IF (GASPIRUN_PATH)
+ get_filename_component(GASPIRUN_FOUND_HOME ${GASPIRUN_PATH} DIRECTORY)
+ get_filename_component(GPI2_INSTALLED_PATH ${GASPIRUN_FOUND_HOME} DIRECTORY)
+ get_filename_component(GPI2_INSTALLED_PATH ${GPI2_INSTALLED_PATH} REALPATH)
+ENDIF(GASPIRUN_PATH)
+
+find_path (GPI2_INCLUDE_DIR GASPI.h
+ PATHS ${GPI2_DEFAULT_PATH} ${GPI2_INSTALLED_PATH}
+ PATHS ENV LD_LIBRARY_PATH DYLD_LIBRARY_PATH
+ PATH_SUFFIXES include)
+
+find_library (GPI2_DBG_LIBRARY ${GPI2_DBG_LIBRARY_NAME}
+ PATHS ${GPI2_DEFAULT_PATH} ${GPI2_INSTALLED_PATH}
+ PATHS ENV LD_LIBRARY_PATH DYLD_LIBRARY_PATH
+ PATH_SUFFIXES lib lib64)
+
+find_library (GPI2_LIBRARY ${GPI2_LIBRARY_NAME}
+ PATHS ${GPI2_DEFAULT_PATH} ${GPI2_INSTALLED_PATH}
+ PATHS ENV LD_LIBRARY_PATH DYLD_LIBRARY_PATH
+ PATH_SUFFIXES lib lib64)
+
+if (GPI2_DBG_LIBRARY)
+ message(STATUS "GPI2-dbg library path: ${GPI2_DBG_LIBRARY}" )
+else(GPI2_DBG_LIBRARY)
+ message(STATUS "GPI2-dbg library path: not found" )
+endif()
+
+
+if (GPI2_LIBRARY)
+ message(STATUS "GPI2 library path: ${GPI2_LIBRARY}" )
+else(GPI2_LIBRARY)
+ message(STATUS "GPI2 library path: not found" )
+endif()
+
+
+include(FindPackageHandleStandardArgs)
+# handle the QUIETLY and REQUIRED arguments and set GPI2_FOUND to TRUE
+# if all listed variables are TRUE
+find_package_handle_standard_args(GPI2 DEFAULT_MSG
+ GASPIRUN_PATH
+ GPI2_DBG_LIBRARY GPI2_LIBRARY)
+
+mark_as_advanced(GPI2_INCLUDE_DIR GASPIRUN_PATH
+ GPI2_DBG_LIBRARY GPI2_LIBRARY)
+set(GPI2_INCLUDE_DIRS ${GPI2_INCLUDE_DIR} )
+set(GPI2_DBG_LIBRARIES ${GPI2_DBG_LIBRARY} )
+set(GPI2_LIBRARIES ${GPI2_LIBRARY} )
+set(GPI2_GASPI_RUN ${GASPIRUN_PATH})
+
+message(STATUS "Found GPI2: " ${GPI2_FOUND})
+
+if(GPI2_FOUND AND NOT TARGET GPI2::GPI2)
+ set(THREADS_PREFER_PTHREAD_FLAG ON)
+ find_package(Threads REQUIRED)
+ add_library(GPI2::GPI2 SHARED IMPORTED GLOBAL)
+ target_link_libraries(GPI2::GPI2 INTERFACE Threads::Threads)
+ target_include_directories(GPI2::GPI2 INTERFACE ${GPI2_INCLUDE_DIRS})
+ set_property(TARGET GPI2::GPI2 PROPERTY IMPORTED_LOCATION ${GPI2_LIBRARIES})
+
+ add_library(GPI2::GPI2dbg SHARED IMPORTED GLOBAL)
+ target_link_libraries(GPI2::GPI2dbg INTERFACE Threads::Threads)
+ target_include_directories(GPI2::GPI2dbg INTERFACE ${GPI2_INCLUDE_DIRS})
+ set_property(TARGET GPI2::GPI2dbg PROPERTY IMPORTED_LOCATION ${GPI2_DBG_LIBRARIES})
+
+ if (LINK_IB)
+ find_package(IBverbs)
+
+ if (IBverbs_FOUND)
+ message (STATUS "GPI2: linking against ibverbs")
+ target_link_libraries(GPI2::GPI2 INTERFACE IBverbs::IBverbs)
+ target_link_libraries(GPI2::GPI2dbg INTERFACE IBverbs::IBverbs)
+ else()
+ message (FATAL_ERROR "GPI2: could not find ibverbs, disable Infiniband \
+ support (-DLINK_IB=OFF) to load GPI-2")
+ endif()
+ else()
+ message (STATUS "GPI2: loading library without Infiniband support")
+ endif()
+endif()
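+
+# Typical usage (a sketch; `my_target` is a placeholder):
+#   find_package(GPI2 REQUIRED)
+#   target_link_libraries(my_target PRIVATE GPI2::GPI2)  # or GPI2::GPI2dbg for debugging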
diff --git a/cmake/FindIBverbs.cmake b/cmake/FindIBverbs.cmake
new file mode 100644
index 00000000..aeb205e6
--- /dev/null
+++ b/cmake/FindIBverbs.cmake
@@ -0,0 +1,61 @@
+
+#[=======================================================================[.rst:
+FindIBverbs
+-------
+
+Finds the IBverbs library.
+
+Imported Targets
+^^^^^^^^^^^^^^^^
+
+This module provides the following imported targets, if found:
+
+``IBverbs::IBverbs``
+ The IBverbs library
+
+Result Variables
+^^^^^^^^^^^^^^^^
+
+This will define the following variables:
+
+``IBverbs_FOUND``
+ True if the system has the IBverbs library.
+``IBverbs_INCLUDE_DIRS``
+ Include directories needed to use IBverbs.
+``IBverbs_LIBRARIES``
+ Libraries needed to link to IBverbs.
+
+Cache Variables
+^^^^^^^^^^^^^^^
+
+The following cache variables may also be set:
+
+``IBverbs_INCLUDE_DIR``
+ The directory containing the public headers.
+``IBverbs_LIBRARY``
+ The path to the IBverbs library.
+
+#]=======================================================================]
+
+find_path(IBverbs_INCLUDE_DIR
+ NAMES infiniband/verbs.h
+ )
+
+find_library(IBverbs_LIBRARY
+ NAMES ibverbs)
+
+include(FindPackageHandleStandardArgs)
+# handle the QUIETLY and REQUIRED arguments and set IBverbs_FOUND to TRUE
+# if all listed variables are TRUE
+find_package_handle_standard_args(IBverbs DEFAULT_MSG
+ IBverbs_INCLUDE_DIR IBverbs_LIBRARY)
+
+mark_as_advanced(IBverbs_INCLUDE_DIR IBverbs_LIBRARY)
+set(IBverbs_LIBRARIES ${IBverbs_LIBRARY})
+set(IBverbs_INCLUDE_DIRS ${IBverbs_INCLUDE_DIR})
+
+if(IBverbs_FOUND AND NOT TARGET IBverbs::IBverbs)
+ add_library(IBverbs::IBverbs SHARED IMPORTED GLOBAL)
+ target_include_directories(IBverbs::IBverbs INTERFACE ${IBverbs_INCLUDE_DIRS})
+ set_property(TARGET IBverbs::IBverbs PROPERTY IMPORTED_LOCATION ${IBverbs_LIBRARIES})
+endif()
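+
+# Typical usage (a sketch; `my_target` is a placeholder):
+#   find_package(IBverbs REQUIRED)
+#   target_link_libraries(my_target PRIVATE IBverbs::IBverbs)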
diff --git a/cmake/FindPythonModules.cmake b/cmake/FindPythonModules.cmake
new file mode 100644
index 00000000..3cb0ed11
--- /dev/null
+++ b/cmake/FindPythonModules.cmake
@@ -0,0 +1,60 @@
+#[=======================================================================[.rst:
+FindPythonModules
+-------
+
+Finds installed PythonModules
+
+Result Variables
+^^^^^^^^^^^^^^^^
+
+This will define the following variables:
+
+``PythonModules_FOUND``
+ True if all the required PythonModules could be loaded.
+``PythonModules_modulename_FOUND``
+ True if `modulename` could be loaded.
+``Python_EXECUTABLE``
+ Path to the Python executable.
+
+#]=======================================================================]
+
+execute_process(COMMAND sh -c "which python"
+ OUTPUT_VARIABLE python_path
+ RESULT_VARIABLE result
+ ERROR_QUIET
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+if (result EQUAL "0" AND EXISTS ${python_path})
+ set(Python_EXECUTABLE "${python_path}")
+endif()
+
+set(PythonModules_FOUND TRUE)
+if (Python_EXECUTABLE)
+ foreach (module IN LISTS PythonModules_FIND_COMPONENTS)
+ execute_process(COMMAND ${Python_EXECUTABLE} -c
+ "import ${module}"
+ RESULT_VARIABLE result
+ ERROR_QUIET OUTPUT_QUIET)
+
+ if(result)
+ set (PythonModules_${module}_FOUND FALSE)
+ set (PythonModules_FOUND FALSE)
+ else()
+ set (PythonModules_${module}_FOUND TRUE)
+ endif()
+ endforeach()
+endif()
+
+include (FindPackageHandleStandardArgs)
+find_package_handle_standard_args (PythonModules
+ REQUIRED_VARS Python_EXECUTABLE PythonModules_FOUND
+ HANDLE_COMPONENTS)
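+
+# Typical usage (a sketch):
+#   find_package(PythonModules REQUIRED COMPONENTS numpy pytest)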
diff --git a/cmake/FindSphinx.cmake b/cmake/FindSphinx.cmake
new file mode 100644
index 00000000..406dc8bb
--- /dev/null
+++ b/cmake/FindSphinx.cmake
@@ -0,0 +1,16 @@
+include(FindPackageHandleStandardArgs)
+
+find_program(Sphinx_EXECUTABLE
+ NAMES sphinx-build sphinx-build2
+ DOC "Path to sphinx-build executable")
+
+find_package_handle_standard_args(Sphinx REQUIRED_VARS Sphinx_EXECUTABLE)
+
+if (Sphinx_FOUND)
+ mark_as_advanced(Sphinx_EXECUTABLE)
+endif()
+
+if (Sphinx_FOUND AND NOT TARGET Sphinx::Sphinx)
+ add_executable(Sphinx::Sphinx IMPORTED)
+ set_property(TARGET Sphinx::Sphinx PROPERTY IMPORTED_LOCATION ${Sphinx_EXECUTABLE})
+endif()
diff --git a/cmake/FindTensorflow.cmake b/cmake/FindTensorflow.cmake
new file mode 100644
index 00000000..4afa4616
--- /dev/null
+++ b/cmake/FindTensorflow.cmake
@@ -0,0 +1,106 @@
+
+#[=======================================================================[.rst:
+FindTensorflow
+-------
+
+Finds the Tensorflow package as described in:
+https://www.tensorflow.org/guide/create_op#compile_the_op_using_your_system_compiler_tensorflow_binary_installation
+
+
+Imported Targets
+^^^^^^^^^^^^^^^^
+
+This module provides the following imported targets, if found:
+
+``Tensorflow::Tensorflow``
+ The Tensorflow library.
+ The target will set the CXX11_ABI_FLAG according to the ABI used to compile the TensorFlow library.
+
+Result Variables
+^^^^^^^^^^^^^^^^
+
+This will define the following variables:
+
+``Tensorflow_FOUND``
+ True if the system has the Tensorflow library.
+``Tensorflow_INCLUDE_DIRS``
+ Include directories needed to use Tensorflow.
+``Tensorflow_LIBRARIES``
+ Libraries needed to link to Tensorflow.
+
+Cache Variables
+^^^^^^^^^^^^^^^
+
+The following cache variables may also be set:
+
+``Tensorflow_INCLUDE_DIR``
+ The directory containing the Tensorflow library headers.
+``Tensorflow_LIBRARY``
+ The path to the Tensorflow library.
+
+#]=======================================================================]
+
+execute_process(COMMAND sh -c "which python"
+ OUTPUT_VARIABLE python_path
+ RESULT_VARIABLE result
+ ERROR_QUIET
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+if (result EQUAL "0" AND EXISTS ${python_path})
+ set(Python_EXECUTABLE "${python_path}")
+endif()
+
+if (Python_EXECUTABLE)
+ execute_process(COMMAND ${Python_EXECUTABLE} -c
+ "import tensorflow as tf; print(tf.sysconfig.get_lib())"
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ RESULT_VARIABLE result_tf_lib
+ OUTPUT_VARIABLE Tensorflow_LIBRARY_DIR
+ ERROR_QUIET)
+
+ execute_process(COMMAND ${Python_EXECUTABLE} -c
+ "import tensorflow as tf; print(tf.sysconfig.get_include())"
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ RESULT_VARIABLE result_tf_incl
+ OUTPUT_VARIABLE Tensorflow_INCLUDE_DIR
+ ERROR_QUIET)
+
+ execute_process(COMMAND ${Python_EXECUTABLE} -c
+ "import tensorflow as tf; print(tf.sysconfig.CXX11_ABI_FLAG)"
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ RESULT_VARIABLE result_tf_abi_flag
+ OUTPUT_VARIABLE Tensorflow_CXX11_ABI_FLAG
+ ERROR_QUIET)
+endif()
+
+set(Tensorflow_LIBRARY_NAME libtensorflow_framework.so.2)
+find_library (Tensorflow_LIBRARY ${Tensorflow_LIBRARY_NAME}
+ PATHS ${Tensorflow_LIBRARY_DIR}
+ PATHS ENV LD_LIBRARY_PATH DYLD_LIBRARY_PATH)
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(Tensorflow DEFAULT_MSG
+ Tensorflow_LIBRARY
+ Tensorflow_INCLUDE_DIR)
+
+mark_as_advanced(Tensorflow_INCLUDE_DIR Tensorflow_LIBRARY)
+set(Tensorflow_INCLUDE_DIRS ${Tensorflow_INCLUDE_DIR} )
+set(Tensorflow_LIBRARIES ${Tensorflow_LIBRARY} )
+
+message(STATUS "Found Tensorflow: " ${Tensorflow_FOUND})
+
+if(Tensorflow_FOUND AND NOT TARGET Tensorflow::Tensorflow)
+ add_library(Tensorflow::Tensorflow SHARED IMPORTED GLOBAL)
+ target_include_directories(Tensorflow::Tensorflow INTERFACE ${Tensorflow_INCLUDE_DIRS})
+ set_property(TARGET Tensorflow::Tensorflow PROPERTY IMPORTED_LOCATION ${Tensorflow_LIBRARIES})
+
+ # Enable libraries that link against the TensorFlow library to use
+ # the correct value of the CXX11_ABI_FLAG.
+ # E.g., the official pip TensorFlow packages require CXX11_ABI_FLAG=0,
+ # whereas the conda packages set CXX11_ABI_FLAG=1.
+ if ("${result_tf_abi_flag}" EQUAL "0")
+ target_compile_definitions(Tensorflow::Tensorflow INTERFACE _GLIBCXX_USE_CXX11_ABI=${Tensorflow_CXX11_ABI_FLAG})
+ endif()
+endif()
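+
+# Typical usage (a sketch; `my_target` is a placeholder):
+#   find_package(Tensorflow REQUIRED)
+#   target_link_libraries(my_target PRIVATE Tensorflow::Tensorflow)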
+
+
+
diff --git a/cmake/add_macros.cmake b/cmake/add_macros.cmake
new file mode 100644
index 00000000..12411693
--- /dev/null
+++ b/cmake/add_macros.cmake
@@ -0,0 +1,101 @@
+
+macro (_default_if_unset VAR VAL)
+ if (NOT ${VAR})
+ set (${VAR} ${VAL})
+ endif()
+endmacro()
+
+include (parse_arguments)
+
+function (extended_add_library)
+ set (options POSITION_INDEPENDENT PRECOMPILED INSTALL)
+ set (one_value_options NAME NAMESPACE TYPE INSTALL_DESTINATION)
+ set (multi_value_options
+ LIBRARIES SOURCES PUBLIC_HEADERS INCLUDE_DIRECTORIES RPATH
+ SYSTEM_INCLUDE_DIRECTORIES COMPILE_DEFINITIONS COMPILE_OPTIONS DEPENDS
+ )
+ set (required_options NAME)
+ _parse_arguments (ARG "${options}" "${one_value_options}" "${multi_value_options}" "${required_options}" ${ARGN})
+
+ _default_if_unset (ARG_TYPE "STATIC")
+ _default_if_unset (ARG_INSTALL_DESTINATION "lib")
+
+ if (ARG_NAMESPACE)
+ set (target_name "${ARG_NAMESPACE}-${ARG_NAME}")
+ else()
+ set (target_name "${ARG_NAME}")
+ endif()
+
+ if (NOT (${ARG_TYPE} STREQUAL "STATIC" OR ${ARG_TYPE} STREQUAL "SHARED" OR ${ARG_TYPE} STREQUAL "MODULE"))
+ message (FATAL_ERROR "Bad library type: ${ARG_TYPE}")
+ endif()
+
+ set (_scope_specifier)
+ if ((NOT ARG_SOURCES AND NOT ARG_MOC) OR ARG_PRECOMPILED)
+ set (_scope_specifier INTERFACE)
+
+ add_library (${target_name} INTERFACE)
+
+ if (ARG_PRECOMPILED)
+ if (ARG_TYPE STREQUAL "STATIC")
+ list (APPEND ARG_LIBRARIES "${CMAKE_CURRENT_SOURCE_DIR}/lib${target_name}.a")
+ else()
+ list (APPEND ARG_LIBRARIES "${CMAKE_CURRENT_SOURCE_DIR}/lib${target_name}.so")
+ endif()
+ endif()
+
+ target_link_libraries (${target_name} INTERFACE ${ARG_LIBRARIES})
+ else()
+ set (_scope_specifier PUBLIC)
+
+ # _moc (${ARG_NAME}_mocced ${ARG_MOC})
+
+ add_library (${target_name} ${ARG_TYPE} #${${ARG_NAME}_mocced}
+ ${ARG_SOURCES})
+
+ target_link_libraries (${target_name} ${ARG_LIBRARIES})
+ endif()
+ if (ARG_NAMESPACE)
+ add_library (${ARG_NAMESPACE}::${ARG_NAME} ALIAS ${target_name})
+ endif()
+ if (ARG_PUBLIC_HEADERS)
+ set_property (TARGET ${target_name} APPEND
+ PROPERTY PUBLIC_HEADER ${ARG_PUBLIC_HEADERS}
+ )
+ endif()
+
+ if (ARG_SYSTEM_INCLUDE_DIRECTORIES)
+ target_include_directories (${target_name} SYSTEM
+ ${ARG_SYSTEM_INCLUDE_DIRECTORIES})
+ endif()
+ if (ARG_INCLUDE_DIRECTORIES)
+    target_include_directories (${target_name} PUBLIC
+      $<BUILD_INTERFACE:${ARG_INCLUDE_DIRECTORIES}>)
+ endif()
+
+ if (ARG_POSITION_INDEPENDENT)
+ set_property (TARGET ${target_name} APPEND
+ PROPERTY COMPILE_FLAGS -fPIC
+ )
+ endif()
+
+ if (ARG_DEPENDS)
+ add_dependencies (${target_name} ${ARG_DEPENDS})
+ endif()
+
+ if (ARG_COMPILE_DEFINITIONS)
+ target_compile_definitions (${target_name} ${_scope_specifier} ${ARG_COMPILE_DEFINITIONS})
+ endif()
+
+ if (ARG_COMPILE_OPTIONS)
+ target_compile_options (${target_name} ${_scope_specifier} ${ARG_COMPILE_OPTIONS})
+ endif()
+
+ if (ARG_INSTALL)
+ install (TARGETS ${target_name}
+ LIBRARY DESTINATION "${ARG_INSTALL_DESTINATION}"
+ ARCHIVE DESTINATION "${ARG_INSTALL_DESTINATION}"
+ )
+ endif()
+endfunction()
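+
+# Example (a sketch with hypothetical names):
+#   extended_add_library (NAME collectives
+#                         NAMESPACE tnt
+#                         TYPE SHARED
+#                         SOURCES allreduce.cpp
+#                         LIBRARIES GPI2::GPI2
+#                         POSITION_INDEPENDENT
+#                         INSTALL)
+# This creates the target `tnt-collectives`, also aliased as `tnt::collectives`.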
+
diff --git a/cmake/add_test.cmake b/cmake/add_test.cmake
new file mode 100644
index 00000000..a803e2cb
--- /dev/null
+++ b/cmake/add_test.cmake
@@ -0,0 +1,174 @@
+include (parse_arguments)
+
+function (compile_tarantella_test)
+ set(one_value_options NAME DESCRIPTION)
+ set(multi_value_options SOURCES LIBRARIES INCLUDE_DIRECTORIES
+ SYSTEM_INCLUDE_DIRECTORIES ARGS COMPILE_FLAGS)
+ set(required_options NAME SOURCES)
+
+ # save each argument into a variable named "ARG_argname"
+ _parse_arguments_with_unknown(ARG "${options}" "${one_value_options}"
+ "${multi_value_options}" "${required_options}" ${ARGN})
+
+ _default_if_unset(ARG_DESCRIPTION "${ARG_NAME}")
+ set(target_name ${ARG_NAME})
+
+ add_executable (${target_name} ${ARG_SOURCES})
+ list (APPEND ARG_LIBRARIES Boost::unit_test_framework
+ Boost::dynamic_linking)
+ target_compile_definitions (${target_name} PRIVATE
+ "-DBOOST_TEST_MODULE=\"${ARG_DESCRIPTION}\""
+ "-DBOOST_TEST_DYN_LINK")
+
+ #! \note Use RPATH for all tests
+ set_property (TARGET ${target_name} PROPERTY BUILD_WITH_INSTALL_RPATH true)
+ set_property (TARGET ${target_name} APPEND PROPERTY
+ INSTALL_RPATH
+ ${Boost_INCLUDE_DIR}/../lib:${CMAKE_BINARY_DIR})
+
+ if (Boost_VERSION VERSION_EQUAL 1.60 OR Boost_VERSION VERSION_GREATER 1.60)
+ list (INSERT ARG_ARGS 0 "--")
+ endif()
+
+ if (ARG_SYSTEM_INCLUDE_DIRECTORIES)
+ target_include_directories (${target_name} SYSTEM
+ ${ARG_SYSTEM_INCLUDE_DIRECTORIES})
+ endif()
+ if (ARG_INCLUDE_DIRECTORIES)
+ target_include_directories (${target_name} PRIVATE ${ARG_INCLUDE_DIRECTORIES})
+ endif()
+
+ target_link_libraries (${target_name} ${ARG_LIBRARIES})
+ if (ARG_COMPILE_FLAGS)
+ set_property (TARGET ${target_name} PROPERTY COMPILE_FLAGS ${ARG_COMPILE_FLAGS})
+ endif()
+endfunction()
+
+function (tarantella_gen_environment_paths)
+ set(multi_value_options VARIABLE_LIST)
+ set(required_options VARIABLE_LIST)
+ _parse_arguments(ARG "${options}" "${one_value_options}"
+ "${multi_value_options}" "${required_options}" ${ARGN})
+ set(env_var_names PATH LIBRARY_PATH LD_LIBRARY_PATH DYLD_LIBRARY_PATH CPATH PYTHONPATH)
+ set(env_vars )
+
+ foreach (var_name ${env_var_names})
+ if (DEFINED ENV{${var_name}})
+ list(APPEND env_vars "${var_name}=$ENV{${var_name}}")
+ endif()
+ endforeach()
+ set(${ARG_VARIABLE_LIST} ${env_vars} PARENT_SCOPE)
+endfunction()
+
+function (tarantella_gen_executable_script)
+ set(one_value_options SCRIPT_DIR SCRIPT_NAME)
+ set(required_options SCRIPT_DIR SCRIPT_NAME)
+ _parse_arguments(ARG "${options}" "${one_value_options}"
+ "${multi_value_options}" "${required_options}" ${ARGN})
+
+ set(tmp_script_path ${CMAKE_CURRENT_BINARY_DIR}/tmp/${ARG_SCRIPT_NAME})
+ file(REMOVE ${ARG_SCRIPT_DIR}/${ARG_SCRIPT_NAME})
+ file(WRITE ${tmp_script_path} "")
+ file(COPY ${tmp_script_path}
+ DESTINATION ${ARG_SCRIPT_DIR}
+ FILE_PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE
+ )
+ file(REMOVE ${tmp_script_path})
+endfunction()
+
+function (tarantella_gen_gpi_machinefile)
+ set(one_value_options NRANKS FILENAME)
+ set(required_options NRANKS FILENAME)
+ _parse_arguments(ARG "${options}" "${one_value_options}"
+ "${multi_value_options}" "${required_options}" ${ARGN})
+
+ file(WRITE ${ARG_FILENAME} "")
+ cmake_host_system_information(RESULT hostname QUERY HOSTNAME)
+ foreach(index RANGE 1 ${ARG_NRANKS})
+ file(APPEND ${ARG_FILENAME} "${hostname}\n")
+ endforeach()
+endfunction()
+
+function (tarantella_gen_test_script)
+ set(one_value_options NAME SCRIPT_DIR TEST_FILE)
+ set(options IS_PYTHON_TEST)
+ set(required_options NAME SCRIPT_DIR TEST_FILE)
+ _parse_arguments_with_unknown(ARG "${options}" "${one_value_options}"
+ "${multi_value_options}" "${required_options}" ${ARGN})
+
+ message(STATUS "Test: Generating ${ARG_NAME} script")
+ tarantella_gen_executable_script(SCRIPT_NAME ${ARG_NAME}
+ SCRIPT_DIR ${ARG_SCRIPT_DIR})
+
+ tarantella_gen_environment_paths(VARIABLE_LIST env_paths)
+
+ set(script_path ${ARG_SCRIPT_DIR}/${ARG_NAME})
+ foreach (var ${env_paths})
+ file(APPEND ${script_path} "export ${var}\n")
+ endforeach()
+ if (ARG_IS_PYTHON_TEST)
+ # Python test
+ file(APPEND ${script_path} "export PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_SOURCE_DIR}/src:\$\{PYTHONPATH\}\n")
+ file(APPEND ${script_path} "\n${Python_EXECUTABLE} -m pytest ${ARG_TEST_FILE}\n")
+ else()
+ # regular executable test
+ file(APPEND ${script_path} "\n${ARG_TEST_FILE}\n")
+ endif()
+endfunction()
+
+function (tarantella_add_gpi_test)
+ set(one_value_options NAME TARGET_FILE NRANKS RUNCOMMAND TEST_FILE
+ MACHINEFILE CLEANUP TIMEOUT SLEEP)
+ set(multi_value_options LABELS)
+ set(required_options NAME TARGET_FILE NRANKS RUNCOMMAND)
+ _parse_arguments_with_unknown(ARG "${options}" "${one_value_options}"
+ "${multi_value_options}" "${required_options}" ${ARGN})
+ _default_if_unset(ARG_SLEEP 0)
+ set(test_name ${ARG_NAME}_${ARG_NRANKS}ranks)
+
+ # increase overall timeout time to include the sleep time after the actual test
+ if (ARG_TIMEOUT)
+ math(EXPR ARG_TIMEOUT "${ARG_SLEEP} + ${ARG_TIMEOUT}")
+ endif()
+
+ if (ARG_MACHINEFILE)
+ # use user-defined machinefile
+ set(runparams "-n ${ARG_NRANKS} -m ${ARG_MACHINEFILE}")
+ else()
+ # generate machinefile for ARG_NRANKS running on the localhost
+ set(machinefile_path ${CMAKE_CURRENT_BINARY_DIR}/machinefile_${ARG_NAME}_${ARG_NRANKS}.tmp)
+ tarantella_gen_gpi_machinefile(NRANKS ${ARG_NRANKS}
+ FILENAME ${machinefile_path})
+ set(runparams "-n ${ARG_NRANKS} -m ${machinefile_path}")
+ endif()
+
+ # create gaspi_run test
+ add_test(NAME ${test_name}
+ WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND "${CMAKE_COMMAND}"
+ -DRUNCOMMAND=${ARG_RUNCOMMAND}
+ -DRUNCOMMAND_ARGS="${runparams}"
+ -DTEST_EXECUTABLE="${ARG_TARGET_FILE}"
+ -DTEST_DIR="${CMAKE_BINARY_DIR}"
+ -DSLEEP="${ARG_SLEEP}"
+ -P "${CMAKE_SOURCE_DIR}/cmake/run_test.cmake"
+ )
+
+ # set labels if specified
+ if (ARG_LABELS)
+ set_property(TEST ${test_name} PROPERTY LABELS ${ARG_LABELS})
+ endif()
+
+ # set cleanup fixture script if specified
+ if (ARG_CLEANUP)
+ set_tests_properties(${test_name} PROPERTIES FIXTURES_REQUIRED ${ARG_CLEANUP})
+ endif()
+
+ # set timeout if specified
+ if (ARG_TIMEOUT)
+ set_tests_properties(${test_name} PROPERTIES TIMEOUT ${ARG_TIMEOUT})
+ endif()
+
+ # make sure the GPI tests are not run in parallel
+ set_tests_properties(${test_name} PROPERTIES RESOURCE_LOCK GPI_run_serial)
+endfunction()
diff --git a/cmake/add_test_wrappers.cmake b/cmake/add_test_wrappers.cmake
new file mode 100644
index 00000000..119c1ef8
--- /dev/null
+++ b/cmake/add_test_wrappers.cmake
@@ -0,0 +1,151 @@
+include (add_test)
+
+function (tarantella_compile_and_generate_gpi_test)
+ set (one_value_options NAME DESCRIPTION TIMEOUT)
+ set (multi_value_options LOCALRANKS_LIST SOURCES LIBRARIES INCLUDE_DIRECTORIES
+ SYSTEM_INCLUDE_DIRECTORIES ARGS COMPILE_FLAGS)
+ set (required_options NAME SOURCES LOCALRANKS_LIST)
+ _parse_arguments (ARG "${options}" "${one_value_options}"
+ "${multi_value_options}" "${required_options}" ${ARGN})
+ _default_if_unset (ARG_TIMEOUT 10)
+ set(CLEANUP_TEST_NAME gpi_cleanup)
+
+ set (target_name ${ARG_NAME}.test)
+ compile_tarantella_test(${ARGN}
+ NAME ${target_name})
+
+ # wrap call to the test executable in a script that exports the current environment
+ # the script can then be executed within a `gaspi_run` call
+ set(script_name run_${ARG_NAME}.sh)
+ set(script_path ${CMAKE_CURRENT_BINARY_DIR}/${script_name})
+ tarantella_gen_test_script(NAME ${script_name}
+ SCRIPT_DIR ${CMAKE_CURRENT_BINARY_DIR}
+ TEST_FILE ${CMAKE_CURRENT_BINARY_DIR}/${target_name})
+
+ message(STATUS "Test: Generating gaspi_run tests for ${ARG_NAME} with ${ARG_LOCALRANKS_LIST} ranks")
+ foreach(nlocalranks ${ARG_LOCALRANKS_LIST})
+ tarantella_add_gpi_test (NAME ${ARG_NAME}
+ NRANKS ${nlocalranks}
+ TARGET_FILE ${script_path}
+ TEST_FILE "${CMAKE_CURRENT_BINARY_DIR}/${target_name}"
+ RUNCOMMAND ${GPI2_GASPI_RUN}
+ CLEANUP ${CLEANUP_TEST_NAME}
+ TIMEOUT ${ARG_TIMEOUT}
+ SLEEP ${SLEEP_TIME_AFTER_TEST})
+ endforeach()
+endfunction()
+
+function (tarantella_compile_and_generate_test)
+ set (one_value_options NAME DESCRIPTION TIMEOUT)
+ set (multi_value_options SOURCES LIBRARIES INCLUDE_DIRECTORIES
+ SYSTEM_INCLUDE_DIRECTORIES ARGS COMPILE_FLAGS
+ LABELS)
+ set (required_options NAME SOURCES)
+ _parse_arguments (ARG "${options}" "${one_value_options}"
+ "${multi_value_options}" "${required_options}" ${ARGN})
+ _default_if_unset (ARG_TIMEOUT 10)
+
+ set (target_name ${ARG_NAME}.test)
+ compile_tarantella_test(${ARGN}
+ NAME ${target_name})
+  add_test (NAME ${ARG_NAME}
+            COMMAND $<TARGET_FILE:${target_name}> ${ARG_ARGS})
+
+  # set labels if specified
+  if (ARG_LABELS)
+    set_property(TEST ${ARG_NAME} PROPERTY LABELS ${ARG_LABELS})
+  endif()
+
+  # set timeout if specified
+  if (ARG_TIMEOUT)
+    set_tests_properties(${ARG_NAME} PROPERTIES TIMEOUT ${ARG_TIMEOUT})
+  endif()
+endfunction()
+
+function (tarantella_generate_python_gpi_test)
+ set (one_value_options NAME TEST_FILE DESCRIPTION TIMEOUT)
+ set (multi_value_options LOCALRANKS_LIST LABELS ARGS)
+ set (required_options NAME TEST_FILE LOCALRANKS_LIST)
+ _parse_arguments (ARG "${options}" "${one_value_options}"
+ "${multi_value_options}" "${required_options}" ${ARGN})
+ set(CLEANUP_TEST_NAME gpi_cleanup)
+ _default_if_unset (ARG_TIMEOUT 600)
+ _default_if_unset (ARG_LABELS "Python")
+
+ list(APPEND ARG_LABELS "Python")
+ list(REMOVE_DUPLICATES ARG_LABELS)
+
+ # wrap call to the test executable in a script that exports the current environment
+ # the script can then be executed within a `gaspi_run` call
+ set(script_name run_${ARG_NAME}.sh)
+ set(script_path ${CMAKE_CURRENT_BINARY_DIR}/${script_name})
+ tarantella_gen_test_script(NAME ${script_name}
+ SCRIPT_DIR ${CMAKE_CURRENT_BINARY_DIR}
+ TEST_FILE ${ARG_TEST_FILE}
+ IS_PYTHON_TEST)
+
+ message(STATUS "Test: Generating gaspi_run tests for ${ARG_NAME} with ${ARG_LOCALRANKS_LIST} ranks")
+ foreach(nlocalranks ${ARG_LOCALRANKS_LIST})
+ tarantella_add_gpi_test (NAME ${ARG_NAME}
+ NRANKS ${nlocalranks}
+ TARGET_FILE ${script_path}
+ TEST_FILE "${ARG_TEST_FILE}"
+ RUNCOMMAND ${GPI2_GASPI_RUN}
+ TIMEOUT ${ARG_TIMEOUT}
+ CLEANUP ${CLEANUP_TEST_NAME}
+ SLEEP ${SLEEP_TIME_AFTER_TEST}
+ LABELS ${ARG_LABELS})
+ endforeach()
+endfunction()
+
+function (tarantella_generate_python_test)
+ set (one_value_options NAME TEST_FILE DESCRIPTION TIMEOUT)
+ set (multi_value_options LABELS ARGS)
+ set (required_options NAME TEST_FILE)
+ _parse_arguments (ARG "${options}" "${one_value_options}"
+ "${multi_value_options}" "${required_options}" ${ARGN})
+ set(CLEANUP_TEST_NAME gpi_cleanup)
+ _default_if_unset (ARG_TIMEOUT 600)
+ _default_if_unset (ARG_LABELS "Python")
+
+ list(APPEND ARG_LABELS "Python")
+ list(REMOVE_DUPLICATES ARG_LABELS)
+
+ # wrap call to the test executable in a script that exports the current environment
+ # the script can then be executed within a `gaspi_run` call
+ set(script_name run_${ARG_NAME}.sh)
+ set(script_path ${CMAKE_CURRENT_BINARY_DIR}/${script_name})
+ tarantella_gen_test_script(NAME ${script_name}
+ SCRIPT_DIR ${CMAKE_CURRENT_BINARY_DIR}
+ TEST_FILE ${ARG_TEST_FILE}
+ IS_PYTHON_TEST)
+
+ # create gaspi_run test
+ add_test(NAME ${ARG_NAME}
+ WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND "${CMAKE_COMMAND}"
+ -DRUNCOMMAND=bash
+ -DRUNCOMMAND_ARGS=" "
+ -DTEST_EXECUTABLE="${script_path}"
+ -DTEST_DIR="${CMAKE_BINARY_DIR}"
+ -DSLEEP="1"
+ -P "${CMAKE_SOURCE_DIR}/cmake/run_test.cmake"
+ )
+
+ # set labels if specified
+ if (ARG_LABELS)
+ set_property(TEST ${ARG_NAME} PROPERTY LABELS ${ARG_LABELS})
+ endif()
+
+ # set cleanup fixture script if specified
+ if (ARG_CLEANUP)
+ set_tests_properties(${ARG_NAME} PROPERTIES FIXTURES_REQUIRED ${ARG_CLEANUP})
+ endif()
+
+ # set timeout if specified
+ if (ARG_TIMEOUT)
+ set_tests_properties(${ARG_NAME} PROPERTIES TIMEOUT ${ARG_TIMEOUT})
+ endif()
+
+ message(STATUS "Test: Generating test ${ARG_NAME}")
+endfunction()
diff --git a/cmake/cleanup.sh b/cmake/cleanup.sh
new file mode 100644
index 00000000..2562bf59
--- /dev/null
+++ b/cmake/cleanup.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+procs=`ps aux | grep --regexp="\(py\)\?test" | grep -v ctest | grep -v grep`
+if [ -n "$procs" ] ;then
+  ps aux | grep --regexp="\(py\)\?test" | grep -v ctest | grep -v grep | awk '{print $2}' | xargs kill > /dev/null 2>&1
+fi
diff --git a/cmake/parse_arguments.cmake b/cmake/parse_arguments.cmake
new file mode 100644
index 00000000..2d4290d8
--- /dev/null
+++ b/cmake/parse_arguments.cmake
@@ -0,0 +1,27 @@
+# equivalent to CMakeParseArguments except that parse_arguments
+# * forbids UNPARSED_ARGUMENTS but requires to explicitly use
+# parse_arguments_with_unknown
+# * allows to specify required arguments
+
+include (CMakeParseArguments)
+
+macro (_parse_arguments _prefix _options _one_value_options _multi_value_options _required_options)
+ _parse_arguments_with_unknown ("${_prefix}" "${_options}" "${_one_value_options}" "${_multi_value_options}" "${_required_options}" ${ARGN})
+
+ if (${_prefix}_UNPARSED_ARGUMENTS)
+ list (LENGTH ${_prefix}_UNPARSED_ARGUMENTS _unparsed_length)
+ if (NOT _unparsed_length EQUAL 0)
+ message (FATAL_ERROR "unknown arguments: ${${_prefix}_UNPARSED_ARGUMENTS}")
+ endif()
+ endif()
+endmacro()
+
+macro (_parse_arguments_with_unknown _prefix _options _one_value_options _multi_value_options _required_options)
+ cmake_parse_arguments ("${_prefix}" "${_options}" "${_one_value_options}" "${_multi_value_options}" ${ARGN})
+
+ foreach (required ${_required_options})
+ if (NOT ${_prefix}_${required})
+ message (FATAL_ERROR "required argument ${required} missing")
+ endif()
+ endforeach()
+endmacro()
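+
+# Example (a sketch): a function with a required NAME and an optional SOURCES list
+#   function (my_function)
+#     set (one_value_options NAME)
+#     set (multi_value_options SOURCES)
+#     set (required_options NAME)
+#     _parse_arguments (ARG "" "${one_value_options}" "${multi_value_options}"
+#                       "${required_options}" ${ARGN})
+#     message (STATUS "NAME=${ARG_NAME} SOURCES=${ARG_SOURCES}")
+#   endfunction()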
diff --git a/cmake/run_test.cmake b/cmake/run_test.cmake
new file mode 100644
index 00000000..88ef0e67
--- /dev/null
+++ b/cmake/run_test.cmake
@@ -0,0 +1,50 @@
+# Kill old processes that may be still running
+function (kill_old_processes)
+ set(one_value_options TEST_DIR TEST_EXECUTABLE)
+ cmake_parse_arguments(ARG "${options}" "${one_value_options}"
+ "${multi_value_options}" ${ARGN})
+
+ set(find_processes_command "ps -ef | grep ${ARG_TEST_DIR} | grep -v grep | grep -v ${ARG_TEST_EXECUTABLE}")
+ set(kill_command "${find_processes_command} | awk '{print $2}' | xargs -r kill -9")
+
+ execute_process(COMMAND sh -c "echo \"Killing `${find_processes_command} | wc -l` processes\"; ${find_processes_command}")
+ execute_process(COMMAND sh -c "${kill_command}"
+ COMMAND_ECHO STDOUT)
+endfunction()
+
+foreach(var TEST_DIR TEST_EXECUTABLE RUNCOMMAND RUNCOMMAND_ARGS SLEEP)
+ if(NOT DEFINED ${var})
+ message(FATAL_ERROR "'${var}' must be defined on the command line")
+ endif()
+
+ separate_arguments(var_value UNIX_COMMAND "${${var}}")
+ string(LENGTH "${var_value}" var_length)
+ if (var_length LESS 1)
+ message(FATAL_ERROR "'${var}' must be defined on the command line and not be empty")
+ endif()
+endforeach()
+
+separate_arguments(runparams_list UNIX_COMMAND "${RUNCOMMAND_ARGS}")
+separate_arguments(all_command_params UNIX_COMMAND
+ "${runparams_list} ${TEST_EXECUTABLE} ${TEST_ARGS}")
+kill_old_processes(TEST_DIR ${TEST_DIR}
+ TEST_EXECUTABLE ${TEST_EXECUTABLE})
+
+# Execute the test-executable
+execute_process(COMMAND ${RUNCOMMAND} ${all_command_params}
+ COMMAND_ECHO STDOUT
+ RESULT_VARIABLE result)
+
+# Sleep to ensure all processes are done and kill the remainder
+separate_arguments(sleep_time UNIX_COMMAND "${SLEEP}")
+execute_process(COMMAND ${CMAKE_COMMAND} -E sleep "${sleep_time}"
+ COMMAND ${CMAKE_COMMAND} -E echo "Sleep ${sleep_time}")
+kill_old_processes(TEST_DIR ${TEST_DIR}
+ TEST_EXECUTABLE ${TEST_EXECUTABLE})
+
+# Check return status
+if(result)
+ message(FATAL_ERROR "Test failed:'${result}'")
+endif()
+
+
diff --git a/cmake/version.py.in b/cmake/version.py.in
new file mode 100644
index 00000000..863f92f9
--- /dev/null
+++ b/cmake/version.py.in
@@ -0,0 +1,2 @@
+global tnt_version
+tnt_version = "@PROJECT_VERSION@"
diff --git a/docs/CMakeLists.txt b/docs/CMakeLists.txt
new file mode 100644
index 00000000..afe5b076
--- /dev/null
+++ b/docs/CMakeLists.txt
@@ -0,0 +1,17 @@
+set(SPHINX_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/source)
+set(SPHINX_BUILD ${CMAKE_CURRENT_BINARY_DIR}/)
+
+if (Sphinx_FOUND)
+ add_custom_target(docs ALL
+ COMMAND
+ Sphinx::Sphinx -b html
+ -Drelease=${PROJECT_VERSION}
+ ${SPHINX_SOURCE} ${SPHINX_BUILD}
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMENT "Generating documentation with Sphinx")
+
+ install(DIRECTORY ${SPHINX_BUILD}
+ DESTINATION ${CMAKE_INSTALL_PREFIX}/docs)
+else()
+ message(STATUS "Sphinx not found. Skipping documentation build.")
+endif()
\ No newline at end of file
diff --git a/docs/source/advanced_topics.rst b/docs/source/advanced_topics.rst
new file mode 100644
index 00000000..28caa105
--- /dev/null
+++ b/docs/source/advanced_topics.rst
@@ -0,0 +1,144 @@
+Advanced Topics
+===============
+
+This guide covers a number of advanced topics, such as
+performance, reproducibility and user customization.
+
+
+.. _ranks-label:
+
+GASPI ranks
+^^^^^^^^^^^
+
+In order to execute distributed DNN training, Tarantella starts multiple processes
+on different devices. These processes will be assigned different IDs by the GASPI
+communication library, in order to organize communication and synchronization between
+the different devices. These IDs are called *ranks*. Usually, Tarantella abstracts away
+the concept of *ranks*, in such a way that Tarantella's user interface is essentially
+the same as Keras' user interface.
+
+However, it is sometimes useful to execute a specific part of the code on only one
+rank, or on a subgroup of all ranks. In particular, one sometimes wants to execute a code
+block only on the device that started ``tarantella``, the so-called *master rank*.
+
+To access ranks, Tarantella provides the following functions:
+
+* ``tnt.get_rank()``
+* ``tnt.get_size()``
+* ``tnt.get_master_rank()``
+* ``tnt.is_master_rank()``
+
+``tnt.get_rank()`` returns the ID of the local rank.
+``tnt.get_size()`` returns the total number of ranks.
+``tnt.get_master_rank()`` and ``tnt.is_master_rank()`` return the ID of the master rank
+and a boolean for whether the local rank is the master rank or not, respectively.
+
+Here is a simple example, in which the master rank is used to print notifications
+only once to ``stdout``:
+
+.. code-block:: python
+
+ if tnt.is_master_rank():
+ print("Printing from the master rank")
+
+In the same vein, you might want to use ranks to execute callbacks for logging
+only on one rank:
+
+.. code-block:: python
+
+ history_callback = tf.keras.callbacks.History()
+ tnt_model.fit(train_dataset,
+ callbacks = [history_callback] if tnt.is_master_rank() else [])
+
+
+.. _using-local-batch-sizes-label:
+
+Using local batch sizes
+^^^^^^^^^^^^^^^^^^^^^^^
+
+As stated in the :ref:`points to consider <points-to-consider-label>`, when using
+Tarantella the user always specifies the *global* batch size. This has the advantage that
+the optimization process during DNN training, and in particular the loss function, does not
+depend on the number of devices used during execution.
+
+However, when the number of devices becomes
+very large, the (device-local) micro-batch size might become so small that DNN kernel
+implementations are less efficient, resulting in overall performance degradation.
+This is why it is often advisable in practice to scale the global batch size with the number of nodes.
+This will often lead to linear speedups in terms of the time to accuracy when increasing
+the number of devices used, at least up to some *critical batch size*, cf. [Shallue]_ and [McCandlish]_.
+Changing the batch size of the optimizer will, however, also imply the need to adapt the learning
+rate schedule.
+
+.. todo::
+
+ Enable when the Tutorial is updated:
+  For details, cf. for instance the ResNet-50 tutorial.
+
+If you decide to scale the batch size with the number of nodes, Tarantella provides
+two different ways to achieve this easily. The first option is to multiply the local batch size
+(for instance passed via a command-line parameter) with the number of devices used,
+batch your dataset with it, and call ``fit`` on it:
+
+.. code-block:: python
+
+ micro_batch_size = args.micro_batch_size
+ batch_size = tnt.get_size() * micro_batch_size
+ train_dataset = train_dataset.batch(batch_size)
+ tnt_model.fit(train_dataset)
+
+As a second option, you can pass the local batch size directly to the ``tnt_micro_batch_size``
+parameter in ``fit``, and leave your dataset unbatched:
+
+.. code-block:: python
+
+ micro_batch_size = args.micro_batch_size
+ tnt_model.fit(train_dataset,
+ tnt_micro_batch_size = micro_batch_size)
+
+This parameter is also available in ``evaluate`` and ``predict``. In addition, ``fit`` supports
+setting the validation-set micro-batch size in a similar way with ``tnt_validation_micro_batch_size``.
+For more information, please also read the section on distributed datasets.
+
+
+.. _tensor-fusion-threshold-label:
+
+Setting Tensor Fusion threshold
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Tarantella automatically uses :ref:`Tensor Fusion <tensor-fusion-label>` with a default
+threshold of 32kB. This threshold specifies the minimal size of local buffers in *allreduce*
+communication operations used to accumulate partial gradients during *backpropagation*.
+
+Note that the threshold value implies a trade-off between the potential to utilize network
+bandwidth, and the overlap of computation and communication during *backpropagation*. The
+larger the threshold, the more bandwidth-bound the *allreduce* algorithm will get, but
+the less potential there will be to overlap its execution with kernel computations.
+Also note that the ideal threshold value will generally depend on the number of nodes used.
+
+To change the default value, you can pass a threshold value in kB to ``tarantella``:
+
+.. code-block:: bash
+
+   tarantella --hostfile hostfile --fusion-threshold=<threshold_kB> -- model.py
+
+
+.. _reproducibility-label:
+
+Reproducibility
+^^^^^^^^^^^^^^^
+
+Reproducibility is a very important prerequisite for obtaining meaningful results in
+scientific computing and research. Unfortunately, the use of stochastic algorithms and
+pseudo-random number generators, together with the pitfalls of floating-point arithmetic,
+makes reproducibility particularly difficult to achieve in Deep Learning research.
+
+In order to be able to reproduce results obtained with TensorFlow, when running in
+a multi-node/multi-device setting with Tarantella, one needs to meet at least
+the following requirements:
+
+* set the random seed with ``tf.random.set_seed(seed)``
+* set the environment variable ``os.environ['TF_CUDNN_DETERMINISTIC']='1'``
+* set the shuffle seeds when using ``tf.data.Dataset`` with ``shuffle(seed=seed)`` and ``list_files(seed=seed)``
+* set the ``deterministic`` parameter to ``True`` in ``Dataset`` transformations such as ``interleave`` and ``map``
+* make sure the number of samples in your datasets is a multiple of ``batch_size``
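+
+Putting these requirements together, a minimal sketch could look as follows
+(the seed value and file pattern are placeholders; the ``deterministic``
+parameter requires TF 2.2):
+
+.. code-block:: python
+
+    import os
+    import tensorflow as tf
+
+    seed = 42
+    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
+    tf.random.set_seed(seed)
+
+    dataset = tf.data.Dataset.list_files("data/*.tfrecord", seed=seed)
+    dataset = dataset.interleave(tf.data.TFRecordDataset,
+                                 num_parallel_calls=tf.data.experimental.AUTOTUNE,
+                                 deterministic=True)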
diff --git a/docs/source/bug_reports.rst b/docs/source/bug_reports.rst
new file mode 100644
index 00000000..20f632ec
--- /dev/null
+++ b/docs/source/bug_reports.rst
@@ -0,0 +1,35 @@
+.. _bug-reports-label:
+
+Bug Reports
+===========
+
+To report a bug, please open an `issue on GitHub <https://github.com/cc-hpc-itwm/tarantella/issues>`_.
+
+When opening an issue, please make sure you include as much
+information as possible about the issue. Please consider providing at
+least the following points:
+
+ * What version of Tarantella you are using
+  * What Linux distribution you are using (e.g., Ubuntu 20.04)
+ * What kind of system you are experiencing the issue on (type and
+ number of nodes, network interconnect, etc.)
+ * What did you expect to see and what have you seen instead
+ * What exact steps are needed to reproduce the issue
+
+.. _feature-requests-label:
+
+Feature Requests
+================
+
+For contributions other than modifications to the source code, such as
+suggestions for a feature or enhancement, please open
+an `issue on GitHub <https://github.com/cc-hpc-itwm/tarantella/issues>`_
+with the label ``Feature``.
+
+When providing a feature request, please consider providing at least
+the following information:
+
+ * What is the current behavior of the software and how does the feature improve it
+ * Who would benefit from the feature
+ * Is there a relevant reference or academic paper describing the feature
+ * Are you willing to contribute to and/or maintain the feature
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 00000000..cc2e7e3b
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,72 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+# import os
+# import sys
+# sys.path.insert(0, os.path.abspath('.'))
+
+
+# -- Project information -----------------------------------------------------
+
+project = 'Tarantella'
+copyright = '2020 Fraunhofer'
+author = 'Peter Labus, Alexandra Carpen-Amarie, Martin Kuehn'
+
+# The full version, including alpha/beta/rc tags
+release = '0'
+
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = ['sphinx.ext.todo']
+try:
+ import sphinx_rtd_theme
+ extensions += ['sphinx_rtd_theme']
+except ImportError:
+ pass
+
+# Display TODOs by setting to True
+todo_include_todos = False
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = []
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+# html_theme = 'alabaster' # default
+try:
+ import sphinx_rtd_theme
+ html_theme = "sphinx_rtd_theme"
+except ImportError:
+ pass
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
+html_title = release
+html_theme_options = dict()
+html_theme_options['logo_only'] = False
+# html_theme_options['display_version']= True
+# html_logo = "pics/tnt_logo.png"
diff --git a/docs/source/contact.rst b/docs/source/contact.rst
new file mode 100644
index 00000000..472374e2
--- /dev/null
+++ b/docs/source/contact.rst
@@ -0,0 +1,14 @@
+.. _contact-label:
+
+Contact
+=======
+
+In case you have a feature request
+or want to report a bug, please follow
+:ref:`these instructions <bug-reports-label>`.
+
+If you consider contributing to Tarantella, please follow
+the instructions :ref:`here <contributing-label>`.
+
+If you have any further questions or comments, please email us at
+support@tarantella.org.
diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst
new file mode 100644
index 00000000..08cdd8ad
--- /dev/null
+++ b/docs/source/contributing.rst
@@ -0,0 +1,20 @@
+.. _contributing-label:
+
+Contributing
+============
+
+Thank you for considering contributing to Tarantella.
+
+There are many ways to contribute to Tarantella.
+This includes sharing DNN models distributed through Tarantella,
+providing suggestions on improving the documentation,
+as well as contributing changes to the
+`Tarantella code base <https://github.com/cc-hpc-itwm/tarantella>`_.
+Even simply providing suggestions on how we can
+:ref:`improve Tarantella <feature-requests-label>`
+and helping spread the word about it are great ways to contribute
+and make Tarantella better software.
+
+If you want to contribute to Tarantella with changes to its code,
+please open a `pull request <https://github.com/cc-hpc-itwm/tarantella/pulls>`_
+on GitHub.
diff --git a/docs/source/data_parallel.rst b/docs/source/data_parallel.rst
new file mode 100644
index 00000000..0a914d47
--- /dev/null
+++ b/docs/source/data_parallel.rst
@@ -0,0 +1,195 @@
+Distributed Data Parallel Training
+==================================
+
+The following section explains the parallelization strategy Tarantella uses to
+provide distributed training. A full understanding thereof is, however, not required
+to be able to use the software. Please note the :ref:`points to consider <points-to-consider-label>`
+to achieve best performance and reproducibility.
+
+The general idea
+----------------
+
+In order to parallelize the training of DNNs, different, complementary strategies are available.
+The conceptually simplest and most efficient one is called *data parallelism*. This strategy
+is already in use when deploying batched optimizers, such as stochastic gradient descent (SGD)
+or ADAM. In this case, input samples are grouped together in so-called mini-batches and
+are processed in parallel.
+
+Distribution of mini-batches
+----------------------------
+
+Tarantella extends this scheme by splitting each mini-batch into a number of micro-batches,
+which are then executed on different devices (e.g., GPUs).
+In order to do this, the DNN is replicated on each device,
+which then processes part of the data independently of the other devices.
+During the *backpropagation* pass, partial results need to be accumulated via a so-called
+*allreduce* collective operation.
+
+Overlapping communication with computation
+------------------------------------------
+
+Tarantella implements this communication scheme using the
+`Global Address Space Programming Interface (GASPI) <https://www.gaspi.de/>`_.
+This allows in particular to overlap the communication needed to execute *allreduce* operations
+with the computation done in the *backpropagation* part of the DNN training.
+This is done by starting *allreduce* operations as soon as the required local incoming gradients are
+available, while continuing with *backpropagation* calculations at the same time.
+The final, accumulated gradients are only expected once the entire *backpropagation* is completed.
+This drastically mitigates the communication overhead introduced by the need to synchronize
+the different devices, and leads to higher scalability.
+
+.. _tensor-fusion-label:
+
+Tensor Fusion
+-------------
+
+The granularity at which Tarantella executes *allreduce* operations can be varied from
+one *allreduce* per layer (finest granularity) to one *allreduce* per iteration (coarsest granularity).
+Using coarser granularities, i.e., *fusing* gradient tensors,
+can lead to better bandwidth utilization, thus potentially increasing performance.
+*Tensor Fusion* is set up before the first iteration of training and incurs no additional communication overhead.
+Tarantella enables *Tensor Fusion* by default, but its granularity can be adjusted by the user,
+cf. :ref:`here <tensor-fusion-threshold-label>`.
+
+Model initialization and loading
+--------------------------------
+
+In order to guarantee that all devices have the same copy of the DNN when training is initiated,
+the model needs to be communicated from one device to all the others.
+This is done in Tarantella via the use of a so-called
+*broadcast* operation.
+This scheme applies both when the weights of a DNN are initialized randomly,
+or loaded from a checkpoint.
+As Tarantella provides this functionality automatically,
+the user does not have to take care of it.
+
+.. _points-to-consider-label:
+
+Distributed Datasets
+=====================
+
+In order to process micro-batches independently on each device and to obtain the same results
+as in serial execution, the input data of each mini-batch has to be split and distributed
+among all devices.
+
+Tarantella automatically takes care of this through the use of distributed datasets.
+The user simply provides Tarantella with a ``tf.data.Dataset`` that is batched
+with the mini-batch size. Tarantella will then automatically distribute the input data
+by sharding the mini-batch into individual micro-batches. Sharding is done at the level
+of samples (as opposed to, e.g., files) to ensure :ref:`reproducibility <reproducibility-label>`
+of serial results.
+
+To guarantee reproducibility, it is also important that shuffling of samples is done
+in the same way on all devices. Tarantella does this using either the ``seed`` provided
+by the user, or a specific default seed. Please refer to the
+:ref:`Quick Start <quick-start-label>`
+for more details.
+
+Points to Consider
+==================
+
+.. _global-vs-local-batch-size-label:
+
+Global versus local batch size
+------------------------------
+
+As explained above, when using data parallelism, there exists a *mini-batch size*
+(in the following also called global batch size or simply batch size)
+as well as a *micro-batch size* (also called local batch size).
+The former represents the number of samples that
+is averaged over in the loss function of the optimizer, and is equivalent to
+the (mini-)batch size used in non-distributed training. The latter is the number
+of samples that is processed locally by each of the devices per iteration.
+
+.. note::
+
+ In Tarantella, the user always specifies the **global batch size**.
+
+Using a strictly synchronous optimization scheme, and by carefully handling the data distribution,
+**Tarantella guarantees the reproducibility of DNN training results independently of the number of
+devices used**, as long as all hyperparameters (such as global batch size and learning rate)
+are kept constant. [#footnote_random_seeds]_
+
+However, to achieve best performance for certain DNN operators (``Conv2D``, ``Dense``, etc.)
+it is often advisable to *keep the local batch size constant*, and scale the global
+batch size with the number of devices used. This, in turn, will force you to
+adjust other hyperparameters, such as the learning rate, in order to converge
+to a comparable test accuracy, as observed for instance in [Shallue]_.
+
+In practice, the use of a learning rate schedule with an initial *warm-up* phase and
+*linear learning rate scaling* [Goyal]_ often suffices, as sketched below.
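+
+A minimal sketch of such a schedule with Keras is shown below; the base learning rate,
+number of devices, and warm-up length are illustrative and need to be tuned for your model:
+
+.. code-block:: python
+
+   import tensorflow as tf
+
+   base_lr = 0.1      # learning rate tuned for single-device training
+   num_devices = 8    # factor by which the global batch size was scaled
+   warmup_epochs = 5
+
+   def lr_schedule(epoch, lr):
+       scaled_lr = base_lr * num_devices   # linear scaling rule [Goyal]
+       if epoch < warmup_epochs:           # linear warm-up towards the scaled rate
+           return base_lr + (scaled_lr - base_lr) * epoch / warmup_epochs
+       return scaled_lr
+
+   lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule)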
+
+.. tip::
+
+   For best performance, scale the batch size with the number of devices used,
+   and adapt the learning rate schedule as described above.
+
+Batch normalization layers
+--------------------------
+
+The issue of global versus local batch size particularly affects the layers
+that calculate (and learn) statistics over entire batches.
+A well-known example of this type of layer is
+*batch normalization*.
+
+.. caution::
+
+ Tarantella always calculates batch statistics over **local batches**.
+
+As a consequence, the training results for DNNs with batch-normalization layers
+**will not be identical when changing the number of devices, even if
+the global batch size stays the same.**
+At the moment, this can be circumvented by using normalization layers that
+do *not* average over entire batches, such as instance normalization
+[Ulyanov]_.
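+
+As a minimal sketch, assuming the
+`TensorFlow Addons <https://www.tensorflow.org/addons>`_ package is available,
+a ``BatchNormalization`` layer could be replaced like this:
+
+.. code-block:: python
+
+   import tensorflow_addons as tfa
+   from tensorflow import keras
+
+   inputs = keras.Input(shape=(28, 28, 1))
+   x = keras.layers.Conv2D(64, 3, padding="same", activation="relu")(inputs)
+   # normalizes each sample individually, independently of the (micro-)batch size
+   x = tfa.layers.InstanceNormalization()(x)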
+
+Averaging over *local* batches instead of global batches should in practice
+have only minor influence on the quality of the final test accuracy.
+Beware, however, of the extreme case of very small *local* batch sizes.
+
+.. caution::
+
+ Avoid using ``BatchNormalization`` layers when the global batch size
+ divided by the number of devices used is *smaller than 16*.
+
+In such cases, the local batches that are used to collect statistics are
+too small to obtain meaningful results. This will likely reduce the
+benefits of batch normalization, cf. for instance [Yang]_ and [Uppal]_.
+In this case, please consider increasing the global batch size,
+or reducing the number of devices used.
+
+Managing individual devices
+---------------------------
+
+Although Tarantella's user interface abstracts away most of the details of
+parallel programming, it is sometimes useful to be able to control
+Python code execution at device level. This can be achieved using the
+`GASPI <https://www.gaspi.de/>`_ concept
+of a ``rank``. Details on how to do this can be found in the
+advanced topics section.
+
+.. rubric:: References
+
+.. [Shallue] Shallue, Christopher J., et al. "Measuring the effects of data parallelism on neural network training." arXiv preprint arXiv:1811.03600 (2018).
+
+.. [Ulyanov] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).
+
+.. [Goyal] Goyal, Priya, et al. "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour." arXiv preprint arXiv:1706.02677 (2017).
+
+.. [Yang] Yang, Greg, et al. "A mean field theory of batch normalization." arXiv preprint arXiv:1902.08129 (2019).
+
+.. [Uppal] https://towardsdatascience.com/curse-of-batch-normalization-8e6dd20bc304
+
+.. [McCandlish] McCandlish, Sam, et al. "An empirical model of large-batch training." arXiv preprint arXiv:1812.06162 (2018).
+
+.. [He] He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
+
+.. [Vaswani] Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems. 2017.
+
+.. rubric:: Footnotes
+
+.. [#footnote_random_seeds] This is strictly true only when all randomness in TensorFlow is
+   seeded or switched off, as explained in the advanced topics.
+
diff --git a/docs/source/faq.rst b/docs/source/faq.rst
new file mode 100644
index 00000000..b2ba811c
--- /dev/null
+++ b/docs/source/faq.rst
@@ -0,0 +1,80 @@
+.. _faq-label:
+
+Frequently Asked Questions (FAQ)
+================================
+
+This is a list of frequently asked questions about Tarantella.
+Please feel free to suggest new ones!
+
+.. admonition:: Question
+
+ How can I ssh to ``localhost`` without password?
+
+In order to run Tarantella programs, you will need to be able to ssh to ``localhost``
+without a password. To do that, generate ``ssh`` keys first:
+
+.. code-block:: bash
+
+ cd ~/.ssh
+ ssh-keygen
+
+Make sure not to overwrite existing keys.
+When asked for a passphrase (``Enter passphrase (empty for no passphrase):``), simply leave it
+empty and confirm with enter.
+Also take care to set the correct permissions on all files in ``.ssh``,
+as shown below.
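+
+Assuming the default key file names, the permissions should typically look like this:
+
+.. code-block:: bash
+
+   chmod 700 ~/.ssh
+   chmod 600 ~/.ssh/id_rsa
+   chmod 644 ~/.ssh/id_rsa.pub
+   chmod 600 ~/.ssh/authorized_keys
+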
+Next, append the public key to the ``authorized_keys`` file:
+
+.. code-block:: bash
+
+ cat id_rsa.pub >> authorized_keys
+
+Now, install and start an ssh server, e.g., ``openssh-server`` on Fedora.
+
+.. admonition:: Question
+
+ I get an execution error ``GPI library initialization incorrect environment vars`` when
+ trying to run my script. What shall I do?
+
+Most likely you are running your program with ``python my_script.py`` or ``./my_script.py``.
+Please make sure to execute your code with ``tarantella my_script.py`` instead.
+
+.. admonition:: Question
+
+ I get an execution error ``GPI library initialization general error``. What shall I do?
+
+This error occurs when the GASPI library tries to connect to a previously used socket that has
+not yet been released.
+Try to re-run your code after a short while, once the port becomes available again.
+
+.. admonition:: Question
+
+ The execution seems to stall. What shall I do?
+
+Please kill any processes that might still be running from a previous (aborted) call to ``tarantella``.
+
+.. admonition:: Question
+
+ | When trying to build Tarantella, CMake cannot find pybind11:
+ | ``Could not find a package configuration file provided by "pybind11" with any``
+ | ``of the following names: [...]``
+ | What shall I do?
+
+This error occurs when pybind11 is installed using pip.
+Please instead use conda, as recommended in the :ref:`installation guide <installation-label>`.
+
+.. admonition:: Question
+
+ When trying to build Tarantella, CMake does not detect the Python interpreter from the
+ active conda environment. What shall I do?
+
+You will need to manually add the path to the conda environment's ``bin`` directory to your ``PATH``.
+You will also need to specify the path to the python library on the command line when configuring Tarantella:
+
+.. code-block:: bash
+
+ PATH_TO_CONDA_ENV=/path/to/conda/env
+ export PATH=${PATH_TO_CONDA_ENV}/bin:${PATH}
+ cmake -DPYTHON_EXECUTABLE=${PATH_TO_CONDA_ENV}/bin/python \
+ -DPYTHON_LIBRARY=${PATH_TO_CONDA_ENV}/lib ../
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 00000000..8a3a7e37
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,46 @@
+.. image:: pics/tnt_logo_text.png
+ :width: 750
+ :align: center
+
+|
+`Tarantella <https://github.com/cc-hpc-itwm/tarantella>`_
+is an open-source, distributed Deep Learning framework built on top of TensorFlow 2,
+providing scalable Deep Neural Network training on CPU and GPU compute clusters.
+
+Tarantella is easy to use, allows re-using existing TensorFlow 2/Keras models,
+and does not require any knowledge of parallel computing.
+
+.. image:: pics/tnt_run.gif
+ :width: 750
+ :align: center
+
+|
+
+Table of contents
+=================
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Overview
+
+ why_tarantella
+ data_parallel
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Getting started
+
+ installation
+ quick_start
+ tutorials
+ advanced_topics
+ faq
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Community
+
+ bug_reports
+ contributing
+ contact
+ license
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
new file mode 100644
index 00000000..6bb6f809
--- /dev/null
+++ b/docs/source/installation.rst
@@ -0,0 +1,197 @@
+.. _installation-label:
+
+Installation
+============
+
+Tarantella needs to be built
+`from source <https://github.com/cc-hpc-itwm/tarantella>`_.
+Since Tarantella is built on top of `TensorFlow 2 <https://www.tensorflow.org/>`_,
+you will require a recent version of it. Additionally, you will need an installation of
+the open-source communication library `GPI-2 <https://github.com/cc-hpc-itwm/GPI-2>`_,
+which Tarantella uses to communicate between processes.
+Lastly, you will need `pybind11 <https://github.com/pybind/pybind11>`_, which is required
+for interoperability between Python and C++.
+
+In the following we will look at the required steps in detail.
+
+Installing dependencies
+-----------------------
+
+Compiler and build system
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Tarantella can be built using a recent `gcc <https://gcc.gnu.org/>`_
+compiler (from version ``7.4.0``).
+You will also need the build tool `CMake <https://cmake.org/>`_ (from version ``3.8``).
+
+Installing GPI-2
+^^^^^^^^^^^^^^^^
+
+Next, you will need to download, compile and install the GPI-2 library.
+The currently supported version is ``v1.4.0``, which needs to be built with
+position independent flags (``-fPIC``).
+
+To download the required version, clone the
+`git repository <https://github.com/cc-hpc-itwm/GPI-2>`_
+and check out the correct ``tag``:
+
+.. code-block:: bash
+
+ git clone https://github.com/cc-hpc-itwm/GPI-2.git
+ cd GPI-2
+ git fetch --tags
+ git checkout -b v1.4.0 v1.4.0
+
+Now, use autotools to configure and compile the code
+
+.. code-block:: bash
+
+ ./autogen.sh
+ export GPI2_INSTALLATION_PATH=/your/installation/path
+ CFLAGS="-fPIC" CPPFLAGS="-fPIC" ./configure --with-ethernet --prefix=${GPI2_INSTALLATION_PATH}
+ make
+
+where ``${GPI2_INSTALLATION_PATH}`` needs to be replaced with the path where you want to install
+GPI-2. Note the ``--with-ethernet`` option, which will use standard TCP sockets for communication.
+This is the correct option for laptops and workstations.
+
+In case you want to use Infiniband, replace the above option with ``--with-infiniband``.
+Now you are ready to install GPI-2 with
+
+.. code-block:: bash
+
+ make install
+ export PATH=${GPI2_INSTALLATION_PATH}/bin:$PATH
+ export LD_LIBRARY_PATH=${GPI2_INSTALLATION_PATH}/lib64:$LD_LIBRARY_PATH
+
+where the last two commands make the library visible to your system.
+If required, GPI-2 can be removed from the target directory by using ``make uninstall``.
+
+Installing TensorFlow 2
+^^^^^^^^^^^^^^^^^^^^^^^
+
+Next, you will need to install TensorFlow 2.
+Tarantella supports TensorFlow versions ``2.0`` to ``2.2``.
+Any of these versions can be installed in a conda environment using pip,
+as recommended on the `TensorFlow website <https://www.tensorflow.org/install>`_.
+
+In order to do that, first install `conda <https://docs.conda.io/>`_ on your system.
+Then, create and activate an environment for Tarantella:
+
+.. code-block:: bash
+
+   conda create -n tarantella
+   conda activate tarantella
+
+Now, you can install the latest supported TensorFlow version with
+
+.. code-block:: bash
+
+ conda install python=3.7
+ pip install --upgrade tensorflow==2.2
+
+.. _installation-pybind11-label:
+
+Installing pybind11
+^^^^^^^^^^^^^^^^^^^
+
+The last dependency you will need to install is
+`pybind11 <https://github.com/pybind/pybind11>`__,
+which is available through pip and conda.
+We recommend installing pybind11 via conda:
+
+.. code-block:: bash
+
+ conda install pybind11 -c conda-forge
+
+SSH key-based authentication
+----------------------------
+
+In order to use Tarantella on a cluster, make sure you can ssh between nodes
+without a password. For details, refer to the :ref:`FAQ section <faq-label>`.
+In particular, to test Tarantella on your local machine, make sure
+you can ssh to ``localhost`` without a password.
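+
+A quick check is to run a command on ``localhost`` via ssh; it should complete
+without prompting for a password:
+
+.. code-block:: bash
+
+   ssh localhost hostname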
+
+Building Tarantella from source
+-------------------------------
+
+With all dependencies installed, we can now download, configure and compile Tarantella.
+To download the source code, simply clone the
+`GitHub repository <https://github.com/cc-hpc-itwm/tarantella>`__:
+
+.. code-block:: bash
+
+ git clone https://github.com/cc-hpc-itwm/tarantella.git
+
+Next, we need to configure the build system using CMake.
+For a standard out-of-source build, we create a separate ``build`` folder and run ``cmake``
+in it:
+
+.. code-block:: bash
+
+ cd tarantella
+ mkdir build && cd build
+ export TARANTELLA_INSTALLATION_PATH=/your/installation/path
+ cmake -DCMAKE_INSTALL_PREFIX=${TARANTELLA_INSTALLATION_PATH} ..
+
+Now, we can compile and install Tarantella to ``TARANTELLA_INSTALLATION_PATH``:
+
+.. code-block:: bash
+
+ make
+ make install
+ export PATH=${TARANTELLA_INSTALLATION_PATH}/bin:${PATH}
+
+[Optional] Building and running tests
+-------------------------------------
+
+In order to build Tarantella with tests, you will also need to install
+`Boost <https://www.boost.org/>`_
+(for C++ tests) and `pytest <https://docs.pytest.org/>`_ (for Python tests).
+
+To install Boost with the required ``devel`` packages, on Ubuntu you can use
+
+.. code-block:: bash
+
+ sudo apt install libboost-all-dev
+
+while on Fedora you can use
+
+.. code-block:: bash
+
+ sudo dnf install boost boost-devel
+
+To install pytest you can use pip:
+
+.. code-block:: bash
+
+ pip install -U pytest
+
+After having installed these libraries, make sure to configure Tarantella with testing switched on:
+
+.. code-block:: bash
+
+ cmake -DENABLE_TESTING=ON ..
+
+Now you can compile Tarantella and run its tests in the ``build`` directory.
+
+.. code-block:: bash
+
+ make
+ ctest
+
+[Optional] Building documentation
+---------------------------------
+
+If you would like to build the documentation locally,
+run the following ``cmake`` command
+
+.. code-block:: bash
+
+ cmake -DCMAKE_INSTALL_PREFIX=${TARANTELLA_INSTALLATION_PATH} -DBUILD_DOCS=ON ..
+
+before compiling.
+This requires you to have `Sphinx <https://www.sphinx-doc.org/>`_ installed:
+
+.. code-block:: bash
+
+ pip install -U sphinx
diff --git a/docs/source/license.rst b/docs/source/license.rst
new file mode 100644
index 00000000..b46bf268
--- /dev/null
+++ b/docs/source/license.rst
@@ -0,0 +1,5 @@
+License
+=======
+
+.. literalinclude:: ../../LICENSE
+ :language: text
diff --git a/docs/source/model.py b/docs/source/model.py
new file mode 100644
index 00000000..2845f141
--- /dev/null
+++ b/docs/source/model.py
@@ -0,0 +1,89 @@
+import argparse
+import tensorflow as tf
+from tensorflow import keras
+
+import tarantella as tnt
+tnt.init()
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-bs", "--batch_size", type=int, default=64)
+ parser.add_argument("-e", "--number_epochs", type=int, default=1)
+ parser.add_argument("-lr", "--learning_rate", type=float, default=0.01)
+ parser.add_argument("-train", "--train_size", type=int, default=48000)
+ parser.add_argument("-val", "--val_size", type=int, default=6400)
+ parser.add_argument("-test", "--test_size", type=int, default=6400)
+ args = parser.parse_args()
+ return args
+
+def mnist_as_np_arrays(training_samples, validation_samples, test_samples):
+ mnist_train_size = 60000
+ mnist_test_size = 10000
+ assert(training_samples + validation_samples <= mnist_train_size)
+ assert(test_samples <= mnist_test_size)
+
+ # load given number of samples
+ (x_train_all, y_train_all), (x_test_all, y_test_all) = \
+ keras.datasets.mnist.load_data()
+ x_train = x_train_all[:training_samples]
+ y_train = y_train_all[:training_samples]
+ x_val = x_train_all[training_samples:training_samples+validation_samples]
+ y_val = y_train_all[training_samples:training_samples+validation_samples]
+ x_test = x_test_all[:test_samples]
+ y_test = y_test_all[:test_samples]
+
+ # normalization and reshape
+ x_train = x_train.reshape(training_samples,28,28,1).astype('float32') / 255.
+ x_val = x_val.reshape(validation_samples,28,28,1).astype('float32') / 255.
+ x_test = x_test.reshape(test_samples,28,28,1).astype('float32') / 255.
+ y_train = y_train.astype('float32')
+ y_val = y_val.astype('float32')
+ y_test = y_test.astype('float32')
+
+ return (x_train, y_train), (x_val, y_val), (x_test, y_test)
+
+def lenet5_model_generator():
+ inputs = keras.Input(shape=(28,28,1,), name='input')
+ x = keras.layers.Conv2D(20, 5, padding="same", activation='relu')(inputs)
+ x = keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
+ x = keras.layers.Conv2D(50, 5, padding="same", activation='relu')(x)
+ x = keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
+ x = keras.layers.Flatten()(x)
+ x = keras.layers.Dense(500, activation='relu')(x)
+ outputs = keras.layers.Dense(10, activation='softmax')(x)
+ return keras.Model(inputs=inputs, outputs=outputs)
+
+args = parse_args()
+
+# Create Tarantella model
+model = tnt.Model(lenet5_model_generator())
+
+# Compile Tarantella model (as with Keras)
+model.compile(optimizer = keras.optimizers.SGD(learning_rate=args.learning_rate),
+ loss = keras.losses.SparseCategoricalCrossentropy(),
+ metrics = [keras.metrics.SparseCategoricalAccuracy()])
+
+# Load MNIST dataset (as with Keras)
+shuffle_seed = 42
+(x_train, y_train), (x_val, y_val), (x_test, y_test) = \
+ mnist_as_np_arrays(args.train_size, args.val_size, args.test_size)
+
+train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+train_dataset = train_dataset.shuffle(len(x_train), shuffle_seed)
+train_dataset = train_dataset.batch(args.batch_size)
+train_dataset = train_dataset.prefetch(buffer_size = tf.data.experimental.AUTOTUNE)
+
+val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
+val_dataset = val_dataset.batch(args.batch_size)
+
+test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+test_dataset = test_dataset.batch(args.batch_size)
+
+# Train Tarantella model (as with Keras)
+model.fit(train_dataset,
+ validation_data = val_dataset,
+ epochs = args.number_epochs,
+ verbose = 1)
+
+# Evaluate Tarantella model (as with Keras)
+model.evaluate(test_dataset, verbose = 1)
diff --git a/docs/source/pics/tnt_logo.png b/docs/source/pics/tnt_logo.png
new file mode 100644
index 00000000..87f475be
Binary files /dev/null and b/docs/source/pics/tnt_logo.png differ
diff --git a/docs/source/pics/tnt_logo_text.png b/docs/source/pics/tnt_logo_text.png
new file mode 100644
index 00000000..4a829d20
Binary files /dev/null and b/docs/source/pics/tnt_logo_text.png differ
diff --git a/docs/source/pics/tnt_run.gif b/docs/source/pics/tnt_run.gif
new file mode 100644
index 00000000..cc39935a
Binary files /dev/null and b/docs/source/pics/tnt_run.gif differ
diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst
new file mode 100644
index 00000000..36225ff7
--- /dev/null
+++ b/docs/source/quick_start.rst
@@ -0,0 +1,455 @@
+.. _quick-start-label:
+
+Quick Start
+===========
+
+This section explains how to get started using Tarantella to distributedly
+train an existing TensorFlow 2/Keras model.
+First, we will examine what changes have to be made to your code, before looking into
+the execution of your script with ``tarantella`` on the command line.
+Finally, we will present the features Tarantella currently supports and
+the important points that need to be taken into account when using Tarantella.
+
+Code example: LeNet-5 on MNIST
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+After having :ref:`built and installed <installation-label>` Tarantella,
+we are ready to add distributed training support to an existing TensorFlow 2/Keras model.
+We will first illustrate all the necessary steps, using the well-known example of
+**LeNet-5** on the **MNIST** dataset. Although this is not necessarily a good use case
+to take full advantage of Tarantella's capabilities, it will allow you to simply
+copy-paste the code snippets and try them out, even on your laptop.
+
+**Let's get started!**
+
+.. literalinclude:: quick_start_model.py
+ :language: Python
+ :linenos:
+ :emphasize-lines: 3,9,13
+
+As you can see from the marked lines in the code snippet,
+you only need to add *3 lines of code* to train LeNet-5 distributedly using Tarantella!
+Let us go through the code in some more detail, in order to understand what is going on.
+
+First we need to import the Tarantella library:
+
+.. code-block:: Python
+
+ import tarantella as tnt
+
+Having done that we need to initialize the library (which will setup the communication infrastructure):
+
+.. code-block:: Python
+
+ tnt.init()
+
+Note that this should be done before executing any other code. Next, we need to wrap the
+``keras.Model`` object, generated by ``lenet5_model_generator()``, into a ``tnt.Model`` object:
+
+.. code-block:: Python
+
+ model = tnt.Model(lenet5_model_generator())
+
+**That's it!**
+
+All the necessary steps to distribute training and datasets will now automatically be handled by Tarantella.
+In particular, we still run ``model.compile`` on the new ``model`` to generate a compute graph,
+just as we would have done with a typical Keras model.
+
+Next, we load the MNIST data for training and testing, and
+create ``Dataset`` s from it. Note that we ``batch`` the dataset for training.
+This will guarantee that Tarantella is able to distribute the data later on in the correct way.
+Also note that the ``batch_size`` used here is the same as for the original model,
+that is, the *global* batch size. For details concerning local and global batch sizes,
+have a look :ref:`here <global-vs-local-batch-size-label>`.
+
+Now we are able to train our ``model`` using ``model.fit``, in the same familiar
+way used by the standard Keras interface. Note, however, that Tarantella is taking care of proper
+distribution of the ``train_dataset`` in the background. All the possibilities of how to
+feed datasets to Tarantella are explained in more detail below.
+Lastly, we can evaluate the final accuracy of our ``model`` on the ``test_dataset`` using
+``model.evaluate``.
+
+To test and run ``tarantella`` in the next section, you can find a full version of the above example
+in ``docs/source/model.py``.
+
+Executing your model with ``tarantella``
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Next, let's execute our model distributedly using ``tarantella`` on the command line.
+The simplest way to do that is by passing the Python script of the model to ``tarantella``:
+
+.. code-block:: bash
+
+ tarantella -- model.py
+
+This will execute our model distributedly on a single node, using all the available GPUs.
+In case no GPUs can be found, ``tarantella`` will execute in serial mode on the CPU,
+and a ``WARNING`` message will be issued. In case you have GPUs available, but
+want to execute ``tarantella`` on CPUs nonetheless, you can specify the ``--no-gpu`` option:
+
+.. code-block:: bash
+
+ tarantella --no-gpu -- model.py
+
+We can also set command line parameters for the Python script ``model.py``; these have to
+follow the name of the script:
+
+.. code-block:: bash
+
+ tarantella --no-gpu -- model.py --batch_size=64 --learning_rate=0.01
+
+On a single node, we can also explicitly specify the number of TensorFlow instances
+we want to use. This is done with the ``-n`` option:
+
+.. code-block:: bash
+
+ tarantella -n 4 -- model.py --batch_size=64
+
+Here, ``tarantella`` would try to execute distributedly on 4 GPUs.
+If there are not enough GPUs available, ``tarantella`` will print a ``WARNING``
+and run 4 instances of TensorFlow on the CPU instead.
+If there are no GPUs installed or the ``--no-gpu`` option is used,
+``tarantella`` will not print a ``WARNING``.
+
+Next, let's run ``tarantella`` on multiple nodes. In order to do this,
+we need to provide ``tarantella`` with a ``hostfile`` that contains
+the ``hostname`` s of the nodes that we want to use:
+
+.. code-block:: bash
+
+ $ cat hostfile
+ name_of_node_1
+ name_of_node_2
+
+With this ``hostfile`` we can run ``tarantella`` on multiple nodes:
+
+.. code-block:: bash
+
+ tarantella --hostfile hostfile -- model.py
+
+In this case, ``tarantella`` uses *all* GPUs it can find.
+If no GPUs are available, ``tarantella`` will start *one* TensorFlow instance
+per node on the CPUs, and will issue a ``WARNING`` message.
+Again, this can be disabled by explicitly using the ``--no-gpu``
+option.
+
+As before, you can specify the number of GPUs/CPUs used per node
+explicitly with the option ``--n-per-node=``:
+
+.. code-block:: bash
+
+ tarantella --hostfile hostfile --n-per-node=4 --no-gpu -- model.py --batch_size=64
+
+In this example, ``tarantella`` would execute 4 instances of TensorFlow on the CPUs
+of each node specified in ``hostfile``.
+
+.. caution::
+
+   ``tarantella`` requires all the names in the ``hostfile`` to be **unique**,
+   and all nodes to be **homogeneous** (in the number and type of CPUs and GPUs).
+
+In addition, ``tarantella`` can be run with different levels of logging output.
+The available log levels are ``DEBUG``, ``INFO``, ``WARNING`` and ``ERROR``,
+and can be set with ``--log-level``:
+
+.. code-block:: bash
+
+ tarantella --hostfile hostfile --log-level=INFO -- model.py
+
+By default, ``tarantella`` will log on the *master rank* only.
+This can be changed by using the ``--log-on-all-devices`` option, which will print
+log messages for each *rank* individually.
+
+Similarly, by default ``tarantella`` will print outputs from functions like ``fit``,
+``evaluate`` and ``predict``, as well as callbacks only on the master rank.
+Sometimes, it might be useful to print outputs from all devices (e.g., for debugging),
+which can be switched on with the ``--output-on-all-devices`` option.
+
+``tarantella`` uses GPI-2's ``gaspi_run`` internally, taking care of ``export`` ing
+environment variables, and generating an execution script from the user inputs.
+Details of this process can be monitored using the ``--dry-run`` option.
+
+Lastly, you can overwrite the *Tensor Fusion* threshold ``tarantella`` uses
+with ``--fusion-threshold FUSION_THRESHOLD_KB``
+(cf. the :ref:`Tensor Fusion <tensor-fusion-label>` section),
+and set a number of environment variables, most notably
+``TNT_TENSORBOARD_ON_ALL_DEVICES``, as explained in the
+:ref:`callbacks section <callbacks-label>`.
+
+Save and load Tarantella models
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Storing and loading your trained ``tnt.Model`` is very simple.
+
+Tarantella supports all the different ways in which you can load and store a ``keras.Model``
+(for a guide, see for instance the
+`TensorFlow documentation <https://www.tensorflow.org/guide/keras/save_and_serialize>`__).
+In particular, you can:
+
+* save the whole model (including the architecture, the weights and the state of the optimizer)
+* save the model's architecture/configuration only
+* save the model's weights only
+
+Whole-model saving and loading
+------------------------------
+
+Saving the entire model including the architecture, weights and optimizer can be done via
+
+.. code-block:: python
+
+ model = ... # get `tnt.Model`
+ model.save('path/to/location')
+
+Alternatively, you could use ``tnt.models.save_model(model, 'path/to/location')``, which works
+on both ``keras.Model`` s and ``tnt.Model`` s.
+
+You can then load your model back using
+
+.. code-block:: python
+
+ import tarantella as tnt
+ model = tnt.models.load_model('path/to/location')
+
+which will return an instance of ``tnt.Model``.
+
+.. caution::
+
+ At the moment, you will need to re-compile your model after loading.
+
+This is again done with
+
+.. code-block:: python
+
+ model.compile(optimizer = keras.optimizers.SGD(learning_rate=args.learning_rate),
+ loss = keras.losses.SparseCategoricalCrossentropy(),
+ metrics = [keras.metrics.SparseCategoricalAccuracy()])
+
+or similar.
+
+Architecture saving and loading
+-------------------------------
+
+If you only want to save the configuration (that is the architecture) of your model
+(in memory), you can use one of the following functions:
+
+* ``tnt.Model.get_config``
+* ``tnt.Model.to_json``
+* ``tnt.Model.to_yaml``
+
+The architecture without its original weights and optimizer can then be restored
+using:
+
+* ``tnt.models.model_from_config`` / ``tnt.Model.from_config``
+* ``tnt.models.model_from_json``
+* ``tnt.models.model_from_yaml``
+
+respectively.
+Here is an example:
+
+.. code-block:: python
+
+ import tarantella as tnt
+ model = ... # get `tnt.Model`
+ config = model.get_config()
+ new_model = tnt.models.model_from_config(config)
+
+The same can be achieved through cloning:
+
+.. code-block:: python
+
+ import tarantella as tnt
+ model = ... # get `tnt.Model`
+ new_model = tnt.models.clone_model(model)
+
+
+Weights saving and loading
+--------------------------
+
+Storing and loading the weights of a model to/from memory can be done
+using the functions ``tnt.Model.get_weights`` and ``tnt.Model.set_weights``,
+respectively. Saving and loading weights to/from disk is done
+using the functions ``tnt.Model.save_weights`` and ``tnt.Model.load_weights``,
+respectively.
+
+Here is an example how this can be used to restore a model:
+
+.. code-block:: python
+
+ import tarantella as tnt
+ model = ... # get `tnt.Model`
+ config = model.get_config()
+ weights = model.get_weights()
+
+ # initialize a new model with original model's weights
+ new_model = tnt.models.model_from_config(config)
+ new_model.set_weights(weights)
+
+.. _checkpointing-via-callbacks-label:
+
+Checkpointing via callbacks
+---------------------------
+
+Apart from saving and loading models manually, Tarantella also supports checkpointing
+via Keras' ``ModelCheckpoint`` callback, as described for instance in the
+`Keras documentation <https://keras.io/api/callbacks/model_checkpoint/>`__.
+
+.. code-block:: python
+
+ import tensorflow as tf
+ import tarantella as tnt
+
+ model = ... # get `tnt.Model`
+
+ checkpoint_path = 'path/to/checkpoint/location'
+ model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
+ filepath=checkpoint_path, monitor='val_acc', verbose=1, save_best_only=False,
+ save_weights_only=False, mode='auto', save_freq='epoch', options=None)
+
+ model.fit(train_dataset,
+ validation_data = val_dataset,
+ epochs = 2,
+ callbacks = [model_checkpoint_callback])
+
+
+.. note::
+
+ All saving to the filesystem (including ``tnt.Model.save`` and ``tnt.Model.save_weights``)
+ by Tarantella will only be done on the master rank.
+
+This is the default and will yield correct behavior when you are using a distributed filesystem.
+If you wish to explicitly save on all devices you can pass ``tnt_save_all_devices = True``
+to ``tnt.Model.save``, ``tnt.Model.save_weights`` and ``tnt.models.save_model``.
+
+
+.. _using-distributed-datasets-label:
+
+Using distributed datasets
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This section explains what needs to be done in order to use Tarantella's distributed datasets correctly.
+
+The recommended way to provide your dataset to Tarantella is by passing a
+*batched* ``tf.data.Dataset`` to ``tnt.Model.fit``.
+In order to do this, create a ``Dataset`` and apply the ``batch``
+`transformation <https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch>`_
+to it, using the (global) batch size. However, do not provide a value for ``batch_size``
+in ``tnt.Model.fit``, as this would lead to double batching, and thus modified shapes
+of the input data.
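+
+A minimal sketch of this pattern (``x_train``, ``y_train``, and ``model`` are placeholders
+for your data and your ``tnt.Model``):
+
+.. code-block:: python
+
+   import tensorflow as tf
+
+   global_batch_size = 256
+   dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+   dataset = dataset.shuffle(buffer_size=len(x_train), seed=42)
+   dataset = dataset.batch(global_batch_size, drop_remainder=True)
+
+   # note: no `batch_size` argument here, as the dataset is already batched
+   model.fit(dataset, epochs=5)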
+
+Tarantella also supports batched and unbatched ``Dataset`` s in ``tnt.Model.fit``
+when setting the ``tnt_micro_batch_size`` argument. This can be useful to obtain
+maximal performance in multi-node execution, as explained in the advanced topics.
+Keep in mind, however, that Tarantella still expects
+the ``Dataset`` to be batched with the global batch size, and that the micro-batch
+size has to be consistent with the global batch size. [#footnote_consistent]_
+This is why it is recommended to use an unbatched ``Dataset`` when setting
+a ``tnt_micro_batch_size`` explicitly.
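+
+For example (``unbatched_dataset`` is a placeholder; with 8 devices, a micro-batch size
+of 32 corresponds to a global batch size of 8 * 32 = 256):
+
+.. code-block:: python
+
+   model.fit(unbatched_dataset, tnt_micro_batch_size=32)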
+
+Tarantella does not support any other way to feed data to ``fit`` at the moment.
+In particular, NumPy arrays, TensorFlow tensors and generators are not supported.
+
+Tarantella's automatic data distribution can be switched off by passing
+``tnt_distribute_dataset=False`` in ``tnt.Model.fit``, in which case Tarantella
+will issue an ``INFO`` message.
+If a validation dataset is passed to ``tnt.Model.fit``, it should also be batched
+with the global batch size. You can similarly switch off its automatic
+micro-batching mechanism by setting ``tnt_distribute_validation_dataset=False``.
+
+There are a few important points when using distributed datasets in Tarantella:
+
+.. note::
+
+ Batch size must be a multiple of the number of devices used.
+
+This restriction will be lifted in the next release.
+
+.. note::
+
+ The last incomplete batch is always dropped.
+
+We recommend using ``drop_remainder=True`` when batching a ``Dataset``.
+If ``drop_remainder`` is set to ``False``, Tarantella will ignore it
+and issue a ``WARNING`` message. This behavior will be fixed in the next release.
+
+.. note::
+
+ When using ``shuffle`` without a ``seed``, Tarantella will use a fixed default ``seed``.
+
+This guarantees that the input data is shuffled the same way on all devices,
+when no ``seed`` is given, which is necessary for consistency.
+However, when a random ``seed`` is provided by the user, Tarantella will use that one instead.
+
+.. _callbacks-label:
+
+Callbacks
+^^^^^^^^^
+
+At the moment, Tarantella fully supports 3 of the
+`Keras callbacks <https://keras.io/api/callbacks/>`__:
+
+* ``tf.keras.callbacks.LearningRateScheduler``
+* ``tf.keras.callbacks.ModelCheckpoint``
+* ``tf.keras.callbacks.TensorBoard``
+
+The ``LearningRateScheduler`` takes a ``schedule`` function, which will change the learning rate
+on each of the devices used (for a detailed explanation, cf. the
+`Keras documentation <https://keras.io/api/callbacks/learning_rate_scheduler/>`__).
+If ``verbose=1`` is set, Tarantella will only print on one device by default.
+This behavior can be changed by passing ``--output-on-all-devices`` to ``tarantella``.
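+
+A small sketch of its usage with a ``tnt.Model`` (the schedule itself is illustrative):
+
+.. code-block:: python
+
+   import tensorflow as tf
+
+   # halve the learning rate every 10 epochs
+   def schedule(epoch, lr):
+       if epoch > 0 and epoch % 10 == 0:
+           return lr * 0.5
+       return lr
+
+   lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)
+   model.fit(train_dataset, epochs=30, callbacks=[lr_callback])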
+
+``ModelCheckpoint`` can be used to automatically checkpoint the state of the model
+during training. For an example, look :ref:`here <checkpointing-via-callbacks-label>`,
+and into the
+`Keras documentation <https://keras.io/api/callbacks/model_checkpoint/>`__.
+
+The ``TensorBoard`` callback can be used to collect training information for visualization
+in `TensorBoard <https://www.tensorflow.org/tensorboard>`__. By default, Tarantella
+will only collect (device local) information on one device. If you want to collect
+the local information on all devices use the environment variable ``TNT_TENSORBOARD_ON_ALL_DEVICES``:
+
+.. code-block:: bash
+
+ TNT_TENSORBOARD_ON_ALL_DEVICES=true tarantella -- model.py
+
+.. note::
+
+ At the moment, all of the other Keras callbacks will be executed on all devices with
+ local information only.
+
+For instance, the ``BaseLogger`` callback will be executed on each and every rank,
+and will log the accumulated metric averages for the local (micro-batch) information.
+
+Important points
+^^^^^^^^^^^^^^^^
+
+There are a number of points you should be aware of when using Tarantella.
+
+.. note::
+
+ ``tnt.init()`` needs to be called **after** ``import tarantella as tnt``, but **before**
+ any other statement.
+
+This will make sure the GPI-2 communication infrastructure is correctly initialized.
+
+.. note::
+
+ Tarantella does not support custom training loops.
+
+Instead of using custom training loops, please use ``Model.fit(...)``.
+
+.. note::
+
+ Tarantella supports all
+ `TensorFlow optimizers `_
+ with the exception of ``tf.keras.optimizers.Ftrl``.
+
+Since the ``Ftrl`` optimizer does not use batches, it is not supported in Tarantella.
+
+
+.. rubric:: Footnotes
+
+.. [#footnote_consistent] That is, the global batch size must equal the micro batch size times
+ the number of devices used.
diff --git a/docs/source/quick_start_model.py b/docs/source/quick_start_model.py
new file mode 100644
index 00000000..7a345bae
--- /dev/null
+++ b/docs/source/quick_start_model.py
@@ -0,0 +1,39 @@
+import tensorflow as tf
+from tensorflow import keras
+import tarantella as tnt
+
+# Skip function implementations for brevity
+[...]
+
+# Initialize Tarantella (before doing anything else)
+tnt.init()
+args = parse_args()
+
+# Create Tarantella model
+model = tnt.Model(lenet5_model_generator())
+
+# Compile Tarantella model (as with Keras)
+model.compile(optimizer = keras.optimizers.SGD(learning_rate=args.learning_rate),
+ loss = keras.losses.SparseCategoricalCrossentropy(),
+ metrics = [keras.metrics.SparseCategoricalAccuracy()])
+
+# Load MNIST dataset (as with Keras)
+shuffle_seed = 42
+(x_train, y_train), (x_val, y_val), (x_test, y_test) = \
+ mnist_as_np_arrays(args.train_size, args.val_size, args.test_size)
+
+train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+train_dataset = train_dataset.shuffle(len(x_train), shuffle_seed)
+train_dataset = train_dataset.batch(args.batch_size)
+train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
+
+test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+test_dataset = test_dataset.batch(args.batch_size)
+
+# Train Tarantella model (as with Keras)
+model.fit(train_dataset,
+ epochs = args.number_epochs,
+ verbose = 1)
+
+# Evaluate Tarantella model (as with Keras)
+model.evaluate(test_dataset, verbose = 1)
diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst
new file mode 100644
index 00000000..c7c9b189
--- /dev/null
+++ b/docs/source/tutorials.rst
@@ -0,0 +1,304 @@
+Tutorials
+=========
+
+This section delves into more advanced usage of Tarantella with the help of
+state-of-the-art models for two widely-used applications in Deep Learning:
+
+* Image classification: ResNet-50
+* Machine translation: Transformer
+
+The models shown here are adapted from the
+`TensorFlow Model Garden `_.
+While the model implementations and hyperparameters are unchanged to preserve
+compatibility with the TensorFlow official models, we provide simplified training
+schemes that allow for a seamless transition from basic serial training to distributed
+data parallelism using Tarantella.
+
+
+Prerequisites
+-------------
+
+The tutorial models can be downloaded from the
+`Tnt Models repository <https://github.com/cc-hpc-itwm/tarantella_models>`_:
+
+.. code-block:: bash
+
+ export TNT_MODELS_PATH=/your/installation/path
+ cd ${TNT_MODELS_PATH}
+ git clone https://github.com/cc-hpc-itwm/tarantella_models
+
+To use these models, install the following dependencies:
+
+* TensorFlow 2.2.1
+* Tarantella 0.6.0
+
+For a step-by-step installation, follow the :ref:`installation-label` guide.
+In the following we will assume that TensorFlow was installed in a ``conda``
+environment called ``tarantella``.
+
+Now we can install the final dependency,
+`TensorFlow official Model Garden <https://github.com/tensorflow/models>`__:
+
+.. code-block:: bash
+
+ conda activate tarantella
+ pip install tf-models-official==2.2.1
+
+
+.. _resnet50-label:
+
+ResNet-50
+---------
+
+Deep Residual Networks (ResNets) represented a breakthrough in the field of
+computer vision, enabling deeper and more complex deep convolutional networks.
+Introduced in [He]_, ResNet50 has become a standard model for image classification
+tasks, and has been shown to scale to very large numbers of nodes in data-parallel
+training [Goyal]_.
+
+Run ResNet-50 with Tarantella
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Before running the model, we need to add it to the existing ``PYTHONPATH``.
+
+.. code-block:: bash
+
+ export PYTHONPATH=${TNT_MODELS_PATH}/models/resnet:${PYTHONPATH}
+
+Furthermore, the ``ImageNet`` dataset needs to be installed and available on
+all the nodes that we want to use for training.
+TensorFlow provides convenience scripts to download datasets in its ``datasets``
+package, which is installed as a dependency of the TensorFlow Model Garden.
+Install ImageNet on your local machine as described in the
+`TensorFlow datasets catalog <https://www.tensorflow.org/datasets/catalog/imagenet2012>`_.
+
+.. code-block:: bash
+
+ export TNT_DATASETS_PATH=/path/to/downloaded/datasets
+
+ python -m tensorflow_datasets.scripts.download_and_prepare \
+ --datasets=imagenet2012 --data_dir=${TNT_DATASETS_PATH}
+
+
+Let's assume we have access to two nodes (saved in ``hostfile``) equipped with 4 GPUs each.
+We can now simply run the ResNet-50 as follows:
+
+.. code-block:: bash
+
+ tarantella --hostfile ./hostfile --devices-per-node 4 \
+ -- ${TNT_MODELS_PATH}/models/resnet/resnet50_tnt.py --data_dir=${TNT_DATASETS_PATH} \
+ --batch_size=512 \
+ --train_epochs=90 \
+ --epochs_between_evals=10
+
+The above command will train a ResNet-50 model in parallel on the 8 available devices
+for ``90`` epochs, as suggested in [Goyal]_ to achieve convergence.
+The ``--epochs_between_evals`` parameter specifies the frequency of evaluations of the
+``validation`` data performed in between training epochs.
+
+Note the ``--batch_size`` parameter, which specifies the global batch size used in training.
+
+Implementation overview
+^^^^^^^^^^^^^^^^^^^^^^^
+We will now take a closer look at the implementation of the ResNet-50 training scheme.
+The main training steps reside in the ``models/resnet/resnet50_tnt.py`` file.
+
+The most important step in enabling data parallelism with Tarantella is
+to wrap the Keras model:
+
+.. code-block:: python
+
+ model = resnet_model.resnet50(num_classes=tf_imagenet_preprocessing.NUM_CLASSES)
+ model = tnt.Model(model)
+
+The following operations are identical to serial training, as they do not
+require any changes.
+In particular, the ImageNet dataset is loaded and preprocessed as follows:
+
+.. code-block:: python
+
+ train_dataset = imagenet_preprocessing.input_fn(is_training=True,
+ data_dir=flags_obj.data_dir,
+ batch_size=flags_obj.batch_size,
+ shuffle_seed = 42,
+ drop_remainder=True)
+
+The
+`imagenet_preprocessing.input_fn
+`_
+function takes the input files in ``data_dir``, loads the training samples and processes
+them into TensorFlow datasets.
+
+The user only needs to pass the global ``batch_size`` value, and the Tarantella
+framework will ensure that the dataset is properly distributed among devices,
+such that:
+
+  * each device will process an independent set of samples
+  * each device will group the samples into micro-batches, where the micro-batch
+    size will be computed as ``batch_size / num_devices`` (see the sketch below)
+  * each device will apply the same set of transformations to its input samples, as
+    specified in the ``input_fn`` function.
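+
+For the command line above, this amounts to the following micro-batch size
+(a sketch of the arithmetic, not code from the model implementation):
+
+.. code-block:: python
+
+   batch_size = 512         # global batch size passed via --batch_size
+   num_devices = 2 * 4      # 2 nodes with 4 GPUs each
+   micro_batch_size = batch_size // num_devices  # 64 samples per device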
+
+Before starting the training, the model is compiled to use a standard Keras optimizer
+and loss.
+
+.. code-block:: python
+
+ model.compile(optimizer=optimizer,
+ loss='sparse_categorical_crossentropy',
+ metrics=(['sparse_categorical_accuracy']))
+
+We provide flags to enable the most commonly used Keras ``callbacks``, such as
+the ``TensorBoard`` profiler, which can simply be passed to the ``fit`` function
+of the Tarantella model.
+
+.. code-block:: python
+
+ callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=flags_obj.model_dir,
+ profile_batch=2))
+
+If model checkpointing is required, it can be enabled through the ``ModelCheckpoint``
+callback as usual (cf. :ref:`checkpointing models with Tarantella `).
+
+.. code-block:: python
+
+ callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path, save_weights_only=True))
+
+
+There is no need for any further changes to proceed with training:
+
+.. code-block:: python
+
+ history = model.fit(train_dataset,
+ epochs=flags_obj.train_epochs,
+ callbacks=callbacks,
+ validation_data=validation_dataset,
+ validation_freq=flags_obj.epochs_between_evals,
+ verbose=1)
+
+.. todo::
+
+ Advanced topics:
+
+ * scaling batch size with number of ranks (-> only mention here & link to advanced topics)
+ * introduce learning rate warm up
+ * introduce learning rate scaling (with #ranks)
+
+
+.. _transformer-label:
+
+Transformers
+------------
+
+The Transformer is a Deep Neural Network widely used in the field of natural language processing (NLP),
+in particular for tasks such as machine translation.
+It was first proposed by [Vaswani]_.
+
+Run the Transformer with Tarantella
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The Transformer training scheme can be found in the ``models/transformer`` directory
+of the `Tnt Models repository <https://github.com/cc-hpc-itwm/tarantella_models>`__,
+and has to be added to
+the existing ``PYTHONPATH``:
+
+.. code-block:: bash
+
+ export PYTHONPATH=${TNT_MODELS_PATH}/models/transformer:${PYTHONPATH}
+
+We will follow the training procedure presented in [Vaswani]_, where the authors
+show results for training the `big` variant of the Transformer model on
+a machine translation dataset called
+`WMT14 <https://www.statmt.org/wmt14/>`_.
+
+To install the dataset, we will use the TensorFlow ``datasets`` package, which
+should have already been installed in your ``conda`` environment as a
+dependency of the TensorFlow Model Garden, and download the English-German
+dataset to match the results by [Vaswani]_.
+Detailed instructions on how to obtain the dataset are provided in the
+`TensorFlow documentation <https://www.tensorflow.org/datasets>`_.
+
+Now we can start training.
+Once again, let's assume we have access to two nodes (specified in ``hostfile``)
+equipped with 4 GPUs each.
+
+.. code-block:: bash
+
+   export WMT14_PATH=/path/to/the/installed/dataset
+
+   tarantella --hostfile ./hostfile --devices-per-node 4 \
+     -- ${TNT_MODELS_PATH}/models/transformer/transformer_tnt.py \
+     --data_dir=${WMT14_PATH} \
+     --vocab_file=${WMT14_PATH}/vocab.ende.32768 \
+     --bleu_ref=${WMT14_PATH}/newstest2014.de \
+     --bleu_source=${WMT14_PATH}/newstest2014.en \
+     --param_set=big \
+     --train_epochs=30 \
+     --batch_size=32736
+
+The above command will select the ``big`` model implementation and train it
+distributedly on the 8 specified devices.
+To reach the target accuracy, [Vaswani]_ specifies that the model needs to be
+trained for ``30`` epochs.
+
+The Transformer requires access to a vocabulary file, which contains all the
+tokens derived from the dataset. This is provided as the ``vocab_file`` parameter
+and is part of the pre-processed dataset.
+
+After training, one round of evaluation is conducted using the ``newstest2014``
+dataset to translate English sentences into German.
+
+Implementation overview
+^^^^^^^^^^^^^^^^^^^^^^^
+
+The Transformer model itself is implemented and imported from the
+`TensorFlow Model Garden
+`__.
+The training procedure and dataset loading and pre-processing do not require
+extensive changes to work with Tarantella. However, we provide a simplified
+version to highlight the usage of Tarantella with Keras training loops.
+
+Thus, the Keras transformer model is created in
+``models/transformer/transformer_tnt.py`` and wrapped into a Tarantella model:
+
+.. code-block:: python
+
+   model = transformer.create_model(self.params, is_train=True)
+   model = tnt.Model(model)
+
+Data is loaded as follows, without any specific modification to trigger
+distributed training:
+
+.. code-block:: python
+
+ train_ds = data_pipeline.train_input_fn(self.params)
+
+Here, the ``data_pipeline.train_input_fn`` reads in the dataset and applies a series
+of transformations to convert it into a batched set of sentences.
+The advantage of using the *automatic dataset distribution* mechanism of Tarantella
+is that users can reason about their I/O pipeline without having to deal with the details
+of how to distribute it.
+Note, however, that the batch size has to be a multiple of the number of ranks, so
+that it can be evenly divided into micro-batches.
+
+Next, the user can also create callbacks, which can then be simply passed on to
+the training function.
+
+.. code-block:: python
+
+ callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=self.flags_obj.model_dir))
+
+Finally, we can call ``model.fit`` to start distributed training on all devices:
+
+.. code-block:: python
+
+ history = model.fit(train_ds,
+ epochs=self.params["train_epochs"],
+ callbacks=callbacks,
+ verbose=1)
+
+.. todo::
+
+ Important points
+
+ * Mixing Keras and Tarantella models
+
diff --git a/docs/source/why_tarantella.rst b/docs/source/why_tarantella.rst
new file mode 100644
index 00000000..4313dd35
--- /dev/null
+++ b/docs/source/why_tarantella.rst
@@ -0,0 +1,44 @@
+Why Tarantella?
+===============
+
+Tarantella is an open-source Deep Learning framework that focuses on providing fast, scalable and
+efficient training of Deep Neural Networks (DNNs) on High Performance Computing (HPC) clusters.
+
+Goals
+-----
+
+Tarantella is designed to meet the following goals:
+
+.. code-block:: text
+
+ Tarantella...
+
+ 1. ...provides strong scalability
+ 2. ...is easy to use
+ 3. ...follows a synchronous training scheme
+ 4. ...integrates well with existing models
+ 5. ...provides support for GPU and CPU systems
+
+Tarantella provides close to linear speed-up for the training of common Deep Learning architectures,
+thus considerably reducing the required time-to-accuracy in many Deep Learning workflows.
+To make this capability accessible to as many users as possible, Tarantella's interface
+is designed such that its use does not require any expertise in HPC or parallel computing.
+
+To allow integrating Tarantella into any TensorFlow-based Deep Learning workflow,
+we put special emphasis on strictly following the synchronous optimization scheme
+used to train DNNs. This guarantees that results obtained in serial execution can be
+reproduced when using distributed training
+(cf., however, :ref:`these guidelines <points-to-consider-label>`),
+so that computation can be scaled up at any point in time without losing reproducibility
+of the results.
+
+Furthermore, we made sure that existing TensorFlow 2/Keras
+models can be made ready for distributed training with minimal effort
+(follow the :ref:`Quick Start guide <quick-start-label>` to learn more).
+Tarantella supports distributed training on GPU and pure CPU clusters,
+independently of the hardware vendors.
+
+.. todo::
+
+ Performance Results
+
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
new file mode 100644
index 00000000..4870d072
--- /dev/null
+++ b/src/CMakeLists.txt
@@ -0,0 +1,17 @@
+set(TNT_PYTHON_DIRS
+ ${SRC_DIR}/tarantella
+ ${SRC_DIR}/runtime
+ ${SRC_DIR}/gpi_comm_lib/tf_ops/tnt_tfops)
+
+install(DIRECTORY ${TNT_PYTHON_DIRS}
+ DESTINATION ${INSTALL_LIB_DIR}/python
+ FILES_MATCHING PATTERN "*.py")
+
+install(PROGRAMS ${SRC_DIR}/bin/tarantella
+ DESTINATION ${INSTALL_BIN_DIR})
+
+set(VERSION_FILE_TEMPLATE ${CMAKE_SOURCE_DIR}/cmake/version.py.in)
+set(VERSION_FILE ${CMAKE_BUILD_DIR}/version.py)
+configure_file(${VERSION_FILE_TEMPLATE} ${VERSION_FILE} @ONLY)
+install(FILES ${VERSION_FILE}
+ DESTINATION ${INSTALL_LIB_DIR}/python)
diff --git a/src/bin/tarantella b/src/bin/tarantella
new file mode 100755
index 00000000..950e3531
--- /dev/null
+++ b/src/bin/tarantella
@@ -0,0 +1,198 @@
+#!/usr/bin/env python
+import argparse
+import logging
+import os
+import shutil
+import subprocess
+import sys
+
+TNT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+LIB_DIR = os.path.join(TNT_DIR, "lib/tarantella")
+PYLIB_DIR = os.path.join(TNT_DIR, "lib/tarantella/python")
+sys.path.insert(0, LIB_DIR)
+sys.path.insert(0, PYLIB_DIR)
+
+try:
+ from version import tnt_version
+except ImportError:
+ tnt_version = "Unknown version"
+
+try:
+ import runtime
+except ModuleNotFoundError as e:
+ raise RuntimeError("[TNT_CLI] Cannot find Tarantella `runtime` module; \
+make sure the `tarantella` script is started from an installed version.") from e
+
+import runtime.file_management as file_man
+import runtime.logging_config as logging_config
+import runtime.platform_config as platform_config
+import runtime.environment_config as env_config
+from runtime import logger
+
+def parse_args():
+ parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter)
+ singlenode_group = parser.add_argument_group('Single-node execution')
+ singlenode_group.add_argument("-n",
+ help="number of TensorFlow instances to start on the local node",
+ dest = "npernode",
+ metavar = "N",
+ type = int,
+ default = None)
+ multinode_group = parser.add_argument_group('Multi-node execution')
+ multinode_group.add_argument("--hostfile",
+ dest = "hostfile",
+ help="path to the list of nodes (hostnames) on which to execute the SCRIPT",
+ default = None)
+ multinode_group.add_argument("--n-per-node", "--devices-per-node",
+ help="number of devices (i.e., either GPUs or processes on CPUs) to be used on each node",
+ dest = "npernode",
+ type = int,
+ default = None)
+
+ parser.add_argument("--no-gpu", "--no-gpus",
+ help="disallow GPU usage",
+ dest = "use_gpus",
+ action='store_false',
+ default = True)
+ parser.add_argument("--output-on-all-devices",
+ help="enable output on all devices (e.g., training info)",
+ dest = "output_all",
+ action='store_true',
+ default = False)
+ parser.add_argument("--log-on-all-devices",
+ help="enable library logging messages on all devices",
+ dest = "log_all",
+ action='store_true',
+ default = False)
+ log_levels = ('DEBUG', 'INFO', 'WARNING', 'ERROR')
+ parser.add_argument('--log-level', default='WARNING', choices=log_levels,
+ help = "logging level for library messages")
+ parser.add_argument("--fusion-threshold",
+ help="tensor fusion threshold [kilobytes]",
+ dest = "fusion_threshold_kb",
+ type = int,
+ default = None)
+ parser.add_argument("--dry-run",
+ help="print generated files and execution command",
+ dest = "dry_run",
+ action='store_true',
+ default = False)
+ parser.add_argument("--version",
+ action='version',
+ version=generate_version_message())
+ parser.add_argument('script', nargs='+',metavar='-- SCRIPT')
+ args = parser.parse_args()
+ return args
+
+def tnt_run_message(command_list, hostfile_path, exec_script_path):
+ msg = ""
+ if not hostfile_path is None:
+ msg += "\n{}\nGenerated hostfile:\n".format("="*80)
+ with open(hostfile_path, 'r') as f:
+ msg += "============= {} =============\n{}\n".format(hostfile_path,
+ "".join(f.readlines()))
+ if not exec_script_path is None:
+ msg += "\n{}\nGenerated script:\n".format("="*80)
+ with open(exec_script_path, 'r') as f:
+ msg += "============= {} =============\n{}\n".format(exec_script_path,
+ "".join(f.readlines()))
+ msg += "\n{}".format("="*80)
+ msg += "\nCommand:\n\t{}\n".format(" ".join(command_list))
+ return msg
+
+def generate_dry_run_message(command_list, hostfile_path, exec_script_path):
+ msg = "\n{}".format("="*80)
+ msg += "\n{0}{1}DRY RUN {1}{0}\n".format("="*6, " "*30)
+ msg += tnt_run_message(command_list, hostfile_path, exec_script_path)
+ return msg
+
+def generate_run_error_message(e, hostfile_path = None,
+ executed_script_path = None):
+ error_string = ""
+ if not e.stdout is None:
+ error_string += "============= STDOUT =============\n{}\n".format(e.stdout)
+ if not e.stderr is None:
+ error_string += "============= STDERR =============\n{}\n".format(e.stderr)
+ error_string += tnt_run_message(e.cmd, hostfile_path = hostfile_path,
+ exec_script_path = executed_script_path)
+ error_string += "[TNT_CLI] Execution failed with status {}".format(e.returncode)
+ return error_string
+
+def generate_version_message():
+ msg = ["Tarantella {}".format(tnt_version),
+ "Path: {}".format(os.path.dirname(os.path.abspath(__file__))),
+ "Copyright (C) 2020 Fraunhofer"]
+ return "\n".join(msg)
+
+class Tarantella:
+ def __init__(self, hostlist, num_gpus_per_node, num_cpus_per_node, args):
+ self.args = args
+
+ self.hostlist = hostlist
+ self.command_list = args.script
+ self.num_gpus_per_node = num_gpus_per_node
+
+ # compute number of ranks per node to create the hostfile
+ npernode = num_gpus_per_node
+ device_type = "GPUs"
+ if npernode == 0:
+ npernode = num_cpus_per_node
+ device_type = "CPU processes"
+
+ self.nranks = len(hostlist) * npernode
+ self.hostfile = file_man.HostFile(self.hostlist, npernode)
+ self.executable_script = self.generate_executable_script()
+
+ logger.info("Starting Tarantella on {} devices ({} nodes x {} {})".format(self.nranks,
+ len(self.hostlist), npernode, device_type))
+
+
+ def generate_executable_script(self):
+ # create execution script
+ header = "#!/bin/bash\n"
+ header += "cd {}".format(os.path.abspath(os.getcwd()))
+
+ environment = env_config.gen_exports_from_dict(env_config.collect_environment_variables()) + \
+ env_config.gen_exports_from_dict(env_config.collect_tensorflow_variables()) + \
+ env_config.gen_exports_from_dict(env_config.collect_tarantella_variables()) + \
+ env_config.gen_exports_from_dict(env_config.get_tnt_variables_from_args(self.args)) +\
+ env_config.gen_exports_from_dict(env_config.get_tnt_gpus(self.num_gpus_per_node))
+
+ command = "python {}".format(' '.join(self.command_list))
+ return file_man.GPIScriptFile(header, environment, command, dir = os.getcwd())
+
+ def run(self, dry_run = False):
+ with self.hostfile, self.executable_script:
+ command_list = ["gaspi_run", "-n", str(self.nranks),
+ "-m", self.hostfile.name,
+ self.executable_script.filename]
+
+ if dry_run:
+ print(generate_dry_run_message(command_list, self.hostfile.name,
+ self.executable_script.filename))
+ return
+
+ path_to_gpi = shutil.which("gaspi_run")
+ if path_to_gpi is None:
+ sys.exit("[TNT_CLI] Cannot execute `gaspi_run`; make sure it is added to the current `PATH`.")
+
+ try:
+ result = subprocess.run(command_list,
+ check = True,
+ cwd = os.getcwd(),
+ stdout = None, stderr = None)
+ except subprocess.CalledProcessError as e:
+ sys.exit(generate_run_error_message(e, self.hostfile.name,
+ self.executable_script.filename))
+
+if __name__ == "__main__":
+ args = parse_args()
+ logging_config.setup_logging(logger, args.log_level)
+
+ nodes_list = platform_config.generate_nodes_list(args.hostfile)
+ num_gpus, num_cpus = platform_config.generate_num_devices_per_node(npernode = args.npernode,
+ use_gpus = args.use_gpus)
+ env_config.update_environment_paths(LIB_DIR)
+
+ tarantella = Tarantella(nodes_list, num_gpus, num_cpus, args)
+ tarantella.run(args.dry_run)
\ No newline at end of file
diff --git a/src/gpi_comm_lib/AtomicCondition.hpp b/src/gpi_comm_lib/AtomicCondition.hpp
new file mode 100644
index 00000000..99fc54d3
--- /dev/null
+++ b/src/gpi_comm_lib/AtomicCondition.hpp
@@ -0,0 +1,29 @@
+#pragma once
+
+#include <condition_variable>
+#include <mutex>
+
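+// Minimal one-shot signal between two threads: wait() blocks until notify()
+// has been called, then resets the flag so the object can be reused.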
+class AtomicCondition
+{
+ public:
+ void notify()
+ {
+ {
+ std::lock_guard<std::mutex> lk(lock);
+ done = true;
+ }
+ condition.notify_one();
+ }
+
+ void wait()
+ {
+ std::unique_lock<std::mutex> lk(lock);
+ condition.wait(lk, [&done = done]{return done;});
+ done = false;
+ }
+
+ private:
+ std::mutex lock;
+ std::condition_variable condition;
+ bool done = false;
+};
\ No newline at end of file
diff --git a/src/gpi_comm_lib/CMakeLists.txt b/src/gpi_comm_lib/CMakeLists.txt
new file mode 100644
index 00000000..f31227f3
--- /dev/null
+++ b/src/gpi_comm_lib/CMakeLists.txt
@@ -0,0 +1,34 @@
+include (add_macros)
+
+set(GPI_LIB_MODULE "GPICommLib")
+
+set (GPICOMMLIB_SOURCES
+ ${SRC_DIR}/gpi_comm_lib/distribution/SegmentIDBuilder.cpp
+ ${SRC_DIR}/gpi_comm_lib/distribution/utilities.cpp
+ ${SRC_DIR}/gpi_comm_lib/PipelineCommunicator.cpp
+ ${SRC_DIR}/gpi_comm_lib/SynchCommunicator.cpp
+ ${SRC_DIR}/gpi_comm_lib/TensorBroadcaster.cpp
+)
+
+extended_add_library(NAME gpicommlib
+ NAMESPACE tnt
+ TYPE SHARED
+ SOURCES
+ ${GPICOMMLIB_SOURCES}
+ LIBRARIES
+ tnt::gpiresources
+ tnt::collectives
+ INCLUDE_DIRECTORIES
+ ${SRC_DIR}/gpi_comm_lib/
+ INSTALL
+ INSTALL_DESTINATION
+ ${INSTALL_LIB_DIR}
+ POSITION_INDEPENDENT)
+
+pybind11_add_module(${GPI_LIB_MODULE} MODULE
+ ${SRC_DIR}/gpi_comm_lib/pybind11_wrappers.cpp)
+target_link_libraries(${GPI_LIB_MODULE} PRIVATE pybind11::module
+ tnt::gpicommlib)
+install(TARGETS ${GPI_LIB_MODULE}
+ LIBRARY
+ DESTINATION ${INSTALL_LIB_DIR})
diff --git a/src/gpi_comm_lib/PipelineCommunicator.cpp b/src/gpi_comm_lib/PipelineCommunicator.cpp
new file mode 100644
index 00000000..cf4c328d
--- /dev/null
+++ b/src/gpi_comm_lib/PipelineCommunicator.cpp
@@ -0,0 +1,116 @@
+#include "PipelineCommunicator.hpp"
+
+#include "collectives/barrier/GPIBarrier.hpp"
+#include "distribution/GroupBuilder.hpp"
+#include "gpi/gaspiCheckReturn.hpp"
+
+#include <GASPI.h>
+
+#include <cstring>
+
+namespace tarantella
+{
+ PipelineCommunicator::PipelineCommunicator(
+ GPI::Context& context,
+ std::unordered_map<ConnectionID, ConnectionInfo> const& connection_infos,
+ std::size_t num_micro_batches)
+ : resource_manager(context.get_resource_manager())
+ {
+ for(auto const& [conn_id, conn_info] : connection_infos)
+ {
+ auto const segment_id = conn_info.segment_id;
+ auto const buffer_size = conn_info.microbatched_buffer_size_bytes;
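+ // each connection holds one send and one receive buffer per micro-batch,
+ // hence the factor 2 in the segment size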
+ auto const segment_size = 2 * num_micro_batches * buffer_size;
+
+ auto const segment_group = resource_manager.make_group({context.get_rank(), conn_info.other_rank});
+ resource_manager.make_segment_resources(segment_id, segment_group, segment_size);
+
+ std::vector<GPI::SegmentBuffer> send_bufs;
+ std::vector<GPI::SegmentBuffer> recv_bufs;
+ std::vector<gaspi_notification_id_t> notifications;
+ for(std::size_t m_id = 0; m_id < num_micro_batches; ++m_id)
+ {
+ send_bufs.push_back(resource_manager.get_buffer_of_size(segment_id, buffer_size));
+ recv_bufs.push_back(resource_manager.get_buffer_of_size(segment_id, buffer_size));
+ notifications.push_back(resource_manager.get_notification_range(segment_id, 1).first);
+ }
+ connections.emplace(conn_id, SendRecvResources(conn_info.other_rank,
+ send_bufs,
+ recv_bufs,
+ notifications));
+ }
+
+ // Barrier is required, to ensure all ranks have finished registering
+ // their segments to their communication partners
+ collectives::Barrier::GPIBarrierAllRanks barrier;
+ barrier.blocking_barrier();
+ }
+
+ void PipelineCommunicator::non_blocking_send(void* local_send_buf,
+ ConnectionID conn_id,
+ MicrobatchID micro_id)
+ {
+ auto const& local_segment_buf = connections[conn_id].send_bufs[micro_id];
+ auto const& remote_segment_buf = connections[conn_id].recv_bufs[micro_id];
+
+ copy_data_to_segment(local_send_buf, local_segment_buf);
+
+ GPI::gaspiCheckReturn(
+ gaspi_write_notify(local_segment_buf.get_segment_id(),
+ local_segment_buf.get_offset(),
+ connections[conn_id].other_rank,
+ remote_segment_buf.get_segment_id(),
+ remote_segment_buf.get_offset(),
+ local_segment_buf.get_size(),
+ connections[conn_id].notifications[micro_id],
+ micro_id + 1, // to check micro_id at recv (must not be zero)
+ resource_manager.get_queue_id_for_write_notify(),
+ GASPI_BLOCK),
+ "PipelineCommunicator::non_blocking_send");
+ }
+
+ void PipelineCommunicator::blocking_recv(void* local_recv_buf,
+ ConnectionID conn_id,
+ MicrobatchID micro_id)
+ {
+ auto const& local_segment_buf = connections[conn_id].recv_bufs[micro_id];
+ gaspi_notification_id_t received_notification_id = 0;
+ gaspi_notification_t received_notification_value = 0;
+
+ GPI::gaspiCheckReturn(
+ gaspi_notify_waitsome(local_segment_buf.get_segment_id(),
+ connections[conn_id].notifications[micro_id],
+ 1,
+ &received_notification_id,
+ GASPI_BLOCK),
+ "PipelineCommunicator::blocking_recv : gaspi_notify_waitsome");
+ GPI::gaspiCheckReturn(
+ gaspi_notify_reset(local_segment_buf.get_segment_id(),
+ received_notification_id,
+ &received_notification_value),
+ "PipelineCommunicator::blocking_recv : gaspi_notify_reset");
+ if (received_notification_value != micro_id + 1)
+ {
+ throw std::runtime_error("PipelineCommunicator::blocking_recv : \
+ Incorrect notification value received");
+ }
+
+ copy_data_from_segment(local_recv_buf, local_segment_buf);
+ }
+
+ void PipelineCommunicator::copy_data_to_segment(void* local_send_buf,
+ GPI::SegmentBuffer const& segment_buffer)
+ {
+ auto const segment_ptr = segment_buffer.get_ptr();
+ auto const buffer_size = segment_buffer.get_size();
+ std::memcpy(segment_ptr, local_send_buf, buffer_size);
+ }
+
+ void PipelineCommunicator::copy_data_from_segment(void* local_recv_buf,
+ GPI::SegmentBuffer const& segment_buffer)
+ {
+ auto const segment_ptr = segment_buffer.get_ptr();
+ auto const buffer_size = segment_buffer.get_size();
+ std::memcpy(local_recv_buf, segment_ptr, buffer_size);
+ }
+}
diff --git a/src/gpi_comm_lib/PipelineCommunicator.hpp b/src/gpi_comm_lib/PipelineCommunicator.hpp
new file mode 100644
index 00000000..3c6effaa
--- /dev/null
+++ b/src/gpi_comm_lib/PipelineCommunicator.hpp
@@ -0,0 +1,65 @@
+#pragma once
+
+#include <gpi/Context.hpp>
+#include <gpi/ResourceManager.hpp>
+#include <gpi/SegmentBuffer.hpp>
+
+#include <cstddef>
+#include <unordered_map>
+#include <vector>
+
+namespace tarantella
+{
+ class SendRecvResources
+ {
+ public:
+ SendRecvResources() = default;
+ SendRecvResources(GPI::Rank rank,
+ std::vector<GPI::SegmentBuffer> const& send_bufs,
+ std::vector<GPI::SegmentBuffer> const& recv_bufs,
+ std::vector<gaspi_notification_id_t> const& notifications)
+ : other_rank(rank), send_bufs(send_bufs), recv_bufs(recv_bufs), notifications(notifications)
+ {}
+
+ GPI::Rank other_rank;
+ std::vector<GPI::SegmentBuffer> send_bufs;
+ std::vector<GPI::SegmentBuffer> recv_bufs;
+ std::vector<gaspi_notification_id_t> notifications;
+ };
+
+ class ConnectionInfo
+ {
+ public:
+ explicit ConnectionInfo(GPI::SegmentID segment_id, GPI::Rank other_rank, std::size_t buffer_size_bytes)
+ : segment_id(segment_id), other_rank(other_rank), microbatched_buffer_size_bytes(buffer_size_bytes)
+ {}
+
+ GPI::SegmentID segment_id;
+ GPI::Rank other_rank;
+ std::size_t microbatched_buffer_size_bytes;
+ };
+
+ class PipelineCommunicator
+ {
+ public:
+ using ConnectionID = std::size_t;
+ using MicrobatchID = std::size_t;
+
+ PipelineCommunicator(GPI::Context&,
+ std::unordered_map<ConnectionID, ConnectionInfo> const&,
+ std::size_t num_micro_batches);
+
+ void non_blocking_send(void* local_send_buf,
+ ConnectionID,
+ MicrobatchID);
+ void blocking_recv(void* local_recv_buf,
+ ConnectionID,
+ MicrobatchID);
+
+ private:
+ GPI::ResourceManager& resource_manager;
+ std::unordered_map<ConnectionID, SendRecvResources> connections;
+
+ void copy_data_to_segment(void* local_send_buf, GPI::SegmentBuffer const&);
+ void copy_data_from_segment(void* local_recv_buf, GPI::SegmentBuffer const&);
+ };
+}
diff --git a/src/gpi_comm_lib/SynchCommunicator.cpp b/src/gpi_comm_lib/SynchCommunicator.cpp
new file mode 100644
index 00000000..163635b5
--- /dev/null
+++ b/src/gpi_comm_lib/SynchCommunicator.cpp
@@ -0,0 +1,166 @@
+#include "SynchCommunicator.hpp"
+#include "collectives/allreduce/RecursiveHalvingDoubleBuffer.hpp"
+
+#include <cstring>
+#include <memory>
+#include <stdexcept>
+
+namespace tarantella
+{
+ void SynchCommunicator::create_fused_tensor_infos_and_ids(
+ std::vector<collectives::TensorInfo> const& tensor_infos,
+ std::size_t threshold_bytes)
+ {
+ collectives::TensorFusor fusor {threshold_bytes};
+ fusor.fuse_tensor_infos_and_ids(tensor_infos, fused_ids, fused_tensor_infos);
+ }
+
+ void SynchCommunicator::create_fused_tensors_synchronization()
+ {
+ for(auto const& fused_info : fused_tensor_infos)
+ {
+ auto const fused_id = fused_info.first;
+ ready_to_start_counters[fused_id] = std::make_unique<std::atomic<std::size_t>>(0UL);
+ finished_counters[fused_id] = std::make_unique<std::atomic<std::size_t>>(0UL);
+ ready_to_copy_back[fused_id] = std::make_unique<std::atomic<bool>>(false);
+ ready_to_reset_counters[fused_id] = std::make_unique<std::atomic<std::size_t>>(0UL);
+ }
+ }
+
+ SynchCommunicator::SynchCommunicator(GPI::Context& context,
+ GPI::SegmentID segment_id,
+ GPI::Group const& group,
+ std::vector<collectives::TensorInfo> const& tensor_infos,
+ std::size_t threshold_for_tensor_fusion_bytes)
+ : resource_manager(context.get_resource_manager()),
+ segment_id(segment_id),
+ group(group),
+ queue_handler(),
+ fused_ids(),
+ fused_tensor_infos(),
+ operators(),
+ ready_to_start_counters(),
+ finished_counters(),
+ ready_to_copy_back(),
+ ready_to_reset_counters(),
+ setup_has_finished(),
+ terminate_man_thread(false),
+ management_thread(&tarantella::SynchCommunicator::management_thread_task, this)
+ {
+ using AllreduceImplementation = collectives::Allreduce::RecursiveHalvingDoubleBuffer;
+ create_fused_tensor_infos_and_ids(tensor_infos, threshold_for_tensor_fusion_bytes);
+ create_fused_tensors_synchronization();
+ create_segment_resources<AllreduceImplementation>(tensor_infos);
+ create_operators_with_state<AllreduceImplementation>();
+ setup_has_finished.notify();
+ }
+
+ SynchCommunicator::SynchCommunicator(GPI::Context& context,
+ GPI::SegmentID segment_id,
+ GPI::Group const& group,
+ std::vector<collectives::TensorInfo> const& tensor_infos)
+ : SynchCommunicator(context, segment_id, group, tensor_infos, 0UL)
+ { }
+
+ SynchCommunicator::~SynchCommunicator()
+ {
+ terminate_man_thread = true;
+ if (management_thread.joinable())
+ {
+ management_thread.join();
+ }
+ }
+
+ void SynchCommunicator::start_allreduce_impl(GradID const& grad_id, const void* data_ptr)
+ {
+ auto const fused_id = fused_ids[grad_id];
+
+ // All `grad_id`s copy-in their respective data
+ copy_data_to_segment(grad_id, data_ptr);
+ auto const value = ready_to_start_counters[fused_id]->fetch_add(1UL);
+
+ // Make sure all copies are done, before last `grad_id` starts operator
+ if (value == fused_tensor_infos[fused_id].get_num_tensors()-1)
+ {
+ operators[fused_id].allreduce->start();
+ ready_to_start_counters[fused_id]->store(0UL);
+ }
+ }
+
+ void SynchCommunicator::finish_allreduce_impl(GradID const& grad_id, void* results_ptr)
+ {
+ auto const fused_id = fused_ids[grad_id];
+
+ // First `grad_id` to arrive waits for `has_finished`, and notifies
+ // everyone that results can be copied back
+ auto const num_arrived = finished_counters[fused_id]->fetch_add(1UL);
+ if (num_arrived == 0)
+ {
+ operators[fused_id].has_finished->wait();
+ ready_to_copy_back[fused_id]->store(true);
+ }
+
+ // All `grad_id`s copy-out their respective data,
+ // once results have been obtained
+ while(true)
+ {
+ if(ready_to_copy_back[fused_id]->load())
+ {
+ copy_data_from_segment(grad_id, results_ptr);
+ break;
+ }
+ }
+
+ // Make sure all copies are done, before last `grad_id` resets initial state
+ auto const copied_grads = ready_to_reset_counters[fused_id]->fetch_add(1UL);
+ if (copied_grads == fused_tensor_infos[fused_id].get_num_tensors()-1)
+ {
+ operators[fused_id].allreduce->reset_for_reuse();
+ finished_counters[fused_id]->store(0UL);
+ ready_to_copy_back[fused_id]->store(false);
+ ready_to_reset_counters[fused_id]->store(0UL);
+ }
+ }
+
+ void SynchCommunicator::copy_data_to_segment(GradID const& grad_id, const void* data_ptr)
+ {
+ auto const fused_id = fused_ids[grad_id];
+ auto const segment_ptr = reinterpret_cast<char*>(operators[fused_id].allreduce->get_input_ptr())
+ + fused_tensor_infos[fused_id].get_local_offset_bytes(grad_id);
+ std::memcpy(segment_ptr, data_ptr, fused_tensor_infos[fused_id].get_local_size_bytes(grad_id));
+ }
+
+ void SynchCommunicator::copy_data_from_segment(GradID const& grad_id, void* results_ptr)
+ {
+ auto const fused_id = fused_ids[grad_id];
+ auto const segment_ptr = reinterpret_cast<char*>(operators[fused_id].allreduce->get_result_ptr())
+ + fused_tensor_infos[fused_id].get_local_offset_bytes(grad_id);
+ std::memcpy(results_ptr, segment_ptr, fused_tensor_infos[fused_id].get_local_size_bytes(grad_id));
+ }
+
+ void SynchCommunicator::management_thread_task()
+ {
+ setup_has_finished.wait();
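+ // busy-poll all operators: each pass triggers one communication step per
+ // unfinished allreduce and signals `has_finished` once its results are ready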
+ while (!terminate_man_thread)
+ {
+ while (true)
+ {
+ if (terminate_man_thread)
+ {
+ break;
+ }
+ for (auto& element : operators)
+ {
+ auto& op = *(element.second.allreduce.get());
+ if (op.is_finished()) continue;
+
+ op.trigger_communication_step();
+ if (op.is_finished())
+ {
+ element.second.has_finished->notify();
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/SynchCommunicator.hpp b/src/gpi_comm_lib/SynchCommunicator.hpp
new file mode 100644
index 00000000..db0a2fc7
--- /dev/null
+++ b/src/gpi_comm_lib/SynchCommunicator.hpp
@@ -0,0 +1,134 @@
+#pragma once
+
+#include "AtomicCondition.hpp"
+#include "collectives/allreduce/Operator.hpp"
+#include "collectives/barrier/GPIBarrier.hpp"
+#include "collectives/FusedTensorInfo.hpp"
+#include "collectives/TensorInfo.hpp"
+#include "collectives/Types.hpp"
+#include "distribution/utilities.hpp"
+#include "gpi/Context.hpp"
+#include "gpi/ResourceManager.hpp"
+#include "queues.h"
+
+#include <atomic>
+#include <cstddef>
+#include <functional>
+#include <memory>
+#include <thread>
+#include <unordered_map>
+#include <vector>
+
+namespace tarantella
+{
+ using GradID = collectives::GradID;
+ using FusedID = collectives::FusedID;
+
+ class SynchCommunicator
+ {
+ public:
+ SynchCommunicator(GPI::Context&, GPI::SegmentID, GPI::Group const&, std::vector<collectives::TensorInfo> const&);
+ SynchCommunicator(GPI::Context&, GPI::SegmentID, GPI::Group const&, std::vector<collectives::TensorInfo> const&, std::size_t);
+ SynchCommunicator(SynchCommunicator&) = delete;
+ SynchCommunicator& operator=(SynchCommunicator&) = delete;
+ ~SynchCommunicator();
+
+ // TODO: Replace void* with a LocalBuffer struct {ptr, size}
+ void start_allreduce_impl(GradID const&, const void*);
+ void finish_allreduce_impl(GradID const&, void*);
+
+ private:
+ struct OperatorWithState
+ {
+ std::unique_ptr<collectives::Allreduce::Operator> allreduce;
+ std::unique_ptr<AtomicCondition> has_finished;
+ };
+
+ static collectives::Allreduce::Operator::ReductionOp const reduction_op = collectives::Allreduce::Operator::ReductionOp::AVERAGE;
+
+ GPI::ResourceManager& resource_manager;
+ GPI::SegmentID segment_id;
+ GPI::Group const& group;
+ collectives::queues queue_handler; // TODO replace with the ResourceManager
+
+ std::unordered_map<GradID, FusedID> fused_ids;
+ std::unordered_map<FusedID, collectives::FusedTensorInfo> fused_tensor_infos;
+ std::unordered_map<FusedID, OperatorWithState> operators;
+
+ std::unordered_map<FusedID, std::unique_ptr<std::atomic<std::size_t>>> ready_to_start_counters;
+ std::unordered_map<FusedID, std::unique_ptr<std::atomic<std::size_t>>> finished_counters;
+ std::unordered_map<FusedID, std::unique_ptr<std::atomic<bool>>> ready_to_copy_back;
+ std::unordered_map<FusedID, std::unique_ptr<std::atomic<std::size_t>>> ready_to_reset_counters;
+
+ AtomicCondition setup_has_finished;
+ std::atomic<bool> terminate_man_thread;
+ std::thread management_thread;
+ void management_thread_task();
+
+ void copy_data_to_segment(GradID const&, const void*);
+ void copy_data_from_segment(GradID const&, void*);
+
+ void create_fused_tensor_infos_and_ids(std::vector const&, std::size_t);
+ void create_fused_tensors_synchronization();
+
+ template <typename AllreduceAlgorithm>
+ constexpr float get_overhead_factor() const;
+
+ template <typename AllreduceAlgorithm>
+ void create_segment_resources(std::vector<collectives::TensorInfo> const& tensor_infos) const;
+
+ void create_fused_tensor_infos(std::vector<collectives::TensorInfo> const &tensor_infos);
+
+ template <typename AllreduceAlgorithm>
+ std::unique_ptr<collectives::Allreduce::Operator> create_allreduce_op(collectives::TensorInfo const&);
+
+ template <typename AllreduceAlgorithm>
+ void create_operators_with_state();
+ };
+
+ template <typename AllreduceAlgorithm>
+ constexpr float SynchCommunicator::get_overhead_factor() const
+ {
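+ // sizing heuristic: the double-buffered allreduce keeps two copies of the
+ // data plus communication scratch space in the segment, hence the 3.5x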
+ return 3.5;
+ }
+
+ template <typename AllreduceAlgorithm>
+ void SynchCommunicator::create_segment_resources(std::vector<collectives::TensorInfo> const& tensor_infos) const
+ {
+ auto const segment_size = distribution::get_segment_size(tensor_infos, get_overhead_factor<AllreduceAlgorithm>());
+ resource_manager.make_segment_resources(segment_id, group, segment_size);
+
+ // Barrier is required, to ensure all ranks have finished registering
+ // their segments to their communication partners
+ collectives::Barrier::GPIBarrier barrier(group);
+ barrier.blocking_barrier();
+ }
+
+ template <typename AllreduceAlgorithm>
+ std::unique_ptr<collectives::Allreduce::Operator> SynchCommunicator::create_allreduce_op(collectives::TensorInfo const& tensor_info)
+ {
+ auto const required_resources = AllreduceAlgorithm::get_required_resources(tensor_info, group);
+
+ collectives::Allreduce::Operator::ResourceList resources;
+ for (auto const& resource : required_resources)
+ {
+ resources.emplace_back(
+ resource_manager.get_buffer_of_size(segment_id, resource.buffer_size),
+ resource_manager.get_notification_range(segment_id, resource.num_notifications));
+ }
+
+ return std::make_unique<AllreduceAlgorithm>(tensor_info, reduction_op, resources, queue_handler, group);
+ }
+
+ template <typename AllreduceAlgorithm>
+ void SynchCommunicator::create_operators_with_state()
+ {
+ for(auto const& fused_info : fused_tensor_infos)
+ {
+ auto const tensor_id = fused_info.first;
+ auto const tensor_info = fused_info.second.to_tensor_info();
+ OperatorWithState op{create_allreduce_op<AllreduceAlgorithm>(tensor_info), std::make_unique<AtomicCondition>()};
+ operators.emplace(tensor_id, std::move(op));
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/TensorBroadcaster.cpp b/src/gpi_comm_lib/TensorBroadcaster.cpp
new file mode 100644
index 00000000..f28cc19b
--- /dev/null
+++ b/src/gpi_comm_lib/TensorBroadcaster.cpp
@@ -0,0 +1,84 @@
+#include "TensorBroadcaster.hpp"
+
+#include "distribution/utilities.hpp"
+#include "gpi/Context.hpp"
+#include "gpi/ResourceManager.hpp"
+#include "gpi/SegmentBuffer.hpp"
+
+#include <cstring>
+#include <stdexcept>
+
+namespace tarantella
+{
+ TensorBroadcaster::TensorBroadcaster(GPI::Context& context,
+ GPI::SegmentID segment_id,
+ GPI::Group const& group,
+ std::vector<collectives::TensorInfo> const& tensor_infos,
+ GPI::Rank root_rank)
+ : context(context),
+ group(group),
+ queue_handler(),
+ root(root_rank),
+ barrier(group)
+ {
+ if(!group.contains_rank(root_rank))
+ {
+ throw std::runtime_error("[TensorBroadcaster::constructor]:\
+ root_rank is not part of the broadcast group");
+ }
+
+ auto const overhead_factor = 1.0;
+ auto& resource_manager = context.get_resource_manager();
+ auto const segment_size = distribution::get_segment_size(tensor_infos, overhead_factor);
+
+ resource_manager.make_segment_resources(segment_id, group, segment_size);
+
+ // Barrier is required, to ensure all ranks have finished registering
+ // their segments to their communication partners
+ barrier.blocking_barrier();
+
+ for(auto const& info : tensor_infos)
+ {
+ auto const size_in_bytes = info.get_nelems() * getDataTypeSize(info.get_elem_type());
+ buffers.emplace_back(resource_manager.get_buffer_of_size(segment_id, size_in_bytes));
+ }
+
+ auto const notifications = resource_manager.get_notification_range(segment_id,
+ collectives::broadcast::getNumberOfNotifications(group.get_size()));
+ bcast_op = std::make_unique<collectives::broadcast>(root, segment_size, segment_id, buffers.front().get_offset(),
+ notifications.first, queue_handler);
+ }
+
+ void TensorBroadcaster::exec_broadcast(std::vector<void*> const& data_ptrs)
+ {
+ // copy data to segments
+ if (context.get_rank() == root)
+ {
+ for (std::size_t i = 0; i < data_ptrs.size(); ++i)
+ {
+ std::memcpy(buffers[i].get_ptr(), data_ptrs[i], buffers[i].get_size());
+ }
+ }
+
+ // start the operation
+ if (context.get_rank() == root)
+ {
+ bcast_op->signal();
+ }
+ // drive the broadcast state machine until it reports completion
+ while(bcast_op->operator()() != 0);
+
+ // copy results back to buffers
+ if (context.get_rank() != root)
+ {
+ for (std::size_t i = 0; i < data_ptrs.size(); ++i)
+ {
+ std::memcpy(data_ptrs[i], buffers[i].get_ptr(), buffers[i].get_size());
+ }
+ }
+
+ // finalize operation
+ barrier.blocking_barrier();
+ }
+}
+
diff --git a/src/gpi_comm_lib/TensorBroadcaster.hpp b/src/gpi_comm_lib/TensorBroadcaster.hpp
new file mode 100644
index 00000000..c72440ce
--- /dev/null
+++ b/src/gpi_comm_lib/TensorBroadcaster.hpp
@@ -0,0 +1,33 @@
+#pragma once
+
+#include "collectives/barrier/GPIBarrier.hpp"
+#include "collectives/TensorInfo.hpp"
+#include "gpi/Context.hpp"
+#include "gpi/Group.hpp"
+#include "gpi/SegmentBuffer.hpp"
+#include "broadcast.h"
+
+#include <memory>
+#include <vector>
+
+namespace tarantella
+{
+
+ class TensorBroadcaster
+ {
+ public:
+ TensorBroadcaster(GPI::Context&, GPI::SegmentID, GPI::Group const&,
+ std::vector<collectives::TensorInfo> const&, GPI::Rank root_rank);
+ void exec_broadcast(std::vector<void*> const&);
+
+ private:
+ GPI::Context& context;
+ GPI::Group const group;
+ collectives::queues queue_handler; // FIXME: use GPI::ResourcesManager
+ GPI::Rank root;
+ collectives::Barrier::GPIBarrier barrier;
+
+ std::vector<GPI::SegmentBuffer> buffers;
+ std::unique_ptr<collectives::broadcast> bcast_op;
+ };
+}
diff --git a/src/gpi_comm_lib/collectives/BufferElementType.cpp b/src/gpi_comm_lib/collectives/BufferElementType.cpp
new file mode 100644
index 00000000..ac0b457e
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/BufferElementType.cpp
@@ -0,0 +1,26 @@
+#include "BufferElementType.hpp"
+
+#include <unordered_map>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ std::size_t getDataTypeSize(const BufferElementType d)
+ {
+ std::unordered_map<BufferElementType, std::size_t> const sizes
+ {
+ {BufferElementType::FLOAT, sizeof(float)},
+ {BufferElementType::DOUBLE, sizeof(double)},
+ {BufferElementType::INT16, sizeof(int16_t)},
+ {BufferElementType::INT32, sizeof(int32_t)}
+ };
+ return sizes.at(d);
+ }
+
+ std::ostream &operator<<(std::ostream& os, BufferElementType const& elem_type)
+ {
+ return os << static_cast<int>(elem_type);
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/gpi_comm_lib/collectives/BufferElementType.hpp b/src/gpi_comm_lib/collectives/BufferElementType.hpp
new file mode 100644
index 00000000..846dda2c
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/BufferElementType.hpp
@@ -0,0 +1,21 @@
+#pragma once
+
+#include <cstddef>
+#include <ostream>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ enum class BufferElementType
+ {
+ FLOAT,
+ DOUBLE,
+ INT16,
+ INT32
+ };
+
+ std::size_t getDataTypeSize(const BufferElementType d);
+ std::ostream &operator<<(std::ostream& os, BufferElementType const& elem_type);
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/CMakeLists.txt b/src/gpi_comm_lib/collectives/CMakeLists.txt
new file mode 100644
index 00000000..9ed5643d
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/CMakeLists.txt
@@ -0,0 +1,41 @@
+
+set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
+
+set(COLLECTIVES_SRC_DIR ${SRC_DIR}/gpi_comm_lib/collectives)
+set(libSources
+ ${COLLECTIVES_SRC_DIR}/lib/allreduceButterfly.cpp
+ ${COLLECTIVES_SRC_DIR}/lib/allreduceButterflyDoubleBuffer.cpp
+ ${COLLECTIVES_SRC_DIR}/lib/broadcast.cpp
+ ${COLLECTIVES_SRC_DIR}/lib/counter.cpp
+ ${COLLECTIVES_SRC_DIR}/lib/mailBoxGaspi.cpp
+ ${COLLECTIVES_SRC_DIR}/lib/mailBoxLocal.cpp
+ ${COLLECTIVES_SRC_DIR}/lib/queues.cpp
+ ${COLLECTIVES_SRC_DIR}/lib/reduce.cpp
+ ${COLLECTIVES_SRC_DIR}/lib/writer.cpp
+ ${COLLECTIVES_SRC_DIR}/allreduce/RecursiveHalving.cpp
+ ${COLLECTIVES_SRC_DIR}/allreduce/RecursiveHalvingDoubleBuffer.cpp
+ ${COLLECTIVES_SRC_DIR}/allreduce/utils.cpp
+ ${COLLECTIVES_SRC_DIR}/barrier/GPIBarrier.cpp
+ ${COLLECTIVES_SRC_DIR}/BufferElementType.cpp
+ ${COLLECTIVES_SRC_DIR}/FusedTensorInfo.cpp
+ ${COLLECTIVES_SRC_DIR}/TensorInfo.cpp
+)
+
+extended_add_library(NAME collectives
+ NAMESPACE tnt
+ TYPE SHARED
+ SOURCES
+ ${libSources}
+ LIBRARIES
+ optimized GPI2::GPI2
+ debug GPI2::GPI2dbg
+ tnt::gpiresources
+ INCLUDE_DIRECTORIES
+ ${COLLECTIVES_SRC_DIR}/lib/
+ COMPILE_OPTIONS
+ -Wno-unused-private-field
+ INSTALL
+ INSTALL_DESTINATION
+ ${INSTALL_LIB_DIR}
+ POSITION_INDEPENDENT)
+
diff --git a/src/gpi_comm_lib/collectives/FusedTensorInfo.cpp b/src/gpi_comm_lib/collectives/FusedTensorInfo.cpp
new file mode 100644
index 00000000..f0ceed6b
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/FusedTensorInfo.cpp
@@ -0,0 +1,185 @@
+#include "FusedTensorInfo.hpp"
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ void FusedTensorInfo::initialise_from_tensor_info(TensorInfo const& tensor_info)
+ {
+ local_offset_bytes.clear();
+ local_size_bytes.clear();
+
+ id = tensor_info.get_id();
+ nelems = tensor_info.get_nelems();
+ elem_type = tensor_info.get_elem_type();
+ elem_size = getDataTypeSize(elem_type);
+ size_bytes = nelems * elem_size;
+ num_tensors = 1UL;
+ tensor_ids.push_back(id);
+ local_offset_bytes[id] = 0UL;
+ local_size_bytes[id] = size_bytes;
+ }
+
+ FusedTensorInfo::FusedTensorInfo()
+ : id(),
+ nelems(),
+ elem_type(),
+ elem_size(),
+ size_bytes(),
+ num_tensors(),
+ tensor_ids(),
+ local_offset_bytes(),
+ local_size_bytes()
+ { }
+
+ FusedTensorInfo::FusedTensorInfo(TensorInfo const& tensor_info)
+ : FusedTensorInfo()
+ {
+ initialise_from_tensor_info(tensor_info);
+ }
+
+ FusedTensorInfo& FusedTensorInfo::operator=(TensorInfo const& tensor_info)
+ {
+ initialise_from_tensor_info(tensor_info);
+ return *this;
+ }
+
+ bool FusedTensorInfo::operator==(FusedTensorInfo const& other) const
+ {
+ return ( this->id == other.id &&
+ this->nelems == other.nelems &&
+ this->elem_type == other.elem_type &&
+ this->num_tensors == other.num_tensors &&
+ this->local_offset_bytes == other.local_offset_bytes &&
+ this->local_size_bytes == other.local_size_bytes );
+
+ }
+
+ FusedID FusedTensorInfo::get_id() const
+ {
+ return id;
+ }
+
+ std::size_t FusedTensorInfo::get_nelems() const
+ {
+ return nelems;
+ }
+
+ BufferElementType FusedTensorInfo::get_elem_type() const
+ {
+ return elem_type;
+ }
+
+ std::size_t FusedTensorInfo::get_size_bytes() const
+ {
+ return size_bytes;
+ }
+
+ std::size_t FusedTensorInfo::get_num_tensors() const
+ {
+ return num_tensors;
+ }
+
+ std::vector<GradID> FusedTensorInfo::get_tensor_ids() const
+ {
+ return tensor_ids;
+ }
+
+ std::size_t FusedTensorInfo::get_local_offset_bytes(GradID const& grad_id) const
+ {
+ auto const it = local_offset_bytes.find(grad_id);
+ if (it == local_offset_bytes.end())
+ {
+ throw std::logic_error("FusedTensorInfo::get_local_offset_bytes: FusedTensorInfo does not contain GradID");
+ }
+ return it->second;
+ }
+
+ std::size_t FusedTensorInfo::get_local_size_bytes(GradID const& grad_id) const
+ {
+ auto const it = local_size_bytes.find(grad_id);
+ if (it == local_size_bytes.end())
+ {
+ throw std::logic_error("FusedTensorInfo::get_local_size_bytes: FusedTensorInfo does not contain GradID");
+ }
+ return it->second;
+ }
+
+ void FusedTensorInfo::add_tensor_info(TensorInfo const& tensor_info)
+ {
+ if (tensor_info.get_elem_type() != get_elem_type())
+ {
+ throw std::logic_error("FusedTensorInfo::add_tensor_info: Tensors need to have same data type");
+ }
+
+ auto const grad_id = tensor_info.get_id();
+ auto const grad_nelems = tensor_info.get_nelems();
+ auto const grad_size_bytes = grad_nelems * elem_size;
+ auto const current_offset = size_bytes;
+
+ nelems += grad_nelems;
+ size_bytes += grad_size_bytes;
+ num_tensors += 1UL;
+
+ tensor_ids.push_back(grad_id);
+ local_offset_bytes[grad_id] = current_offset;
+ local_size_bytes[grad_id] = grad_size_bytes;
+ }
+
+ TensorInfo FusedTensorInfo::to_tensor_info() const
+ {
+ return {get_id(), get_nelems(), get_elem_type()};
+ }
+
+ TensorFusor::TensorFusor()
+ : threshold_bytes(0UL)
+ { }
+
+ TensorFusor::TensorFusor(std::size_t threshold)
+ : threshold_bytes(threshold)
+ { }
+
+ void TensorFusor::fuse_tensor_infos_and_ids(std::vector<TensorInfo> const& tensor_infos,
+ IDMap& fused_ids,
+ InfoMap& fused_tensor_infos)
+ {
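+ // Greedy fusion: pack consecutive tensors into the current fused tensor
+ // until it reaches `threshold_bytes`, then start a new one; a single
+ // tensor is simply mapped to itself.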
+ if (tensor_infos.size() == 1)
+ {
+ auto const tensor_info = tensor_infos.front();
+ auto const id = tensor_info.get_id();
+ fused_ids[id] = id;
+ fused_tensor_infos[id] = tensor_info;
+ return;
+ }
+
+ collectives::FusedTensorInfo fused_info(tensor_infos.front());
+ auto tensor_id = tensor_infos.front().get_id();
+ FusedID fused_id(tensor_id);
+ fused_ids[tensor_id] = fused_id;
+
+ for (auto idx = 1UL; idx < tensor_infos.size(); ++idx)
+ {
+ tensor_id = tensor_infos[idx].get_id();
+
+ if (fused_info.get_size_bytes() < threshold_bytes)
+ {
+ fused_info.add_tensor_info(tensor_infos[idx]);
+ }
+ else
+ {
+ fused_tensor_infos[fused_id] = fused_info;
+ fused_id = tensor_id;
+ fused_info = tensor_infos[idx];
+ }
+
+ fused_ids[tensor_id] = fused_id;
+
+ // Always add the last fused_tensor to the vector.
+ // Note, that it might still be smaller than `threshold_bytes`.
+ if (idx == tensor_infos.size() - 1)
+ {
+ fused_tensor_infos[fused_id] = fused_info;
+ }
+ }
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/FusedTensorInfo.hpp b/src/gpi_comm_lib/collectives/FusedTensorInfo.hpp
new file mode 100644
index 00000000..92925d2a
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/FusedTensorInfo.hpp
@@ -0,0 +1,69 @@
+#pragma once
+
+#include "BufferElementType.hpp"
+#include "TensorInfo.hpp"
+#include "Types.hpp"
+
+#include <cstddef>
+#include <unordered_map>
+#include <vector>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class FusedTensorInfo
+ {
+ public:
+ FusedTensorInfo();
+ FusedTensorInfo(TensorInfo const&);
+ FusedTensorInfo& operator=(TensorInfo const&);
+ bool operator==(FusedTensorInfo const&) const;
+
+ FusedID get_id() const;
+ std::size_t get_nelems() const;
+ BufferElementType get_elem_type() const;
+ std::size_t get_size_bytes() const;
+
+ std::size_t get_num_tensors() const;
+ std::vector<GradID> get_tensor_ids() const;
+
+ std::size_t get_local_offset_bytes(GradID const&) const;
+ std::size_t get_local_size_bytes(GradID const&) const;
+
+ void add_tensor_info(TensorInfo const&);
+ TensorInfo to_tensor_info() const;
+
+ private:
+ FusedID id;
+ std::size_t nelems;
+ BufferElementType elem_type;
+ std::size_t elem_size;
+ std::size_t size_bytes;
+ std::size_t num_tensors;
+
+ std::vector<GradID> tensor_ids;
+ std::unordered_map<GradID, std::size_t> local_offset_bytes;
+ std::unordered_map<GradID, std::size_t> local_size_bytes;
+
+ void initialise_from_tensor_info(TensorInfo const&);
+ };
+
+ class TensorFusor
+ {
+ public:
+ using IDMap = std::unordered_map<GradID, FusedID>;
+ using InfoMap = std::unordered_map<FusedID, FusedTensorInfo>;
+
+ TensorFusor();
+ TensorFusor(std::size_t threshold);
+
+ void fuse_tensor_infos_and_ids(std::vector<TensorInfo> const&,
+ IDMap&,
+ InfoMap&);
+
+ private:
+ std::size_t threshold_bytes;
+ };
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/TensorInfo.cpp b/src/gpi_comm_lib/collectives/TensorInfo.cpp
new file mode 100644
index 00000000..b8b9d345
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/TensorInfo.cpp
@@ -0,0 +1,26 @@
+#include "TensorInfo.hpp"
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ TensorInfo::TensorInfo(GradID tensid, std::size_t nelems, BufferElementType elem_type)
+ : id(tensid), nelems(nelems), elem_type(elem_type)
+ {}
+
+ GradID TensorInfo::get_id() const
+ {
+ return id;
+ }
+
+ std::size_t TensorInfo::get_nelems() const
+ {
+ return nelems;
+ }
+
+ BufferElementType TensorInfo::get_elem_type() const
+ {
+ return elem_type;
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/TensorInfo.hpp b/src/gpi_comm_lib/collectives/TensorInfo.hpp
new file mode 100644
index 00000000..374ce0dc
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/TensorInfo.hpp
@@ -0,0 +1,27 @@
+#pragma once
+
+#include "BufferElementType.hpp"
+#include "Types.hpp"
+
+#include <cstddef>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class TensorInfo
+ {
+ public:
+ TensorInfo(GradID tensid, std::size_t nelems, BufferElementType elem_type);
+
+ GradID get_id() const;
+ std::size_t get_nelems() const;
+ BufferElementType get_elem_type() const;
+
+ private:
+ const GradID id;
+ const std::size_t nelems;
+ const BufferElementType elem_type;
+ };
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/Types.hpp b/src/gpi_comm_lib/collectives/Types.hpp
new file mode 100644
index 00000000..2efe0caa
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/Types.hpp
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <cstddef>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ using GradID = std::size_t;
+ using FusedID = std::size_t;
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/allreduce/Operator.hpp b/src/gpi_comm_lib/collectives/allreduce/Operator.hpp
new file mode 100644
index 00000000..e268721c
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/allreduce/Operator.hpp
@@ -0,0 +1,71 @@
+#pragma once
+
+#include "collectives/BufferElementType.hpp"
+#include "gpi/NotificationManager.hpp"
+#include "gpi/SegmentBuffer.hpp"
+
+#include <cstddef>
+#include <utility>
+#include <vector>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ namespace Allreduce
+ {
+ // \note
+ // Interface for non-blocking, asynchronous Allreduce algorithms (not thread-safe)
+ class Operator
+ {
+ public:
+ class RequiredResource
+ {
+ public:
+ std::size_t buffer_size;
+ std::size_t num_notifications;
+ };
+ using RequiredResourceList = std::vector;
+ using Resource = std::pair<GPI::SegmentBuffer, GPI::NotificationManager::NotificationRange>;
+ using ResourceList = std::vector;
+
+ enum class ReductionOp
+ {
+ SUM,
+ AVERAGE
+ };
+
+ enum class OperatorState
+ {
+ NOT_STARTED,
+ RUNNING,
+ FINISHED
+ };
+
+ virtual ~Operator() = default;
+
+ // Initiates the Allreduce operation (non-blocking)
+ // and sets is_running == TRUE
+ virtual void start() = 0;
+
+ // Makes partial progress towards computing the Allreduce result
+ // and has to be called multiple times until the operation is completed,
+ // when is_finished == TRUE
+ // can be called independently of the state;
+ // it only tries to make progress if is_running == TRUE
+ virtual void trigger_communication_step() = 0;
+
+ // Enables the Allreduce to be started again
+ // and sets is_running == FALSE and is_finished == FALSE
+ virtual void reset_for_reuse() = 0;
+ virtual bool is_running() const = 0;
+
+ // If TRUE, results are available until reset_for_reuse() is called
+ virtual bool is_finished() const = 0;
+
+ // TODO: void* -> SegmentBuffer
+ virtual void* get_input_ptr() const = 0;
+ virtual void* get_result_ptr() const = 0;
+ };
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/allreduce/RecursiveHalving.cpp b/src/gpi_comm_lib/collectives/allreduce/RecursiveHalving.cpp
new file mode 100644
index 00000000..be62d494
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/allreduce/RecursiveHalving.cpp
@@ -0,0 +1,98 @@
+#include "RecursiveHalving.hpp"
+
+#include "utils.hpp"
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ namespace Allreduce
+ {
+ RecursiveHalving::RecursiveHalving(TensorInfo tensor_info,
+ ReductionOp reduction_op,
+ ResourceList const &resource_list,
+ queues &queues,
+ GPI::Group const &group)
+ : group(group),
+ state(OperatorState::NOT_STARTED),
+ allreduce(tensor_info.get_nelems(), to_allreduce_dataType(tensor_info.get_elem_type()),
+ to_allreduce_reductionType(reduction_op),
+ to_allreduce_segment_buffer(resource_list.at(0)),
+ to_allreduce_segment_buffer(resource_list.at(1)),
+ queues, group),
+ barrier(group)
+ {}
+
+ void RecursiveHalving::start()
+ {
+ if (is_running())
+ {
+ throw std::logic_error("[RecursiveHalving::start] Operation already started.");
+ }
+ if (is_finished())
+ {
+ throw std::logic_error("[RecursiveHalving::start] Operation not reset after finish.");
+ }
+ allreduce.signal();
+ state = OperatorState::RUNNING;
+ }
+
+ void RecursiveHalving::trigger_communication_step()
+ {
+ if (is_running())
+ {
+ auto const result = allreduce();
+ if (result == 0)
+ {
+ barrier.blocking_barrier();
+ state = OperatorState::FINISHED;
+ }
+ }
+ else
+ {
+ // do nothing before start() is called
+ }
+ }
+
+ void RecursiveHalving::reset_for_reuse()
+ {
+ if (is_running())
+ {
+ throw std::logic_error("[RecursiveHalving::reset] Cannot reset while running.");
+ }
+ state = OperatorState::NOT_STARTED;
+ }
+
+ bool RecursiveHalving::is_running() const
+ {
+ return state == OperatorState::RUNNING;
+ }
+
+ bool RecursiveHalving::is_finished() const
+ {
+ return state == OperatorState::FINISHED;
+ }
+
+ Operator::RequiredResourceList RecursiveHalving::get_required_resources(
+ TensorInfo const& tensor_info, GPI::Group const& group)
+ {
+ auto const num_notifications = allreduceButterfly::getNumberOfNotifications(group.get_size());
+ auto const num_elements_data_segment = tensor_info.get_nelems();
+ auto const num_elements_temp_segment = static_cast<std::size_t>(
+ allreduceButterfly::getNumberOfElementsSegmentCommunicate(tensor_info.get_nelems(), group.get_size()));
+ return {{num_elements_data_segment * getDataTypeSize(tensor_info.get_elem_type()), num_notifications},
+ {num_elements_temp_segment * getDataTypeSize(tensor_info.get_elem_type()), num_notifications}};
+ }
+
+ void* RecursiveHalving::get_input_ptr() const
+ {
+ return allreduce.getReducePointer();
+ }
+
+ void* RecursiveHalving::get_result_ptr() const
+ {
+ return allreduce.getReducePointer();
+ }
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/allreduce/RecursiveHalving.hpp b/src/gpi_comm_lib/collectives/allreduce/RecursiveHalving.hpp
new file mode 100644
index 00000000..1687cf20
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/allreduce/RecursiveHalving.hpp
@@ -0,0 +1,49 @@
+#pragma once
+
+#include "Operator.hpp"
+#include "allreduceButterfly.h"
+#include "collectives/barrier/GPIBarrier.hpp"
+#include "collectives/TensorInfo.hpp"
+#include "gpi/Group.hpp"
+#include "gpi/NotificationManager.hpp"
+#include "gpi/SegmentBuffer.hpp"
+
+#include <atomic>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ namespace Allreduce
+ {
+ class RecursiveHalving : public Operator
+ {
+ public:
+ RecursiveHalving(TensorInfo,
+ ReductionOp,
+ ResourceList const&,
+ queues&,
+ GPI::Group const&);
+ RecursiveHalving(const RecursiveHalving&) = delete;
+ RecursiveHalving& operator=(const RecursiveHalving&) = delete;
+ ~RecursiveHalving() = default;
+
+ void start() override;
+ void trigger_communication_step() override;
+
+ void reset_for_reuse() override;
+ bool is_running() const override;
+ bool is_finished() const override;
+
+ void* get_input_ptr() const override;
+ void* get_result_ptr() const override;
+
+ static RequiredResourceList get_required_resources(TensorInfo const&, GPI::Group const&);
+
+ private:
+ GPI::Group const& group;
+ std::atomic<OperatorState> state;
+ allreduceButterfly allreduce;
+ Barrier::GPIBarrier barrier;
+ };
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/gpi_comm_lib/collectives/allreduce/RecursiveHalvingDoubleBuffer.cpp b/src/gpi_comm_lib/collectives/allreduce/RecursiveHalvingDoubleBuffer.cpp
new file mode 100644
index 00000000..0a2f842a
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/allreduce/RecursiveHalvingDoubleBuffer.cpp
@@ -0,0 +1,98 @@
+#include "RecursiveHalvingDoubleBuffer.hpp"
+
+#include "gpi/gaspiCheckReturn.hpp"
+#include "utils.hpp"
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ namespace Allreduce
+ {
+ RecursiveHalvingDoubleBuffer::RecursiveHalvingDoubleBuffer(TensorInfo tensor_info,
+ ReductionOp reduction_op,
+ ResourceList const& resource_list,
+ queues& queues,
+ GPI::Group const& group)
+ : state(OperatorState::NOT_STARTED),
+ allreduce(tensor_info.get_nelems(),
+ to_allreduce_dataType(tensor_info.get_elem_type()),
+ to_allreduce_reductionType(reduction_op),
+ to_allreduce_segment_buffer(resource_list.at(0)),
+ to_allreduce_segment_buffer(resource_list.at(1)),
+ to_allreduce_segment_buffer(resource_list.at(2)),
+ queues, group)
+ { }
+
+ void RecursiveHalvingDoubleBuffer::start()
+ {
+ if (is_running())
+ {
+ throw std::logic_error("[RecursiveHalvingDoubleBuffer::start] Operation already started.");
+ }
+ if (is_finished())
+ {
+ throw std::logic_error("[RecursiveHalvingDoubleBuffer::start] Operation not reset after finish.");
+ }
+ allreduce.signal();
+ state = OperatorState::RUNNING;
+ }
+
+ void RecursiveHalvingDoubleBuffer::trigger_communication_step()
+ {
+ if (is_running())
+ {
+ auto const result = allreduce();
+ if (result == 0)
+ {
+ state = OperatorState::FINISHED;
+ }
+ }
+ }
+
+ void RecursiveHalvingDoubleBuffer::reset_for_reuse()
+ {
+ if (is_running())
+ {
+ throw std::logic_error("[RecursiveHalvingDoubleBuffer::reset] Cannot reset while running.");
+ }
+ state = OperatorState::NOT_STARTED;
+ }
+
+ bool RecursiveHalvingDoubleBuffer::is_running() const
+ {
+ return state == OperatorState::RUNNING;
+ }
+
+ bool RecursiveHalvingDoubleBuffer::is_finished() const
+ {
+ return state == OperatorState::FINISHED;
+ }
+
+ Operator::RequiredResourceList RecursiveHalvingDoubleBuffer::get_required_resources(
+ TensorInfo const& tensor_info, GPI::Group const& group)
+ {
+ auto const num_notifications = allreduceButterflyDoubleBuffer::getNumberOfNotifications(group.get_size());
+
+ auto const num_elements_data_segment = tensor_info.get_nelems();
+ auto const num_elements_temp_segment = static_cast<std::size_t>(
+ allreduceButterflyDoubleBuffer::getNumberOfElementsSegmentCommunicate(
+ tensor_info.get_nelems(), group.get_size()));
+
+ return {{num_elements_data_segment * getDataTypeSize(tensor_info.get_elem_type()), num_notifications},
+ {num_elements_data_segment * getDataTypeSize(tensor_info.get_elem_type()), num_notifications},
+ {num_elements_temp_segment * getDataTypeSize(tensor_info.get_elem_type()), num_notifications}};
+ }
+
+ void* RecursiveHalvingDoubleBuffer::get_input_ptr() const
+ {
+ return allreduce.getActiveReducePointer();
+ }
+
+ void* RecursiveHalvingDoubleBuffer::get_result_ptr() const
+ {
+ return allreduce.getResultsPointer();
+ }
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/allreduce/RecursiveHalvingDoubleBuffer.hpp b/src/gpi_comm_lib/collectives/allreduce/RecursiveHalvingDoubleBuffer.hpp
new file mode 100644
index 00000000..653a891f
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/allreduce/RecursiveHalvingDoubleBuffer.hpp
@@ -0,0 +1,44 @@
+#pragma once
+
+#include "Operator.hpp"
+#include "allreduceButterflyDoubleBuffer.h"
+#include "collectives/TensorInfo.hpp"
+#include "gpi/Group.hpp"
+
+#include <atomic>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ namespace Allreduce
+ {
+ class RecursiveHalvingDoubleBuffer : public Operator
+ {
+ public:
+ RecursiveHalvingDoubleBuffer(TensorInfo,
+ ReductionOp,
+ ResourceList const&,
+ queues&,
+ GPI::Group const&);
+ RecursiveHalvingDoubleBuffer(const RecursiveHalvingDoubleBuffer&) = delete;
+ RecursiveHalvingDoubleBuffer& operator=(const RecursiveHalvingDoubleBuffer&) = delete;
+ ~RecursiveHalvingDoubleBuffer() = default;
+
+ void start() override;
+ void trigger_communication_step() override;
+
+ void reset_for_reuse() override;
+ bool is_running() const override;
+ bool is_finished() const override;
+
+ void* get_input_ptr() const override;
+ void* get_result_ptr() const override;
+
+ static RequiredResourceList get_required_resources(TensorInfo const&, GPI::Group const& group);
+
+ private:
+ std::atomic<OperatorState> state;
+ allreduceButterflyDoubleBuffer allreduce;
+ };
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/allreduce/utils.cpp b/src/gpi_comm_lib/collectives/allreduce/utils.cpp
new file mode 100644
index 00000000..da5f5b3b
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/allreduce/utils.cpp
@@ -0,0 +1,39 @@
+#include "utils.hpp"
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ namespace Allreduce
+ {
+ allreduce::dataType to_allreduce_dataType(const BufferElementType type)
+ {
+ std::unordered_map<BufferElementType, allreduce::dataType> const types{
+ {BufferElementType::FLOAT, allreduce::FLOAT},
+ {BufferElementType::DOUBLE, allreduce::DOUBLE},
+ {BufferElementType::INT16, allreduce::INT16},
+ {BufferElementType::INT32, allreduce::INT32},
+ };
+ return types.at(type);
+ }
+
+ allreduce::reductionType to_allreduce_reductionType(const Operator::ReductionOp op)
+ {
+ std::unordered_map<Operator::ReductionOp, allreduce::reductionType> const reduction_ops{
+ {Operator::ReductionOp::SUM, allreduce::SUM},
+ {Operator::ReductionOp::AVERAGE, allreduce::AVERAGE},
+ };
+ return reduction_ops.at(op);
+ }
+
+ allreduceButterfly::segmentBuffer to_allreduce_segment_buffer(Operator::Resource const& resource)
+ {
+ auto const [data_segment_buffer, notif_range] = resource;
+ allreduceButterfly::segmentBuffer buffer{data_segment_buffer.get_segment_id(),
+ data_segment_buffer.get_offset(),
+ static_cast<gaspi_notification_id_t>(notif_range.first)};
+ return buffer;
+ }
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/allreduce/utils.hpp b/src/gpi_comm_lib/collectives/allreduce/utils.hpp
new file mode 100644
index 00000000..462faa51
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/allreduce/utils.hpp
@@ -0,0 +1,21 @@
+#pragma once
+
+#include "allreduce.h"
+#include "allreduceButterfly.h"
+#include "collectives/BufferElementType.hpp"
+#include "Operator.hpp"
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ namespace Allreduce
+ {
+ allreduce::dataType to_allreduce_dataType(const BufferElementType type);
+ allreduce::reductionType to_allreduce_reductionType(
+ const Operator::ReductionOp op);
+ allreduceButterfly::segmentBuffer to_allreduce_segment_buffer(
+ Operator::Resource const &resource);
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/gpi_comm_lib/collectives/barrier/GPIBarrier.cpp b/src/gpi_comm_lib/collectives/barrier/GPIBarrier.cpp
new file mode 100644
index 00000000..c954a6a9
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/barrier/GPIBarrier.cpp
@@ -0,0 +1,38 @@
+#include "GPIBarrier.hpp"
+#include "gpi/gaspiCheckReturn.hpp"
+
+#include <stdexcept>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ namespace Barrier
+ {
+ GPIBarrier::GPIBarrier(GPI::Group const &group)
+ {
+ gaspi_rank_t comm_size;
+ GPI::gaspiCheckReturn(gaspi_proc_num(&comm_size),
+ "GPIBarrier::GPIBarrier : get number of ranks");
+ if (group.get_size() != comm_size)
+ {
+ throw std::invalid_argument("GPIBarrier::GPIBarrier : can only be used with all ranks in \
+ the default GPI communicator");
+ }
+ }
+
+ // TODO: implement for any GPI::Group
+ void GPIBarrier::blocking_barrier()
+ {
+ GPI::gaspiCheckReturn(gaspi_barrier(GASPI_GROUP_ALL, GASPI_BLOCK),
+ "GPIBarrier::GPIBarrier : barrier failed");
+ }
+
+ void GPIBarrierAllRanks::blocking_barrier()
+ {
+ GPI::gaspiCheckReturn(gaspi_barrier(GASPI_GROUP_ALL, GASPI_BLOCK),
+ "GPIBarrierAllRanks::GPIBarrierAllRanks : barrier failed");
+ }
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/barrier/GPIBarrier.hpp b/src/gpi_comm_lib/collectives/barrier/GPIBarrier.hpp
new file mode 100644
index 00000000..3ac367a1
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/barrier/GPIBarrier.hpp
@@ -0,0 +1,30 @@
+#pragma once
+
+#include "gpi/Group.hpp"
+#include "Operator.hpp"
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ namespace Barrier
+ {
+ // GPI Barrier implementation for GROUP_COMM_ALL
+ class GPIBarrier : public Operator
+ {
+ public:
+
+ GPIBarrier(GPI::Group const & group);
+ void blocking_barrier();
+ };
+
+ class GPIBarrierAllRanks : public Operator
+ {
+ public:
+
+ GPIBarrierAllRanks() = default;
+ void blocking_barrier();
+ };
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/barrier/Operator.hpp b/src/gpi_comm_lib/collectives/barrier/Operator.hpp
new file mode 100644
index 00000000..81e863bb
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/barrier/Operator.hpp
@@ -0,0 +1,20 @@
+#pragma once
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ namespace Barrier
+ {
+ // \note
+ // Interface for Barrier algorithms (not thread-safe)
+ class Operator
+ {
+ public:
+ virtual ~Operator() = default;
+
+ virtual void blocking_barrier() = 0;
+ };
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/allreduce.h b/src/gpi_comm_lib/collectives/lib/allreduce.h
new file mode 100755
index 00000000..b9f12fe5
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/allreduce.h
@@ -0,0 +1,27 @@
+#pragma once
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class allreduce {
+ public:
+ enum reductionType {
+ SUM = 0,
+ AVERAGE = 1,
+ NUM_RED = 2
+ };
+ enum dataType {
+ FLOAT = 0,
+ DOUBLE = 1,
+ INT16 = 2,
+ INT32 = 3,
+ NUM_TYPE = 4
+ };
+
+ virtual int operator()() = 0;
+ virtual void signal() = 0;
+ virtual ~allreduce() {}
+ };
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/allreduceButterfly.cpp b/src/gpi_comm_lib/collectives/lib/allreduceButterfly.cpp
new file mode 100755
index 00000000..26fb2b58
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/allreduceButterfly.cpp
@@ -0,0 +1,418 @@
+#include "allreduceButterfly.h"
+#include "gpi/gaspiCheckReturn.hpp"
+#include "mailBoxGaspi.h"
+#include "gpi/Group.hpp"
+
+#include <algorithm>
+#include <cmath>
+#include <stdexcept>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ using tarantella::GPI::gaspiCheckReturn;
+
+ nestedRingParameter::nestedRingParameter(const rankIndexType numRanks_,
+ const rankIndexType rank_) :
+ numRanks(numRanks_),
+ rank(rank_),
+ ringSizes(getRingSizes(numRanks)),
+ strides(getStrides(ringSizes)),
+ ringIndices(getRingIndices(ringSizes, rank_)) {}
+
+ inline nestedRingParameter::ringSizesType nestedRingParameter::getRingSizes(
+ rankIndexType numRanks) {
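+ // decompose the number of ranks into small (prime) factors; each factor
+ // becomes the length of one ring of the nested-ring topology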
+ ringSizesType s;
+
+ unsigned long limit = std::sqrt(numRanks) + 2;
+
+ for (unsigned long factor=2; factor < limit; factor++) {
+ while ((numRanks % factor) == 0) {
+ s.push_back(factor);
+ numRanks /= factor;
+ }
+ }
+
+ if (numRanks > 1) {
+ s.push_back(numRanks);
+ }
+
+ return s;
+ }
+
+ inline nestedRingParameter::stridesType nestedRingParameter::getStrides(
+ const ringSizesType& ringSizes) {
+ const long numLevels = ringSizes.size();
+ stridesType s(ringSizes.size());
+ unsigned long factor = 1;
+ for (long level=numLevels - 1; level >= 0; level--) {
+ s[level] = factor;
+ factor *= ringSizes[level];
+ }
+
+ return s;
+ }
+
+ inline nestedRingParameter::ringIndicesType
+ nestedRingParameter::getRingIndices(const ringSizesType& ringSizes,
+ const rankIndexType rank) {
+ ringIndicesType indices;
+
+ rankIndexType product = 1;
+ for (unsigned long i=0; i < ringSizes.size(); i++) {
+ indices.push_back((rank / product) % ringSizes[i]);
+ product *= ringSizes[i];
+ }
+
+ return indices;
+ }
+
+ nestedRingParameter::rankIndexType
+ nestedRingParameter::getNumberOfRings() const{
+ return ringSizes.size();
+ }
+
+ nestedRingParameter::rankIndexType nestedRingParameter::getRingLength(
+ const levelType level) const {
+ return ringSizes[level];
+ }
+
+ nestedRingParameter::rankIndexType nestedRingParameter::getLocalRankInRing(
+ const levelType level) const {
+ return ringIndices[level];
+ }
+
+ nestedRingParameter::rankIndexType
+ nestedRingParameter::getGlobalRankToWriteInRing(
+ const levelType level) const {
+ long numLevels = ringSizes.size();
+ rankIndexType r = 0;
+ for (long i=numLevels - 1; i > long(level); i--) {
+ r = ringIndices[i] + ringSizes[i] * r;
+ }
+ const rankIndexType next = (ringIndices[level] + 1) % ringSizes[level];
+ r = next + ringSizes[level] * r;
+ for (long i=long(level) - 1; i >= 0; i--) {
+ r = ringIndices[i] + ringSizes[i] * r;
+ }
+ return r;
+ }
+
+ nestedRingParameter::bufferIndexType nestedRingParameter::getBufferLength(
+ const levelType level) const {
+ return strides[level];
+ }
+
+ nestedRingParameter::bufferIndexType nestedRingParameter::getBufferStart(
+ const levelType level,
+ const bufferIndexType buffer) const {
+ // we assume that each global rank aggregates on each level the buffer
+ // that matches the local ring id. This buffer is
+ // I.E. getBufferStart(level, getRankInRing(level))
+ // -> getBufferStart(level, getRankInRing(level)) + getBufferLength(level)
+
+ bufferIndexType s = 0;
+ for (unsigned long i=0; i < level; i++) {
+ s += ringIndices[i] * strides[i];
+ }
+ s += buffer * strides[level];
+
+ return s;
+ }
+
+ allreduceButterfly::allreduceButterfly(
+ const long len,
+ const dataType data,
+ const reductionType reduction,
+ const segmentBuffer locationReduce_,
+ const segmentBuffer locationCommunicate_,
+ queues& queues_,
+ GPI::Group const& group_
+ )
+ : totalLength(len),
+ dataElement(data),
+ group(group_),
+ numRanks(getNumRanks()),
+ rank(getRank()),
+ locationReduce(locationReduce_),
+ locationReducePointer(getSegmentPointer(locationReduce_.segment)
+ + locationReduce_.offset),
+ locationCommunicate(locationCommunicate_),
+ topology(numRanks, getRankIndex(rank, getRanks())),
+ sender(queues_),
+ reducer(getReduce(data, reduction)),
+ status(2 * getNumberOfNotifications(numRanks) + 1){
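+ // one status phase per reduce-scatter and per all-gather step on each
+ // ring, plus the initial local trigger: 2 * notifications + 1 in total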
+
+ std::vector<gaspi_rank_t> ranks = getRanks();
+
+ setReduceScatter(ranks);
+ setAllToAll(ranks);
+ }
+
+ long allreduceButterfly::getNumRanks() const {
+ return group.get_size();
+ }
+
+ long allreduceButterfly::getRank() {
+ gaspi_rank_t rank;
+ gaspiCheckReturn(gaspi_proc_rank(&rank),
+ "gaspi_proc_rank failed with ");
+ return rank;
+ }
+
+ std::vector<gaspi_rank_t> allreduceButterfly::getRanks() const {
+ return group.get_ranks();
+ }
+
+ unsigned long allreduceButterfly::getRankIndex(
+ gaspi_rank_t rank,
+ const std::vector<gaspi_rank_t>& ranks) {
+ unsigned long rankIndex;
+ if (find(ranks.begin(), ranks.end(), rank) == ranks.end()) {
+ throw std::runtime_error("rank not member of group");
+ } else {
+ rankIndex = find(ranks.begin(), ranks.end(), rank)
+ - ranks.begin();
+ }
+ return rankIndex;
+ }
+
+ void allreduceButterfly::setReduceScatter(
+ const std::vector<gaspi_rank_t>& ranks) {
+ gaspi_notification_id_t nextNotification
+ = locationCommunicate.firstNotification;
+ gaspi_offset_t nextLocalCommunicationBufferByte = 0;
+ const char* const reductionSourceBasePointer =
+ getSegmentPointer(locationCommunicate.segment)
+ + locationCommunicate.offset;
+ char* const reductionDestinationBasePointer =
+ getSegmentPointer(locationReduce.segment)
+ + locationReduce.offset;
+
+ receiver.push_back(&trigger);
+ jobs.push_back(jobType());
+
+ for (unsigned long ring=0; ring < topology.getNumberOfRings(); ring++) {
+
+ const rankIndexType ringLength = topology.getRingLength(ring);
+ const rankIndexType ringRank = topology.getLocalRankInRing(ring);
+ const bufferIndexType bufferLengthIndex = topology.getBufferLength(ring);
+ const gaspi_rank_t outgoingGlobalRank =
+ ranks[topology.getGlobalRankToWriteInRing(ring)];
+ gaspi_offset_t nextRemoteCommunicationBufferByte
+ = nextLocalCommunicationBufferByte;
+
+
+ for (unsigned long loop=0; loop < ringLength - 1; loop++) {
+ const unsigned long currentJob = receiver.size() - 1;
+ receiver.push_back(
+ new mailBoxGaspi(locationCommunicate.segment, nextNotification));
+ jobs.push_back(jobType());
+
+ const bufferIndexType sendBufferID =
+ (ringRank + ringLength - loop - 1) % ringLength;
+ const bufferIndexType sendStartIndex =
+ topology.getBufferStart(ring, sendBufferID);
+ const gaspi_offset_t sendStartByte =
+ chunkIndexToByte(sendStartIndex);
+ const long sendLengthByte =
+ chunkIndexToByte(sendStartIndex + bufferLengthIndex)
+ - sendStartByte;
+ const writer::transferParameters transfer(
+ true,
+ outgoingGlobalRank,
+ locationReduce.segment,
+ locationReduce.offset + sendStartByte,
+ locationCommunicate.segment,
+ locationCommunicate.offset + nextRemoteCommunicationBufferByte,
+ sendLengthByte,
+ nextNotification);
+ jobs[currentJob].second = transfer;
+
+ const bufferIndexType receiveBufferID =
+ (ringRank + ringLength - loop - 2) % ringLength;
+ const bufferIndexType receiveStartIndex =
+ topology.getBufferStart(ring, receiveBufferID);
+ const gaspi_offset_t receiveStartByte =
+ chunkIndexToByte(receiveStartIndex);
+ const long receiveLengthByte =
+ chunkIndexToByte(receiveStartIndex + bufferLengthIndex)
+ - receiveStartByte;
+ const reduce::task copy(
+ reductionSourceBasePointer + nextLocalCommunicationBufferByte,
+ reductionDestinationBasePointer + receiveStartByte,
+ receiveLengthByte / getDataTypeSize(dataElement));
+ jobs[currentJob + 1].first = copy;
+
+ nextNotification++;
+ nextRemoteCommunicationBufferByte += sendLengthByte;
+ nextLocalCommunicationBufferByte += receiveLengthByte;
+ }
+ }
+
+ jobs.back().first.scaling = numRanks;
+ }
+
+ inline char* allreduceButterfly::getSegmentPointer(
+ const gaspi_segment_id_t segment) {
+ gaspi_pointer_t p;
+ gaspiCheckReturn(gaspi_segment_ptr(segment, &p),
+ "failed getting segment pointer");
+ return (char*) p;
+ }
+
+ inline unsigned long allreduceButterfly::chunkIndexToByte(
+ const long chunkIndex) const {
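+ // data is split into ceil(totalLength / numRanks)-element chunks; convert
+ // a chunk index into its byte offset within the segment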
+ return ((totalLength * chunkIndex + numRanks - 1) / numRanks)
+ * getDataTypeSize(dataElement);
+ }
+
+ void allreduceButterfly::setAllToAll(
+ const std::vector<gaspi_rank_t>& ranks) {
+ gaspi_notification_id_t nextNotification = locationReduce.firstNotification;
+
+ for (long ring = topology.getNumberOfRings() - 1; ring >= 0; ring--) {
+
+ const rankIndexType ringLength = topology.getRingLength(ring);
+ const rankIndexType ringRank = topology.getLocalRankInRing(ring);
+ const bufferIndexType bufferLengthIndex = topology.getBufferLength(ring);
+ const gaspi_rank_t outgoingGlobalRank =
+ ranks[topology.getGlobalRankToWriteInRing(ring)];
+
+ for (unsigned long loop=0; loop < ringLength - 1; loop++) {
+ const unsigned long currentJob = receiver.size() - 1;
+ receiver.push_back(
+ new mailBoxGaspi(locationReduce.segment, nextNotification));
+ jobs.push_back(jobType());
+
+ const bufferIndexType transferBufferID =
+ (ringRank + ringLength - loop) % ringLength;
+ const bufferIndexType transferStartIndex =
+ topology.getBufferStart(ring, transferBufferID);
+ const gaspi_offset_t transferStartByte =
+ chunkIndexToByte(transferStartIndex);
+ const long transferLengthByte =
+ chunkIndexToByte(transferStartIndex + bufferLengthIndex)
+ - transferStartByte;
+
+ const writer::transferParameters transfer(
+ true,
+ outgoingGlobalRank,
+ locationReduce.segment,
+ locationReduce.offset + transferStartByte,
+ locationReduce.segment,
+ locationReduce.offset + transferStartByte,
+ transferLengthByte,
+ nextNotification);
+ jobs[currentJob].second = transfer;
+
+ nextNotification++;
+ }
+ }
+ }
+
+ allreduceButterfly::~allreduceButterfly() {
+ delete reducer;
+ for (unsigned long i=1; i < receiver.size(); i++) {
+ delete receiver[i];
+ }
+ }
+
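+ // Executes one phase per call: returns -1 while the phase's notification
+ // is still outstanding, otherwise runs the local reduction (if any) and
+ // the outgoing write, and returns 0 once the final phase has completed.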
+ int allreduceButterfly::operator()() {
+ const unsigned long phase = status.get();
+ // could be a problem if we overtake by one iteration?
+ if (!receiver[phase]->gotNotification()) {
+ return -1;
+ }
+
+ reducer->operator()(jobs[phase].first);
+ // release the buffer already at this point?
+ sender(jobs[phase].second);
+
+ return (status.increment() == 0) ? 0 : -1;
+ }
+
+ void allreduceButterfly::signal() {
+ trigger.notify();
+ }
+
+ gaspi_pointer_t allreduceButterfly::getReducePointer() const {
+ return locationReducePointer;
+ }
+
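+ // The communicate segment must hold numRanks - 1 incoming chunks of
+ // ceil(len / numRanks) elements each.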
+ long allreduceButterfly::getNumberOfElementsSegmentCommunicate(
+ const long len,
+ const long numRanks) {
+ return ((len + numRanks - 1) / numRanks) * (numRanks - 1);
+ }
+
+ unsigned long allreduceButterfly::getNumberOfNotifications(
+ const long numRanks) {
+ const nestedRingParameter topology(numRanks);
+
+ gaspi_notification_id_t notifications = 0;
+ for (unsigned long i=0; i < topology.getNumberOfRings(); i++) {
+ notifications += topology.getRingLength(i) - 1;
+ }
+
+ return notifications;
+ }
+
+ std::ostream& allreduceButterfly::report(std::ostream& s) const {
+ char* pr = getSegmentPointer(locationReduce.segment);
+ char* pc = getSegmentPointer(locationCommunicate.segment);
+ const unsigned long phase = status.get();
+ s << "total length: " << totalLength << std::endl
+ << "dataElement: " << dataElement << std::endl
+ << "numRanks: " << numRanks << std::endl
+ << "rank: " << rank << std::endl
+ << "topology.getNumberOfRings" << topology.getNumberOfRings() << std::endl
+ << "getNumberOfNotifications(): "
+ << getNumberOfNotifications(numRanks) << std::endl
+ << "segmentReduce: " << long(locationReduce.segment) << std::endl
+ << "offsetReduce: " << locationReduce.offset << std::endl
+ << "firstNotificationReduce: " << locationReduce.firstNotification
+ << std::endl
+ << "segmentCommunicate: " << long(locationCommunicate.segment)
+ << std::endl
+ << "offsetCommunicate: " << locationCommunicate.offset << std::endl
+ << "firstNotificationCommunicate: "
+ << locationCommunicate.firstNotification << std::endl
+ << "pointer segment reduce : "
+ << (void*)getSegmentPointer(locationReduce.segment) << std::endl
+ << "pointer segment communicate: "
+ << (void*)getSegmentPointer(locationCommunicate.segment) << std::endl
+ << "phase " << phase << std::endl;
+ for (unsigned long i=0; i < jobs.size(); i++) {
+ s << ".........................." << std::endl;
+ s << "phase " << i << std::endl;
+ if (i==0) {
+ s << "Receiver: " << "user" << std::endl;
+ } else {
+ mailBoxGaspi* m = (mailBoxGaspi*) receiver[i];
+ s << "Receiver: segment " << long(m->getSegmentID())
+ << " notification ID " << m->getMailID() << std::endl;
+ }
+
+ if (jobs[i].first.len > 0) {
+ s << "Reduce : src " << jobs[i].first.source
+ << " (" << (char*)jobs[i].first.source - pc << ")"
+ << " dst " << jobs[i].first.destination
+ << " (" << (char*)jobs[i].first.destination - pr << ")"
+ << " ele " << jobs[i].first.len
+ << " (" << jobs[i].first.len * getDataTypeSize(dataElement) << ")"
+ << std::endl;
+ } else {
+ s << "Reduce : idle" << std::endl;
+ }
+
+ s << "Send : ";
+ jobs[i].second.report(s) << std::endl;
+ }
+ s << ".........................." << std::endl;
+
+ return s;
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/allreduceButterfly.h b/src/gpi_comm_lib/collectives/lib/allreduceButterfly.h
new file mode 100644
index 00000000..194bac0a
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/allreduceButterfly.h
@@ -0,0 +1,118 @@
+#pragma once
+
+#include "allreduce.h"
+#include "counter.h"
+#include "gpi/Group.hpp"
+#include "mailBox.h"
+#include "mailBoxLocal.h"
+#include "queues.h"
+#include "reduce.h"
+#include "writer.h"
+
+#include <GASPI.h>
+
+#include <vector>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class nestedRingParameter {
+ public:
+ typedef unsigned long rankIndexType;
+ typedef unsigned long levelType;
+ typedef unsigned long bufferIndexType;
+
+ nestedRingParameter(const rankIndexType numRanks_,
+ const rankIndexType rank_=0);
+
+ rankIndexType getNumberOfRings() const;
+ rankIndexType getRingLength(const levelType level) const;
+ rankIndexType getLocalRankInRing(const levelType level) const;
+ rankIndexType getGlobalRankToWriteInRing(const levelType level) const;
+ bufferIndexType getBufferLength(const levelType level) const;
+ bufferIndexType getBufferStart(const levelType level,
+ const bufferIndexType buffer) const;
+
+ private:
+
+ typedef std::vector<rankIndexType> ringIndicesType;
+ typedef std::vector<rankIndexType> ringSizesType;
+ typedef std::vector<rankIndexType> stridesType;
+
+ static inline ringSizesType getRingSizes(rankIndexType numRanks);
+ static inline stridesType getStrides(const ringSizesType& ringSizes);
+ static inline ringIndicesType getRingIndices(const ringSizesType& ringSizes,
+ const rankIndexType rank);
+
+ const rankIndexType numRanks;
+ const rankIndexType rank;
+ const ringSizesType ringSizes;
+ const stridesType strides;
+ const ringIndicesType ringIndices;
+ };
+
+ class allreduceButterfly : public allreduce {
+ public:
+
+ struct segmentBuffer {
+ gaspi_segment_id_t segment;
+ gaspi_offset_t offset;
+ gaspi_notification_id_t firstNotification;
+ };
+
+ allreduceButterfly(const long len,
+ const dataType data,
+ const reductionType reduction,
+ const segmentBuffer segmentReduce,
+ const segmentBuffer segmentCommunicate,
+ queues& queues_,
+ GPI::Group const& group_);
+ ~allreduceButterfly();
+ int operator()();
+ void signal();
+
+ gaspi_pointer_t getReducePointer() const;
+ static long getNumberOfElementsSegmentCommunicate(const long len,
+ const long numRanks);
+ static unsigned long getNumberOfNotifications(const long numRanks);
+ std::ostream& report(std::ostream& s) const;
+
+ private:
+
+ typedef nestedRingParameter::rankIndexType rankIndexType;
+ typedef nestedRingParameter::bufferIndexType bufferIndexType;
+ typedef std::pair<reduce::task, writer::transferParameters> jobType;
+
+ inline long getNumRanks() const;
+ static inline long getRank();
+ std::vector<gaspi_rank_t> getRanks() const;
+ static inline rankIndexType getRankIndex(
+ gaspi_rank_t rank,
+ const std::vector<gaspi_rank_t>& ranks);
+ void setReduceScatter(const std::vector<gaspi_rank_t>& ranks);
+ inline static char* getSegmentPointer(const gaspi_segment_id_t segment);
+ inline unsigned long chunkIndexToByte(const long chunkIndex) const;
+ void setAllToAll(const std::vector<gaspi_rank_t>& ranks);
+
+ const long totalLength;
+ const dataType dataElement;
+ GPI::Group const group;
+ const long numRanks;
+ const gaspi_rank_t rank;
+ const segmentBuffer locationReduce;
+ const gaspi_pointer_t locationReducePointer;
+ const segmentBuffer locationCommunicate;
+
+ const nestedRingParameter topology;
+
+ mailBoxLocal trigger;
+ std::vector<mailBox*> receiver;
+ std::vector<jobType> jobs;
+
+ writer sender;
+ reduce * reducer;
+ counter status;
+ };
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/allreduceButterflyDoubleBuffer.cpp b/src/gpi_comm_lib/collectives/lib/allreduceButterflyDoubleBuffer.cpp
new file mode 100755
index 00000000..484b5268
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/allreduceButterflyDoubleBuffer.cpp
@@ -0,0 +1,91 @@
+#include "allreduceButterflyDoubleBuffer.h"
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ allreduceButterflyDoubleBuffer::allreduceButterflyDoubleBuffer(
+ const long len,
+ const dataType data,
+ const reductionType reduction,
+ const allreduceButterfly::segmentBuffer segmentReduce0,
+ const allreduceButterfly::segmentBuffer segmentReduce1,
+ const allreduceButterfly::segmentBuffer segmentCommunicate,
+ queues& queues,
+ GPI::Group const& group)
+ : state(0),
+ reduceFirst(len, data, reduction, segmentReduce0,
+ segmentCommunicate, queues, group),
+ reduceSecond(len, data, reduction, segmentReduce1,
+ segmentCommunicate, queues, group) {
+ tableReduce[0] = &reduceFirst;
+ tableReduce[1] = &reduceSecond;
+ }
+
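+ // Runs the currently active butterfly; once it completes, flips to the
+ // other buffer so the next allreduce can start while the finished result
+ // stays readable through getResultsPointer().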
+ int allreduceButterflyDoubleBuffer::operator()() {
+ const int result = getReduce()();
+
+ if (!result) {
+ flipReduce();
+ }
+
+ return result;
+ }
+
+ inline allreduceButterfly& allreduceButterflyDoubleBuffer::getReduce() const {
+ return *tableReduce[stateToIndex(state)];
+ }
+
+ inline long allreduceButterflyDoubleBuffer::stateToIndex(const long state) {
+ return state & 1l;
+ }
+
+ inline void allreduceButterflyDoubleBuffer::flipReduce() {
+ __sync_fetch_and_add(&state, 1l);
+ }
+
+ void allreduceButterflyDoubleBuffer::signal() {
+ getReduce().signal();
+ }
+
+ gaspi_pointer_t allreduceButterflyDoubleBuffer::getActiveReducePointer() const {
+ return getReduce().getReducePointer();
+ }
+
+ gaspi_pointer_t allreduceButterflyDoubleBuffer::getResultsPointer() const {
+ return getOtherReduce().getReducePointer();
+ }
+
+ inline const allreduceButterfly&
+ allreduceButterflyDoubleBuffer::getOtherReduce() const {
+ return *tableReduce[invertIndex(stateToIndex(state))];
+ }
+
+ inline long allreduceButterflyDoubleBuffer::invertIndex(const long state) {
+ return state ^ 1l;
+ }
+
+ long allreduceButterflyDoubleBuffer::getNumberOfElementsSegmentCommunicate(
+ const long len,
+ const long numRanks) {
+ return allreduceButterfly::getNumberOfElementsSegmentCommunicate(len,
+ numRanks);
+ }
+
+ unsigned long allreduceButterflyDoubleBuffer::getNumberOfNotifications(
+ const long numRanks) {
+ return allreduceButterfly::getNumberOfNotifications(numRanks);
+ }
+
+ std::ostream& allreduceButterflyDoubleBuffer::report(std::ostream& s) const {
+ s << "stateExecute: " << state << std::endl
+ << "***** reduceFirst *****" << std::endl;
+ reduceFirst.report(s);
+ s << "***** reduceSecond *****" << std::endl;
+ reduceSecond.report(s);
+
+ return s;
+ }
+ }
+}
+
\ No newline at end of file
diff --git a/src/gpi_comm_lib/collectives/lib/allreduceButterflyDoubleBuffer.h b/src/gpi_comm_lib/collectives/lib/allreduceButterflyDoubleBuffer.h
new file mode 100755
index 00000000..e4f503b7
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/allreduceButterflyDoubleBuffer.h
@@ -0,0 +1,52 @@
+#pragma once
+
+#include "allreduceButterfly.h"
+#include "gpi/Group.hpp"
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class allreduceButterflyDoubleBuffer : public allreduce {
+ public:
+
+ allreduceButterflyDoubleBuffer(
+ const long len,
+ const dataType data,
+ const reductionType reduction,
+ const allreduceButterfly::segmentBuffer segmentReduce0,
+ const allreduceButterfly::segmentBuffer segmentReduce1,
+ const allreduceButterfly::segmentBuffer segmentCommunicate,
+ queues& queues,
+ GPI::Group const& group);
+ int operator()();
+ void signal();
+
+ gaspi_pointer_t getActiveReducePointer() const;
+ gaspi_pointer_t getResultsPointer() const;
+ static long getNumberOfElementsSegmentCommunicate(const long len,
+ const long numRanks);
+ static unsigned long getNumberOfNotifications(const long numRanks);
+ std::ostream& report(std::ostream& s) const;
+
+ private:
+
+ inline allreduceButterfly& getReduce() const;
+ static inline long stateToIndex(const long state);
+ inline void flipReduce();
+ inline const allreduceButterfly& getOtherReduce() const;
+ static inline long invertIndex(const long state);
+
+ static const long CACHE_LINE_SIZE = 64;
+
+ char pad0[CACHE_LINE_SIZE];
+ volatile long state;
+ char pad1[CACHE_LINE_SIZE];
+
+ allreduceButterfly reduceFirst;
+ allreduceButterfly reduceSecond;
+ allreduceButterfly* tableReduce[2];
+ };
+ }
+}
+
\ No newline at end of file
diff --git a/src/gpi_comm_lib/collectives/lib/broadcast.cpp b/src/gpi_comm_lib/collectives/lib/broadcast.cpp
new file mode 100755
index 00000000..c1b814e1
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/broadcast.cpp
@@ -0,0 +1,197 @@
+#include "broadcast.h"
+#include "gpi/gaspiCheckReturn.hpp"
+#include "mailBoxGaspi.h"
+
+#include <algorithm>
+#include <stdexcept>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ using tarantella::GPI::gaspiCheckReturn;
+
+ broadcast::broadcast(
+ const gaspi_rank_t master_,
+ const long len,
+ const gaspi_segment_id_t segment_,
+ const gaspi_offset_t offset_,
+ const gaspi_notification_id_t firstNotification_,
+ queues& queues_ )
+ : totalLength(len),
+ group(GASPI_GROUP_ALL),
+ numRanks(getNumRanks()),
+ rank(getRank()),
+ masterRank(master_),
+ segment(segment_),
+ offset(offset_),
+ firstNotification(firstNotification_),
+ sender(queues_),
+ status((rank == masterRank) ? 1 : numRanks){
+
+ std::vector<gaspi_rank_t> ranks(numRanks);
+ gaspiCheckReturn(gaspi_group_ranks(group, &ranks[0]),
+ "gaspi_group_ranks failed with");
+ const unsigned long rankIndex = getRankIndex(rank, ranks);
+
+ if (rank == masterRank) {
+ setMaster(rankIndex, ranks);
+ } else {
+ setWorker(rankIndex, ranks);
+ }
+ }
+
+ long broadcast::getNumRanks() const {
+ gaspi_number_t size;
+ gaspiCheckReturn(gaspi_group_size(group, &size),
+ "gaspi_group_size failed with ");
+ return size;
+ }
+
+ long broadcast::getRank() {
+ gaspi_rank_t rank;
+ gaspiCheckReturn(gaspi_proc_rank(&rank),
+ "gaspi_proc_rank failed with ");
+ return rank;
+ }
+
+ long broadcast::getRankIndex(gaspi_rank_t rank,
+ const std::vector<gaspi_rank_t>& ranks) {
+ const auto it = find(ranks.begin(), ranks.end(), rank);
+ if (it == ranks.end()) {
+ throw std::runtime_error("rank not member of group");
+ }
+ const unsigned long rankIndex = it - ranks.begin();
+ return rankIndex;
+ }
+
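+ // The broadcast is a ring pipeline: the master pushes numRanks chunks to
+ // its neighbour, each worker forwards every chunk it receives, and the
+ // worker whose neighbour is the master only receives.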
+ void broadcast::setMaster(
+ const unsigned long rankIndex,
+ const std::vector<gaspi_rank_t>& ranks) {
+ const gaspi_rank_t partner = ranks[getPartnerIndex(rankIndex)];
+
+ receiver.push_back(&trigger);
+
+ if (partner != rank) {
+ for (long c=0; c < numRanks; c++) {
+ writer::transferParameters job(
+ true,
+ partner,
+ segment,
+ offset + chunkIndexToByte(c),
+ segment,
+ offset + chunkIndexToByte(c),
+ chunkIndexToByte(c + 1) - chunkIndexToByte(c),
+ firstNotification + c);
+ jobs.push_back(job);
+ }
+ }
+ }
+
+ inline unsigned long broadcast::getPartnerIndex(
+ const unsigned long rankIndex) const {
+ return (rankIndex + 1) % numRanks;
+ }
+
+ void broadcast::setWorker(
+ const unsigned long rankIndex,
+ const std::vector<gaspi_rank_t>& ranks) {
+ const gaspi_rank_t partner = ranks[getPartnerIndex(rankIndex)];
+
+ for (long c=0; c < numRanks; c++) {
+ receiver.push_back(
+ new mailBoxGaspi(segment, firstNotification + c));
+
+ if (partner == masterRank) {
+ jobs.push_back(writer::transferParameters());
+ } else {
+ writer::transferParameters transfer(
+ true,
+ partner,
+ segment,
+ offset + chunkIndexToByte(c),
+ segment,
+ offset + chunkIndexToByte(c),
+ chunkIndexToByte(c + 1) - chunkIndexToByte(c),
+ firstNotification + c);
+ jobs.push_back(transfer);
+ }
+ }
+ }
+
+ inline unsigned long broadcast::chunkIndexToByte(
+ const long chunkIndex) const {
+ return ((totalLength * chunkIndex + numRanks - 1) / numRanks);
+ }
+
+ broadcast::~broadcast() {
+ if (rank != masterRank) {
+ for (unsigned long i=0; i < receiver.size(); i++) {
+ delete receiver[i];
+ }
+ }
+ }
+
+ int broadcast::operator()() {
+ const unsigned long phase = status.get();
+ if (!receiver[phase]->gotNotification()) {
+ return -1;
+ }
+
+ if (rank == masterRank) {
+ for (unsigned long i=0; i < jobs.size(); i++) {
+ sender(jobs[i]);
+ }
+ } else {
+ sender(jobs[phase]);
+ }
+
+ return (status.increment() == 0) ? 0 : -1;
+ }
+
+ void broadcast::signal() {
+ trigger.notify();
+ }
+
+ long broadcast::getNumberOfNotifications(const long numRanks) {
+ return (numRanks > 1) ? numRanks : 0;
+ }
+
+ std::ostream& broadcast::report(std::ostream& s) const {
+ const unsigned long phase = status.get();
+ s << "total length: " << totalLength << std::endl
+ << "numRanks: " << numRanks << std::endl
+ << "rank: " << rank << std::endl
+ << "masterRank: " << masterRank << std::endl
+ << "segment: " << long(segment) << std::endl
+ << "offset: " << offset << std::endl
+ << "firstNotification: " << firstNotification << std::endl
+ << std::endl
+ << "phase " << phase << std::endl;
+ for (unsigned long i=0; i < jobs.size(); i++) {
+ s << ".........................." << std::endl;
+ s << "phase " << i << std::endl;
+ if ((i==0) && (rank == masterRank)) {
+ s << "Receiver: " << "user" << std::endl;
+ } else {
+ if (i < receiver.size()) {
+ mailBoxGaspi* m = (mailBoxGaspi*) receiver[i];
+ s << "Receiver: segment " << long(m->getSegmentID())
+ << " notification ID " << m->getMailID() << std::endl;
+ } else {
+ s << "Receiver: idle" << std::endl;
+ }
+ }
+
+ s << "Send : ";
+ jobs[i].report(s) << std::endl;
+ }
+ s << ".........................." << std::endl;
+
+ return s;
+ }
+ }
+}
+
\ No newline at end of file
diff --git a/src/gpi_comm_lib/collectives/lib/broadcast.h b/src/gpi_comm_lib/collectives/lib/broadcast.h
new file mode 100755
index 00000000..e59e0e89
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/broadcast.h
@@ -0,0 +1,62 @@
+#pragma once
+
+#include "writer.h"
+#include "mailBox.h"
+#include "mailBoxLocal.h"
+#include "counter.h"
+#include "queues.h"
+
+#include <GASPI.h>
+#include <vector>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class broadcast {
+ public:
+ broadcast(const gaspi_rank_t master_,
+ const long len,
+ const gaspi_segment_id_t segment_,
+ const gaspi_offset_t offset_,
+ const gaspi_notification_id_t firstNotification_,
+ queues& queues_);
+ ~broadcast();
+ int operator()();
+ void signal();
+ static long getNumberOfNotifications(const long numRanks);
+ std::ostream& report(std::ostream& s) const;
+
+ private:
+
+ long getNumRanks() const;
+ static long getRank();
+ static long getRankIndex(gaspi_rank_t rank,
+ const std::vector<gaspi_rank_t>& ranks);
+ void setMaster(const unsigned long rankIndex,
+ const std::vector<gaspi_rank_t>& ranks);
+ inline unsigned long getPartnerIndex(const unsigned long rankIndex) const;
+ void setWorker(const unsigned long rankIndex,
+ const std::vector<gaspi_rank_t>& ranks);
+ inline unsigned long chunkIndexToByte(const long chunkIndex) const;
+ inline static char* getSegmentPointer(const gaspi_segment_id_t segment);
+
+ const long totalLength;
+ const gaspi_group_t group;
+ const long numRanks;
+ const gaspi_rank_t rank;
+ const gaspi_rank_t masterRank;
+ const gaspi_segment_id_t segment;
+ const gaspi_offset_t offset;
+ const gaspi_notification_id_t firstNotification;
+
+ mailBoxLocal trigger;
+ std::vector<mailBox*> receiver;
+ std::vector<writer::transferParameters> jobs;
+
+ writer sender;
+ counter status;
+ };
+ }
+}
+
\ No newline at end of file
diff --git a/src/gpi_comm_lib/collectives/lib/counter.cpp b/src/gpi_comm_lib/collectives/lib/counter.cpp
new file mode 100755
index 00000000..f36ea582
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/counter.cpp
@@ -0,0 +1,19 @@
+#include "counter.h"
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ counter::counter(const unsigned long phasePeriod_)
+ : phasePeriod(phasePeriod_),
+ value(0) {}
+
+ unsigned long counter::increment() {
+ return (++value) % phasePeriod;
+ }
+
+ unsigned long counter::get() const {
+ return value % phasePeriod;
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/counter.h b/src/gpi_comm_lib/collectives/lib/counter.h
new file mode 100755
index 00000000..5a630592
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/counter.h
@@ -0,0 +1,21 @@
+#pragma once
+
+#include <atomic>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class counter {
+ public:
+ counter(const unsigned long phasePeriod_ = 1);
+ unsigned long increment();
+ unsigned long get() const;
+ private:
+
+ const unsigned long phasePeriod;
+ std::atomic<unsigned long> value;
+ };
+ }
+}
+
\ No newline at end of file
diff --git a/src/gpi_comm_lib/collectives/lib/mailBox.h b/src/gpi_comm_lib/collectives/lib/mailBox.h
new file mode 100755
index 00000000..a7d9a9d2
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/mailBox.h
@@ -0,0 +1,14 @@
+#pragma once
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class mailBox
+ {
+ public:
+ virtual bool gotNotification() = 0;
+ virtual ~mailBox() = default;
+ };
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/mailBoxGaspi.cpp b/src/gpi_comm_lib/collectives/lib/mailBoxGaspi.cpp
new file mode 100755
index 00000000..3e2b3738
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/mailBoxGaspi.cpp
@@ -0,0 +1,47 @@
+#include "mailBoxGaspi.h"
+#include "gpi/gaspiCheckReturn.hpp"
+
+#include <cassert>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ using tarantella::GPI::gaspiCheckReturn;
+
+ mailBoxGaspi::mailBoxGaspi(const gaspi_segment_id_t segmentID_,
+ const gaspi_notification_id_t mailID_)
+ : segmentID(segmentID_),
+ mailID(mailID_) {}
+
+ bool mailBoxGaspi::gotNotification() {
+ gaspi_notification_id_t event;
+ gaspi_return_t err = gaspi_notify_waitsome(segmentID,
+ mailID,
+ 1,
+ &event,
+ GASPI_TEST);
+ if (err == GASPI_TIMEOUT)
+ {
+ return false;
+ }
+ gaspiCheckReturn(err, "gaspi_notify_waitsome failed with ");
+
+ assert(mailID == event);
+ gaspi_notification_t value;
+ gaspiCheckReturn(gaspi_notify_reset(segmentID,
+ event,
+ &value),
+ "gaspi_notify_reset failed with ");
+ return value != 0;
+ }
+
+ gaspi_segment_id_t mailBoxGaspi::getSegmentID() const {
+ return segmentID;
+ }
+
+ gaspi_notification_id_t mailBoxGaspi::getMailID() const {
+ return mailID;
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/mailBoxGaspi.h b/src/gpi_comm_lib/collectives/lib/mailBoxGaspi.h
new file mode 100755
index 00000000..c33b6c86
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/mailBoxGaspi.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include "mailBox.h"
+
+#include <GASPI.h>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class mailBoxGaspi : public mailBox
+ {
+ public:
+ mailBoxGaspi(const gaspi_segment_id_t segmentID_,
+ const gaspi_notification_id_t mailID_);
+ bool gotNotification() override;
+ gaspi_segment_id_t getSegmentID() const;
+ gaspi_notification_id_t getMailID() const;
+
+ private:
+
+ const gaspi_segment_id_t segmentID;
+ const gaspi_notification_id_t mailID;
+ };
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/mailBoxLocal.cpp b/src/gpi_comm_lib/collectives/lib/mailBoxLocal.cpp
new file mode 100755
index 00000000..352f7545
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/mailBoxLocal.cpp
@@ -0,0 +1,20 @@
+#include "mailBoxLocal.h"
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ mailBoxLocal::mailBoxLocal()
+ : status(0),
+ target(0) {}
+
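+ // Consumes one pending notification: succeeds only while notify() has been
+ // called more often than notifications have been consumed; the CAS ensures
+ // concurrent consumers cannot take the same notification twice.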
+ bool mailBoxLocal::gotNotification() {
+ unsigned long statusOld = status;
+ return (statusOld < target) && status.compare_exchange_strong(statusOld, statusOld + 1);
+ }
+
+ void mailBoxLocal::notify() {
+ ++target;
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/mailBoxLocal.h b/src/gpi_comm_lib/collectives/lib/mailBoxLocal.h
new file mode 100755
index 00000000..0d9b34fa
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/mailBoxLocal.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include "mailBox.h"
+#include <atomic>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class mailBoxLocal : public mailBox
+ {
+ public:
+ mailBoxLocal();
+ bool gotNotification() override;
+ void notify();
+
+ private:
+ std::atomic<unsigned long> status;
+ std::atomic<unsigned long> target;
+ };
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/queues.cpp b/src/gpi_comm_lib/collectives/lib/queues.cpp
new file mode 100755
index 00000000..d8feb3cb
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/queues.cpp
@@ -0,0 +1,59 @@
+#include "queues.h"
+#include "gpi/gaspiCheckReturn.hpp"
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ using tarantella::GPI::gaspiCheckReturn;
+
+ queues::queues(const long num,
+ const gaspi_queue_id_t first)
+ : numQueues(num)
+ , state(0) {
+ for (long i=first; i < first + num; i++) {
+ queueStock.push_back(i);
+ }
+ }
+
+ queues::queues(const std::vector<gaspi_queue_id_t>& queues_)
+ : numQueues(queues_.size()),
+ state(0),
+ queueStock(queues_) {
+ }
+
+ gaspi_queue_id_t queues::get() const {
+ return stateToQueue(state);
+ }
+
+ inline gaspi_queue_id_t queues::stateToQueue(const long state_) const {
+ return queueStock[state_ % numQueues];
+ }
+
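+ // Switches away from a full queue without locking: drains the successor
+ // queue, then advances the shared state with a CAS; if another thread wins
+ // the race, the queue that thread installed is used instead.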
+ gaspi_queue_id_t queues::swap(gaspi_queue_id_t badQueue) {
+ const long stateLocal = state;
+ const gaspi_queue_id_t queueLocal = stateToQueue(stateLocal);
+
+ if (queueLocal != badQueue) {
+ return queueLocal;
+ } else {
+ const long stateLocalNew = stateLocal + 1;
+ const gaspi_queue_id_t queueLocalNew = stateToQueue(stateLocalNew);
+
+ clearQueue(queueLocalNew);
+
+ const long stateBeforeSwap =
+ __sync_val_compare_and_swap(&state, stateLocal, stateLocalNew);
+
+ return (stateBeforeSwap == stateLocal)
+ ? queueLocalNew
+ : stateToQueue(stateBeforeSwap);
+ }
+ }
+
+ inline void queues::clearQueue(const gaspi_queue_id_t queue) {
+ gaspiCheckReturn(gaspi_wait(queue, GASPI_BLOCK),
+ "Failed to clear queue with ");
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/queues.h b/src/gpi_comm_lib/collectives/lib/queues.h
new file mode 100755
index 00000000..0679a7d0
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/queues.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include <GASPI.h>
+#include <vector>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class queues {
+ public:
+ queues(const long num = 2,
+ const gaspi_queue_id_t first = 0);
+ queues(const std::vector<gaspi_queue_id_t>& queues_);
+
+ gaspi_queue_id_t get() const;
+ gaspi_queue_id_t swap(gaspi_queue_id_t badQueue);
+
+ private:
+ inline gaspi_queue_id_t stateToQueue(const long) const;
+ inline void clearQueue(const gaspi_queue_id_t queue);
+
+ static const long CACHE_LINE_SIZE = 64;
+ const long numQueues;
+
+ char pad0 [CACHE_LINE_SIZE];
+ volatile long state;
+ char pad1 [CACHE_LINE_SIZE];
+
+ std::vector<gaspi_queue_id_t> queueStock;
+ };
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/reduce.cpp b/src/gpi_comm_lib/collectives/lib/reduce.cpp
new file mode 100755
index 00000000..6e1ba9c8
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/reduce.cpp
@@ -0,0 +1,188 @@
+#include "reduce.h"
+
+#include <cstdint>
+#include <stdexcept>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ namespace
+ {
+ template <typename T>
+ inline void add(const reduce::task& t) {
+ const T* const a = (const T*) t.source;
+ T* const b = (T*) t.destination;
+ const long n = t.len;
+
+ for (long i=0; i < n; i++) {
+ b[i] += a[i];
+ }
+ }
+
+ template <typename T>
+ inline void average(const reduce::task& t) {
+ if (t.scaling > 1) {
+ const T* const a = (const T*) t.source;
+ T* const b = (T*) t.destination;
+ const long n = t.len;
+ const T s = t.scaling;
+
+ for (long i=0; i < n; i++) {
+ b[i] = (b[i] + a[i]) / s;
+ }
+ } else {
+ add<T>(t);
+ }
+ }
+
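+ // Same as average, but multiplies by the precomputed reciprocal instead of
+ // dividing per element; only used for the floating-point types below, where
+ // the rounding difference is acceptable.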
+ template <typename T>
+ inline void averageopt(const reduce::task& t) {
+ if (t.scaling > 1) {
+ const T* const a = (const T*) t.source;
+ T* const b = (T*) t.destination;
+ const long n = t.len;
+ const T s = T(1) / T(t.scaling);
+
+ for (long i=0; i < n; i++) {
+ b[i] = (b[i] + a[i]) * s;
+ }
+ } else {
+ add<T>(t);
+ }
+ }
+
+ class reduce_float_sum : public reduce {
+ public:
+ void operator()(const task& t) const {
+ add<float>(t);
+ }
+ };
+
+ class reduce_float_average : public reduce {
+ public:
+ void operator()(const task& t) const {
+ averageopt<float>(t);
+ }
+ };
+
+ class reduce_double_sum : public reduce {
+ public:
+ void operator()(const task& t) const {
+ add<double>(t);
+ }
+ };
+
+ class reduce_double_average : public reduce {
+ public:
+ void operator()(const task& t) const {
+ averageopt<double>(t);
+ }
+ };
+
+ class reduce_int16_sum : public reduce {
+ public:
+ void operator()(const task& t) const {
+ add<int16_t>(t);
+ }
+ };
+
+ class reduce_int16_average : public reduce {
+ public:
+ void operator()(const task& t) const {
+ average<int16_t>(t);
+ }
+ };
+
+ class reduce_int32_sum : public reduce {
+ public:
+ void operator()(const task& t) const {
+ add<int32_t>(t);
+ }
+ };
+
+ class reduce_int32_average : public reduce {
+ public:
+ void operator()(const task& t) const {
+ average<int32_t>(t);
+ }
+ };
+ }
+
+ reduce * getReduce(const allreduce::dataType data,
+ const allreduce::reductionType reduction) {
+ reduce* p = NULL;
+
+ switch (data) {
+ case allreduce::FLOAT:
+ switch (reduction) {
+ case allreduce::SUM:
+ p = new reduce_float_sum();
+ break;
+ case allreduce::AVERAGE:
+ p = new reduce_float_average();
+ break;
+ default:
+ break;
+ }
+ break;
+ case allreduce::DOUBLE:
+ switch (reduction) {
+ case allreduce::SUM:
+ p = new reduce_double_sum;
+ break;
+ case allreduce::AVERAGE:
+ p = new reduce_double_average;
+ break;
+ default:
+ break;
+ }
+ break;
+ case allreduce::INT16:
+ switch (reduction) {
+ case allreduce::SUM:
+ p = new reduce_int16_sum;
+ break;
+ case allreduce::AVERAGE:
+ p = new reduce_int16_average;
+ break;
+ default:
+ break;
+ }
+ break;
+ case allreduce::INT32:
+ switch (reduction) {
+ case allreduce::SUM:
+ p = new reduce_int32_sum;
+ break;
+ case allreduce::AVERAGE:
+ p = new reduce_int32_average;
+ break;
+ default:
+ break;
+ }
+ break;
+ default:
+ break;
+ };
+
+ if (p == NULL) {
+ throw std::runtime_error(
+ "Unsupported combination of data type and reduction type");
+ }
+
+ return p;
+ }
+
+ size_t getDataTypeSize(const allreduce::dataType d) {
+ const size_t sizes[allreduce::NUM_TYPE] = {
+ sizeof(float),
+ sizeof(double),
+ sizeof(int16_t),
+ sizeof(int32_t)
+ };
+
+ return sizes[d];
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/reduce.h b/src/gpi_comm_lib/collectives/lib/reduce.h
new file mode 100755
index 00000000..e0b7ea91
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/reduce.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include "allreduce.h"
+
+#include <cstddef>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class reduce {
+ public:
+ struct task {
+ const void* source;
+ void* destination;
+ long len;
+ unsigned long scaling;
+ task(const void* s = NULL,
+ void* d = NULL,
+ long n = 0,
+ unsigned long sc = 0)
+ : source(s), destination(d), len(n), scaling(sc) {}
+ };
+
+ virtual void operator()(const task& t) const = 0;
+ virtual ~reduce() {}
+ };
+
+ reduce * getReduce(const allreduce::dataType data,
+ const allreduce::reductionType reduction);
+
+ size_t getDataTypeSize(const allreduce::dataType d);
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/writer.cpp b/src/gpi_comm_lib/collectives/lib/writer.cpp
new file mode 100755
index 00000000..ed5ad550
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/writer.cpp
@@ -0,0 +1,80 @@
+#include "writer.h"
+#include "gpi/gaspiCheckReturn.hpp"
+
+#include <stdexcept>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ const gaspi_size_t writer::MESSAGE_LENGTH_LIMIT = 0x40000000;
+
+ using tarantella::GPI::gaspiCheckReturn;
+
+ writer::transferParameters::transferParameters(
+ bool a,
+ gaspi_rank_t r,
+ gaspi_segment_id_t sl,
+ gaspi_offset_t ol,
+ gaspi_segment_id_t sr,
+ gaspi_offset_t orm,
+ gaspi_size_t sz,
+ gaspi_notification_id_t id)
+ : active(a),
+ rank(r),
+ segmentLocal(sl),
+ offsetLocal(ol),
+ segmentRemote(sr),
+ offsetRemote(orm),
+ size(sz),
+ notificationID(id)
+ {}
+
+ std::ostream& writer::transferParameters::report(std::ostream& s) const {
+ if (active) {
+ s << "rank " << rank
+ << " | sl " << long(segmentLocal)
+ << " ol " << offsetLocal
+ << " | sr " << long(segmentRemote)
+ << " or " << offsetRemote
+ << " ID " << notificationID
+ << " | sz " << size;
+ } else {
+ s << "idle";
+ }
+ return s;
+ }
+
+ writer::writer(queues& queues_)
+ : queueSource(queues_) {}
+
+ void writer::operator()(const transferParameters& p) {
+ if (!p.active) return;
+ // thread safe? watch queue management!
+
+ if (p.size > MESSAGE_LENGTH_LIMIT) {
+ throw std::runtime_error("writer: message is too long");
+ }
+
+ gaspi_return_t err;
+ gaspi_queue_id_t queueLocal = queueSource.get();
+ while ((err = gaspi_write_notify(p.segmentLocal,
+ p.offsetLocal,
+ p.rank,
+ p.segmentRemote,
+ p.offsetRemote,
+ p.size,
+ p.notificationID,
+ 1,
+ queueLocal,
+ GASPI_BLOCK))
+ != GASPI_SUCCESS) {
+ if (err == GASPI_QUEUE_FULL) {
+ queueLocal = queueSource.swap(queueLocal);
+ } else {
+ gaspiCheckReturn(err, "gaspi_write_notify failed with ");
+ }
+ }
+ }
+ }
+}
diff --git a/src/gpi_comm_lib/collectives/lib/writer.h b/src/gpi_comm_lib/collectives/lib/writer.h
new file mode 100755
index 00000000..db580448
--- /dev/null
+++ b/src/gpi_comm_lib/collectives/lib/writer.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include "queues.h"
+
+#include <GASPI.h>
+#include <ostream>
+
+namespace tarantella
+{
+ namespace collectives
+ {
+ class writer {
+ public:
+ struct transferParameters {
+ bool active;
+ gaspi_rank_t rank;
+ gaspi_segment_id_t segmentLocal;
+ gaspi_offset_t offsetLocal;
+ gaspi_segment_id_t segmentRemote;
+ gaspi_offset_t offsetRemote;
+ gaspi_size_t size;
+ gaspi_notification_id_t notificationID;
+ transferParameters(
+ bool a = false,
+ gaspi_rank_t r = 0,
+ gaspi_segment_id_t sl = 0,
+ gaspi_offset_t ol = 0,
+ gaspi_segment_id_t sr = 0,
+ gaspi_offset_t orm = 0,
+ gaspi_size_t sz = 0,
+ gaspi_notification_id_t id = 0);
+ std::ostream& report(std::ostream& s) const;
+ };
+
+ writer(queues& queues_);
+ void operator()(const transferParameters& p);
+
+ private:
+
+ static const gaspi_size_t MESSAGE_LENGTH_LIMIT;
+
+ queues& queueSource;
+ };
+
+ }
+}
\ No newline at end of file
diff --git a/src/gpi_comm_lib/distribution/GroupBuilder.hpp b/src/gpi_comm_lib/distribution/GroupBuilder.hpp
new file mode 100644
index 00000000..a3b78b8f
--- /dev/null
+++ b/src/gpi_comm_lib/distribution/GroupBuilder.hpp
@@ -0,0 +1,34 @@
+#pragma once
+
+#include "gpi/Context.hpp"
+#include "gpi/ResourceManager.hpp"
+
+#include <numeric>
+
+namespace tarantella
+{
+ namespace distribution
+ {
+ class DataParallelGroupBuilder
+ {
+ public:
+ DataParallelGroupBuilder(GPI::Context& context)
+ : context(context)
+ { }
+
+ GPI::Group const get_group()
+ {
+ auto& resource_manager = context.get_resource_manager();
+ auto const num_ranks = context.get_comm_size();
+
+ std::vector<gaspi_rank_t> all_ranks(num_ranks);
+ std::iota(all_ranks.begin(), all_ranks.end(), static_cast