[Prim][PIR] Support composite rules of Llama ops (PaddlePaddle#58018)
* pir prim support decomposition rules of ops

* PIR Prim support ops

* change unnecessary error raise

* fix code

* fix code

* support prim op register

* support max vjp

* fix code

* fix test case
cyber-pioneer committed Oct 13, 2023
1 parent 5fcf600 commit e70e8a5
Showing 22 changed files with 870 additions and 76 deletions.
@@ -82,6 +82,7 @@
'maximum',
'argsort',
'min',
'max',
'batch_norm',
'max_pool2d_with_index',
'pool2d',
@@ -183,6 +184,7 @@
'maximum',
'argsort',
'min',
'max',
'batch_norm',
'max_pool2d_with_index',
'pool2d',
12 changes: 12 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -57,11 +57,23 @@
'tanh_grad',
'transpose_grad',
'concat_grad',
'erf_grad',
'exp_grad',
'expand_grad',
'log_grad',
'gather_nd_grad',
'pad_grad',
'max_grad',
'slice_grad',
'tile_grad',
] # vjp list of primitive op
CUSTOM_VJP = [
'gelu_grad',
'layer_norm_grad',
'dropout_grad',
'silu_grad',
'softmax_grad',
'sqrt_grad',
] # custom vjp list of composite op
VJP_COMPS = PRIM_VJP + CUSTOM_VJP
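(Reading the trailing comments: PRIM_VJP holds backward ops whose VJPs are written directly in primitive ops, while CUSTOM_VJP holds composite ops with hand-written custom VJP rules; VJP_COMPS, their union, is what the code generator consumes.)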

1 change: 1 addition & 0 deletions paddle/fluid/primitive/primitive.yaml
@@ -51,3 +51,4 @@
- full
- cast
- sign
- slice
300 changes: 300 additions & 0 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -536,6 +536,306 @@ void dropout_grad(const Tensor& mask,
}
}

template <typename T>
void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
auto m_2_sqrt_pi = full<T>(phi::vectorize(x.dims()), M_2_SQRTPI, x.dtype());
auto neg_one = full<T>(phi::vectorize(x.dims()), -1.0, x.dtype());
auto neg_tmp = neg_one * x * x;
auto mul_tmp = m_2_sqrt_pi * exp<T>(neg_tmp);
set_output<T>(out_grad * mul_tmp, x_grad);
}
}
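(This is the standard error-function derivative; M_2_SQRTPI is the C math constant 2/\sqrt{\pi}:

    \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2}\,dt
    \quad\Rightarrow\quad
    \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{2}{\sqrt{\pi}} e^{-x^2} \,)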

template <typename T>
void expand_grad(const Tensor& x,
const Tensor& out_grad,
const IntArray& shape,
Tensor* x_grad) {
if (x_grad) {
auto out_dims = phi::make_ddim(shape.GetData());
if (out_dims != x.dims()) {
auto axes = get_reduce_dims(x.dims(), out_dims);
if (!axes.size()) {
by_pass<T>(out_grad, x_grad);
} else {
auto reduced = out_grad.sum(phi::vectorize(axes), x.dtype(), false);
if (reduced.dims().size() != x.dims().size()) {
reduced = reshape<T>(reduced, x.shape());
}
set_output<T>(reduced, x_grad);
}
} else {
by_pass<T>(out_grad, x_grad);
}
}
}
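(Since expand broadcasts x up to `shape`, its VJP reduces out_grad back over the broadcast axes and, when the rank changed, reshapes to x's shape:

    \frac{\partial L}{\partial x} = \operatorname{reshape}\Bigl(\sum_{i \,\in\, \text{broadcast axes}} \frac{\partial L}{\partial y},\ \operatorname{shape}(x)\Bigr) \,)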

template <typename T>
void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
// dx = dout / x
set_output<T>(out_grad / x, x_grad);
}
}

template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
if (out.dtype() == phi::DataType::FLOAT16 ||
out.dtype() == phi::DataType::BFLOAT16) {
Tensor out_promote = cast<T>(out, phi::DataType::FLOAT32);
Tensor out_grad_promote = cast<T>(out_grad, phi::DataType::FLOAT32);
set_output<T>(cast<T>(out_promote * out_grad_promote, out.dtype()),
x_grad);
} else {
set_output<T>(out_grad * out, x_grad);
}
}
}
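(Because e^x is its own derivative, the VJP is just out * out_grad; fp16/bf16 inputs are promoted to fp32 so the multiply runs at higher precision before casting back:

    y = e^x \ \Rightarrow\ \frac{\partial L}{\partial x} = y \cdot \frac{\partial L}{\partial y} \,)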

template <typename T>
void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
// This calculation is important for resnet.
auto x_grad_tmp = (0.5 / out) * out_grad;
set_output<T>(x_grad_tmp, x_grad);
}
}
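(The 0.5/out form follows from rewriting the derivative of y = \sqrt{x} in terms of the forward output:

    \frac{dy}{dx} = \frac{1}{2\sqrt{x}} = \frac{1}{2y}
    \ \Rightarrow\
    \frac{\partial L}{\partial x} = \frac{0.5}{y} \cdot \frac{\partial L}{\partial y} \,)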

template <typename T>
void silu_grad(const Tensor& x,
const Tensor& out,
const Tensor& out_grad,
Tensor* x_grad) {
if (x_grad) {
auto org_dtype = x.dtype();
bool need_cast = org_dtype == phi::DataType::FLOAT16 ||
org_dtype == phi::DataType::BFLOAT16;
if (need_cast) {
auto x_cast = cast<T>(x, phi::DataType::FLOAT32);
auto out_cast = cast<T>(out, phi::DataType::FLOAT32);
auto out_grad_cast = cast<T>(out_grad, phi::DataType::FLOAT32);
auto sigmoid = 1.0 / (1.0 + exp<T>(-x_cast));
auto res = out_grad_cast * sigmoid * (1.0 + x_cast - out_cast);
set_output<T>(cast<T>(res, org_dtype), x_grad);
} else {
auto sigmoid = 1.0 / (1.0 + exp<T>(-x));
auto res = out_grad * sigmoid * (1.0 + x - out);
set_output<T>(res, x_grad);
}
}
}
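(The (1 + x - out) factor comes from expressing silu's derivative through its forward output, out = x\,\sigma(x):

    \frac{d}{dx}\,x\sigma(x) = \sigma(x)\bigl(1 + x(1 - \sigma(x))\bigr) = \sigma(x)\bigl(1 + x - x\sigma(x)\bigr) = \sigma(x)\,(1 + x - \mathrm{out}) \,)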

template <typename T>
void softmax_grad(const Tensor& out,
const Tensor& out_grad,
int axis,
Tensor* x_grad) {
if (x_grad) {
if (out_grad.dims().size() > 0) {
if (axis >= 0) {
auto new_out_grad = out_grad * out;
auto tmp_x_grad = new_out_grad -
out * sum<T>(new_out_grad, {axis}, out.dtype(), true);
set_output<T>(tmp_x_grad, x_grad);
} else {
auto new_out_grad = out_grad * out;
auto tmp_x_grad =
new_out_grad - out * sum<T>(new_out_grad,
{out.dims().size() + axis},
out.dtype(),
true);
set_output<T>(tmp_x_grad, x_grad);
}
} else {
set_output<T>(
full<T>(phi::vectorize(out_grad.dims()), 0.0, out_grad.dtype()),
x_grad);
}
}
}
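(Both axis branches implement the usual softmax Jacobian-vector product along the normalized axis; the zero-dim branch returns zeros because softmax of a 0-d tensor is constant:

    \frac{\partial L}{\partial x_i} = y_i\Bigl(\frac{\partial L}{\partial y_i} - \sum_j y_j \frac{\partial L}{\partial y_j}\Bigr) \,)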

template <typename T>
void gather_nd_grad(const Tensor& x,
const Tensor& index,
const Tensor& out_grad,
Tensor* x_grad) {
if (x_grad) {
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
auto x_grad_tmp = scatter_nd_add<T>(zero_tensor, index, out_grad);
set_output<T>(x_grad_tmp, x_grad);
}
}
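(The VJP scatters out_grad back into a zero tensor of x's shape at the gathered indices; scatter_nd_add also accumulates correctly when index contains duplicates. A hand-worked sketch: for x of shape [3, 2] and index = [[1], [2]], the two rows of out_grad land at rows 1 and 2 while row 0 stays zero.)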

template <typename T>
void pad_grad(const Tensor& input,
const Tensor& out_grad,
const std::vector<int>& paddings,
const Scalar& pad_value,
Tensor* input_grad) {
if (input_grad) {
size_t rank = input.dims().size();
auto out_dims = out_grad.dims();

std::vector<int64_t> starts(rank, 0);
std::vector<int64_t> ends(rank, 0);
std::vector<int64_t> axes(rank, 0);
std::vector<int64_t> infer_flags(rank, 1);
std::vector<int64_t> decrease_axis({});
for (size_t i = 0; i < rank; ++i) {
starts[i] = static_cast<int64_t>(paddings[2 * i]);
ends[i] = static_cast<int64_t>(out_dims[i] - paddings[2 * i + 1]);
axes[i] = i;
}
auto out_tmp =
slice<T>(out_grad, axes, starts, ends, infer_flags, decrease_axis);
set_output<T>(out_tmp, input_grad);
}
}
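(The VJP of pad is the complementary slice: on every axis it crops paddings[2*i] elements from the front and paddings[2*i+1] from the back of out_grad, recovering a tensor of the input's shape. pad_value plays no role in the gradient, since the padding contributes only constants.)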

template <typename T>
void max_grad(const Tensor& x,
const Tensor& out,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all,
Tensor* x_grad) {
if (!x_grad) {
return;
}
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
std::vector<int64_t> x_dim = phi::vectorize<int64_t>(x.dims());
int64_t axis_size = axis.size();
int64_t x_dim_size = x_dim.size();
reduce_all = false;
if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
reduce_all = true;
} else {
reduce_all = false;
}
auto x_grad_tmp = Tensor();
if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
auto out_grad_tmp = out_grad.expand(IntArray(x_dim));
auto out_tmp = out.expand(IntArray(x_dim));
auto mask = equal<T>(x, out_tmp);
x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
} else {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_[i] = axis[i] + x_dim_size;
}
}
}
auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
auto out_ = reshape<T>(out, out_grad_shape);
auto out_grad_tmp = out_grad_.expand(IntArray(x_dim));
auto out_tmp = out_.expand(IntArray(x_dim));
auto mask = equal<T>(x, out_tmp);
x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
}
set_output<T>(x_grad_tmp, x_grad);
}
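(Gradient flows only to positions where x attains the reduced maximum: out and out_grad are broadcast back to x's shape, the mask x == out selects the max positions, and everything else gets zero via where. When several entries tie for the maximum, this rule sends the full out_grad to every tied position. The keepdim == false path first reinserts the reduced axes as size-1 dims, via get_unsqueeze_dims and reshape, so the broadcast lines up.)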

template <typename T>
void slice_grad(const Tensor& input,
const Tensor& out_grad,
const std::vector<int64_t>& axes,
const IntArray& starts,
const IntArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
Tensor* input_grad) {
if (input_grad) {
size_t rank = input.dims().size();
auto out_dims = out_grad.dims();
std::vector<int64_t> origin_out_shape;
auto in_dims = input.dims();

auto decrease_size = decrease_axis.size();
if (decrease_size > 0) {
if (decrease_size == static_cast<size_t>(in_dims.size())) {
// all dims decrease
out_dims = phi::make_ddim(std::vector<int>(decrease_size, 1));
} else {
origin_out_shape.resize(out_dims.size() + decrease_size, -1);
for (size_t i = 0; i < decrease_size; ++i) {
origin_out_shape[decrease_axis[i]] = 1;
}

int index = 0;
for (size_t i = 0; i < origin_out_shape.size(); ++i) {
if (origin_out_shape[i] == -1) {
origin_out_shape[i] = out_dims[index];
++index;
}
}
out_dims = phi::make_ddim(origin_out_shape);
}
}

std::vector<int> offsets(rank, 0);
std::vector<int> extents(rank, 0);
for (size_t i = 0; i < rank; ++i) {
offsets[i] = 0;
extents[i] = out_dims[i];
}
for (size_t i = 0; i < axes.size(); ++i) {
int axis = axes[i];
int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
start = std::max(start, static_cast<int64_t>(0));
offsets[axis] = start;
}

std::vector<int> paddings;
for (size_t i = 0; i < rank; ++i) {
paddings.push_back(offsets[i]);
paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]);
}
if (decrease_size > 0 &&
(decrease_size != static_cast<size_t>(in_dims.size()))) {
auto out_tmp =
pad<T>(reshape<T>(out_grad, origin_out_shape), paddings, 0.0);
set_output<T>(out_tmp, input_grad);
} else {
auto out_tmp = pad<T>(out_grad, paddings, 0.0);
set_output<T>(out_tmp, input_grad);
}
}
}
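(The VJP of slice pads out_grad back to the input's shape with zeros: each axis gets offsets[i] zeros in front, i.e. the clamped start, and in_dims[i] - out_dims[i] - offsets[i] zeros behind; axes removed by decrease_axis are first restored as size-1 dims through reshape so pad sees the full rank.)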

template <typename T>
void tile_grad(const Tensor& x,
const Tensor& out_grad,
const IntArray& repeat_times,
Tensor* x_grad) {
if (x_grad) {
auto repeat_times_data = repeat_times.GetData();
auto out_grad_shape = phi::vectorize<int>(out_grad.dims());
auto result = out_grad;
for (int i = 0; i < static_cast<int>(repeat_times_data.size()); i++) {
int size = out_grad_shape[i] / repeat_times_data[i];
std::vector<int> sections(repeat_times_data[i], size);
auto split_arr = split<T>(result, IntArray(sections), i);
result = full<T>(phi::vectorize(split_arr[0].dims()), 0.0, x.dtype());
for (int j = 0; j < static_cast<int>(split_arr.size()); j++) {
result = split_arr[j] + result;
}
}
result = reshape<T>(result, x.shape());
set_output<T>(result, x_grad);
}
}
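(Each tiled axis is folded back in turn: out_grad is split into repeat_times[i] equal sections along axis i and the sections are summed, since every repeat contributes additively to the forward output; the final reshape restores x's original rank when repeat_times has more entries than x has dims.)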

} // namespace details
} // namespace primitive
} // namespace paddle
8 changes: 8 additions & 0 deletions paddle/fluid/pybind/pir.cc
@@ -363,6 +363,10 @@ void BindValue(py::module *m) {
.def("first_use", &Value::first_use, return_value_policy::reference)
.def("has_one_use", &Value::HasOneUse)
.def("use_empty", &Value::use_empty)
.def("replace_all_uses_with",
[](Value &self, Value &op_value) {
self.ReplaceAllUsesWith(op_value);
})
.def("__eq__", &Value::operator==)
.def("__eq__",
[](Value &self, OpResult &other) {
@@ -610,6 +614,10 @@ void BindOpResult(py::module *m) {
return false;
}
})
.def("replace_all_uses_with",
[](OpResult &self, OpResult &op_result) {
self.ReplaceAllUsesWith(op_result);
})
.def_property(
"stop_gradient",
[](OpResult &self) {
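A minimal usage sketch of the new binding (hypothetical names: old_result and new_result stand for OpResult handles taken from ops in a PIR program, e.g. via op.result(0)):

# Redirect every use of old_result to new_result in the program.
# Afterwards old_result.use_empty() should return True.
old_result.replace_all_uses_with(new_result)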
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/legacy_backward.yaml
@@ -747,7 +747,7 @@
kernel :
func : tile_grad
no_need_buffer : x
- composite : tile_grad(x, outgrad, repeat_times, x_grad)
+ composite : tile_grad(x, out_grad, repeat_times, x_grad)
backward : tile_double_grad

- backward_op : trans_layout_grad
3 changes: 2 additions & 1 deletion python/paddle/autograd/ir_backward.py
@@ -13,6 +13,7 @@
# limitations under the License.

import collections
import logging
from collections.abc import Sequence

import paddle.pir
@@ -556,7 +557,7 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state):
if state.value_to_valuegrad[item] != []:
outputs_set.add(state.value_to_valuegrad[item][0][0])
else:
- raise ValueError("input provided by inputs has no use")
+ logging.warning("input provided by inputs has no use")

inputs_set = set()
for output in outputs: