Skip to content
This repository has been archived by the owner on Apr 28, 2023. It is now read-only.

Commit

Permalink
PyTorch Integration with Tensor Comprehensions
Browse files Browse the repository at this point in the history
  • Loading branch information
prigoyal committed Feb 27, 2018
1 parent c9b372f commit a1def65
Show file tree
Hide file tree
Showing 93 changed files with 5,419 additions and 189 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Expand Up @@ -13,5 +13,6 @@ conda
*/.nfs*
tensor_comprehensions.egg-info/
tensor_comprehensions/version.py
tensor_comprehensions/*.proto
slurm-*
examples/results*
examples/results*
9 changes: 4 additions & 5 deletions CMakeLists.txt
Expand Up @@ -174,8 +174,6 @@ endif()

################################################################################
# ATen

# ATen - if someone ships libATen.so, we try to use that if available
# first find python path
execute_process(COMMAND which python OUTPUT_VARIABLE __python OUTPUT_STRIP_TRAILING_WHITESPACE)
message(STATUS "PYTHON output: \n${__python}")
Expand All @@ -184,14 +182,15 @@ message(STATUS "PYTHON output: \n${__python}")
execute_process(COMMAND "${__python}" "-c" "import torch;" RESULT_VARIABLE __torch_install OUTPUT_STRIP_TRAILING_WHITESPACE)
message(STATUS "IMPORTING TORCH: \n${__torch_install}")

# also get the site-packages path where conda installs things
# also get the site-packages path where conda installs pytorch
# Ask the interpreter located above (the same one used for the torch import
# check) for its site-packages directory, so both probes are guaranteed to
# describe the same Python installation.
execute_process(COMMAND "${__python}" "-c"
"from distutils.sysconfig import get_python_lib; print(get_python_lib())"
OUTPUT_VARIABLE PYTHON_SITE_PACKAGES OUTPUT_STRIP_TRAILING_WHITESPACE)
message(STATUS "PYTHON site packages: \n${PYTHON_SITE_PACKAGES}")

# if PyTorch is installed, we get libATen.so.1 from there, otherwise build it
if (__torch_install EQUAL 0)
message(STATUS "TORCH INSTALLED, linking to ATen")
message(STATUS "TORCH INSTALLED, linking to ATen from PyTorch")
set(ATEN_INCLUDE_DIR "${PYTHON_SITE_PACKAGES}/torch/lib/include")
include_directories(${ATEN_INCLUDE_DIR})
find_library(ATEN_LIBRARIES NAMES libATen.so.1 PATHS ${PYTHON_SITE_PACKAGES}/torch/lib)
Expand Down Expand Up @@ -238,7 +237,7 @@ message(STATUS "Found glog: ${GLOG_LIBRARIES}")

add_subdirectory(src)
enable_testing()
add_subdirectory(pybinds)
add_subdirectory(tensor_comprehensions/pybinds)
add_subdirectory(test)

if (WITH_CAFFE2)
Expand Down
67 changes: 41 additions & 26 deletions CodeOwners.md
@@ -1,16 +1,52 @@
This file lists the "owners" for each part of Tensor Comprehensions.
These people are major contributors of a specific part, they ensure issues in that part get addressed and pull requests get reviewed.
If you want to contribute to Tensor Comprehensions, make sure to include these people in your review requests.
These people are major contributors to a specific part; they ensure that issues in
that part get addressed and pull requests get reviewed. If you want to contribute
to Tensor Comprehensions, make sure to include these people in your review requests.

### Build system and Docker
### Documentation
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
* **Albert Cohen** [@albertcohen](https://github.com/albertcohen)
```
docs/*
*.md
```

### Build system
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
* **Nicolas Vasilache** [@nicolasvasilache](https://github.com/nicolasvasilache)
```
docker/*
.circleci/*
*/CMakeLists.txt
```

### Tensor Comprehensions language and frontend
### Docker
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
```
docker/*
```

### Conda Packaging
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
```
conda_recipes/*
```

### Python bindings
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
```
tensor_comprehensions/pybinds/*
test_python/*
```

### Framework integration (PyTorch, Caffe2, etc.)
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
```
tensor_comprehensions/*
src/c2/*
src/aten/*
```

### Tensor Comprehensions language and frontend
* **Zachary DeVito** [@zdevito](https://github.com/zdevito)
```
src/lang/*
Expand All @@ -37,27 +73,6 @@ src/core/polyhedral/*
src/autotuner/*
```

### Framework integration (ATen, Caffe2, etc.)
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
```
src/c2/*
src/aten/*
```

### Documentation
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
* **Albert Cohen** [@albertcohen](https://github.com/albertcohen)
```
docs/*
*.md
```

### Python bindings
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
```
pybind/*
```

### Protocol buffers
* **Oleksandr Zinenko** [@ftynse](https://github.com/ftynse)
* **Theodoros Theodoridis** [@ttheodor](https://github.com/ttheodor)
Expand Down
60 changes: 30 additions & 30 deletions README.md
@@ -1,56 +1,56 @@
# ![Tensor Comprehensions](docs/source/_static/img/tc-logo-full-color-with-text-2.png)

Tensor Comprehensions (TC) is a fully-functional C++ library to *automatically* synthesize high-performance machine learning kernels using [Halide](https://github.com/halide/Halide), [ISL](http://isl.gforge.inria.fr/) and NVRTC or LLVM. TC additionally provides basic integration with Caffe2 and pybind11 bindings for use with python. We provide more details in our paper on [arXiv](https://arxiv.org/abs/1802.04730).
Tensor Comprehensions (TC) is a fully-functional C++ library to *automatically* synthesize high-performance machine learning kernels using [Halide](https://github.com/halide/Halide), [ISL](http://isl.gforge.inria.fr/) and NVRTC or LLVM. TC additionally provides basic integration with Caffe2 and PyTorch. We provide more details in our paper on [arXiv](https://arxiv.org/abs/1802.04730).

This library is designed to be highly portable, machine-learning-framework agnostic and only requires a simple tensor library with memory allocation, offloading and synchronization capabilities.

For now, we have integrated TC with the [Caffe2](https://github.com/caffe2/caffe2) and [ATen](https://github.com/pytorch/pytorch/tree/master/aten/src/ATen) tensor libraries.
For now, we have integrated TC with the [Caffe2](https://github.com/caffe2/caffe2) and [PyTorch](https://github.com/pytorch/pytorch/).

# A simple example

The following illustrates a short but powerful feature of the library: the capacity to JIT-compile high-performance machine learning kernels on demand, for specific sizes.

```cpp
#include <ATen/ATen.h>
#include "tc/aten/aten_compiler.h"
#include "tc/core/mapping_options.h"

// 1. Define and setup the TC compilation unit with CUDA memory management backed by ATen.
std::string tc = R"TC(
def TensorDot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) {
```python
import tensor_comprehensions as tc
import torch
lang = """
def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) {
O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w)
})TC";
}
"""
N, C1, C2, C3, H, W = 32, 512, 8, 2, 28, 28
tensordot = tc.define(lang, name="tensordot")
I0, I1 = torch.randn(N, C1, C2, H, W).cuda(), torch.randn(N, C2, C3, H, W).cuda()
best_options = tensordot.autotune(I0, I1, cache=True)
out = tensordot(I0, I1, options=best_options)
```

// 2. Allocate tensors with random data
at::Tensor I0 = at::CUDA(at::kFloat).rand({32, 512, 8, 28, 28});
at::Tensor I1 = at::CUDA(at::kFloat).rand({32, 8, 2, 28, 28});
std::vector<at::Tensor> outputs;
After a few generations of `autotuning` on a 2-GPU P100 system, we see results resembling:

// 3. Run autotuning with evolutionary search starting from a naive option
auto options = tc::MappingOptions::makeNaiveMappingOptions();
auto bestOption = autotune(cacheFilename, tc, "TensorDot", {I0, I1}, options, {options});
![Autotuning Sample](docs/source/_static/img/autotuning.png)

// 4. Compile and run the TC with the best option.
tc::ATenCompilationUnit atCompl;
atCompl.define(tc);
auto handle = atCompl.compile("TensorDot", {I0, I1}, bestOption);
atCompl.run("TensorDot", {I0, I1}, outputs, handle);
We have not yet characterized the precise fraction of peak performance we obtain, but it is not uncommon to obtain 80%+ of peak shared memory bandwidth after autotuning. Solid register-level optimizations are still in the works, but TC in its current form already addresses the productivity gap between the needs of research and the needs of production, which is why we are excited to share it with the entire community and bring this collaborative effort into the open.

// 5. Perform precision checks against an ATen reference implementation
check({I0, I1}, outputs, [&I0, &I1](){ return ...; });
```
# Documentation

After a few generations of autotuning on a 2-GPU P100 system, we see results resembling:
**General**: You can find detailed information about Tensor Comprehensions [here](https://facebookresearch.github.io/TensorComprehensions/).

![Autotuning Sample](docs/source/_static/img/autotuning.png)
**C++ API**: We also provide documentation for our C++ API, which can be found [here](https://facebookresearch.github.io/TensorComprehensions/api/).

We have not yet characterized the precise fraction of peak performance we obtain, but it is not uncommon to obtain 80%+ of peak shared memory bandwidth after autotuning. Solid register-level optimizations are still in the works, but TC in its current form already addresses the productivity gap between the needs of research and the needs of production, which is why we are excited to share it with the entire community and bring this collaborative effort into the open.
# Installation

## Binaries

We provide a conda package to make it easy to install and use the TC binaries. Please refer to our documentation
[here](https://facebookresearch.github.io/TensorComprehensions/framework/pytorch_integration/getting_started.html) for instructions.

## From Source

# Installation / Documentation
You can find documentation [here](https://facebookresearch.github.io/TensorComprehensions/) which contains instructions for building TC via Docker, with conda packages, or in a non-conda environment.

# Communication

* **Email**: tensorcomp@fb.com
* **GitHub issues**: bug reports, feature requests, install issues, RFCs, thoughts, etc.
* **Slack**: For discussion around framework integration, build support, collaboration, etc. join our slack channel https://tensorcomprehensions.herokuapp.com/.

Expand Down
33 changes: 29 additions & 4 deletions build.sh
Expand Up @@ -245,9 +245,9 @@ function install_caffe2() {
rm -rf * || exit 1

if ! test ${USE_CONTBUILD_CACHE}; then
${CMAKE_VERSION} -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DPYTHON_EXECUTABLE=${PYTHON} -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DBUILD_PYTHON=${WITH_PYTHON_C2} -DUSE_GLOG=OFF -DUSE_GFLAGS=OFF -DUSE_NNPACK=${WITH_NNPACK} -DUSE_GLOO=OFF -DUSE_NCCL=OFF -DUSE_LMDB=OFF -DUSE_LEVELDB=OFF -DBUILD_TEST=OFF -DUSE_OPENCV=OFF -DUSE_OPENMP=OFF -DCMAKE_INSTALL_MESSAGE=NEVER -DCMAKE_CXX_FLAGS="-fno-var-tracking-assignments" -DPROTOBUF_PROTOC_EXECUTABLE=${PROTOC} -DCUDNN_ROOT_DIR=${CUDNN_ROOT_DIR} -DCUB_INCLUDE_DIR=${CUB_INCLUDE_DIR} -DCMAKE_C_COMPILER=${CC} -DCMAKE_CXX_COMPILER=${CXX} .. || exit
${CMAKE_VERSION} -DBUILD_BINARY=OFF -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_PREFIX_PATH=${INSTALL_PREFIX} -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DBUILD_PYTHON=${WITH_PYTHON_C2} -DUSE_GLOG=OFF -DUSE_GFLAGS=OFF -DUSE_NNPACK=${WITH_NNPACK} -DGLOG_ROOT_DIR=${INSTALL_PREFIX} -DUSE_GLOO=OFF -DUSE_NCCL=OFF -DUSE_LMDB=OFF -DUSE_LEVELDB=OFF -DBUILD_TEST=OFF -DUSE_OPENCV=OFF -DUSE_OPENMP=OFF -DCMAKE_INSTALL_MESSAGE=NEVER -DCMAKE_CXX_FLAGS="-fno-var-tracking-assignments" -DPROTOBUF_PROTOC_EXECUTABLE=${PROTOC} -DCUDNN_ROOT_DIR=${CUDNN_ROOT_DIR} -DCUB_INCLUDE_DIR=${CUB_INCLUDE_DIR} -DCMAKE_C_COMPILER=${CC} -DCMAKE_CXX_COMPILER=${CXX} .. || exit
else
${CMAKE_VERSION} -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DPYTHON_EXECUTABLE=${PYTHON} -DCUDA_ARCH_NAME="Maxwell" -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DBUILD_PYTHON=${WITH_PYTHON_C2} -DUSE_GLOG=OFF -DUSE_GLOO=OFF -DUSE_NNPACK=${WITH_NNPACK} -DUSE_NCCL=OFF -DUSE_GFLAGS=OFF -DUSE_LMDB=OFF -DUSE_LEVELDB=OFF -DBUILD_TEST=OFF -DUSE_OPENCV=OFF -DUSE_OPENMP=OFF -DCMAKE_INSTALL_MESSAGE=NEVER -DCMAKE_CXX_FLAGS="-fno-var-tracking-assignments" -DPROTOBUF_PROTOC_EXECUTABLE=${PROTOC} -DCMAKE_C_COMPILER=${CC} -DCMAKE_CXX_COMPILER=${CXX} .. || exit
${CMAKE_VERSION} -DBUILD_BINARY=OFF -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCUDA_ARCH_NAME="Maxwell" -DCMAKE_PREFIX_PATH=${INSTALL_PREFIX} -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DBUILD_PYTHON=${WITH_PYTHON_C2} -DUSE_GLOG=OFF -DUSE_GFLAGS=OFF -DGLOG_ROOT_DIR=${INSTALL_PREFIX} -DUSE_GLOO=OFF -DUSE_NNPACK=${WITH_NNPACK} -DUSE_NCCL=OFF -DUSE_LMDB=OFF -DUSE_LEVELDB=OFF -DBUILD_TEST=OFF -DUSE_OPENCV=OFF -DUSE_OPENMP=OFF -DCMAKE_INSTALL_MESSAGE=NEVER -DCMAKE_CXX_FLAGS="-fno-var-tracking-assignments" -DPROTOBUF_PROTOC_EXECUTABLE=${PROTOC} -DCUB_INCLUDE_DIR=${CUB_INCLUDE_DIR} -DCMAKE_C_COMPILER=${CC} -DCMAKE_CXX_COMPILER=${CXX} .. || exit
fi
fi
VERBOSE=${VERBOSE} make -j $CORES install -s || exit 1
Expand Down Expand Up @@ -320,8 +320,33 @@ function install_cub() {

# Make the TC Python code reachable from Python.
#
# Reads:  TC_DIR, USE_CONTBUILD_CACHE, BUILD_TYPE, PYTHONPATH.
# Effect depends on the environment:
#   * USE_CONTBUILD_CACHE == "1" (CI): only prepend the built pybinds
#     directory to PYTHONPATH.
#   * conda found on PATH: run setup.py from TC_DIR with the python3 found on
#     PATH ("install" when BUILD_TYPE is "Release", "develop" otherwise).
#   * no conda: prepend ${TC_DIR}/tensor_comprehensions to PYTHONPATH.
function install_tc_python() {
echo "Setting up python now"
# NOTE(review): this export runs unconditionally, before the branches below
# choose their own PYTHONPATH entry. It looks like the pre-change line of the
# diff (old build/pybinds location) — confirm whether it should still be here.
export PYTHONPATH=${TC_DIR}/build/pybinds:${PYTHONPATH}
echo "PYTHONPATH: ${PYTHONPATH}"
echo "USE_CONTBUILD_CACHE: ${USE_CONTBUILD_CACHE}"

if [ "$USE_CONTBUILD_CACHE" == "1" ]; then
# CI path: no package installation, just expose the freshly built bindings.
echo "Running on CI, setting PYTHONPATH only"
export PYTHONPATH=${TC_DIR}/build/tensor_comprehensions/pybinds:${PYTHONPATH}
echo "PYTHONPATH: ${PYTHONPATH}"
else
# `which conda` is used purely as an existence probe; its output is discarded.
if which conda &> /dev/null; then
echo "Found conda, going to install Python packages"
# setup.py below is given as a relative path, so it is resolved against
# TC_DIR. The directory change is not undone inside this function.
cd ${TC_DIR}
# Uses whichever python3 is first on PATH.
# NOTE(review): presumably the conda one — confirm.
export CONDA_PYTHON=$(which python3)
echo "CONDA_PYTHON: ${CONDA_PYTHON}"
if [ "$BUILD_TYPE" == "Release" ]; then
echo "Install mode setup for python"
${CONDA_PYTHON} setup.py install
else
# Any BUILD_TYPE other than "Release" takes the develop path.
echo "Develop mode setup for python"
${CONDA_PYTHON} setup.py develop
fi
else
echo "Conda not found, setting PYTHONPATH instead"
echo "Setting PYTHONPATH now"
export PYTHONPATH=${TC_DIR}/tensor_comprehensions:$PYTHONPATH
echo "PYTHONPATH: ${PYTHONPATH}"
fi
fi
echo "python all set now"
}

function install_tc() {
Expand Down
68 changes: 68 additions & 0 deletions conda_recipes/Dockerfile
@@ -0,0 +1,68 @@
# Build environment for the Tensor Comprehensions conda recipes.
# Starts from the CUDA 8.0 / cuDNN 6 development image on Ubuntu 14.04 and adds
# GCC 4.8, Tapir-LLVM (installed under /usr/local/clang+llvm-tapir5.0),
# Anaconda3 5.0.1 (under /opt/conda) and protobuf 3.4.0.
FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu14.04

ENV DEBIAN_FRONTEND noninteractive

# `apt-get update` lives in the same layer as the installs so a cached, stale
# package index can never be paired with a fresh install list; the index is
# dropped afterwards to keep the layer small.
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        make git ssh realpath wget unzip cmake3 vim \
        libgoogle-glog-dev libyaml-dev \
        libgtest-dev libz-dev libgmp3-dev \
        automake libtool valgrind subversion \
        ca-certificates software-properties-common && \
    rm -rf /var/lib/apt/lists/*

RUN cmake --version

# GCC 4.8.*
# -y keeps add-apt-repository from waiting on a confirmation prompt.
RUN add-apt-repository -y ppa:ubuntu-toolchain-r/test && \
    apt-get update && \
    apt-get install -y --no-install-recommends libcilkrts5 gcc-4.8 g++-4.8 && \
    rm -rf /var/lib/apt/lists/*
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-4.8 50
RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-4.8 50

ENV CC /usr/bin/gcc
ENV CXX /usr/bin/g++

# LLVM+Clang
ENV LLVM_SOURCES /tmp/llvm_sources-tapir5.0
WORKDIR $LLVM_SOURCES

ENV CLANG_PREFIX /usr/local/clang+llvm-tapir5.0
ENV CMAKE_VERSION cmake

# Clone, build, install and delete the sources in a single layer: removing the
# source tree in a later RUN would not shrink the image, because the earlier
# layer would still contain it.
RUN git clone --recursive https://github.com/wsmoses/Tapir-LLVM llvm && \
    mkdir -p ${LLVM_SOURCES}/llvm_build && cd ${LLVM_SOURCES}/llvm_build && \
    ${CMAKE_VERSION} -DLLVM_ENABLE_EH=ON -DLLVM_ENABLE_OCAMLDOC=OFF -DLLVM_INSTALL_OCAMLDOC_HTML_DIR=/tmp -DLLVM_OCAML_INSTALL_PATH=/tmp -DCMAKE_INSTALL_PREFIX=${CLANG_PREFIX} -DLLVM_TARGETS_TO_BUILD=X86 -DCOMPILER_RT_BUILD_CILKTOOLS=OFF -DLLVM_ENABLE_CXX1Y=ON -DLLVM_ENABLE_TERMINFO=OFF -DLLVM_BUILD_TESTS=OFF -DLLVM_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=Release -DLLVM_BUILD_LLVM_DYLIB=ON -DLLVM_ENABLE_RTTI=ON ../llvm/ && \
    make -j"$(nproc)" -s && make install -j"$(nproc)" -s && \
    cd / && rm -Rf ${LLVM_SOURCES}

# Anaconda3
ENV LD_LIBRARY_PATH /usr/local/cuda/lib64:/usr/local/cuda/targets/x86_64-linux/lib/stubs/:$LD_LIBRARY_PATH
ENV PATH /usr/local/bin:/usr/local/cuda/bin:$PATH

WORKDIR /conda-install
RUN echo 'export PATH=/opt/conda/bin:$PATH' > /etc/profile.d/conda.sh && \
    wget --quiet https://repo.continuum.io/archive/Anaconda3-5.0.1-Linux-x86_64.sh -O anaconda.sh && \
    chmod +x anaconda.sh && \
    ./anaconda.sh -b -p /opt/conda && \
    rm anaconda.sh

ENV PATH /opt/conda/bin:$PATH

# -y makes the install explicitly non-interactive; `docker build` has no TTY.
RUN conda install -y numpy decorator six future cmake pyyaml

# Protobuf 3.4*
WORKDIR /proto-install
RUN wget --quiet https://github.com/google/protobuf/archive/v3.4.0.zip -O proto.zip && unzip -qq proto.zip -d /

# Use every available core, matching the LLVM build above.
RUN cd /protobuf-3.4.0 && ./autogen.sh && ./configure && make -j"$(nproc)"
RUN cd /protobuf-3.4.0 && make install && ldconfig

# Sanity checks: fail the build early if any expected tool is missing from PATH.
RUN which conda && conda --version && \
    which protoc && protoc --version && \
    which python && python --version

CMD ["bash"]

0 comments on commit a1def65

Please sign in to comment.