@@ -12580,6 +12580,198 @@ class DecomposeAtenRoundDecimalsOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
+public:
+  using OpRewritePattern<AtenAsStridedOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAsStridedOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // The `aten.as_strided` operation is decomposed into a series of
+    // operations that compute the indices based on the provided sizes and
+    // strides, and then index into the flattened input tensor as follows:
+
+    // input_flat = input.view(-1)
+    //
+    // for dim, s in enumerate(self.size):
+    //     arange = torch.arange(s)
+    //     view_shape = []
+    //     for i in range(len(self.size)):
+    //         if i == dim:
+    //             view_shape.append(-1)
+    //         else:
+    //             view_shape.append(1)
+    //     arange = arange.view(view_shape)
+    //     if dim == 0:
+    //         idx = arange * self.stride[dim]
+    //     else:
+    //         idx = idx + arange * self.stride[dim]
+    //
+    // # Flatten indices and add offset
+    // final_indices = idx.reshape(-1) + self.storage_offset
+    //
+    // # Index the flattened input tensor
+    // output = input_flat[final_indices]
+    //
+    // # Reshape to desired output size
+    // return output.view(self.size)
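+    //
+    // As an illustrative example, for size = [2, 2], stride = [1, 2], and
+    // storage_offset = 0:
+    //     dim 0: arange(2).view(-1, 1) * 1 = [[0], [1]]
+    //     dim 1: arange(2).view(1, -1) * 2 = [[0, 2]]
+    //     idx = [[0, 2], [1, 3]]  ->  final_indices = [0, 2, 1, 3]
+    // and the result is input_flat[[0, 2, 1, 3]] viewed as [2, 2].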
+
+    Location loc = op.getLoc();
+    MLIRContext *context = op->getContext();
+    Value input = op.getSelf();
+    auto inputType = dyn_cast<BaseTensorType>(input.getType());
+
+    if (!inputType || !inputType.hasSizes() || !inputType.areAllSizesKnown())
+      return rewriter.notifyMatchFailure(op, "input must have known sizes");
+
+    SmallVector<int64_t> sizesInts;
+    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizesInts)))
+      return rewriter.notifyMatchFailure(
+          op, "sizes must be a list of constant ints");
+
+    SmallVector<int64_t> stridesInts;
+    if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stridesInts)))
+      return rewriter.notifyMatchFailure(
+          op, "strides must be a list of constant ints");
+
+    int64_t storageOffset = 0;
+    if (!isa<Torch::NoneType>(op.getStorageOffset().getType())) {
+      if (!matchPattern(op.getStorageOffset(),
+                        m_TorchConstantInt(&storageOffset)))
+        return rewriter.notifyMatchFailure(
+            op, "storage_offset must be a constant integer");
+    }
+
+    ArrayRef<int64_t> inputSizes = inputType.getSizes();
+    int64_t inputRank = inputSizes.size();
+    int64_t resultRank = sizesInts.size();
+
+    Value cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    if (inputRank > 1) {
+      // If the input is not a 1-d tensor, we need to flatten it
+      // to a 1D tensor before applying the strided indexing.
+      int64_t flattenedInputSize = 1;
+      for (int64_t size : inputSizes)
+        flattenedInputSize *= size;
+
+      auto flattenedInputTy =
+          cast<BaseTensorType>(inputType.getWithSizesAndDtype(
+              {flattenedInputSize}, inputType.getOptionalDtype()));
+
+      Value end = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(inputRank - 1));
+      input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenedInputTy,
+                                                      input, cstZero, end);
+    }
+
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value cstMinusOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+
+    SmallVector<int64_t> viewShapeInts(resultRank, 1);
+    SmallVector<Value> viewShapeListElems(resultRank, cstOne);
+
+    auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
+    Value finalIndices;
+    for (unsigned dim = 0; dim < sizesInts.size(); dim++) {
+      int64_t size = sizesInts[dim];
+      Value cstNone = rewriter.create<ConstantNoneOp>(loc);
+      Value end = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(size));
+
+      auto arangeType =
+          ValueTensorType::get(context, llvm::ArrayRef(size), si64Type);
+      Value index = rewriter.create<Torch::AtenArangeOp>(
+          loc, arangeType, end, cstNone, cstNone, cstNone, cstNone);
+
+      // Set the current dimension to -1 for broadcasting
+      viewShapeInts[dim] = -1;
+      viewShapeListElems[dim] = cstMinusOne;
+
+      Value viewShapeList = rewriter.create<Torch::PrimListConstructOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(context)),
+          viewShapeListElems);
+
+      auto viewType = ValueTensorType::get(
+          context, llvm::ArrayRef(viewShapeInts), si64Type);
+      index = rewriter.create<AtenViewOp>(loc, viewType, index, viewShapeList);
+
+      // Multiply the index with the stride for the current dimension
+      Value cstStride = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(stridesInts[dim]));
+      index = rewriter.create<AtenMulScalarOp>(loc, viewType, index, cstStride);
+
+      // Reset the current dimension to 1 for the next iteration
+      viewShapeInts[dim] = 1;
+      viewShapeListElems[dim] = cstOne;
+
+      if (dim == 0) {
+        finalIndices = index;
+        continue;
+      }
+
+      // calculate common shape for broadcast
+      SmallVector<int64_t> broadcastShape;
+      SmallVector<Value> broadcastShapeValue;
+      computeBroadcastShape(rewriter, loc, finalIndices, index, broadcastShape,
+                            broadcastShapeValue);
+      Type broadcastType = ValueTensorType::get(
+          context, llvm::ArrayRef(broadcastShape), si64Type);
+
+      finalIndices = rewriter.create<AtenAddTensorOp>(
+          loc, broadcastType, finalIndices, index, cstOne);
+    }
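+
+    // After the loop, `finalIndices` holds the broadcasted sum of the
+    // per-dimension contributions arange(size[d]) * stride[d] and has the
+    // requested output shape.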
+
+    int64_t flattenedResultSize = 1;
+    for (int64_t size : sizesInts)
+      flattenedResultSize *= size;
+
+    // Flattening the indices and adding the storage offset
+    finalIndices = rewriter.create<AtenFlattenUsingIntsOp>(
+        loc,
+        ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
+                             si64Type),
+        finalIndices, cstZero, cstMinusOne); // -1 means flatten all
+
+    if (storageOffset != 0) {
+      Value cstStorageOffset = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(storageOffset));
+      finalIndices = rewriter.create<AtenAddScalarOp>(
+          loc, finalIndices.getType(), finalIndices, cstStorageOffset, cstOne);
+    }
+
+    // Index the flattened input tensor
+    Type listElemType =
+        inputType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
+                                       /*optionalDtype=*/nullptr);
+    Value indicesList = rewriter.create<Torch::PrimListConstructOp>(
+        loc, Torch::ListType::get(listElemType),
+        SmallVector<Value>{finalIndices});
+
+    auto flattenedResultTy =
+        ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
+                             inputType.getOptionalDtype());
+    Value result = rewriter.create<AtenIndexTensorOp>(loc, flattenedResultTy,
+                                                      input, indicesList);
+
+    // Reshape the result to the desired output size
+    SmallVector<Value> sizesIntsValues;
+    for (int64_t size : sizesInts) {
+      sizesIntsValues.push_back(rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(size)));
+    }
+    Value resultSizeList = rewriter.create<Torch::PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        sizesIntsValues);
+    result =
+        rewriter.create<AtenViewOp>(loc, op.getType(), result, resultSizeList);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12904,6 +13096,7 @@ class DecomposeComplexOpsPass
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRoundDecimalsOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenAsStridedOp>(patterns);
 
     GreedyRewriteConfig config;
     config.setUseTopDownTraversal(true);