min/max support for SymInt/Floats, finish as_strided/scatter/squeeze() backward symint support (pytorch#86643)

Pull Request resolved: pytorch#86643
Approved by: https://github.com/anjali411
albanD authored and pytorchmergebot committed Oct 11, 2022
1 parent 6923dc3 commit 86f914e
Showing 12 changed files with 92 additions and 32 deletions.
7 changes: 3 additions & 4 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -16,7 +16,6 @@

#include <c10/cuda/CUDAMathCompat.h>

- #include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>

@@ -368,8 +367,8 @@ __global__ void transform_bias_rescale_qkv_add_padding_kernel(
}

Tensor collapse_dims_1_and_2(const Tensor& sizes) {
- auto sizes_dim1 = at::native::narrow(sizes, 1, 0, 1);
- auto sizes_dim2 = at::native::narrow(sizes, 1, 1, 1);
+ auto sizes_dim1 = at::native::narrow_symint(sizes, 1, 0, 1);
+ auto sizes_dim2 = at::native::narrow_symint(sizes, 1, 1, 1);

return (sizes_dim1 * sizes_dim2).contiguous();
}
@@ -451,7 +450,7 @@ __host__ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cuda(
auto sizes = collapse_dims_1_and_2(nt_qkv->get_nested_size_tensor());
auto offsets =
NestedTensor_batch_offsets_from_size_tensor(sizes, sizes.numel());
- at::native::narrow(offsets, 0, sizes.numel() + 1, sizes.numel())
+ at::native::narrow_symint(offsets, 0, sizes.numel() + 1, sizes.numel())
.copy_(sizes.reshape({-1}));
auto metadata = offsets.to(at::Device(kCUDA), at::kInt, true, true);
const auto offsets_ptr = metadata.data_ptr<int>();
15 changes: 15 additions & 0 deletions c10/core/SymInt.cpp
@@ -136,6 +136,21 @@ bool SymInt::operator>=(SymInt sci) const {
return res[0]->ge(res[1])->bool_();
}

+ SymInt SymInt::min(SymInt sci) const {
+ if (!is_symbolic() && !sci.is_symbolic()) {
+ return std::min(data_, sci.data_);
+ }
+ auto res = normalize_symints(*this, sci);
+ return SymInt::toSymInt(res[0]->min(res[1]));
+ }
+ SymInt SymInt::max(SymInt sci) const {
+ if (!is_symbolic() && !sci.is_symbolic()) {
+ return std::max(data_, sci.data_);
+ }
+ auto res = normalize_symints(*this, sci);
+ return SymInt::toSymInt(res[0]->max(res[1]));
+ }

void SymInt::operator*=(SymInt sci) {
*this = *this * sci;
}
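Note (not part of the diff): a minimal sketch of the concrete fast path added here — when neither operand is symbolic, min()/max() fold to std::min/std::max on the wrapped integers; otherwise both sides are normalized and the call defers to the SymIntNode. The standalone program below is illustrative only and assumes the usual c10 headers plus the guard_int(__FILE__, __LINE__) accessor used elsewhere in this commit.

    #include <c10/core/SymInt.h>
    #include <iostream>

    int main() {
      // Both operands are plain (non-symbolic) SymInts, so min()/max() take the
      // fast path and reduce to std::min/std::max on the stored int64_t values.
      c10::SymInt a{5};
      c10::SymInt b{3};
      std::cout << a.min(b).guard_int(__FILE__, __LINE__) << "\n"; // prints 3
      std::cout << a.max(b).guard_int(__FILE__, __LINE__) << "\n"; // prints 5
      return 0;
    }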
3 changes: 3 additions & 0 deletions c10/core/SymInt.h
@@ -170,6 +170,9 @@ class C10_API SymInt {
void operator*=(SymInt sci);
void operator+=(SymInt sci);

+ SymInt min(SymInt sci) const;
+ SymInt max(SymInt sci) const;

SymInt operator*(int64_t sci) const;
bool operator<(int64_t sci) const;
bool operator==(int64_t sci) const;
6 changes: 6 additions & 0 deletions c10/core/SymIntNodeImpl.h
@@ -63,6 +63,12 @@ class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target {
virtual SymIntNode ceil() {
TORCH_CHECK(false, "NYI");
};
+ virtual SymIntNode min(const SymIntNode& other) {
+ TORCH_CHECK(false, "NYI");
+ };
+ virtual SymIntNode max(const SymIntNode& other) {
+ TORCH_CHECK(false, "NYI");
+ };
virtual SymIntNode clone() {
TORCH_CHECK(false, "NYI");
};
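Note (not part of the diff): the two new virtuals are hooks for symbolic backends to override. A hypothetical toy node is sketched below only to show where they sit; ToyIntNode and value_ are made-up names, and a real backend — like PythonSymIntNodeImpl further down in torch/csrc/jit/python/init.cpp — dispatches to its own symbolic expression type rather than wrapping a plain integer.

    #include <algorithm>
    #include <c10/core/SymIntNodeImpl.h>
    #include <c10/util/intrusive_ptr.h>

    // Toy node wrapping a concrete integer; every virtual not overridden here
    // keeps its default "NYI" TORCH_CHECK body from SymIntNodeImpl.
    struct ToyIntNode : c10::SymIntNodeImpl {
      explicit ToyIntNode(int64_t v) : value_(v) {}
      c10::SymIntNode min(const c10::SymIntNode& other) override {
        auto* o = dynamic_cast<ToyIntNode*>(other.get());
        TORCH_CHECK(o != nullptr, "expected a ToyIntNode");
        return c10::make_intrusive<ToyIntNode>(std::min(value_, o->value_));
      }
      c10::SymIntNode max(const c10::SymIntNode& other) override {
        auto* o = dynamic_cast<ToyIntNode*>(other.get());
        TORCH_CHECK(o != nullptr, "expected a ToyIntNode");
        return c10::make_intrusive<ToyIntNode>(std::max(value_, o->value_));
      }
      int64_t value_;
    };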
2 changes: 0 additions & 2 deletions test/functorch/test_aotdispatch.py
@@ -1062,7 +1062,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
xfail('nn.functional.interpolate', 'trilinear'), # Cannot call sizes() on tensor with symbolic sizes/st...
xfail('nn.functional.kl_div', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.l1_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
- xfail('nn.functional.linear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.local_response_norm', ''), # aten.fill.Scalar - couldn't find symbolic meta functio...
xfail('nn.functional.max_pool1d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.max_pool2d', ''), # aten.max_pool2d_with_indices_backward.default - couldn't find s...
@@ -1137,7 +1136,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic ...
xfail('special.xlog1py', ''), # aten.special_xlog1py.default - couldn't find symbolic meta function/deco...
xfail('split', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
- xfail('squeeze', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('std', ''), # Cannot call numel() on tensor with symbolic sizes/strides
xfail('std_mean', ''), # Cannot call numel() on tensor with symbolic sizes/strides
xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
@@ -1056,7 +1056,6 @@ def f(a, b, c, d, e):
xfail('argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition
xfail('argsort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition
xfail('argwhere', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition
- xfail('as_strided_scatter', ''), # aten.as_strided_scatter.default - couldn't find symbolic meta function/decomposition
xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
xfail('bernoulli', ''), # aten.bernoulli.default - couldn't find symbolic meta function/decomposition
xfail('bucketize', ''), # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition
8 changes: 4 additions & 4 deletions tools/autograd/derivatives.yaml
@@ -1493,19 +1493,19 @@
result: auto_element_wise

- name: squeeze(Tensor(a) self) -> Tensor(a)
- self: unsqueeze_to(grad, self.sizes())
+ self: unsqueeze_to(grad, self.sym_sizes())
result: auto_linear

- name: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
- self: unsqueeze_to(grad, dim, self.sizes())
+ self: unsqueeze_to(grad, dim, self.sym_sizes())
result: auto_linear

- name: squeeze_(Tensor(a!) self) -> Tensor(a!)
- self: unsqueeze_to(grad, self.sizes())
+ self: unsqueeze_to(grad, self.sym_sizes())
result: auto_linear

- name: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!)
- self: unsqueeze_to(grad, dim, self.sizes())
+ self: unsqueeze_to(grad, dim, self.sym_sizes())
result: auto_linear

- name: std.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor
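Note (not part of the diff): the formulas above only switch from sizes() to sym_sizes(); the recipe is unchanged — the backward of squeeze re-inserts every size-1 dimension recorded in the input's (possibly symbolic) sizes. A concrete sketch of that behavior using plain ATen calls (at::randn/at::ones_like here are just for illustration):

    #include <ATen/ATen.h>

    int main() {
      auto self = at::randn({2, 1, 3, 1});
      auto out = self.squeeze();      // shape [2, 3]
      auto grad = at::ones_like(out); // incoming gradient, shape [2, 3]

      // What unsqueeze_to(grad, self.sym_sizes()) produces: walk the saved
      // sizes and unsqueeze wherever a dimension of extent 1 was squeezed away.
      auto grad_input = grad.unsqueeze(1).unsqueeze(3); // back to [2, 1, 3, 1]
      TORCH_CHECK(grad_input.sizes().equals(self.sizes()));
      return 0;
    }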
1 change: 1 addition & 0 deletions torch/_subclasses/fake_tensor.py
@@ -885,6 +885,7 @@ def wrap(e, device=None):
def functions_with_cpp_meta_impl_that_support_symint(self):
return [
aten.empty_strided.default,
+ aten.as_strided_scatter.default,
aten.as_strided.default,
aten.zeros.default,
aten.detach.default,
35 changes: 22 additions & 13 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -848,23 +848,26 @@ Tensor unbind_backward(const variable_list& grads, int64_t dim) {
return at::stack(grads_tensors, dim);
}

- Tensor unsqueeze_to(const Tensor& self, IntArrayRef sizes) {
+ Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) {
auto result = self;

- int64_t nDims = sizes.size();
+ int64_t nDims = sym_sizes.size();
for (const auto dim : c10::irange(nDims)) {
- if (sizes[dim] == 1) {
+ if (sym_sizes[dim] == 1) {
result = result.unsqueeze(dim);
}
}
return result;
}

- Tensor unsqueeze_to(const Tensor& self, int64_t dim, IntArrayRef sizes) {
- dim = at::maybe_wrap_dim(dim, sizes.size());
+ Tensor unsqueeze_to(
+ const Tensor& self,
+ int64_t dim,
+ c10::SymIntArrayRef sym_sizes) {
+ dim = at::maybe_wrap_dim(dim, sym_sizes.size());
// in NumPy it's not an error to unsqueeze a scalar, but we still need to
// avoided unsqueezing in the backward.
- if (sizes.size() > 0 && sizes[dim] == 1) {
+ if (sym_sizes.size() > 0 && sym_sizes[dim] == 1) {
return self.unsqueeze(dim);
}
return self;
@@ -2836,21 +2839,27 @@ Tensor as_strided_backward(

// Step (1): create underlying tensor as "storage"
auto shared_offset =
- std::min(input_geometry.sym_storage_offset(), sym_storage_offset);
+ // TODO: symint-ify. Do we need a min() and max() for SymInts?
+ input_geometry.sym_storage_offset().min(sym_storage_offset);
auto inp_effective_offset =
input_geometry.sym_storage_offset() - shared_offset;
auto out_effective_offset = sym_storage_offset - shared_offset;
- auto base_size = std::max(
- _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset),
- _min_storage_size(out_sizes_, out_strides_, out_effective_offset));
- auto storage = grad.new_empty_symint(c10::SymIntArrayRef(base_size));
- storage.zero_();
+ auto base_size1 =
+ _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset);
+ auto base_size2 =
+ _min_storage_size(out_sizes_, out_strides_, out_effective_offset);
+ auto base_size = base_size1.max(base_size2);
+ auto storage = grad.new_zeros_symint(c10::SymIntArrayRef(base_size));

// prepare indices tensor if we will do index_add_ later
c10::optional<at::Tensor> flatten_full_indices;
if (inp_maybe_overlap || out_maybe_overlap) {
flatten_full_indices =
- at::arange(0, base_size, grad.options().dtype(at::kLong));
+ // TODO: should we symint-ify arange? Need SymScalar.
+ at::arange(
+ 0,
+ base_size.guard_int(__FILE__, __LINE__),
+ grad.options().dtype(at::kLong));
}

// Step (2): use output geometry to scatter gradients into storage
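Note (not part of the diff): at::arange still takes concrete integers, which is why the symbolic base_size goes through guard_int(__FILE__, __LINE__) above — for a plain SymInt that simply returns the stored value, and under symbolic tracing it records a guard that specializes the trace to that value. A minimal sketch of the conversion, assuming a non-symbolic SymInt:

    #include <ATen/ATen.h>
    #include <c10/core/SymInt.h>

    int main() {
      c10::SymInt base_size{8}; // non-symbolic in this sketch
      // guard_int() hands back a concrete int64_t that factory ops accept.
      int64_t n = base_size.guard_int(__FILE__, __LINE__);
      auto indices = at::arange(0, n, at::TensorOptions().dtype(at::kLong));
      TORCH_CHECK(indices.numel() == n);
      return 0;
    }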
4 changes: 2 additions & 2 deletions torch/csrc/autograd/FunctionsManual.h
@@ -215,11 +215,11 @@ at::Tensor logcumsumexp_backward(
at::Tensor result,
int64_t dim);
at::Tensor unbind_backward(const variable_list& grads, int64_t dim);
- at::Tensor unsqueeze_to(const at::Tensor& self, at::IntArrayRef sizes);
+ at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes);
at::Tensor unsqueeze_to(
const at::Tensor& self,
int64_t dim,
- at::IntArrayRef sizes);
+ c10::SymIntArrayRef sym_sizes);
std::vector<at::Tensor> cat_tensors_backward(
const at::Tensor& grad,
const std::vector<std::vector<c10::SymInt>>& sizes,
19 changes: 19 additions & 0 deletions torch/csrc/jit/python/init.cpp
@@ -242,6 +242,13 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
return dispatch_common_(__FUNCTION__, other);
}

+ virtual SymIntNode min(const SymIntNode& other) override {
+ return dispatch_common_(__FUNCTION__, other);
+ }
+ virtual SymIntNode max(const SymIntNode& other) override {
+ return dispatch_common_(__FUNCTION__, other);
+ }

virtual SymIntNode ceil() override {
return dispatch_common_(__FUNCTION__);
}
@@ -1481,6 +1488,18 @@ void initJITBindings(PyObject* module) {
.def(
"__ceil__",
[](c10::SymIntNode a) -> c10::SymIntNode { return a->ceil(); })
+ .def(
+ "__min__",
+ [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
+ auto snb = toSymIntNode(a, b);
+ return a->min(snb);
+ })
+ .def(
+ "__max__",
+ [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
+ auto snb = toSymIntNode(a, b);
+ return a->max(snb);
+ })
.def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
.def("__int__", [](c10::SymIntNode a) { return a->int_(); })
// Intentionally don't set file line, as the Python backtrace matters
23 changes: 17 additions & 6 deletions torch/fx/experimental/symbolic_shapes.py
@@ -205,7 +205,7 @@ def eval(cls, a):
'mul': lambda a, b: a * b,
'mod': lambda a, b: a % b,
'truediv': lambda a, b: a / b,
- 'floordiv': lambda a, b: FloorDiv(a, b)
+ 'floordiv': lambda a, b: FloorDiv(a, b),
}

magic_methods = {
@@ -215,7 +215,9 @@ def eval(cls, a):
'lt': lambda a, b: sympy.Lt(a, b),
'le': lambda a, b: sympy.Le(a, b),
'ge': lambda a, b: sympy.Ge(a, b),
- 'ceil': lambda a: Ceil(a)
+ 'ceil': lambda a: Ceil(a),
+ 'min': lambda a, b: sympy.Min(a, b),
+ 'max': lambda a, b: sympy.Max(a, b),
}

unary_magic_methods = {
@@ -228,14 +230,22 @@ def _make_magic(method, func, py_type):
func = lru_cache(256)(func)

def magic_impl(self, other):
if method in ["min", "max"]:
# op = getattr(builtins, method)
return self
else:
op = getattr(operator, method)
if SYM_FUNCTION_MODE:
- return _handle_sym_dispatch(getattr(operator, method), (self, other), {})
+ return _handle_sym_dispatch(op, (self, other), {})
if isinstance(other, py_type):
- other = other.expr
+ other_expr = other.expr
+ else:
+ assert isinstance(other, sympy.Expr)
+ other_expr = other
# TODO: consider constant prop here
expr = self.shape_env.replace(self.expr)
- other = self.shape_env.replace(other)
- out = func(expr, other)
+ other_expr = self.shape_env.replace(other_expr)
+ out = func(expr, other_expr)
out = sympy.expand(out)
if method in ["truediv"]:
return PySymFloat(out, self.shape_env)
@@ -246,6 +256,7 @@ def magic_impl(self, other):

def unary_magic_impl(self):
if SYM_FUNCTION_MODE:
+ # TODO: Should this if/else be moved outside of SYM_FUNCTION_MODE ?
if method in ["ceil", "floor"]:
op = getattr(math, method)
else:
