Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DIP] Fix DIP codegen for supporting multi-channel processing #288

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions frontend/Interfaces/lib/DIP.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ func.func @corr_2d_constant_padding(%inputImage : memref<?x?xf32>, %kernel : mem
return
}

func.func @corr_2d_nchw_fchw_constant_padding(%inputImage : memref<?x?x?x?xf32>, %kernel : memref<?x?x?x?xf32>, %outputImage : memref<?x?x?x?xf32>, %centerX : index, %centerY : index, %constantValue : f32) attributes{llvm.emit_c_interface}
{
dip.corr_2d_nchw_fchw <CONSTANT_PADDING> %inputImage, %kernel, %outputImage, %centerX, %centerY, %constantValue : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, index, index, f32
return
}

func.func @corr_2d_replicate_padding(%inputImage : memref<?x?xf32>, %kernel : memref<?x?xf32>, %outputImage : memref<?x?xf32>, %centerX : index, %centerY : index, %constantValue : f32) attributes{llvm.emit_c_interface}
{
dip.corr_2d <REPLICATE_PADDING> %inputImage, %kernel, %outputImage, %centerX, %centerY , %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
Expand Down
32 changes: 32 additions & 0 deletions midend/include/Dialect/DIP/DIPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,38 @@ def DIP_Corr2DOp : DIP_Op<"corr_2d"> {
}];
}

def DIP_Corr2DOpNCHWFCHW : DIP_Op<"corr_2d_nchw_fchw"> {
let summary = [{Same as corr_2d but with support for 4 dimensional input, filter and output.
Input format is expected to be NCHW while the filter is expected to be in FCHW format where,
1. N is batch size.
2. C is no. of input channels.
3. H is height of the respective container.
4. W is width of the respective container.
5. F is no. of output channels.
For example:

```mlir
dip.corr_2d_nchw_fchw CONSTANT_PADDING %inputImage, %kernel, %output, %centerX, %centerY, %constantValue
: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, index, index, index
```
}];

let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "inputMemref",
[MemRead]>:$memrefI,
Arg<AnyRankedOrUnrankedMemRef, "kernelMemref",
[MemRead]>:$memrefK,
Arg<AnyRankedOrUnrankedMemRef, "outputMemref",
[MemRead]>:$memrefCO,
Index : $centerX,
Index : $centerY,
AnyTypeOf<[AnyI8, AnyI32, AnyI64, AnyFloat]> : $constantValue,
DIP_BoundaryOptionAttr:$boundary_option);

let assemblyFormat = [{
$boundary_option $memrefI `,` $memrefK `,` $memrefCO `,` $centerX `,` $centerY `,` $constantValue attr-dict `:` type($memrefI) `,` type($memrefK) `,` type($memrefCO) `,` type($centerX) `,` type($centerY) `,` type($constantValue)
}];
}

def DIP_CorrFFT2DOp : DIP_Op<"corrfft_2d">
{
let summary = [{
Expand Down
12 changes: 12 additions & 0 deletions midend/include/Utils/DIPUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,18 @@ void traverseImagewBoundaryExtrapolation(
Value constantValue, Value strideVal, Type elemTy,
buddy::dip::BoundaryOption boundaryOptionAttr, int64_t stride, DIP_OP op);

// Intended for the same purpose as above function on 4 dimensional memref
// inputs. These memrefs must be in NCHW_FCHW format were,
// 1. N is the batch size.
// 2. C is the no. of channels in input.
// 3. H is the height of the respective container.
// 4. W is the width of the respective container.
void traverseImagewBoundaryExtrapolation4DMemRefsNCHWFCHW(
OpBuilder &rewriter, Location loc, MLIRContext *ctx, Value input,
Value kernel, Value output, Value centerX, Value centerY,
Value constantValue, Value strideVal, Type elemTy,
buddy::dip::BoundaryOption boundaryOptionAttr, int64_t stride, DIP_OP op);

// Function for applying type check mechanisms for all DIP dialect operations.
template <typename DIPOP>
DIP_ERROR checkDIPCommonTypes(DIPOP op, const std::vector<Value> &args);
Expand Down
52 changes: 52 additions & 0 deletions midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,57 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
int64_t stride;
};

class DIPCorr2DOpNCHWFCHWLowering
: public OpRewritePattern<dip::Corr2DOpNCHWFCHW> {
public:
using OpRewritePattern<dip::Corr2DOpNCHWFCHW>::OpRewritePattern;

explicit DIPCorr2DOpNCHWFCHWLowering(MLIRContext *context,
int64_t strideParam)
: OpRewritePattern(context) {
stride = strideParam;
}

LogicalResult matchAndRewrite(dip::Corr2DOpNCHWFCHW op,
PatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto *ctx = op->getContext();

// Register operand values.
Value input = op->getOperand(0);
Value kernel = op->getOperand(1);
Value output = op->getOperand(2);
Value centerX = op->getOperand(3);
Value centerY = op->getOperand(4);
Value constantValue = op->getOperand(5);
dip::BoundaryOption boundaryOptionAttr = op.getBoundaryOption();
Value strideVal = rewriter.create<arith::ConstantIndexOp>(loc, stride);

auto inElemTy = input.getType().cast<MemRefType>().getElementType();
dip::DIP_ERROR error = dip::checkDIPCommonTypes<dip::Corr2DOpNCHWFCHW>(
op, {input, kernel, output, constantValue});

if (error == dip::DIP_ERROR::INCONSISTENT_TYPES) {
return op->emitOpError() << "input, kernel, output and constant must "
"have the same element type";
} else if (error == dip::DIP_ERROR::UNSUPPORTED_TYPE) {
return op->emitOpError() << "supports only f32, f64 and integer types. "
<< inElemTy << "is passed";
}

traverseImagewBoundaryExtrapolation4DMemRefsNCHWFCHW(
rewriter, loc, ctx, input, kernel, output, centerX, centerY,
constantValue, strideVal, inElemTy, boundaryOptionAttr, stride,
dip::DIP_OP::CORRELATION_2D);
// Remove the origin convolution operation.
rewriter.eraseOp(op);
return success();
}

private:
int64_t stride;
};

class DIPCorrFFT2DOpLowering : public OpRewritePattern<dip::CorrFFT2DOp> {
public:
using OpRewritePattern<dip::CorrFFT2DOp>::OpRewritePattern;
Expand Down Expand Up @@ -1305,6 +1356,7 @@ class DIPMorphGrad2DOpLowering : public OpRewritePattern<dip::MorphGrad2DOp> {
void populateLowerDIPConversionPatterns(RewritePatternSet &patterns,
int64_t stride) {
patterns.add<DIPCorr2DOpLowering>(patterns.getContext(), stride);
patterns.add<DIPCorr2DOpNCHWFCHWLowering>(patterns.getContext(), stride);
patterns.add<DIPCorrFFT2DOpLowering>(patterns.getContext(), stride);
patterns.add<DIPRotate2DOpLowering>(patterns.getContext(), stride);
patterns.add<DIPResize2DOpLowering>(patterns.getContext(), stride);
Expand Down
Loading
Loading