Skip to content

Commit

Permalink
[CINN] fix bug: remove redandunt conversion in select (PaddlePaddle#6…
Browse files Browse the repository at this point in the history
  • Loading branch information
hxzd5568 authored and co63oc committed May 13, 2024
1 parent acceeed commit 19ed190
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 28 deletions.
10 changes: 0 additions & 10 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,6 @@
func : scale
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : select
args : (Tensor condition, Tensor true_value, Tensor false_value )
output : Tensor(out)
infer_meta :
func : WhereInferMeta
spmd_rule: WhereInferSpmd
kernel :
func : where
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : slice
args : (Tensor x, int64_t[] axes, int64_t[] starts, int64_t[] ends, int64_t[] infer_flags, int64_t[] decrease_axis)
output : Tensor
Expand Down
18 changes: 0 additions & 18 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1073,23 +1073,6 @@ class SigmoidOpPattern
}
};

class WhereOpPattern : public pir::OpRewritePattern<paddle::dialect::WhereOp> {
public:
using pir::OpRewritePattern<paddle::dialect::WhereOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::WhereOp op,
pir::PatternRewriter &rewriter) const override {
auto select_op = rewriter.Build<cinn::dialect::SelectOp>(
op->operand_source(0), op->operand_source(1), op->operand_source(2));

rewriter.ReplaceAllUsesWith(op.result(0), select_op.result(0));

rewriter.EraseOp(op);

return true;
}
};

class GatherOpPattern
: public pir::OpRewritePattern<paddle::dialect::GatherOp> {
public:
Expand Down Expand Up @@ -1154,7 +1137,6 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
ps.Add<UnsqueezeOpPattern>(context);
ps.Add<SigmoidOpPattern>(context);
ps.Add<GatherOpPattern>(context);
ps.Add<WhereOpPattern>(context);
ps.Add<FlattenOpPattern>(context);

return ps;
Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/hlir/framework/pir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ const std::unordered_map<std::string, std::string> CompatibleInfo::OP_NAMES = {
{"pd_op.unsqueeze", "reshape"},
{"pd_op.split_with_num", "split"},
{"pd_op.expand", "broadcast_to"},
{"pd_op.where", "select"},
{"cinn_op.generate_shape", "generate_shape"},
{"cinn_op.broadcast", "broadcast_to"}};

Expand Down

0 comments on commit 19ed190

Please sign in to comment.