Skip to content

Commit

Permalink
Bind/run interface to CodeGen (pytorch#60)
Browse files Browse the repository at this point in the history
* Bind/run interface to CodeGen

* Make LLVMCodeGen implement CodeGen interface

* Allow bind/run to be unimplemented for the moment (CUDA)

* Cache compilation result

* Two nasty bugs: forgot virtual dtor, forgot to clear bindings after run()
  • Loading branch information
bertmaher authored and Mikhail Zolotukhin committed Feb 7, 2020
1 parent 3cfe72d commit c8e8fad
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 31 deletions.
13 changes: 13 additions & 0 deletions test/test_tensorexpr.py
Expand Up @@ -260,3 +260,16 @@ def easy(x, y):
b = torch.zeros(1024, dtype=torch.int32)
x= traced(a, b)
np.testing.assert_allclose(np.zeros(1024), x.numpy())

def test_reps():
    """Call a traced element-wise add repeatedly and verify each result.

    The traced function is invoked 32 times with a ones tensor and a zeros
    tensor of length 1024; every call must produce an all-ones result.
    """
    def easy(x, y):
        return torch.add(x, y)

    example_inputs = (torch.rand(1024), torch.rand(1024))
    traced = torch.jit.trace(easy, example_inputs)

    expected = np.ones(1024)
    for _ in range(32):
        lhs = torch.ones(1024)
        rhs = torch.zeros(1024)
        result = traced(lhs, rhs)
        np.testing.assert_allclose(expected, result.numpy())
34 changes: 10 additions & 24 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
Expand Up @@ -272,6 +272,7 @@ struct TensorExprKernel {
std::unordered_map<int64_t, Tensor> tensors;
std::unordered_map<int64_t, Expr> constants;
Stmt stmt;
std::unique_ptr<CodeGen> codegen;

Expr constant(torch::jit::Value* v) {
if (v->node()->kind() == prim::Constant) {
Expand Down Expand Up @@ -532,9 +533,7 @@ struct TensorExprKernel {
}
}
stmt = sch.Lower();
}

void run(Stack& stack) {
#ifdef ENABLE_LLVM
// Set up formal params (inputs, then outputs) for kernel.
std::vector<Buffer*> params;
Expand All @@ -548,41 +547,28 @@ struct TensorExprKernel {
params.push_back(&outbuf);

// Generate code.
LLVMCodeGen codegen(stmt, params);
codegen = std::make_unique<LLVMCodeGen>(stmt, params);
#else
codegen = std::make_unique<SimpleIREvaluator>(stmt);
#endif
}

// Executes the kernel for one invocation. The last buffer_args.size() entries
// of `stack` are taken as the input tensors, a float output tensor sized by
// bufferSizes(*tensor_output) is allocated, and those inputs are then dropped
// from the stack and replaced by the output.
//
// NOTE(review): this span appears to be a rendered diff without +/- markers,
// so statements using a local `args` vector / `codegen.value<int32_t>` and the
// trailing `#else ... #endif` evaluator path seem to belong to a different
// revision than the `codegen->bind` / `codegen->run` statements — confirm
// against the real file before editing.
void run(Stack& stack) {
// Set up arguments (inputs, then outputs) for kernel call.
auto inputs = last(stack, buffer_args.size());
std::vector<void*> args;
// NOTE(review): `int i` is compared against an unsigned size() here, while
// the loop further down uses size_t.
for (int i = 0; i < buffer_args.size(); i++) {
args.push_back(inputs[i].toTensor().data_ptr());
codegen->bind(buffer_args[i], inputs[i].toTensor().data_ptr());
}
at::Tensor output =
at::empty(bufferSizes(*tensor_output), at::ScalarType::Float);
args.push_back(output.data_ptr());
// The output buffer is bound last, after all inputs.
codegen->bind(*tensor_output, output.data_ptr());

// Call the kernel.
codegen.value<int32_t>(args);
codegen->run();

// Update the stack.
drop(stack, buffer_args.size());
stack.insert(stack.end(), std::move(output));
#else
// Interpreter path: bind buffers on a locally constructed SimpleIREvaluator
// and evaluate `stmt` directly.
SimpleIREvaluator eval(stmt);
std::vector<std::vector<float>> backing;

auto inputs = last(stack, buffer_args.size());
for (size_t i = 0; i < buffer_args.size(); i++) {
eval.bindBuffer(buffer_args[i], inputs[i].toTensor().data_ptr());
}

at::Tensor output =
at::empty(bufferSizes(*tensor_output), at::ScalarType::Float);
eval.bindBuffer(*tensor_output, output.data_ptr());

eval.eval();
drop(stack, buffer_args.size());
stack.insert(stack.end(), std::move(output));
#endif
}
};

Expand Down
17 changes: 16 additions & 1 deletion torch/csrc/jit/tensorexpr/codegen.h
Expand Up @@ -24,6 +24,11 @@ class CodeGen {
CodeGen(const Expr& expr, Ts... ts)
: ir_node_(expr.node()), buffer_args_({BufferArg(ts)...}) {}

// Constructs a CodeGen for the IR rooted at `node`. No buffer arguments are
// supplied through this overload, so buffer_args_ is default-constructed.
CodeGen(const IRNode* node)
: ir_node_(node) {}

// Virtual so that a derived code generator destroyed through a CodeGen
// pointer runs its own destructor.
virtual ~CodeGen() = default;

RefHandle<IRNode>& ir_node() {
return ir_node_;
}
Expand All @@ -40,6 +45,14 @@ class CodeGen {
return buffer_args_;
}

// Associates the buffer argument `buf` with the call-time value `data` for a
// subsequent run(). This base version is a placeholder: it only logs a fatal
// "Unimplemented interface" error, so code generators that support the
// bind/run interface must override it.
virtual void bind(const BufferArg& buf, const CallArg& data) {
LOG(FATAL) << "Unimplemented interface";
}

// Executes the generated code using whatever was supplied through bind().
// This base version is a placeholder: it only logs a fatal
// "Unimplemented interface" error, so code generators that support the
// bind/run interface must override it.
virtual void run() {
LOG(FATAL) << "Unimplemented interface";
}

private:
RefHandle<IRNode> ir_node_;
std::vector<BufferArg> buffer_args_;
Expand Down Expand Up @@ -77,7 +90,9 @@ class CodeGen::CallArg {
template <typename T>
CallArg(const std::vector<T>& buffer) : ptr_(const_cast<T*>(buffer.data())) {}

void* data() {
CallArg(void* ptr) : ptr_(ptr) {}

void* data() const {
return ptr_;
}

Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/tensorexpr/cuda_codegen.h
Expand Up @@ -217,6 +217,8 @@ class CudaCodeGen : public CodeGen {
#endif
}

// No explicit cleanup; `override` ties this to the base class's virtual
// destructor.
~CudaCodeGen() override = default;

template <typename... Ts>
void operator()(const Ts&... ts) {
std::vector<CallArg> args({CallArg(ts)...});
Expand Down
10 changes: 6 additions & 4 deletions torch/csrc/jit/tensorexpr/eval.h
Expand Up @@ -81,13 +81,15 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
public:
using CodeGen::CodeGen;

template <typename Buf>
void bindBuffer(Buf b, void* d) {
buffer_mapping_[BufferArg(b).var().node()] = d;
// No explicit cleanup; `override` ties this to the base class's virtual
// destructor.
~SimpleIREvaluator() override = default;

// Records the raw pointer carried by `data` as the backing storage for the
// variable of `buf`. A later bind for the same variable overwrites the entry.
void bind(const BufferArg& buf, const CallArg& data) override {
  void* const storage = data.data();
  buffer_mapping_[buf.var().node()] = storage;
}

void eval() {
// Evaluates the IR by having the root node accept this evaluator as its
// visitor, then discards every buffer binding so that each run() requires a
// fresh set of bind() calls.
void run() override {
  auto root = ir_node().node();
  root->accept(this);
  buffer_mapping_.clear();
}

template <typename... Ts>
Expand Down
12 changes: 11 additions & 1 deletion torch/csrc/jit/tensorexpr/llvm_codegen.cpp
Expand Up @@ -35,7 +35,8 @@ LLVMCodeGen::LLVMCodeGen(const Expr& expr)
{}

LLVMCodeGen::LLVMCodeGen(const IRNode* node, const std::vector<Buffer*>& args, Dtype dtype)
: context_(std::make_unique<llvm::LLVMContext>()),
: CodeGen(node),
context_(std::make_unique<llvm::LLVMContext>()),
irb_(*context_.getContext()) {
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
Expand Down Expand Up @@ -154,6 +155,15 @@ LLVMCodeGen::LLVMCodeGen(const IRNode* node, const std::vector<Buffer*>& args, D
kernelAddress_ = cantFail(sym.getAddress());
}

// Appends the raw pointer carried by `data` to the pending argument list.
// `buf` is not consulted, so arguments are matched purely by call order.
// NOTE(review): presumably callers must bind in the same order as the
// parameters the generator was constructed with — confirm.
void LLVMCodeGen::bind(const BufferArg& buf, const CallArg& data) {
  void* const ptr = data.data();
  args_.push_back(ptr);
}

void LLVMCodeGen::run() {
value<float>(args_);
args_.clear();
}

// TODO: The binary ops are copypasta.

void LLVMCodeGen::visit(const Add* v) {
Expand Down
11 changes: 10 additions & 1 deletion torch/csrc/jit/tensorexpr/llvm_codegen.h
Expand Up @@ -4,6 +4,7 @@
#include <torch/csrc/WindowsTorchApiMacro.h>

#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "torch/csrc/jit/tensorexpr/codegen.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
#include "torch/csrc/jit/tensorexpr/ir_visitor.h"
#include "torch/csrc/jit/tensorexpr/llvm_jit.h"
Expand All @@ -24,7 +25,7 @@ namespace torch {
namespace jit {
namespace compiler {

class TORCH_API LLVMCodeGen : public IRVisitor {
class TORCH_API LLVMCodeGen : public CodeGen, public IRVisitor {
private:
llvm::orc::ThreadSafeContext context_;
llvm::IRBuilder<> irb_;
Expand All @@ -42,6 +43,8 @@ class TORCH_API LLVMCodeGen : public IRVisitor {
std::unordered_map<const BaseExprNode*, int> varToArg_;
std::unordered_map<const Variable*, llvm::Value*> varToVal_;

std::vector<void*> args_;

private:
explicit LLVMCodeGen(
const IRNode* node,
Expand All @@ -60,6 +63,12 @@ class TORCH_API LLVMCodeGen : public IRVisitor {
Dtype dtype = kInt32);
explicit LLVMCodeGen(const Expr& expr);

// No explicit cleanup; `override` ties this to the base class's virtual
// destructor.
~LLVMCodeGen() override = default;

void bind(const BufferArg& buf, const CallArg& data) override;

void run() override;

void visit(const Add* v) override;
void visit(const Sub* v) override;
void visit(const Mul* v) override;
Expand Down

0 comments on commit c8e8fad

Please sign in to comment.