[Phi] Add unbind yaml and final state api (PaddlePaddle#41277)
* add unbind yaml

* fix unittest
chenwhql authored and wu.zeng committed Apr 10, 2022
1 parent 9c6625d commit 6708774
Showing 8 changed files with 97 additions and 8 deletions.
48 changes: 48 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.cc
@@ -475,6 +475,54 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
  return api_output;
}

std::vector<Tensor> unbind_impl(const Tensor& input, int axis) {
  auto kernel_key_set = ParseKernelKeyByInputArgs(input);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

  Backend kernel_backend = kernel_key.backend();
  DataLayout kernel_layout = kernel_key.layout();
  DataType kernel_data_type = kernel_key.dtype();

  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      "unbind", {kernel_backend, kernel_layout, kernel_data_type});
  VLOG(6) << "unbind API kernel key: [" << kernel_backend << ", "
          << kernel_layout << ", " << kernel_data_type << "]";
  VLOG(6) << "unbind API kernel: " << kernel;

  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

  auto dense_input = PrepareData(input, kernel.InputAt(0), {});

  // Calculate the number of out tensors
  auto input_shape = input.dims();
  if (axis < 0) {
    axis = input_shape.size() + axis;
  }
  auto out_num = input_shape[axis];

  std::vector<Tensor> out;
  auto dense_outs = SetKernelOutput(out_num, kernel_backend, &out);
  std::vector<phi::MetaTensor> meta_outs;
  meta_outs.reserve(out_num);
  std::vector<phi::MetaTensor*> meta_out_ptrs;
  meta_out_ptrs.reserve(out_num);
  for (int64_t i = 0; i < out_num; ++i) {
    meta_outs.push_back(dense_outs[i]);
    meta_out_ptrs.push_back(&meta_outs.back());
  }

  phi::UnbindInferMeta(MakeMetaTensor(*dense_input), axis, meta_out_ptrs);

  using kernel_signature = void (*)(const phi::DeviceContext&,
                                    const phi::DenseTensor&,
                                    int,
                                    std::vector<phi::DenseTensor*>&);
  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
  (*kernel_fn)(*dev_ctx, *dense_input, axis, dense_outs);

  return out;
}

////////////////// Backward(grad) api impls //////////////////////

// TODO(chenweihang): the original sum grad op can support higher-level
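A note on the output-count logic above: `unbind_impl` normalizes a negative axis and then takes the input's extent along that axis as the number of output tensors. A minimal Python sketch of that bookkeeping (the helper name is ours, for illustration only):

```python
def unbind_out_num(in_shape, axis):
    # mirror unbind_impl: normalize a negative axis first
    if axis < 0:
        axis += len(in_shape)
    return in_shape[axis]

assert unbind_out_num([2, 3, 4], 1) == 3    # three output tensors
assert unbind_out_num([2, 3, 4], -1) == 4   # axis -1 -> axis 2
```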
4 changes: 4 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.h
Expand Up @@ -14,6 +14,8 @@ limitations under the License. */

#pragma once

#include <vector>

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
@@ -73,6 +75,8 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
    bool multi_precision,
    float rescale_grad);

std::vector<Tensor> unbind_impl(const Tensor& input, int axis);

////////////////// Backward(grad) api impls //////////////////////

std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,
12 changes: 6 additions & 6 deletions paddle/phi/infermeta/unary.cc
@@ -2429,7 +2429,7 @@ void TransposeGradInferMeta(const MetaTensor& x,

void UnbindInferMeta(const MetaTensor& x,
                     int axis,
-                    std::vector<MetaTensor>* outs) {
+                    std::vector<MetaTensor*> outs) {
  auto in_dims = x.dims();
  std::vector<int> out_dim;
  axis = axis < 0 ? in_dims.size() + axis : axis;
@@ -2438,11 +2438,11 @@ void UnbindInferMeta(const MetaTensor& x,
  }
  auto out_dims = phi::make_ddim(out_dim);

-  for (size_t i = 0; i < outs->size(); ++i) {
-    (*outs)[i].set_dtype(x.dtype());
-    (*outs)[i].set_dims(out_dims);
-    (*outs)[i].set_layout(x.layout());
-    (*outs)[i].share_lod(x);
+  for (size_t i = 0; i < outs.size(); ++i) {
+    outs[i]->set_dtype(x.dtype());
+    outs[i]->set_dims(out_dims);
+    outs[i]->set_layout(x.layout());
+    outs[i]->share_lod(x);
  }
}

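`UnbindInferMeta` gives every output the input's dtype and layout, plus the input shape with the unbind axis dropped. A hedged Python sketch of that shape rule (helper name ours):

```python
def unbind_out_shape(in_shape, axis):
    # same normalization as UnbindInferMeta, then drop the unbind axis
    axis = axis if axis >= 0 else len(in_shape) + axis
    return [d for i, d in enumerate(in_shape) if i != axis]

assert unbind_out_shape([2, 3, 4], 0) == [3, 4]
assert unbind_out_shape([2, 3, 4], -1) == [2, 3]
```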
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/unary.h
@@ -365,7 +365,7 @@ void TrilTriuInferMeta(const MetaTensor& x,

void UnbindInferMeta(const MetaTensor& x,
                     int axis,
-                    std::vector<MetaTensor>* outs);
+                    std::vector<MetaTensor*> outs);

void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out);

22 changes: 22 additions & 0 deletions python/paddle/fluid/tests/unittests/test_unbind_op.py
@@ -17,9 +17,11 @@
import unittest
import numpy as np
from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
import paddle.tensor as tensor
from paddle.fluid import compiler, Program, program_guard, core
from paddle.fluid.framework import _test_eager_guard


class TestUnbind(unittest.TestCase):
@@ -39,6 +41,25 @@ def test_unbind(self):
        assert np.array_equal(res_1, input_1[0, 0:100])
        assert np.array_equal(res_2, input_1[1, 0:100])

    def test_unbind_dygraph(self):
        with fluid.dygraph.guard():
            np_x = np.random.random([2, 3]).astype("float32")
            x = paddle.to_tensor(np_x)
            x.stop_gradient = False
            [res_1, res_2] = paddle.unbind(x, 0)
            self.assertTrue(np.array_equal(res_1, np_x[0, 0:100]))
            self.assertTrue(np.array_equal(res_2, np_x[1, 0:100]))

            out = paddle.add_n([res_1, res_2])

            np_grad = np.ones(x.shape, np.float32)
            out.backward()
            self.assertTrue(np.array_equal(x.grad.numpy(), np_grad))

    def test_unbind_dygraph_final_state(self):
        with _test_eager_guard():
            self.test_unbind_dygraph()

class TestLayersUnbind(unittest.TestCase):
    def test_layers_unbind(self):
@@ -157,6 +178,7 @@ def outReshape(self):
class TestUnbindBF16Op(OpTest):
    def setUp(self):
        self._set_op_type()
        self.python_api = paddle.unbind
        self.dtype = self.get_dtype()
        self.axis = 0
        self.num = 3
5 changes: 4 additions & 1 deletion python/paddle/tensor/manipulation.py
@@ -1469,6 +1469,9 @@ def unbind(input, axis=0):
            # x3.shape [3, 5]
    """
    if in_dygraph_mode():
        return _C_ops.final_state_unbind(input, axis)

    if not isinstance(axis, (int)):
        raise TypeError("The type of 'axis' must be int, but received %s." %
                        (type(axis)))
@@ -1477,7 +1480,7 @@
    input_shape = input.shape
    axis_ = axis if axis >= 0 else len(input_shape) + axis
    num = input_shape[axis_]
-    if paddle.in_dynamic_mode():
+    if _in_legacy_dygraph():
        return _C_ops.unbind(input, num, 'axis', axis)

    helper = LayerHelper("unbind", **locals())
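After this change, eager ("final state") mode dispatches to the generated `final_state_unbind`, legacy dygraph keeps the old `_C_ops.unbind` path, and static graph still goes through `LayerHelper`; the user-facing call is unchanged in all three modes. For example:

```python
import paddle

x = paddle.arange(6, dtype="float32").reshape([2, 3])
a, b = paddle.unbind(x, axis=0)  # two tensors of shape [3]
print(a.numpy())  # [0. 1. 2.]
print(b.numpy())  # [3. 4. 5.]
```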
6 changes: 6 additions & 0 deletions python/paddle/utils/code_gen/api.yaml
@@ -1939,6 +1939,12 @@
    backend : place
    data_type : dtype

- api : unbind
  args : (Tensor input, int axis)
  output : Tensor[]
  invoke : unbind_impl(input, axis)
  backward : unbind_grad

# unfold
- api : unfold
  args : (Tensor x, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
6 changes: 6 additions & 0 deletions python/paddle/utils/code_gen/backward.yaml
@@ -1480,6 +1480,12 @@
  kernel :
    func : trunc_grad

- backward_api : unbind_grad
  forward : unbind (Tensor input, int axis) -> Tensor[](out)
  args : (Tensor[] out_grad, int axis)
  output : Tensor(input_grad)
  invoke : stack(out_grad, axis)

- backward_api : unfold_grad
  forward : unfold (Tensor x, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations) -> Tensor(out)
  args : (Tensor x, Tensor out_grad, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
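`unbind_grad` needs no dedicated kernel: stacking the output gradients along the unbind axis reconstructs the input gradient, which is why the backward entry simply invokes `stack`. A short sketch of that identity, using ones-valued output grads as in the new unit test:

```python
import paddle

# gradients flowing back into the two [3]-shaped outputs of unbind(x, axis=0)
out_grads = [paddle.ones([3]), paddle.ones([3])]
x_grad = paddle.stack(out_grads, axis=0)  # shape [2, 3], matches the input
```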
