This repository has been archived by the owner on Apr 28, 2023. It is now read-only.

PyTorch Integration with Tensor Comprehensions #72

Merged · 1 commit · Feb 27, 2018
3 changes: 2 additions & 1 deletion .gitignore
@@ -13,5 +13,6 @@ conda
*/.nfs*
tensor_comprehensions.egg-info/
tensor_comprehensions/version.py
tensor_comprehensions/*.proto
slurm-*
examples/results*
examples/results*
9 changes: 4 additions & 5 deletions CMakeLists.txt
@@ -174,8 +174,6 @@ endif()

################################################################################
# ATen

# ATen - if someone ships libATen.so, we try to use that if available
Contributor: don't we anymore?

Contributor Author: we do, but this has been more properly highlighted in the next steps, when we search for libATen in the PyTorch install.
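For reference, here is a minimal Python sketch of that lookup, mirroring the `execute_process` calls in the CMake code below; the `torch/lib/include` and `torch/lib/libATen.so.1` layout is taken from this diff, and the script itself is illustrative, not part of the PR.

```python
# Sketch of the check CMakeLists.txt performs: if PyTorch is importable,
# ATen headers and libATen.so.1 are taken from its install; otherwise
# ATen is built from source.
import os
from distutils.sysconfig import get_python_lib

try:
    import torch  # mirrors: python -c "import torch;"
    site_packages = get_python_lib()
    aten_include = os.path.join(site_packages, "torch", "lib", "include")
    aten_library = os.path.join(site_packages, "torch", "lib", "libATen.so.1")
    print("ATEN_INCLUDE_DIR:", aten_include)
    print("ATEN_LIBRARIES:", aten_library)
except ImportError:
    print("PyTorch not found; ATen would be built from source instead.")
```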

# first find python path
execute_process(COMMAND which python OUTPUT_VARIABLE __python OUTPUT_STRIP_TRAILING_WHITESPACE)
message(STATUS "PYTHON output: \n${__python}")
@@ -184,14 +182,15 @@ message(STATUS "PYTHON output: \n${__python}")
execute_process(COMMAND "${__python}" "-c" "import torch;" RESULT_VARIABLE __torch_install OUTPUT_STRIP_TRAILING_WHITESPACE)
message(STATUS "IMPORTING TORCH: \n${__torch_install}")

# also get the site-packages path where conda installs things
# also get the site-packages path where conda installs pytorch
execute_process(COMMAND "python" "-c"
"from distutils.sysconfig import get_python_lib; print(get_python_lib())"
OUTPUT_VARIABLE PYTHON_SITE_PACKAGES OUTPUT_STRIP_TRAILING_WHITESPACE)
message(STATUS "PYTHON site packages: \n${PYTHON_SITE_PACKAGES}")

# if PyTorch is installed, we get libATen.so.1 from there, otherwise build it
if (__torch_install EQUAL 0)
message(STATUS "TORCH INSTALLED, linking to ATen")
message(STATUS "TORCH INSTALLED, linking to ATen from PyTorch")
set(ATEN_INCLUDE_DIR "${PYTHON_SITE_PACKAGES}/torch/lib/include")
include_directories(${ATEN_INCLUDE_DIR})
find_library(ATEN_LIBRARIES NAMES libATen.so.1 PATHS ${PYTHON_SITE_PACKAGES}/torch/lib)
@@ -238,7 +237,7 @@ message(STATUS "Found glog: ${GLOG_LIBRARIES}")

add_subdirectory(src)
enable_testing()
add_subdirectory(pybinds)
add_subdirectory(tensor_comprehensions/pybinds)
add_subdirectory(test)

if (WITH_CAFFE2)
67 changes: 41 additions & 26 deletions CodeOwners.md
@@ -1,16 +1,52 @@
This file lists the "owners" for each part of Tensor Comprehensions.
These people are major contributors of a specific part, they ensure issues in that part get addressed and pull requests get reviewed.
If you want to contribute to Tensor Comprehensions, make sure to include these people in your review requests.
These people are major contributors of a specific part; they ensure issues in
that part get addressed and pull requests get reviewed. If you want to contribute
to Tensor Comprehensions, make sure to include these people in your review requests.

### Build system and Docker
### Documentation
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
* **Albert Cohen** [@albertcohen](https://github.com/albertcohen)
```
docs/*
*.md
```

### Build system
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
* **Nicolas Vasilache** [@nicolasvasilache](https://github.com/nicolasvasilache)
```
docker/*
.circleci/*
*/CMakeLists.txt
```

### Tensor Comprehensions language and frontend
### Docker
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
```
docker/*
```

### Conda Packaging
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
```
conda_recipes/*
```

### Python bindings
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
```
tensor_comprehensions/pybinds/*
test_python/*
```

### Framework integration (PyTorch, Caffe2, etc.)
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
```
tensor_comprehensions/*
src/c2/*
src/aten/*
```

### Tensor Comprehensions language and frontend
* **Zachary DeVito** [@zdevito](https://github.com/zdevito)
```
src/lang/*
@@ -37,27 +73,6 @@ src/core/polyhedral/*
src/autotuner/*
```

### Framework integration (ATen, Caffe2, etc.)
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
```
src/c2/*
src/aten/*
```

### Documentation
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
* **Albert Cohen** [@albertcohen](https://github.com/albertcohen)
```
docs/*
*.md
```

### Python bindings
* **Priya Goyal** [@prigoyal](https://github.com/prigoyal)
```
pybind/*
```

### Protocol buffers
* **Oleksandr Zinenko** [@ftynse](https://github.com/ftynse)
* **Theodoros Theodoridis** [@ttheodor](https://github.com/ttheodor)
60 changes: 30 additions & 30 deletions README.md
@@ -1,56 +1,56 @@
# ![Tensor Comprehensions](docs/source/_static/img/tc-logo-full-color-with-text-2.png)

Tensor Comprehensions (TC) is a fully-functional C++ library to *automatically* synthesize high-performance machine learning kernels using [Halide](https://github.com/halide/Halide), [ISL](http://isl.gforge.inria.fr/) and NVRTC or LLVM. TC additionally provides basic integration with Caffe2 and pybind11 bindings for use with python. We provide more details in our paper on [arXiv](https://arxiv.org/abs/1802.04730).
Tensor Comprehensions (TC) is a fully-functional C++ library to *automatically* synthesize high-performance machine learning kernels using [Halide](https://github.com/halide/Halide), [ISL](http://isl.gforge.inria.fr/) and NVRTC or LLVM. TC additionally provides basic integration with Caffe2 and PyTorch. We provide more details in our paper on [arXiv](https://arxiv.org/abs/1802.04730).

This library is designed to be highly portable, machine-learning-framework agnostic and only requires a simple tensor library with memory allocation, offloading and synchronization capabilities.

For now, we have integrated TC with the [Caffe2](https://github.com/caffe2/caffe2) and [ATen](https://github.com/pytorch/pytorch/tree/master/aten/src/ATen) tensor libraries.
For now, we have integrated TC with the [Caffe2](https://github.com/caffe2/caffe2) and [PyTorch](https://github.com/pytorch/pytorch/).

# A simple example

The following illustrates a short but powerful feature of the library: the capacity to JIT-compile high-performance machine learning kernels on demand, for specific sizes.

```cpp
#include <ATen/ATen.h>
#include "tc/aten/aten_compiler.h"
#include "tc/core/mapping_options.h"

// 1. Define and setup the TC compilation unit with CUDA memory management backed by ATen.
std::string tc = R"TC(
def TensorDot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) {
```python
import tensor_comprehensions as tc
import torch
lang = """
def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) {
O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w)
})TC";
}
"""
N, C1, C2, C3, H, W = 32, 512, 8, 2, 28, 28
tensordot = tc.define(lang, name="tensordot")
I0, I1 = torch.randn(N, C1, C2, H, W).cuda(), torch.randn(N, C2, C3, H, W).cuda()
best_options = tensordot.autotune(I0, I1, cache=True)
out = tensordot(I0, I1, options=best_options)
```
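As a further illustration (a hypothetical sketch, not part of this PR, using only the API shown above: `tc.define`, `autotune(..., cache=True)`, and calling the defined function with `options`), a plain matrix multiplication could be written the same way:

```python
import tensor_comprehensions as tc
import torch

# Hypothetical second kernel written against the same API as the
# tensordot example above; sizes are illustrative.
lang = """
def matmul(float(M, K) A, float(K, N) B) -> (C) {
    C(m, n) +=! A(m, k) * B(k, n)
}
"""
matmul = tc.define(lang, name="matmul")
A, B = torch.randn(128, 64).cuda(), torch.randn(64, 256).cuda()
best_options = matmul.autotune(A, B, cache=True)  # evolutionary search
C = matmul(A, B, options=best_options)            # run the tuned kernel
```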

// 2. Allocate tensors with random data
at::Tensor I0 = at::CUDA(at::kFloat).rand({32, 512, 8, 28, 28});
at::Tensor I1 = at::CUDA(at::kFloat).rand({32, 8, 2, 28, 28});
std::vector<at::Tensor> outputs;
After a few generations of `autotuning` on a 2-GPU P100 system, we see results resembling:

// 3. Run autotuning with evolutionary search starting from a naive option
auto options = tc::MappingOptions::makeNaiveMappingOptions();
auto bestOption = autotune(cacheFilename, tc, "TensorDot", {I0, I1}, options, {options});
![Autotuning Sample](docs/source/_static/img/autotuning.png)

// 4. Compile and run the TC with the best option.
tc::ATenCompilationUnit atCompl;
atCompl.define(tc);
auto handle = atCompl.compile("TensorDot", {I0, I1}, bestOption);
atCompl.run("TensorDot", {I0, I1}, outputs, handle);
We have not yet characterized the precise fraction of peak performance we obtain, but it is not uncommon to obtain 80%+ of peak shared memory bandwidth after autotuning. Solid register-level optimizations are still in the works, but TC in its current form already addresses the productivity gap between the needs of research and the needs of production, which is why we are excited to share it with the entire community and bring this collaborative effort into the open.

// 5. Perform precision checks against an ATen reference implementation
check({I0, I1}, outputs, [&I0, &I1](){ return ...; });
```
# Documentation

After a few generations of autotuning on a 2-GPU P100 system, we see results resembling:
**General**: You can find detailed information about Tensor Comprehensions [here](https://facebookresearch.github.io/TensorComprehensions/).

![Autotuning Sample](docs/source/_static/img/autotuning.png)
**C++ API**: We also provide documentation for our C++ API, which can be found [here](https://facebookresearch.github.io/TensorComprehensions/api/).

We have not yet characterized the precise fraction of peak performance we obtain, but it is not uncommon to obtain 80%+ of peak shared memory bandwidth after autotuning. Solid register-level optimizations are still in the works, but TC in its current form already addresses the productivity gap between the needs of research and the needs of production, which is why we are excited to share it with the entire community and bring this collaborative effort into the open.
# Installation

## Binaries

We provide a conda package to make it easy to install and use the TC binaries. Please refer to our documentation
[here](https://facebookresearch.github.io/TensorComprehensions/framework/pytorch_integration/getting_started.html) for instructions.

## From Source

# Installation / Documentation
You can find documentation [here](https://facebookresearch.github.io/TensorComprehensions/) with instructions for building TC via docker, conda packages, or in a non-conda environment.

# Communication

* **Email**: tensorcomp@fb.com
* **GitHub issues**: bug reports, feature requests, install issues, RFCs, thoughts, etc.
* **Slack**: For discussion around framework integration, build support, collaboration, etc. join our slack channel https://tensorcomprehensions.herokuapp.com/.

33 changes: 29 additions & 4 deletions build.sh
@@ -245,9 +245,9 @@ function install_caffe2() {
rm -rf * || exit 1

if ! test ${USE_CONTBUILD_CACHE}; then
${CMAKE_VERSION} -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DPYTHON_EXECUTABLE=${PYTHON} -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DBUILD_PYTHON=${WITH_PYTHON_C2} -DUSE_GLOG=OFF -DUSE_GFLAGS=OFF -DUSE_NNPACK=${WITH_NNPACK} -DUSE_GLOO=OFF -DUSE_NCCL=OFF -DUSE_LMDB=OFF -DUSE_LEVELDB=OFF -DBUILD_TEST=OFF -DUSE_OPENCV=OFF -DUSE_OPENMP=OFF -DCMAKE_INSTALL_MESSAGE=NEVER -DCMAKE_CXX_FLAGS="-fno-var-tracking-assignments" -DPROTOBUF_PROTOC_EXECUTABLE=${PROTOC} -DCUDNN_ROOT_DIR=${CUDNN_ROOT_DIR} -DCUB_INCLUDE_DIR=${CUB_INCLUDE_DIR} -DCMAKE_C_COMPILER=${CC} -DCMAKE_CXX_COMPILER=${CXX} .. || exit
${CMAKE_VERSION} -DBUILD_BINARY=OFF -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_PREFIX_PATH=${INSTALL_PREFIX} -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DBUILD_PYTHON=${WITH_PYTHON_C2} -DUSE_GLOG=OFF -DUSE_GFLAGS=OFF -DUSE_NNPACK=${WITH_NNPACK} -DGLOG_ROOT_DIR=${INSTALL_PREFIX} -DUSE_GLOO=OFF -DUSE_NCCL=OFF -DUSE_LMDB=OFF -DUSE_LEVELDB=OFF -DBUILD_TEST=OFF -DUSE_OPENCV=OFF -DUSE_OPENMP=OFF -DCMAKE_INSTALL_MESSAGE=NEVER -DCMAKE_CXX_FLAGS="-fno-var-tracking-assignments" -DPROTOBUF_PROTOC_EXECUTABLE=${PROTOC} -DCUDNN_ROOT_DIR=${CUDNN_ROOT_DIR} -DCUB_INCLUDE_DIR=${CUB_INCLUDE_DIR} -DCMAKE_C_COMPILER=${CC} -DCMAKE_CXX_COMPILER=${CXX} .. || exit
else
${CMAKE_VERSION} -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DPYTHON_EXECUTABLE=${PYTHON} -DCUDA_ARCH_NAME="Maxwell" -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DBUILD_PYTHON=${WITH_PYTHON_C2} -DUSE_GLOG=OFF -DUSE_GLOO=OFF -DUSE_NNPACK=${WITH_NNPACK} -DUSE_NCCL=OFF -DUSE_GFLAGS=OFF -DUSE_LMDB=OFF -DUSE_LEVELDB=OFF -DBUILD_TEST=OFF -DUSE_OPENCV=OFF -DUSE_OPENMP=OFF -DCMAKE_INSTALL_MESSAGE=NEVER -DCMAKE_CXX_FLAGS="-fno-var-tracking-assignments" -DPROTOBUF_PROTOC_EXECUTABLE=${PROTOC} -DCMAKE_C_COMPILER=${CC} -DCMAKE_CXX_COMPILER=${CXX} .. || exit
${CMAKE_VERSION} -DBUILD_BINARY=OFF -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCUDA_ARCH_NAME="Maxwell" -DCMAKE_PREFIX_PATH=${INSTALL_PREFIX} -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DBUILD_PYTHON=${WITH_PYTHON_C2} -DUSE_GLOG=OFF -DUSE_GFLAGS=OFF -DGLOG_ROOT_DIR=${INSTALL_PREFIX} -DUSE_GLOO=OFF -DUSE_NNPACK=${WITH_NNPACK} -DUSE_NCCL=OFF -DUSE_LMDB=OFF -DUSE_LEVELDB=OFF -DBUILD_TEST=OFF -DUSE_OPENCV=OFF -DUSE_OPENMP=OFF -DCMAKE_INSTALL_MESSAGE=NEVER -DCMAKE_CXX_FLAGS="-fno-var-tracking-assignments" -DPROTOBUF_PROTOC_EXECUTABLE=${PROTOC} -DCUB_INCLUDE_DIR=${CUB_INCLUDE_DIR} -DCMAKE_C_COMPILER=${CC} -DCMAKE_CXX_COMPILER=${CXX} .. || exit
fi
fi
VERBOSE=${VERBOSE} make -j $CORES install -s || exit 1
@@ -320,8 +320,33 @@ function install_cub() {

function install_tc_python() {
echo "Setting up python now"
export PYTHONPATH=${TC_DIR}/build/pybinds:${PYTHONPATH}
echo "PYTHONPATH: ${PYTHONPATH}"
echo "USE_CONTBUILD_CACHE: ${USE_CONTBUILD_CACHE}"

if [ "$USE_CONTBUILD_CACHE" == "1" ]; then
echo "Running on CI, setting PYTHONPATH only"
export PYTHONPATH=${TC_DIR}/build/tensor_comprehensions/pybinds:${PYTHONPATH}
echo "PYTHONPATH: ${PYTHONPATH}"
else
if which conda &> /dev/null; then
echo "Found conda, going to install Python packages"
cd ${TC_DIR}
export CONDA_PYTHON=$(which python3)
echo "CONDA_PYTHON: ${CONDA_PYTHON}"
if [ "$BUILD_TYPE" == "Release" ]; then
echo "Install mode setup for python"
${CONDA_PYTHON} setup.py install
else
echo "Develop mode setup for python"
${CONDA_PYTHON} setup.py develop
fi
else
echo "Conda not found, setting PYTHONPATH instead"
echo "Setting PYTHONPATH now"
export PYTHONPATH=${TC_DIR}/tensor_comprehensions:$PYTHONPATH
echo "PYTHONPATH: ${PYTHONPATH}"
fi
fi
echo "python all set now"
}
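Whichever branch runs, a quick way to check the result afterwards is simply importing the package (a minimal sketch; the `tensor_comprehensions` package name comes from this PR, the check itself is not part of build.sh):

```python
# Minimal post-install sanity check: if `setup.py install`/`develop`
# succeeded, or PYTHONPATH points at the built bindings, this import works.
import tensor_comprehensions as tc
print("tensor_comprehensions imported from:", tc.__file__)
```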

function install_tc() {
68 changes: 68 additions & 0 deletions conda_recipes/Dockerfile
@@ -0,0 +1,68 @@
FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu14.04
Contributor: FMI, why do we rewrite a full new Dockerfile rather than build on top of what we have been using in CI?

Contributor Author: the one in CI doesn't have a GPU, but for building conda packages I need to run the Python tests, which basically need a GPU in all cases. Hence building the docker image that has GPUs. You can read up on how nvidia-docker works; I provided README instructions here for using and building this image.

Contributor Author: #resolved

Contributor (@nicolasvasilache, Feb 26, 2018): The docker image can't have GPUs, you can't download them :)

All our trusty docker images used to inherit from this FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu14.04 and all our xenial docker images used to inherit from this FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04, but I now realize all that has changed drastically in the last stretch before the release and that we now use this file.

Still, when providing the Docker for conda_recipes you are using a very similar Docker base to what I wrote a few months back, which got me confused.

Does this mean we should uniformize what is in here and revert back to the old docker files like you did in this PR?

I'll open a task specifically for this.

Contributor Author (@prigoyal, Feb 26, 2018): The docker files will be moved from here to a separate repo https://github.com/prigoyal/tensorcomp-dockerfiles that I have set up for the Jenkins setup. The question of whether the images should be tiered or not is something I am already discussing with @pietern and @ezyang. Please open the issue in the above repo to centralize discussions properly.

> Still, when providing the Docker for conda_recipes you are using a very similar Docker base to what I wrote a few months back, which got me confused.

Not quite sure; the docker images follow a quite different structure now and provide more safeguarding.

> FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu14.04

Why did we inherit from those earlier? Did you use nvidia-smi, or did you ever run GPU tests in that docker image? :)

> The docker image can't have GPUs, you can't download them :)

Yeah, that's why we use nvidia-docker https://github.com/NVIDIA/nvidia-docker; looking at the diagram there will make things clear. Docker is not a VM, it's an image (note: the terminology has been used incorrectly in the past).

> revert back to the old docker files like you did in this PR

Nope, setting up CircleCI to properly use the GPU images we build is cumbersome, so we are going to use the AWS Jenkins setup soon.

Contributor Author: #resolved

Contributor:

> why did we inherit from those earlier? did you use nvidia-smi?

No, since there is no CUDA driver on the CPU-only CI machines. However, we have always had cross-compilation tests on CPU-only CI that were testing simple parts of the CUDA toolkit.

In particular, you don't need a physical GPU to generate PTX and see whether compilation passes (e.g. NVRTC + CUB)...

Contributor Author (@prigoyal, Feb 26, 2018): you can continue doing that, no need to start from NVIDIA images; the toolkit is installed in the images, as you also pointed out. nvidia-docker images are useless if you don't run CI GPU tests with them.


ENV DEBIAN_FRONTEND noninteractive

RUN apt-get update

RUN apt-get install -y --no-install-recommends make git ssh realpath wget unzip cmake3 vim
RUN apt-get install -y --no-install-recommends libgoogle-glog-dev libyaml-dev
RUN apt-get install -y --no-install-recommends libgtest-dev libz-dev libgmp3-dev
RUN apt-get install -y --no-install-recommends automake libtool valgrind subversion
RUN apt-get install -y --no-install-recommends ca-certificates software-properties-common

RUN cmake --version

# GCC 4.8.*
RUN add-apt-repository ppa:ubuntu-toolchain-r/test
RUN apt-get update
RUN apt-get install -y --no-install-recommends libcilkrts5 gcc-4.8 g++-4.8
Contributor: could you explain why it is required to install gcc-4.8 manually? IIRC Ubuntu 14.04 (trusty) should ship with gcc-4.8.

Contributor Author: It should, but this is done as a safety measure.

Contributor Author: #resolved

RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-4.8 50
RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-4.8 50

ENV CC /usr/bin/gcc
ENV CXX /usr/bin/g++

# LLVM+Clang
ENV LLVM_SOURCES /tmp/llvm_sources-tapir5.0
WORKDIR $LLVM_SOURCES

ENV CLANG_PREFIX /usr/local/clang+llvm-tapir5.0
ENV CMAKE_VERSION cmake

RUN git clone --recursive https://github.com/wsmoses/Tapir-LLVM llvm && \
mkdir -p ${LLVM_SOURCES}/llvm_build && cd ${LLVM_SOURCES}/llvm_build && \
${CMAKE_VERSION} -DLLVM_ENABLE_EH=ON -DLLVM_ENABLE_OCAMLDOC=OFF -DLLVM_INSTALL_OCAMLDOC_HTML_DIR=/tmp -DLLVM_OCAML_INSTALL_PATH=/tmp -DCMAKE_INSTALL_PREFIX=${CLANG_PREFIX} -DLLVM_TARGETS_TO_BUILD=X86 -DCOMPILER_RT_BUILD_CILKTOOLS=OFF -DLLVM_ENABLE_CXX1Y=ON -DLLVM_ENABLE_TERMINFO=OFF -DLLVM_BUILD_TESTS=OFF -DLLVM_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=Release -DLLVM_BUILD_LLVM_DYLIB=ON -DLLVM_ENABLE_RTTI=ON ../llvm/ && \
make -j"$(nproc)" -s && make install -j"$(nproc)" -s

RUN rm -Rf ${LLVM_SOURCES}

# Anaconda3
ENV LD_LIBRARY_PATH /usr/local/cuda/lib64:/usr/local/cuda/targets/x86_64-linux/lib/stubs/:$LD_LIBRARY_PATH
ENV PATH /usr/local/bin:/usr/local/cuda/bin:$PATH

WORKDIR /conda-install
RUN echo 'export PATH=/opt/conda/bin:$PATH' > /etc/profile.d/conda.sh &&\
wget --quiet https://repo.continuum.io/archive/Anaconda3-5.0.1-Linux-x86_64.sh -O anaconda.sh && \
chmod +x anaconda.sh && \
./anaconda.sh -b -p /opt/conda && \
rm anaconda.sh

ENV PATH /opt/conda/bin:$PATH

RUN conda install numpy decorator six future cmake pyyaml

# Protobuf 3.4*
WORKDIR /proto-install
RUN wget --quiet https://github.com/google/protobuf/archive/v3.4.0.zip -O proto.zip && unzip -qq proto.zip -d /

RUN cd /protobuf-3.4.0 && ./autogen.sh && ./configure && make -j 8
RUN cd /protobuf-3.4.0 && make install && ldconfig

RUN which conda
RUN conda --version
RUN which protoc
RUN protoc --version
RUN which python
RUN python --version

CMD ["bash"]