Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,19 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
endmacro()

# NOTE(review): TORCH_MLIR_ENABLE_STABLEHLO is declared twice below (here and
# a few lines down with a different docstring). This looks like old/new diff
# lines flattened together — keep only one declaration. TODO confirm against
# the actual CMakeLists.txt.
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON)
# Optional conversion targets.
# NOTE(review): TORCH_MLIR_ENABLE_LINALG is used here but its option() is only
# declared after the TOSA block below — the declaration should precede first
# use (an undeclared variable is falsey, silently disabling the definition).
if(TORCH_MLIR_ENABLE_LINALG)
add_definitions(-DTORCH_MLIR_ENABLE_LINALG)
endif()
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect conversions" ON)
if(TORCH_MLIR_ENABLE_STABLEHLO)
# Expose the option to C++ as a preprocessor define so sources can guard
# stablehlo-specific code. (Directory-scoped; a target-scoped
# target_compile_definitions would be preferable in new code.)
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
endif()
option(TORCH_MLIR_ENABLE_TOSA "Add tosa dialect conversions" ON)
if(TORCH_MLIR_ENABLE_TOSA)
add_definitions(-DTORCH_MLIR_ENABLE_TOSA)
endif()
option(TORCH_MLIR_ENABLE_LINALG "Add linalg dialect" ON)

option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
Expand Down
19 changes: 12 additions & 7 deletions include/torch-mlir/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# Generate pass declarations for the common (always-built) conversion passes
# from Passes.td.
set(LLVM_TARGET_DEFINITIONS Passes.td)
# NOTE(review): the if/else block below and the unconditional mlir_tablegen
# that follows both generate Passes.h.inc, and two tablegen targets
# (TorchMLIRConversionPassIncGen / TorchMLIRConversionCommonPassIncGen) plus
# two add_mlir_doc calls are present. This reads like the removed and added
# diff lines flattened into one file — only the "Common" variants should
# survive. TODO confirm against the real file.
if(TORCH_MLIR_ENABLE_STABLEHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()
add_public_tablegen_target(TorchMLIRConversionPassIncGen)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(TorchMLIRConversionCommonPassIncGen)
add_mlir_doc(Passes TorchMLIRConversionCommonPasses ./ -gen-pass-doc)

add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc)
# Backend-specific conversion subdirectories are optional; each is gated on
# its corresponding TORCH_MLIR_ENABLE_* option.
if(TORCH_MLIR_ENABLE_LINALG)
add_subdirectory(TorchToLinalg)
endif()
if(TORCH_MLIR_ENABLE_TOSA)
add_subdirectory(TorchToTosa)
endif()
if(TORCH_MLIR_ENABLE_STABLEHLO)
add_subdirectory(TorchToStablehlo)
endif()
15 changes: 10 additions & 5 deletions include/torch-mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,21 @@
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_PASSES_H
#define TORCHMLIR_CONVERSION_PASSES_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace torch {

/// Factory functions for the common (unconditionally built) conversion
/// passes. Each creates a function-level pass lowering Torch ops to the
/// named dialect.
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToArithPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToSCFPass();

/// Registers the common torch-mlir conversion passes.
// Note that this only registers common conversion passes. Backend
// specific passes with their own Passes.h in a subdirectory must be
// included/registered explicitly as they are all optional.
void registerConversionPasses();

} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_PASSES_H
130 changes: 1 addition & 129 deletions include/torch-mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
include "mlir/Pass/PassBase.td"

//===----------------------------------------------------------------------===//
// Torch conversions
// Common conversions
//===----------------------------------------------------------------------===//

def ConvertTorchToArith : Pass<"convert-torch-to-arith", "func::FuncOp"> {
Expand All @@ -26,132 +26,4 @@ def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToSCFPass()";
}

def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to Linalg ops";
let description = [{
Convert ATen ops to linalg ops.

This pass's main responsibility is to bridge the world between ops
that safely terminate the program in case of operand shape mismatches
(ATen) and ops where such mismatches are undefined behavior (linalg).

To model the termination of the program for implementing error guards,
we use the `cf.assert` op.
This is a design decision that is at variance from other passes in the
ecosystem, which use the
`shape` dialect's witness system (`shape.cstr_*` family of ops feeding into
`shape.assuming` regions). This is a change in design decisions
from those passes (which the authors of this pass have contributed to).
The reasons for this change are heuristic, but boil down to:
1. The modeling of `shape.assuming` is odd, as it uses a region, which is
not a good fit for modeling error guards. Regions mark a "start" and an
"end" (which is their nesting property). But
modeling assertions in the program doesn't fit into that. For assertions,
only the "start" matters (once tested, a predicate remains true "forever"
-- it doesn't end at the "yield" of the region).
Thus, having regions places arbitrary "end"s that just add IR structure
that has no semantic value for modeling this problem! (and to make things
worse the "end"s, which we don't need, are what require "yielding"
values, which interrupts use-def chains). Consider the different
structural properties of regions:
a. IsolatedFromAbove region:
- "start" interrupts use-def chains,
- "end" interrupts use-def chains
- structurally protects from intra-block upward and downward
code motion
b. Capturing region (like `shape.assuming`):
- "start" does not interrupt use-def chains,
- "end" interrupts use-def chains
- structurally protects from intra-block upward and downward
code motion
c. What we "ideally" want:
- "start" interrupts use-def chains (can be pruned though)
- no "end" IR structure!
- structurally protects from intra-block upward code motion
(but not downward code motion!)
- Observation: We probably can't get all of this, but overall this
problem is much better suited for a "MemorySSA"-like
abstraction, call it "EffectSSA" which is constructed on-demand
based on MLIR's effect modeling system (rather than
`shape.assuming`, which only covers the effects the IR creator
encoded -- with witnesses/`shape.assuming` -- it is easy to forget
to handle effects other than those encoded in the
witness structure).
2. The presence of `shape.assuming` regions tends to create highly nested
IR structures, which don't interoperate well with any other IR
structures, and creates very bulky IR (and IR creation code). In general
if we are going to do anything with anything (e.g. canonicalize) we
end up needing to either:
a. Flatten the `shape.assuming` IR (defeating the purpose of having
it).
b. Do some sort of shape.assuming "region merging".
c. Have special patterns that handle a subset of special cases (looking
through "yields" and such) and don't generalize.
3. Witnesses tend to encourage non-scalable peephole transformations, which
tend to make analyses/transformations non-robust to the presence of
control flow and side effecting ops (easy to forget to handle side
effects other than those modeled by the witness system).
4. All this code operates on ranked tensors, for which using individual
SSA values for sizes (rather than a "shape type") seems to
work really well at this level of abstraction based on prior experience
in other projects. (unranked code tends to benefit from having a discrete
"shape type" to model shapes).

We will see if we end up needing something like `shape.assuming`, but for
now, it seems likely we can do something simpler and just bypass it. The
design of having an EffectSSA that is constructed on-demand seems very
compelling for modeling effects more broadly.
}];
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
}

def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
let summary = "Convert Torch ops to TOSA ops";
let description = [{
This pass assumes that TOSA ops are responsible for emitting error
guards in case of shape mismatches.
}];
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
}

def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
let description = [{
Convert ATen ops to tmtensor/linalg ops.

This pass is similar to the TorchToLinalg pass; the difference is that this
pass also makes use of TMTensor Dialect, which the former one doesn't.
}];
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
}

def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> {
let summary = "Convert recognized TorchConversion ops to MLProgram ops";
let description = [{
Convert TorchConversion ops to mlprogram ops.
}];
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
}

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> {
let summary = "Convert Torch ops to Stablehlo ops";
let description = [{
Convert Torch ops to Stablehlo ops.
}];
let constructor = "mlir::torch::createConvertTorchToStablehloPass()";

// Specify any options.
let options = [
Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false",
"Enable static shape conversion">,
// The i64 calculation is much slower than i32 on some devices, such as
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
// are unlikely to exceed the range of i32(4GiB)
Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false",
"Enable truncate index from i64 to i32(unsafely)">,
];
}
#endif

#endif // TORCHMLIR_CONVERSION_PASSES

This file was deleted.

4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchToLinalg/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Generate the TorchToLinalg pass declarations (Passes.h.inc) from the
# tablegen definitions in this directory's Passes.td.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(TorchMLIRConversionLinalgPassIncGen)
# Generate markdown documentation for the linalg conversion passes.
add_mlir_doc(Passes TorchMLIRConversionLinalgPasses ./ -gen-pass-doc)
41 changes: 41 additions & 0 deletions include/torch-mlir/Conversion/TorchToLinalg/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_LINALG_PASSES_H
#define TORCHMLIR_CONVERSION_LINALG_PASSES_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace torch {

/// Creates a pipeline that lowers from the torch backend contract to the
/// linalg-on-tensors backend contract.
void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);

/// Creates a module-level pass verifying that the IR satisfies the
/// linalg-on-tensors backend contract.
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyLinalgOnTensorsBackendContractPass();

/// Creates a function-level pass converting recognized Torch ops to Linalg.
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();

/// Creates a module-level pass converting TorchConversion ops to MLProgram.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTorchConversionToMLProgramPass();

/// Creates a function-level pass converting recognized Torch ops to
/// TMTensor/Linalg ops.
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTMTensorPass();

/// Registers the linalg-backend conversion passes declared above. (Only the
/// linalg conversions — other optional backends register separately.)
void registerLinalgConversionPasses();

} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_LINALG_PASSES_H
Loading