Skip to content

Commit

Permalink
Fix some missing TensorDomain::getContiguousContiguity calls (pytorch…
Browse files Browse the repository at this point in the history
…#2083)

Co-authored-by: Xiang Gao <qasdfgtyuiop@gmail.com>
  • Loading branch information
csarofeen and zasdfgbnm committed Oct 16, 2022
1 parent ccafae8 commit eedc290
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 4 deletions.
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,8 @@ TensorView* castIntermediateValueInCompleteFusion(
// Create the actual domain and tv.
return IrBuilder::create<TensorView>(
IrBuilder::create<TensorDomain>(
new_root_domain, std::vector<bool>(new_root_domain.size(), true)),
new_root_domain,
TensorDomain::getContiguousContiguity(new_root_domain)),
data_type);
};

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/ops/alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ TensorView* permute(TensorView* x, const std::vector<int64_t>& new2old) {

TensorView* out_tensor = IrBuilder::create<TensorView>(
IrBuilder::create<TensorDomain>(
out_domain, std::vector<bool>(out_domain.size(), true)),
out_domain, TensorDomain::getContiguousContiguity(out_domain)),
x->getDataType().value());
IrBuilder::create<TransposeOp>(out_tensor, x, normalized_new2old);
return out_tensor;
Expand Down
55 changes: 55 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6659,6 +6659,61 @@ TEST_F(NVFuserTest, FusionIssue2074_CUDA) {
ASSERT_TRUE(at::allclose(cg_outputs[1], t4));
}

// Regression test named after issue 2075.
//
// Builds a fusion with two inputs:
//   - tv0 ([1, i, 1]) is copied and expanded to [x, i, z].
//   - tv1 ([1, 1, i]) is broadcast to rank 6, expanded in its two leading
//     dimensions, copied, permuted, and then reduced over the three size-1
//     axes so that it ends up rank 3 again.
// The two results are compared element-wise with `le` and the boolean result
// is cast to float. The same computation is reproduced with ATen ops and used
// as the reference for testValidate.
TEST_F(NVFuserTest, FusionIssue2075_CUDA) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  // Concrete sizes used both for the expand extents in the fusion and for the
  // ATen input/reference tensors below.
  int x = 2;
  int y = 128;
  int z = 128;

  auto tv0 = makeContigConcreteTensor({1, -1, 1});
  fusion.addInput(tv0);
  auto tv1 = makeContigConcreteTensor({1, 1, -1});
  fusion.addInput(tv1);

  auto tv2 = set(tv0);
  // Expand the two outer size-1 dimensions of tv2 to x and z, keeping the
  // extent of the middle axis as is.
  auto tv3 = expand(
      tv2,
      {IrBuilder::create<Int>(x),
       tv2->axis(1)->extent(),
       IrBuilder::create<Int>(z)});

  // [1, 1, 128] -> [1, 1, 1, 1, 1, 128]
  auto tv4 = broadcast(tv1, {{false, false, true, true, true, false}});
  // [1, 1, 1, 1, 1, 128] -> [2, 128, 1, 1, 1, 128]
  auto tv5 = expand(
      tv4,
      {IrBuilder::create<Int>(x),
       IrBuilder::create<Int>(y),
       tv4->axis(2)->extent(),
       tv4->axis(3)->extent(),
       tv4->axis(4)->extent(),
       tv4->axis(5)->extent()});
  auto tv6 = set(tv5);
  // [2, 128, 1, 1, 1, 128] -> [2, 1, 128, 1, 1, 128]
  auto tv7 = permute(tv6, {0, 3, 1, 2, 4, 5});
  // Reduce over the three size-1 axes: [2, 1, 128, 1, 1, 128] -> [2, 128, 128]
  auto tv8 = sum(tv7, {1, 3, 4});
  auto tv9 = le(tv8, tv3);
  auto tv10 = castOp(DataType::Float, tv9);
  fusion.addOutput(tv10);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor t0 = at::randn({1, y, 1}, options);
  at::Tensor t1 = at::randn({1, 1, z}, options);

  // ATen reference; tN mirrors tvN in the fusion above.
  auto t3 = t0.expand({x, y, z});
  auto t4 = t1.unsqueeze(-2).unsqueeze(-2).unsqueeze(-2);
  auto t5 = t4.expand({x, y, 1, 1, 1, z});
  auto t7 = t5.permute({0, 3, 1, 2, 4, 5});
  // Summing over size-1 axes is reproduced here by squeezing them away.
  auto t8 = t7.squeeze(-2).squeeze(-2).squeeze(-3);
  // The fusion uses `le` (<=), so the reference must use <= as well rather
  // than a strict <; otherwise the two disagree on equal elements.
  auto t9 = t8 <= t3;
  auto t10 = t9.to(at::kFloat);

  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto cg_outputs = executor_cache.runFusionWithInputs({t0, t1});
  testValidate(&fusion, cg_outputs, {t0, t1}, {t10}, __LINE__, __FILE__);
}

// Test file size should be up to 10K LoC. Create a new file for more tests.

} // namespace jit
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(
new_producer_root,
new_producer_rfactor_domain,
new_producer_domain,
std::vector<bool>(new_producer_rfactor_domain.size(), true));
TensorDomain::getContiguousContiguity(new_producer_rfactor_domain));

// Producer has been finished, now work on consumer.

Expand Down Expand Up @@ -467,7 +467,7 @@ std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(
original_td->container(),
new_consumer_root_domain,
new_consumer_domain,
std::vector<bool>(new_consumer_root_domain.size(), true));
TensorDomain::getContiguousContiguity(new_consumer_root_domain));

return std::make_pair(producer_domain, consumer_domain);
}
Expand Down

0 comments on commit eedc290

Please sign in to comment.