[functorch] batch rule : few decomposition ops (#96744)
kshitij12345 authored and cyyever committed Mar 27, 2023
1 parent 39950d9 commit f3dff0b
Showing 3 changed files with 7 additions and 12 deletions.
7 changes: 7 additions & 0 deletions aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -117,6 +117,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
 OP_DECOMPOSE(flipud);
 OP_DECOMPOSE2(float_power, Tensor_Tensor);
 OP_DECOMPOSE2(float_power, Tensor_Scalar);
+OP_DECOMPOSE2(float_power, Scalar);
 OP_DECOMPOSE2(floor_divide, Scalar);
 OP_DECOMPOSE(gather_backward);
 OP_DECOMPOSE(ger);
@@ -279,6 +280,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
 OP_DECOMPOSE(vstack);
 OP_DECOMPOSE2(where, ScalarOther);
 OP_DECOMPOSE2(where, ScalarSelf);
+OP_DECOMPOSE2(where, Scalar);
 OP_DECOMPOSE(orgqr);
 m.impl("unflatten.int", native::unflatten_symint);
 m.impl("_convolution_double_backward", native::_convolution_double_backward);
@@ -325,6 +327,11 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
 OP_DECOMPOSE2(linalg_matrix_rank, atol_rtol_tensor);
 OP_DECOMPOSE2(linalg_matrix_rank, atol_rtol_float);

+// comparison ops
+OP_DECOMPOSE2(greater, Scalar);
+OP_DECOMPOSE2(less_equal, Scalar);
+OP_DECOMPOSE2(less, Scalar);
+OP_DECOMPOSE2(not_equal, Scalar);
 }

 }}
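The new OP_DECOMPOSE entries route these Scalar overloads through their existing decompositions instead of requiring dedicated batching rules. A minimal sketch of exercising the newly covered overloads under vmap (assuming a PyTorch build that includes this commit; the inputs and op choices below are illustrative only):

import torch
from torch.func import vmap

x = torch.rand(3, 5)

# float_power.Scalar: scalar base, batched exponent.
pow_out = vmap(lambda t: torch.float_power(2.0, t))(x)

# where.Scalar: scalar payloads selected by a batched condition.
mask_out = vmap(lambda c: torch.where(c, 1.0, 0.0))(x > 0.5)

# Comparison aliases whose Scalar overloads are now decomposed
# (greater / less / less_equal / not_equal).
cmp_out = vmap(lambda t: torch.greater(t, 0.5))(x)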
6 changes: 0 additions & 6 deletions test/functorch/test_vmap.py
@@ -3743,13 +3743,7 @@ def test_vmap_exhaustive(self, device, dtype, op):
 decorate('bitwise_left_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
 decorate('bitwise_right_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
 # One or more of the overload doesn't have a Batch rule.
-xfail('where'),
 xfail('bincount'),
-xfail('float_power'),
-xfail('gt'),
-xfail('le'),
-xfail('lt'),
-xfail('ne'),
 # UBSAN: runtime error: 1.27043e+262 is outside the range of representable values of type 'float'
 decorate('special.zeta', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
 # RuntimeError: Expected all tensors to be on the same device,
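Dropping these xfail entries means test_vmap_exhaustive now expects the listed ops (via their Scalar overloads and aliases) to match a per-sample reference under vmap. A hand-rolled version of that loop-vs-vmap comparison for one of the affected aliases might look like the sketch below; this is an illustration, not the OpInfo machinery the test actually uses.

import torch
from torch.func import vmap

x = torch.rand(4, 3)

# Reference: run the op on each batch element in a Python loop.
expected = torch.stack([torch.not_equal(row, 0.5) for row in x])

# Same computation under vmap; with not_equal.Scalar decomposed,
# this no longer fails with a missing-batching-rule error.
actual = vmap(lambda row: torch.not_equal(row, 0.5))(x)

assert torch.equal(actual, expected)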
6 changes: 0 additions & 6 deletions test/functorch/test_vmap_registrations.py
@@ -113,14 +113,12 @@
 "aten::flatten.using_ints",
 "aten::flatten.using_names",
 "aten::flatten_dense_tensors",
-"aten::float_power.Scalar",
 "aten::float_power_.Scalar",
 "aten::float_power_.Tensor",
 "aten::floor_divide_.Scalar",
 "aten::frobenius_norm",
 "aten::fused_moving_avg_obs_fake_quant",
 "aten::get_gradients",
-"aten::greater.Scalar",
 "aten::greater_.Scalar",
 "aten::greater_.Tensor",
 "aten::greater_equal_.Scalar",
@@ -140,10 +138,8 @@
 "aten::item",
 "aten::kl_div",
 "aten::ldexp_",
-"aten::less.Scalar",
 "aten::less_.Scalar",
 "aten::less_.Tensor",
-"aten::less_equal.Scalar",
 "aten::less_equal_.Scalar",
 "aten::less_equal_.Tensor",
 "aten::linalg_cond.p_str",
@@ -199,7 +195,6 @@
 "aten::norm.names_ScalarOpt_dim",
 "aten::norm.names_ScalarOpt_dim_dtype",
 "aten::norm_except_dim",
-"aten::not_equal.Scalar",
 "aten::not_equal_.Scalar",
 "aten::not_equal_.Tensor",
 "aten::one_hot",
@@ -297,7 +292,6 @@
 "aten::var_mean.correction_names",
 "aten::var_mean.names_dim",
 "aten::where",
-"aten::where.Scalar",

 }