[MLIR][ONNX] Add support for onnx.ScatterND
This commit adds support for the onnx.ScatterND op in the ONNX pipeline.
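
For reference, the ScatterND semantics targeted by this lowering can be
summarized by the following NumPy-style sketch (illustrative only; the helper
name scatter_nd_reference is not part of this change, and the reduction
handling mirrors the values accepted by the new pattern):

import numpy as np

def scatter_nd_reference(data, indices, updates, reduction="none"):
    # data: rank-r tensor; indices: rank-q integer tensor with
    # indices.shape[-1] = k <= r; updates: shape indices.shape[:-1] + data.shape[k:].
    output = np.copy(data)
    for idx in np.ndindex(indices.shape[:-1]):
        target = tuple(indices[idx])  # selects a slice of shape data.shape[k:]
        if reduction == "add":
            output[target] += updates[idx]
        elif reduction == "mul":
            output[target] *= updates[idx]
        elif reduction == "max":
            output[target] = np.maximum(output[target], updates[idx])
        elif reduction == "min":
            output[target] = np.minimum(output[target], updates[idx])
        else:  # "none": plain overwrite
            output[target] = updates[idx]
    return output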

Signed-off-by: Gaurav Shukla <gaurav.shukla@amd.com>
Shukla-Gaurav committed Jun 21, 2024
1 parent acd57a3 commit cff22ab
Showing 3 changed files with 512 additions and 12 deletions.
287 changes: 287 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -3258,4 +3258,291 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
rewriter.replaceOp(binder.op, inputSequence);
return success();
});
patterns.onOp(
"ScatterND", 16,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data, indices, updates;
std::string reduction;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorOperandAtIndex(indices, 1) ||
binder.tensorOperandAtIndex(updates, 2) ||
binder.tensorResultType(resultType) ||
binder.customOpNameStringAttr(reduction, "reduction", "none"))
return failure();

// Map onnx reduction type to torch reduction type.
if (reduction == "add") {
reduction = "sum";
} else if (reduction == "mul") {
reduction = "prod";
} else if (reduction == "max") {
reduction = "amax";
} else if (reduction == "min") {
reduction = "amin";
} else if (reduction != "none") {
return rewriter.notifyMatchFailure(
binder.op, "expects reduction to be one of add, mul, max, min, "
"none(default)");
}

Location loc = binder.getLoc();
auto dataTy = cast<Torch::ValueTensorType>(data.getType());
auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
auto updatesTy = cast<Torch::ValueTensorType>(updates.getType());
if (!dataTy || !dataTy.hasSizes())
return failure();
if (!indicesTy || !indicesTy.hasSizes())
return failure();
if (!updatesTy || !updatesTy.hasSizes())
return failure();

// step 1. Get shapes and ranks of data, indices and updates.
// The last dimension of indices is expected to be static.
ArrayRef<int64_t> dataShape = dataTy.getSizes();
int64_t dataRank = dataShape.size();
ArrayRef<int64_t> updatesShape = updatesTy.getSizes();
int64_t updatesRank = updatesShape.size();
ArrayRef<int64_t> indicesShape = indicesTy.getSizes();
int64_t indicesRank = indicesShape.size();
int64_t indicesLastDim = indicesShape.back();
// Given a data tensor of rank r >= 1, an indices tensor of rank q >= 1,
// and an updates tensor of rank q + r - indices_shape[-1] - 1, the
// output is produced by creating a copy of the input data and then
// updating its values, at the index positions specified by indices, to
// the values specified by updates. The output shape is the same as the
// shape of data.
// indices_shape[-1] must be static to have deterministic ranks.
if (dataRank < 1 || indicesRank < 1 || updatesRank < 1)
return rewriter.notifyMatchFailure(
binder.op, "expected data, indices and updates rank to be >= 1");
if (indicesLastDim == Torch::kUnknownSize || indicesLastDim <= 0)
return rewriter.notifyMatchFailure(
binder.op, "expected last dimension of indices to be static and "
"greater than zero");

// step 2. Get dimension list of data.
SmallVector<Value> dataDims;
for (int64_t i = 0; i < dataRank; ++i) {
Value k = rewriter.create<Torch::ConstantIntOp>(loc, i);
Value dataDim = rewriter.create<Torch::AtenSizeIntOp>(loc, data, k);
dataDims.push_back(dataDim);
}

// step 3. Get dimension list of indices.
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
SmallVector<Value> indicesDimsMinusOne;
Value indicesFlattenDim = constOne;
for (int64_t i = 0; i < indicesRank - 1; ++i) {
Value k = rewriter.create<Torch::ConstantIntOp>(loc, i);
Value indicesDim =
rewriter.create<Torch::AtenSizeIntOp>(loc, indices, k);
indicesDimsMinusOne.push_back(indicesDim);
indicesFlattenDim = rewriter.create<Torch::AtenMulIntOp>(
loc, indicesFlattenDim, indicesDim);
}
ArrayRef<int64_t> indicesShapeMinusOne = indicesShape.drop_back();

// Algorithm: We cannot directly perform torch.scatter as it requires
// the ranks of data(`r`), indices(`q`) and updates to be the same.
// So we perform collapse and expand operations to match the ranks of
// data, indices and updates (making sure the semantics of
// onnx.scatter_nd are preserved), perform the torch.scatter operation,
// and later unflatten the scatter result to match the onnx.scatter_nd
// output. For example, assume indices is of shape (4, 5, 3, 2), data is
// (4, 10, 11, 7, 4) and updates is (4, 5, 3, 11, 7, 4). First, convert
// indices to 1-D indexing, since the torch.scatter op supports only
// single-dimensional indexing (this algorithm would have been simpler
// if we could use a torch op that supports indexing at multiple
// dimensions simultaneously). The 1-D indexed indices will be of shape
// (4, 5, 3, 1); now materialize it to `r - indices_shape[-1]` trailing
// dimensions of data, i.e. reshape it to (4, 5, 3, 1, 1, 1). The next
// step is to flatten+expand the indices and flatten the data to shapes
// (60, 11, 7, 4) and (40, 11, 7, 4) respectively, and then perform the
// torch.scatter operation. After the scatter operation, unflatten the
// first dimension of the result to (4, 10, 11, 7, 4), which is the
// required result.

// step 4. Convert indices_shape[-1] dimensional indexing to 1D
// indexing.
Value sliceDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(indicesRank - 1));
SmallVector<int64_t> indicesSliceShape(indicesShapeMinusOne);
indicesSliceShape.push_back(1);
auto indicesSliceTy = rewriter.getType<Torch::ValueTensorType>(
indicesSliceShape, indicesTy.getOptionalDtype());

Value start = constZero;
Value updatedIndices;
for (int64_t i = 0; i < indicesLastDim; ++i) {
Value end = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i + 1));
Value indicesSlice = rewriter.create<Torch::AtenSliceTensorOp>(
loc, indicesSliceTy, indices, sliceDim, start, end,
/*step=*/constOne);
start = end;
// Apply bounds checking on the indices slice.
auto boolTy = rewriter.getType<Torch::ValueTensorType>(
indicesSliceShape, rewriter.getI1Type());
Value lt = rewriter.create<Torch::AtenLtScalarOp>(
loc, boolTy, indicesSlice, constZero);
Value add = rewriter.create<Torch::AtenAddScalarOp>(
loc, indicesSliceTy, indicesSlice, dataDims[i],
/*alpha=*/constOne);
indicesSlice = rewriter.create<Torch::AtenWhereSelfOp>(
loc, indicesSliceTy, lt, add, indicesSlice);
if (i == 0) {
updatedIndices = indicesSlice;
continue;
}
updatedIndices = rewriter.create<Torch::AtenAddTensorOp>(
loc, indicesSliceTy, indicesSlice, updatedIndices, dataDims[i]);
}
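// After this loop `updatedIndices` holds the row-major linearized index
// (...((i0 * d1 + i1) * d2 + i2)...) over the leading indices_shape[-1]
// dims of data, with negative index entries wrapped into range by adding
// the corresponding data dim size beforehand.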

// step 5. Compute all the required result types here.
SmallVector<int64_t> reshapeIndicesShape(indicesShapeMinusOne);
SmallVector<Value> reshapeIndicesDims(indicesDimsMinusOne);
// Determine the collapsed dim size of indices (indices_shape[-1] is not
// part of the collapse, as it has already been removed by the 1-D
// indexing above).
SmallVector<int64_t> flattenIndicesShape;
int64_t indicesCt = 1;
for (int64_t i = 0; i < indicesRank - 1; ++i) {
if (indicesShape[i] == Torch::kUnknownSize) {
indicesCt = Torch::kUnknownSize;
break;
}
indicesCt *= indicesShape[i];
}
flattenIndicesShape.push_back(indicesCt);
// Compute the shape of expand op.
SmallVector<Value> expandIndicesDims;
expandIndicesDims.push_back(indicesFlattenDim);
SmallVector<int64_t> expandIndicesShape;
expandIndicesShape.push_back(indicesCt);
// Determine the collapsed dim size of data.
SmallVector<int64_t> flattenDataShape;
int64_t dataCt = 1;
for (int64_t i = 0; i < indicesLastDim; ++i) {
if (dataShape[i] == Torch::kUnknownSize) {
dataCt = Torch::kUnknownSize;
break;
}
dataCt *= dataShape[i];
}
flattenDataShape.push_back(dataCt);
// Determine the collapsed dim size of updates.
SmallVector<int64_t> flattenUpdatesShape;
int64_t updatesCt = 1;
for (int64_t i = 0; i < indicesRank - 1; ++i) {
if (updatesShape[i] == Torch::kUnknownSize) {
updatesCt = Torch::kUnknownSize;
break;
}
updatesCt *= updatesShape[i];
}
flattenUpdatesShape.push_back(updatesCt);
flattenUpdatesShape.insert(flattenUpdatesShape.end(),
updatesShape.begin() + indicesRank - 1,
updatesShape.end());
// Append `r-indices_shape[-1]` unit or data dims appropriately to all
// result types.
for (int64_t i = indicesLastDim; i < dataRank; ++i) {
reshapeIndicesShape.push_back(1);
flattenIndicesShape.push_back(1);
flattenDataShape.push_back(dataShape[i]);
expandIndicesShape.push_back(dataShape[i]);
reshapeIndicesDims.push_back(constOne);
expandIndicesDims.push_back(dataDims[i]);
}
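// For the running example (indices (4, 5, 3, 2), data (4, 10, 11, 7, 4),
// updates (4, 5, 3, 11, 7, 4)), these work out to:
//   reshapeIndicesShape = (4, 5, 3, 1, 1, 1)
//   flattenIndicesShape = (60, 1, 1, 1)
//   expandIndicesShape  = (60, 11, 7, 4)
//   flattenDataShape    = (40, 11, 7, 4)
//   flattenUpdatesShape = (60, 11, 7, 4)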

// step 6. Reshape 1-D indexed indices to match the rank of flattened
// data by inserting unit dimensions.
auto intListTy = rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>());
Value reshapeIndicesSizeList =
rewriter.create<Torch::PrimListConstructOp>(loc, intListTy,
reshapeIndicesDims);
auto reshapeIndicesTy = rewriter.getType<Torch::ValueTensorType>(
reshapeIndicesShape, indicesTy.getOptionalDtype());
Value reshapedIndices = rewriter.create<Torch::AtenViewOp>(
loc, reshapeIndicesTy, updatedIndices, reshapeIndicesSizeList);

// step 7. Flatten `q-1` dimensions of the indices and updates.
auto flattenIndicesTy = rewriter.getType<Torch::ValueTensorType>(
flattenIndicesShape, indicesTy.getOptionalDtype());
auto flattenUpdatesTy = rewriter.getType<Torch::ValueTensorType>(
flattenUpdatesShape, updatesTy.getOptionalDtype());
Value flattenedIndices = reshapedIndices;
Value flattenedUpdates = updates;
if (indicesRank == 1) {
flattenedIndices = rewriter.create<Torch::AtenUnsqueezeOp>(
loc, flattenIndicesTy, reshapedIndices, constZero);
flattenedUpdates = rewriter.create<Torch::AtenUnsqueezeOp>(
loc, flattenUpdatesTy, updates, constZero);
} else if (indicesRank > 1) {
Value endDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(indicesRank - 2));
flattenedIndices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
loc, flattenIndicesTy, reshapedIndices, constZero, endDim);
flattenedUpdates = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
loc, flattenUpdatesTy, updates, constZero, endDim);
}

// step 8. Expand `r-indices_shape[-1]` dims of flattened indices.
auto expandIndicesTy = rewriter.getType<Torch::ValueTensorType>(
expandIndicesShape, indicesTy.getOptionalDtype());
Value expandIndicesSizeList =
rewriter.create<Torch::PrimListConstructOp>(loc, intListTy,
expandIndicesDims);
Value constFalse = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getType<Torch::BoolType>(),
rewriter.getBoolAttr(false));
Value expandedIndices = rewriter.create<Torch::AtenExpandOp>(
loc, expandIndicesTy, flattenedIndices, expandIndicesSizeList,
/*implicit=*/constFalse);

// step 9. Flatten indices_shape[-1] dimensions of data.
auto flattenDataTy = rewriter.getType<Torch::ValueTensorType>(
flattenDataShape, dataTy.getOptionalDtype());
Value endDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(indicesLastDim - 1));
Value flattenedData = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
loc, flattenDataTy, data, constZero, endDim);

// step 10. Now flattenedData, expandedIndices and flattenedUpdates have
// the same rank, so the scatter operation can be performed.
auto scatterTy = rewriter.getType<Torch::ValueTensorType>(
flattenDataShape, dataTy.getOptionalDtype());

Value scatter;
if (reduction == "none") {
scatter = rewriter.create<Torch::AtenScatterSrcOp>(
loc, scatterTy, flattenedData, /*axis=*/constZero,
expandedIndices, flattenedUpdates);
} else {
Value cstReduction =
rewriter.create<Torch::ConstantStrOp>(loc, reduction);
Value constTrue = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getType<Torch::BoolType>(),
rewriter.getBoolAttr(true));
scatter = rewriter.create<Torch::AtenScatterReduceTwoOp>(
loc, scatterTy, flattenedData, /*axis=*/constZero,
expandedIndices, flattenedUpdates, cstReduction,
/*include_self=*/constTrue);
}
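// aten.scatter.src overwrites the selected elements, while
// aten.scatter_reduce.two combines them with the existing values using
// the mapped reduction; include_self=true keeps the original data value
// in the reduction, matching onnx.ScatterND's reduction semantics.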

// step 11. Unflatten the collapsed data dims of scatter result.
if (indicesLastDim == 1) {
rewriter.replaceOp(binder.op, scatter);
return success();
}
// Only the leading indices_shape[-1] data dims were collapsed in step 9.
SmallVector<Value> unflattenDims(dataDims.begin(),
dataDims.begin() + indicesLastDim);
Value unflattenSizeList = rewriter.create<Torch::PrimListConstructOp>(
loc, intListTy, unflattenDims);
rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
binder.op, resultType, scatter, constZero, unflattenSizeList);
return success();
});
}
24 changes: 12 additions & 12 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2682,29 +2682,29 @@
"ScatterValueFloatModule_basic",
# Failure - onnx_lowering: onnx.ScatterND
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic",
# "IndexPut1DFloatNonAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic",
"IndexPut1DIntNonAccumulateModule_basic",
# "IndexPut1DIntNonAccumulateModule_basic",
"IndexPut2DFloatAccumulateModule_basic",
"IndexPut2DFloatNonAccumulateModule_basic",
# "IndexPut2DFloatNonAccumulateModule_basic",
"IndexPut2DIntAccumulateModule_basic",
"IndexPut2DIntNonAccumulateModule_basic",
# "IndexPut2DIntNonAccumulateModule_basic",
"IndexPut3DFloatAccumulateModule_basic",
"IndexPut3DFloatNonAccumulateModule_basic",
# "IndexPut3DFloatNonAccumulateModule_basic",
"IndexPut3DIntAccumulateModule_basic",
"IndexPut3DIntNonAccumulateModule_basic",
# "IndexPut3DIntNonAccumulateModule_basic",
"IndexPutHackedTwin1DFloatAccumulateModule_basic",
"IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
# "IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin1DIntAccumulateModule_basic",
"IndexPutHackedTwin1DIntNonAccumulateModule_basic",
# "IndexPutHackedTwin1DIntNonAccumulateModule_basic",
"IndexPutHackedTwin2DFloatAccumulateModule_basic",
"IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
# "IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin2DIntAccumulateModule_basic",
"IndexPutHackedTwin2DIntNonAccumulateModule_basic",
# "IndexPutHackedTwin2DIntNonAccumulateModule_basic",
"IndexPutHackedTwin3DFloatAccumulateModule_basic",
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
# "IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin3DIntAccumulateModule_basic",
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
# "IndexPutHackedTwin3DIntNonAccumulateModule_basic",
# RuntimeError: unsupported input type: Device
"PrimsIotaModule_basic",
# Error: 'aten::renorm' to ONNX opset version 17 is not supported.