Issue 1770 (#1774)
* Fix issue #1770

* Adds assertions to avoid segfault
naoyam committed Jun 24, 2022
1 parent 0773c33 commit c8b4f42
Showing 2 changed files with 60 additions and 16 deletions.
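The segfault in issue #1770 stemmed from downcasting an op's output to TensorView without checking its dynamic type. Below is a minimal, self-contained sketch of that failure mode and of the guard this commit adds, using stand-in classes rather than the real NVFuser Val/TensorView hierarchy:

// Illustrative stand-ins only; not the NVFuser IR classes.
#include <cassert>

struct Val {
  virtual ~Val() = default;
};
struct Scalar : Val {};
struct TensorView : Val {
  int nDims() const { return ndims_; }
  int ndims_ = 0;
};

// Pre-fix pattern: a blind downcast. If v is really a Scalar, this is
// undefined behavior that typically surfaces as a segfault once the
// "TensorView" is dereferenced downstream.
TensorView* asTensorViewUnchecked(Val* v) {
  return static_cast<TensorView*>(v);
}

// Post-fix pattern: assert the dynamic type at the cast site, mirroring
// TORCH_INTERNAL_ASSERT(out->isA<TensorView>()), so a mistyped value
// fails loudly and immediately instead of corrupting memory later.
TensorView* asTensorViewChecked(Val* v) {
  assert(dynamic_cast<TensorView*>(v) != nullptr && "expected a TensorView");
  return static_cast<TensorView*>(v);
}

int main() {
  Scalar s;
  Val* v = &s;
  asTensorViewChecked(v); // assert fires here with a clear message
  return 0;
}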
41 changes: 25 additions & 16 deletions arith.cpp
@@ -561,8 +561,9 @@ namespace {
 // Helper function to reduce repetitive code
 template <typename T1, typename T2>
 TensorView* arithOpOverloads(Val* (*func)(Val*, Val*), T1* v1, T2* v2) {
-  return func(v1->template as<Val>(), v2->template as<Val>())
-      ->template as<TensorView>();
+  Val* out = func(v1->template as<Val>(), v2->template as<Val>());
+  TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
+  return out->as<TensorView>();
 }

 template <typename T1, typename T2>
@@ -571,9 +572,10 @@ TensorView* arithOpOverloads(
     T1* v1,
     T2* v2,
     DataType common_dtype) {
-  return binaryOp(
-      type, v1->template as<Val>(), v2->template as<Val>(), common_dtype)
-      ->template as<TensorView>();
+  Val* out = binaryOp(
+      type, v1->template as<Val>(), v2->template as<Val>(), common_dtype);
+  TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
+  return out->as<TensorView>();
 }

 template <typename T1, typename T2, typename T3>
@@ -583,11 +585,12 @@ TensorView* arithOpOverloads(
     T2* v2,
     T3* v3) {
   auto vals = maybeBroadcast({v1, v2, v3});
-  return func(
-      vals[0]->template as<Val>(),
-      vals[1]->template as<Val>(),
-      vals[2]->template as<Val>())
-      ->template as<TensorView>();
+  Val* out = func(
+      vals[0]->template as<Val>(),
+      vals[1]->template as<Val>(),
+      vals[2]->template as<Val>());
+  TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
+  return out->as<TensorView>();
 }

 template <typename T1, typename T2, typename T3, typename T4>
@@ -598,12 +601,13 @@ TensorView* arithOpOverloads(
     T3* v3,
     T4* v4) {
   auto vals = maybeBroadcast({v1, v2, v3, v4});
-  return func(
-      vals[0]->template as<Val>(),
-      vals[1]->template as<Val>(),
-      vals[2]->template as<Val>(),
-      vals[3]->template as<Val>())
-      ->template as<TensorView>();
+  Val* out = func(
+      vals[0]->template as<Val>(),
+      vals[1]->template as<Val>(),
+      vals[2]->template as<Val>(),
+      vals[3]->template as<Val>());
+  TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
+  return out->as<TensorView>();
 }

 // Output type promotion logic for binary operators
@@ -1509,6 +1513,11 @@ Val* where(Val* c, Val* v1, Val* v2) {
       promote_type(v1->getDataType().value(), v2->getDataType().value());
   auto out_vtype =
       promote_type(v1->getValType().value(), v2->getValType().value());
+  // Even when v1 and v2 are scalar, the output is a tensor if the
+  // conditional input is a tensor.
+  if (c->getValType() == ValType::TensorView) {
+    out_vtype = ValType::TensorView;
+  }
   auto vals = maybeBroadcast({c, v1, v2});
   Val* out = nullptr;
   if (out_vtype == ValType::TensorView) {
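The behavioral half of the fix is the out_vtype override in the where() hunk above: the output must be a tensor whenever any input is a tensor, including the conditional. Sketched as a standalone rule (hypothetical helper with a stand-in enum, not the NVFuser API):

enum class ValType { Scalar, TensorView };

// Output value type of where(c, v1, v2). Before the fix only v1 and v2
// were consulted, so where(tensor_pred, scalar, scalar) was typed as a
// scalar and the subsequent cast to TensorView segfaulted.
ValType whereOutputValType(ValType c, ValType v1, ValType v2) {
  if (c == ValType::TensorView || v1 == ValType::TensorView ||
      v2 == ValType::TensorView) {
    return ValType::TensorView;
  }
  return ValType::Scalar;
}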
35 changes: 35 additions & 0 deletions test/test_gpu.cpp
@@ -23706,6 +23706,41 @@ TEST_F(NVFuserTest, FusionIgnoreZeroDimReduction_CUDA) {
       __FILE__);
 }

+// Repro of issue #1770
+TEST_F(NVFuserTest, FusionIssue1770Repro_CUDA) {
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  auto tv0 = makeSymbolicTensor(1);
+  fusion->addInput(tv0);
+  auto tv1 = makeSymbolicTensor(1);
+  fusion->addInput(tv1);
+
+  auto tv2 = ge(tv0, tv1);
+  auto tv3 =
+      where(tv2, IrBuilder::create<Double>(1), IrBuilder::create<Double>(2));
+  fusion->addOutput(tv3);
+
+  std::vector<int64_t> shape({999});
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn(shape, options);
+  at::Tensor t1 = at::randn(shape, options);
+  std::vector<IValue> aten_inputs({t0, t1});
+
+  FusionExecutorCache executor_cache(std::move(fusion));
+  auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);
+
+  auto ref = where(t0 >= t1, 1.0, 2.0);
+
+  testValidate(
+      executor_cache.fusion(),
+      cg_outputs,
+      aten_inputs,
+      {ref},
+      __LINE__,
+      __FILE__);
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
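With the standard googletest runner, the repro can be selected on its own, assuming a test binary built with these suites (the binary name varies by build; <test_binary> below is a placeholder):

<test_binary> --gtest_filter=NVFuserTest.FusionIssue1770Repro_CUDA

Running the repro against a build without the where() fix reproduces the crash from #1770; with the fix it passes, and the new assertions turn any similar mistyping into an internal-assert failure rather than a segfault.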
