A domain-specific language to express machine learning workloads.


Tensor Comprehensions

Tensor Comprehensions (TC) is a fully functional C++ library for automatically synthesizing high-performance machine learning kernels using Halide, ISL, and NVRTC or LLVM. TC additionally provides basic integration with Caffe2 and PyTorch. We provide more details in our paper on arXiv.

This library is designed to be highly portable and machine-learning-framework agnostic: it requires only a simple tensor library with memory allocation, offloading and synchronization capabilities.

So far, we have integrated TC with Caffe2 and PyTorch.
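The portability requirement above can be read as a small contract that a host framework has to fulfil. The Python `Protocol` below is a hypothetical sketch of that contract, not TC's actual integration API: the method names `allocate`, `to_device`, `to_host` and `synchronize` are assumptions chosen to mirror the three listed capabilities (memory allocation, offloading and synchronization). `HostOnlyTensorLib` is a toy in-memory stand-in that only shows the shape of an implementation.

```python
from typing import List, Protocol, Sequence, runtime_checkable


@runtime_checkable
class TensorLibrary(Protocol):
    """Hypothetical minimal contract a host tensor library would provide."""

    def allocate(self, sizes: Sequence[int]) -> List[float]: ...
    def to_device(self, host: List[float]) -> List[float]: ...  # offloading
    def to_host(self, device: List[float]) -> List[float]: ...
    def synchronize(self) -> None: ...


class HostOnlyTensorLib:
    """Toy implementation: 'device' buffers are plain Python lists."""

    def allocate(self, sizes: Sequence[int]) -> List[float]:
        count = 1
        for s in sizes:
            count *= s
        return [0.0] * count  # flat, zero-initialized buffer

    def to_device(self, host: List[float]) -> List[float]:
        return list(host)  # a real backend would copy into GPU memory here

    def to_host(self, device: List[float]) -> List[float]:
        return list(device)

    def synchronize(self) -> None:
        pass  # nothing runs asynchronously here; a GPU backend would wait on streams


lib = HostOnlyTensorLib()
print(isinstance(lib, TensorLibrary))  # True: structural match on method names
print(len(lib.allocate((2, 3, 4))))    # 24
```

Because the contract is structural, any framework whose tensor layer exposes these operations could, in this sketch, be plugged in without inheriting from a TC-specific base class.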

A simple example

The following illustrates a short but powerful feature of the library: the ability to JIT-compile high-performance machine learning kernels on demand for specific tensor sizes.

```python
import tensor_comprehensions as tc
import torch
lang = """
def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) {
    O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w)
}
"""
N, C1, C2, C3, H, W = 32, 512, 8, 2, 28, 28
tensordot = tc.define(lang, name="tensordot")
I0, I1 = torch.randn(N, C1, C2, H, W).cuda(), torch.randn(N, C2, C3, H, W).cuda()
best_options = tensordot.autotune(I0, I1, cache=True)
out = tensordot(I0, I1, options=best_options)
```
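To make the semantics of the TC above concrete, here is a plain-Python reference loop nest that computes the same contraction on small nested lists. It follows the usual TC conventions: `+=!` initializes each output element to zero before accumulating, and an index that appears only on the right-hand side (here `c2`) is reduced over. The helpers `full` and `tensordot_reference` are hypothetical and not part of the `tensor_comprehensions` API; a loop nest like this can serve as a slow correctness check against a compiled kernel on tiny sizes.

```python
from itertools import product


def full(shape, value=0.0):
    """Build a nested list with the given shape, filled with `value`."""
    if len(shape) == 1:
        return [value] * shape[0]
    return [full(shape[1:], value) for _ in range(shape[0])]


def tensordot_reference(I0, I1):
    """O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w)."""
    N, C1, C2 = len(I0), len(I0[0]), len(I0[0][0])
    H, W = len(I0[0][0][0]), len(I0[0][0][0][0])
    C3 = len(I1[0][0])
    O = full((N, C1, C3, H, W))
    for n, c1, c3, h, w in product(range(N), range(C1), range(C3), range(H), range(W)):
        acc = 0.0  # `+=!`: start from the identity of +
        for c2 in range(C2):  # c2 is absent from the left-hand side, so reduce over it
            acc += I0[n][c1][c2][h][w] * I1[n][c2][c3][h][w]
        O[n][c1][c3][h][w] = acc
    return O


I0 = full((1, 2, 3, 1, 2), 1.0)  # N=1, C1=2, C2=3, H=1, W=2
I1 = full((1, 3, 2, 1, 2), 2.0)  # N=1, C2=3, C3=2, H=1, W=2
O = tensordot_reference(I0, I1)
print(O[0][1][1][0][1])  # 6.0: three c2 terms of 1.0 * 2.0
```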

After a few generations of autotuning on a 2-GPU P100 system, we see results resembling:

[Figure: Autotuning Sample]

In C++, a minimal autotuning example resembles the following:

```cpp
TEST(TensorDot, SimpleAutotune) {
  // 0. Backend alias used by the steps below; this example targets CUDA.
  using Backend = tc::CudaBackend;

  // 1. Define and set up the TC compilation unit with CUDA memory
  // management backed by ATen tensors.
  std::string tc = R"TC(
def tensordot(float(N, C1, C2, H, W) I0,
              float(N, C2, C3, H, W) I1)  -> (O)
{
    O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w)
}
  )TC";

  // 2. Allocate tensors with random data.
  at::Tensor I0 = at::CUDA(at::kFloat).rand({32,  8, 16, 17, 25});
  at::Tensor I1 = at::CUDA(at::kFloat).rand({32, 16, 2, 17, 25});

  // 3. Run autotuning with evolutionary search starting from a naive option.
  auto naiveOptions = Backend::MappingOptionsType::makeNaiveMappingOptions();
  tc::aten::ATenAutotuner<tc::CudaBackend, tc::autotune::GeneticSearch>
      geneticAutotuneATen(tc);
  auto bestOption =
      geneticAutotuneATen.tune("tensordot", {I0, I1}, {naiveOptions});

  // 4. Compile and run the TC with the best option after allocating output
  //    tensors.
  auto pExecutor =
      tc::aten::compile<Backend>(tc, "tensordot", {I0, I1}, bestOption[0]);
  auto outputs = tc::aten::prepareOutputs(tc, "tensordot", {I0, I1});
  auto timings = tc::aten::profile(*pExecutor, {I0, I1}, outputs);
  std::cout << "tensordot size I0: " << I0.sizes() << ", "
            << "size I1: " << I1.sizes()
            << " ran in: " << timings.kernelRuntime.toMicroSeconds() << "us\n";
}
```
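The `tc::autotune::GeneticSearch` autotuner used above reports its progress one generation at a time, as the `Generation k ... (best/median/worst)us` lines in the log further below show. The sketch that follows is a mutation-only toy evolutionary search with elitism, not TC's implementation: the search space (two hypothetical power-of-two tile sizes), the `mutate` rule, and the synthetic `cost` function standing in for a measured kernel runtime are all assumptions. Carrying the best candidates of each generation into the next one guarantees that the best cost never gets worse between generations, which matches the non-increasing "best" column in that log.

```python
import random
import statistics

# Hypothetical search space: each of two tile sizes is a power of two.
CHOICES = [1, 2, 4, 8, 16, 32, 64, 128]


def cost(option):
    """Synthetic stand-in for a measured runtime (lower is better), minimal at (32, 8)."""
    t0, t1 = option
    return 100 + abs(t0.bit_length() - 6) * 40 + abs(t1.bit_length() - 4) * 25


def mutate(option, rng):
    """Move one randomly chosen tile size to a neighbouring power of two."""
    i = rng.randrange(len(option))
    j = CHOICES.index(option[i]) + rng.choice((-1, 1))
    j = min(max(j, 0), len(CHOICES) - 1)
    new = list(option)
    new[i] = CHOICES[j]
    return tuple(new)


def evolve(start, generations=3, population=10, survivors=3, seed=0):
    rng = random.Random(seed)
    # Generation 0: the starting (naive) option plus random candidates.
    pop = [start] + [(rng.choice(CHOICES), rng.choice(CHOICES))
                     for _ in range(population - 1)]
    history = []
    for g in range(generations):
        ranked = sorted(pop, key=cost)
        costs = [cost(o) for o in ranked]
        best, median, worst = costs[0], statistics.median(costs), costs[-1]
        history.append((best, median, worst))
        print(f"Generation {g}  (best/median/worst): {best}/{median}/{worst}")
        elite = ranked[:survivors]
        # Next generation: keep the elite, fill the rest with mutants of it.
        pop = elite + [mutate(rng.choice(elite), rng)
                       for _ in range(population - survivors)]
    return min(pop, key=cost), history


best_option, history = evolve(start=(1, 1))
print("best option:", best_option, "cost:", cost(best_option))
```

A real autotuner replaces `cost` with compiling and timing the kernel, which is why each generation is comparatively expensive and why reusing tuned options across sizes pays off.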

Note that we only need to autotune a TC once to obtain reasonable mapping options, which can then be reused for other problem sizes of the same TC, as the following snippet illustrates:

```cpp
// 5. Reuse bestOption from autotuning for other problem sizes.
// Sizes are stored as owning std::vector<int64_t>: at::IntList is a
// non-owning view, and views of braced lists would dangle inside the loop.
for (const auto& sizes :
     std::vector<std::pair<std::vector<int64_t>, std::vector<int64_t>>>{
         {{4, 9, 7, 16, 14}, {4, 7, 3, 16, 14}},
         {{8, 5, 11, 10, 10}, {8, 11, 16, 10, 10}},
     }) {
  at::Tensor I0 = makeATenTensor<Backend>(sizes.first);
  at::Tensor I1 = makeATenTensor<Backend>(sizes.second);
  auto pExecutor =
      tc::aten::compile<Backend>(tc, "tensordot", {I0, I1}, bestOption[0]);
  auto outputs = tc::aten::prepareOutputs(tc, "tensordot", {I0, I1});
  auto timings = tc::aten::profile(*pExecutor, {I0, I1}, outputs);
  std::cout << "tensordot size I0: " << I0.sizes() << ", "
            << "size I1: " << I1.sizes()
            << " ran in: " << timings.kernelRuntime.toMicroSeconds()
            << "us\n";
}
```
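The pattern above, tuning once and then compiling a kernel specialized for each set of input sizes, can be pictured as a cache keyed by the TC name, the input shapes and the mapping options. The sketch below is hypothetical: `ShapeSpecializedCache` is not part of the TC API, and `fake_compile` merely stands in for a backend compile call such as `tc::aten::compile`. It shows that every new size combination triggers exactly one compilation with the same tuned options, while repeated sizes reuse the kernel already built.

```python
from typing import Callable, Dict, Hashable, Tuple

Shape = Tuple[int, ...]
Key = Tuple[str, Tuple[Shape, ...], Hashable]


class ShapeSpecializedCache:
    """Hypothetical cache of kernels specialized for concrete input shapes."""

    def __init__(self, compile_fn: Callable[[str, Tuple[Shape, ...], Hashable], object]):
        self._compile_fn = compile_fn
        self._kernels: Dict[Key, object] = {}
        self.compilations = 0

    def get(self, name: str, shapes: Tuple[Shape, ...], options: Hashable) -> object:
        key = (name, shapes, options)
        if key not in self._kernels:
            # JIT step: one compilation per new (name, shapes, options) combination.
            self._kernels[key] = self._compile_fn(name, shapes, options)
            self.compilations += 1
        return self._kernels[key]


def fake_compile(name, shapes, options):
    # Stand-in for a real backend compile call; it just describes the request.
    return f"{name}{shapes}@{options}"


cache = ShapeSpecializedCache(fake_compile)
best_option = "tuned-once"  # placeholder for the result of a single autotuning run
for sizes in [((4, 9, 7, 16, 14), (4, 7, 3, 16, 14)),
              ((8, 5, 11, 10, 10), (8, 11, 16, 10, 10)),
              ((4, 9, 7, 16, 14), (4, 7, 3, 16, 14))]:  # repeated sizes
    kernel = cache.get("tensordot", sizes, best_option)
print(cache.compilations)  # 2: the repeated sizes hit the cache
```

Including the options in the key keeps the cache correct if a later autotuning run produces different options for the same sizes.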

Putting it all together, one may see:

```
> build$ ./examples/example_simple
[==========] Running 1 test from 1 test case.
[----------] Global test environment set-up.
[----------] 1 test from TensorDot
[ RUN      ] TensorDot.SimpleAutotune
Generation 0    Jobs(Compiled, GPU)/total  (10, 10)/10   (best/median/worst)us: 226/4238/7345
Generation 1    Jobs(Compiled, GPU)/total  (10, 10)/10   (best/median/worst)us: 220/221/233
Generation 2    Jobs(Compiled, GPU)/total  (10, 10)/10   (best/median/worst)us: 220/221/234
tensordot size I0: [16, 8, 16, 17, 25], size I1: [16, 16, 2, 17, 25] ran in: 239us
tensordot size I0: [4, 9, 7, 16, 14], size I1: [4, 7, 3, 16, 14] ran in: 56us
tensordot size I0: [8, 5, 11, 10, 10], size I1: [8, 11, 16, 10, 10] ran in: 210us
[       OK ] TensorDot.SimpleAutotune (27812 ms)
[----------] 1 test from TensorDot (27812 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test case ran. (27812 ms total)
[  PASSED  ] 1 test.
```

We have not yet characterized the precise fraction of peak performance we obtain, but it is not uncommon to reach 80% or more of peak shared memory bandwidth after autotuning. Solid register-level optimizations are still in the works, but TC in its current form already addresses the productivity gap between the needs of research and the needs of production, which is why we are excited to share it with the entire community and bring this collaborative effort into the open.

Documentation

General: You can find detailed information about Tensor Comprehensions here.

C++ API: We also provide documentation for our C++ API, which can be found here.

Installation

Binaries

We provide conda packages that make it easy to install and use TC binaries. Please refer to our documentation here for instructions.

From Source

You can find documentation here with instructions for building TC via Docker, via conda packages, or in a non-conda environment.

Communication

Code of Conduct

See the CODE_OF_CONDUCT.md file for more details.

License

Tensor Comprehensions is distributed under a permissive Apache v2.0 license; see the LICENSE file for more details.

Contributing

See the CONTRIBUTING.md file for more details.