Skip to content

Commit

Permalink
Draft version: Make AliasAnalysisKind optional in Op Registration API (
Browse files Browse the repository at this point in the history
…pytorch#30187)

Summary:
Don't look into deep into the diff's implementation. The reason to send out this diff is to help sync on the design first. Once we agree on the design, I will update the implementation accordingly.

**Here is the basic design for achieving this functionality:**

**Q1: Do we need to tell apart case between the following:**
case 1:  registry 1: PURE -> registry 2: CONSERVATIVE
case 2:  registry 1: PURE -> registry 2: <not set>

A: should be yes though, right now both cases have same value(due to defaulting to CONSERVATIVE) in operators_ and operatorLookupTable_.
case 1 should be denied while case 2 should be legal case where registry 1 will be PURE at the end.

**How to tell apart both cases:**

Right now, AliasAnalysisKind::CONSERVATIVE is by default (code pointer: https://our.intern.facebook.com/intern/diffusion/FBS/browse/master/fbcode/caffe2/aten/src/ATen/core/dispatch/OperatorOptions.h?lines=22%2C52)

Current approach: Introducing a boolean flag in OperatorOptions called isDefault, defaulting to value true. When manually call setAliasAnalysis(AliasAnalysisKind), it will be set too false.
And then when findSchema() in Dispatcher.cpp,  we will check response's option's isDefault value.
If isDefault = true, then with some sanity check and if all checks passed, we can update the option info in both operators_ and operatorLookupTable_

Other approaches:
1. Introducing a new AliasAnalaysisKind maybe called NOT_SPECIFIED.  (I am not doing it this way since then I need to update other callosities related to AliasAnalaysisKind::CONSERVATIVE) Also, we will need to have additional logics to align between NOT_SPECIFIED and CONSERVATIVE

**What data to be updated:**
corresponding entry in std::list<OperatorDef> operators_ and LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_

(More things to be discussed here.)

**Do we need to trigger listeners if an entry get updated:**
I think no.
callOnOperatorRegistered(op) seems only to be using OperatorHandle.schema now from the only callsite from register_c10_ops.cpp
(code pointers: https://our.intern.facebook.com/intern/diffusion/FBS/browse/master/fbcode/caffe2/aten/src/ATen/core/dispatch/Dispatcher.cpp?commit=b4cefeaa98dca5b1ec5f7a0bca6028e368960244&lines=87-90
and https://our.intern.facebook.com/intern/diffusion/FBS/browse/master/fbcode/caffe2/torch/csrc/jit/register_c10_ops.cpp?lines=178&link_ref=biggrep)

However, things can be much more complicated if future extensions may use options when some listeners want to use options value to register operators.

**Future reading list + remaining questions:**
1. How options get consumed on the other side.
2. Usages for fields in OperatorEntry besides schema/options/kernals
Pull Request resolved: pytorch#30187

Test Plan:
[xintchen@devvm6308.prn2 ~/fbsource/fbcode] buck test mode/dev //caffe2:ATen-core-test

All tests passed

Differential Revision: D18530964

Pulled By: charliechen0401

fbshipit-source-id: 60c0560a63a36e54f09f397667bb7122b61d6a8e
  • Loading branch information
Xintao Chen authored and facebook-github-bot committed Nov 22, 2019
1 parent c478a92 commit 5d7b208
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 5 deletions.
10 changes: 8 additions & 2 deletions aten/src/ATen/core/dispatch/Dispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,14 @@ OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema, Operat
str << schema << " vs " << found->schema();
TORCH_CHECK(false, "Tried to register multiple operators with the same name and the same overload name but different schemas: ", str.str());
}
if (found->options() != options) {
TORCH_CHECK(false, "Tried to register multiple operators with the same schema but different options: ", toString(schema));
if (options.isDefaultAliasAnalysisKind()) {
// just do nothing and let it pass.
} else if (found->options().isDefaultAliasAnalysisKind()) {
found->operatorIterator_->op.updateOptionsAliasAnalysis(options.aliasAnalysis());
} else {
TORCH_CHECK(
found->options() == options,
"Tried to register multiple operators with the same schema but different options: ", toString(schema));
}
return *found;
}
Expand Down
1 change: 0 additions & 1 deletion aten/src/ATen/core/dispatch/Dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,4 +258,3 @@ inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatch
}

} // namespace c10

4 changes: 4 additions & 0 deletions aten/src/ATen/core/dispatch/OperatorEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class OperatorEntry final {
return options_;
}

void updateOptionsAliasAnalysis(AliasAnalysisKind a) {
options_.setAliasAnalysis(a);
}

private:
void deregisterKernel_(TensorTypeId dispatch_key, std::list<KernelFunction>::iterator kernel);
void deregisterCatchallKernel_(std::list<KernelFunction>::iterator kernel);
Expand Down
10 changes: 8 additions & 2 deletions aten/src/ATen/core/dispatch/OperatorOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,14 @@ inline const char* toString(AliasAnalysisKind aliasAnalysisKind) {

struct OperatorOptions final {
public:
bool isDefaultAliasAnalysisKind() const {
return aliasAnalysisKind_ == c10::nullopt;
}

AliasAnalysisKind aliasAnalysis() const {
return aliasAnalysisKind_;
return !isDefaultAliasAnalysisKind()
? *aliasAnalysisKind_
: AliasAnalysisKind::CONSERVATIVE;
}

void setAliasAnalysis(AliasAnalysisKind v) {
Expand All @@ -49,7 +55,7 @@ struct OperatorOptions final {
}

private:
AliasAnalysisKind aliasAnalysisKind_ = AliasAnalysisKind::CONSERVATIVE;
c10::optional<AliasAnalysisKind> aliasAnalysisKind_;
};

} // namespace c10
45 changes: 45 additions & 0 deletions aten/src/ATen/core/op_registration/op_registration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,51 @@ struct MockKernel final : OperatorKernel {
bool* called_;
};

TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithAliasAnalysisAfterRegisteringWithoutAliasAnalysis_thenCanBeCalled) {
{
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::TensorTypeId::CPUTensorId));
auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::TensorTypeId::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION));

auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
ASSERT_TRUE(op.has_value());
EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION);
}
{
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::TensorTypeId::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION));
auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::TensorTypeId::CPUTensorId));

auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
ASSERT_TRUE(op.has_value());
EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION);
}
}

TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithSameAliasAnalysis_thenCanBeCalled) {
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::TensorTypeId::CPUTensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION));
auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::TensorTypeId::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION));

auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
ASSERT_TRUE(op.has_value());
EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION);
}

TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithNoAliasAnalysis_thenCanBeCalled) {
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::TensorTypeId::CPUTensorId));
auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::TensorTypeId::XLATensorId));

auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
ASSERT_TRUE(op.has_value());
EXPECT_TRUE(op->options().isDefaultAliasAnalysisKind());
EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::CONSERVATIVE);
}

TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithDifferentAliasAnalysis_thenShouldThrow) {
expectThrows<c10::Error>([] {
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::TensorTypeId::CPUTensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION));
auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::TensorTypeId::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::CONSERVATIVE));
}, "Tried to register multiple operators with the same schema but different options:");
}

TEST(OperatorRegistrationTest, whenRegisteringWithSchemaBeforeKernelInOptionsObject_thenCanBeCalled) {
bool called = false;
auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().schema("_test::dummy(Tensor dummy) -> ()").catchAllKernel<MockKernel>(&called));
Expand Down

0 comments on commit 5d7b208

Please sign in to comment.