Skip to content

Commit

Permalink
[TensorExpr] Remove dtype_ and add buf_ fields to `CodeGen::Buffe…
Browse files Browse the repository at this point in the history
…rArg`.

`BufferArg` is used to describe parameters passed to the codegen: it
indicates whether the parameter is a var or a buf and holds a pointer to
the corresponding var/buf. Both var and buf contain dtype, and thus
duplicating it in BufferArg is unnecessary - we can always get it from
the var/buf. Hence we're removing dtype_ field from BufferArg in this
PR. We're also adding a `buf_` field here: this is done so that
BufferArg truly has all the info about the parameter.

ghstack-source-id: e9e9ff211b1708ef2ef1dfe1c191964ba0e9c18d
Pull Request resolved: pytorch#57382
  • Loading branch information
Mikhail Zolotukhin committed Apr 30, 2021
1 parent 124bd45 commit cd46e63
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
29 changes: 15 additions & 14 deletions torch/csrc/jit/tensorexpr/codegen.h
Expand Up @@ -88,30 +88,31 @@ class TORCH_API CodeGen {

class CodeGen::BufferArg {
public:
BufferArg(const Placeholder& buffer)
: var_(buffer.data()->base_handle()), dtype_(buffer.dtype()) {}
BufferArg(Tensor* tensor)
: var_(tensor->buf()->base_handle()), dtype_(tensor->buf()->dtype()) {}
BufferArg(const VarHandle& var)
: var_(var.node()), dtype_(var.dtype()), isVar_(true) {}
BufferArg(const BufHandle& buf)
: var_(buf.node()->base_handle()), dtype_(buf.node()->dtype()) {}
BufferArg(const Placeholder& buffer) : buf_(buffer.data()) {}
BufferArg(Tensor* tensor) : buf_(tensor->buf()) {}
BufferArg(const VarHandle& var) : var_(var.node()), isVar_(true) {}
BufferArg(const BufHandle& buf) : buf_(buf.node()) {}

const Var* var() const {
return var_;
return isVar_ ? var_ : buf_->base_handle();
}
Dtype dtype() const {
return dtype_;

const Buf* buf() const {
return buf_;
}

bool isVar() const {
return isVar_;
}

Dtype dtype() const {
return isVar_ ? var_->dtype() : buf_->dtype();
}

private:
const Var* var_;
Dtype dtype_;
bool isVar_{false};
const Var* var_ = nullptr;
const Buf* buf_ = nullptr;
bool isVar_ = false;
};

class CodeGen::CallArg {
Expand Down
14 changes: 7 additions & 7 deletions torch/csrc/jit/tensorexpr/eval.cpp
Expand Up @@ -989,16 +989,16 @@ void SimpleIREvaluator::call(const std::vector<CallArg>& args) {
USE_TRIGGER(simple_ir_eval_executed);
}

void SimpleIREvaluator::bindArg(const BufferArg& buf, const CallArg& data) {
if (!buf.isVar()) {
impl_->bindBuf(buf.var(), data.data());
void SimpleIREvaluator::bindArg(const BufferArg& bufArg, const CallArg& data) {
if (!bufArg.isVar()) {
impl_->bindBuf(bufArg.var(), data.data());
return;
}

switch (buf.dtype().scalar_type()) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: \
impl_->bindVar(buf.var(), data.Name##Data()); \
switch (bufArg.dtype().scalar_type()) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: \
impl_->bindVar(bufArg.var(), data.Name##Data()); \
break;
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
Expand Down

0 comments on commit cd46e63

Please sign in to comment.