-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #478 from j2kun:tensor-ext
PiperOrigin-RevId: 613199009
- Loading branch information
Showing
12 changed files
with
248 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# TensorExt tablegen and headers | ||
|
||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
exports_files( | ||
[ | ||
"TensorExtDialect.h", | ||
"TensorExtOps.h", | ||
], | ||
) | ||
|
||
td_library( | ||
name = "td_files", | ||
srcs = [ | ||
"TensorExtDialect.td", | ||
"TensorExtOps.td", | ||
], | ||
# include from the heir-root to enable fully-qualified include-paths | ||
includes = ["../../../.."], | ||
deps = [ | ||
"@llvm-project//mlir:BuiltinDialectTdFiles", | ||
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles", | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:SideEffectInterfacesTdFiles", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "dialect_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-dialect-decls", | ||
], | ||
"TensorExtDialect.h.inc", | ||
), | ||
( | ||
[ | ||
"-gen-dialect-defs", | ||
], | ||
"TensorExtDialect.cpp.inc", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "TensorExtDialect.td", | ||
deps = [ | ||
":td_files", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "ops_inc_gen", | ||
tbl_outs = [ | ||
( | ||
["-gen-op-decls"], | ||
"TensorExtOps.h.inc", | ||
), | ||
( | ||
["-gen-op-defs"], | ||
"TensorExtOps.cpp.inc", | ||
), | ||
( | ||
["-gen-op-doc"], | ||
"TensorExtOps.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "TensorExtOps.td", | ||
deps = [ | ||
":dialect_inc_gen", | ||
":td_files", | ||
"@heir//include/Dialect/Polynomial/IR:td_files", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_H_ | ||
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_H_ | ||
|
||
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project | ||
|
||
// Generated headers (block clang-format from messing up order) | ||
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h.inc" | ||
|
||
#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_ | ||
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_ | ||
|
||
include "mlir/IR/DialectBase.td" | ||
|
||
def TensorExt_Dialect : Dialect { | ||
let name = "tensor_ext"; | ||
let description = [{ | ||
The `tensor_ext` dialect contains operations on plaintext tensors that | ||
correspond to the computation model of certain FHE schemes, but are | ||
unlikely to be upstreamed to MLIR due to their specificity to FHE. | ||
}]; | ||
|
||
let cppNamespace = "::mlir::heir::tensor_ext"; | ||
} | ||
|
||
#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_ | ||
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_ | ||
|
||
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h" | ||
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project | ||
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project | ||
|
||
#define GET_OP_CLASSES | ||
#include "include/Dialect/TensorExt/IR/TensorExtOps.h.inc" | ||
|
||
#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_ | ||
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_ | ||
|
||
include "include/Dialect/TensorExt/IR/TensorExtDialect.td" | ||
include "mlir/IR/BuiltinAttributes.td" | ||
include "mlir/IR/CommonTypeConstraints.td" | ||
include "mlir/IR/OpBase.td" | ||
include "mlir/Interfaces/InferTypeOpInterface.td" | ||
include "mlir/Interfaces/SideEffectInterfaces.td" | ||
|
||
|
||
class TensorExt_Op<string mnemonic, list<Trait> traits = []> : | ||
Op<TensorExt_Dialect, mnemonic, traits> { | ||
let cppNamespace = "::mlir::heir::tensor_ext"; | ||
} | ||
|
||
def TensorExt_RotateOp : TensorExt_Op<"rotate", [Pure, AllTypesMatch<["tensor", "output"]>]> { | ||
let summary = "Rotate a tensor some number of indices left."; | ||
let description = [{ | ||
This op represents a left-rotation of a tensor by given number of indices. | ||
Negative shift values are interpreted as right-rotations. | ||
|
||
This corresponds to the `rotate` operation in arithmetic FHE schemes like | ||
BGV. | ||
|
||
Examples: | ||
|
||
```mlir | ||
%0 = ... : tensor<16xi32> | ||
%c7 = arith.constant 7 : i32 | ||
%1 = tensor_ext.rotate %0, %c7 : tensor<16xi32>, i32 | ||
``` | ||
}]; | ||
|
||
let arguments = (ins AnyTensor:$tensor, SignlessIntegerLike:$shift); | ||
let results = (outs AnyTensor:$output); | ||
let assemblyFormat = "operands attr-dict `:` qualified(type($tensor)) `,` type($shift)"; | ||
} | ||
|
||
#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# TensorExt dialect implementation | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "Dialect", | ||
srcs = [ | ||
"TensorExtDialect.cpp", | ||
], | ||
hdrs = [ | ||
"@heir//include/Dialect/TensorExt/IR:TensorExtDialect.h", | ||
"@heir//include/Dialect/TensorExt/IR:TensorExtOps.h", | ||
], | ||
deps = [ | ||
":TensorExtOps", | ||
"@heir//include/Dialect/TensorExt/IR:dialect_inc_gen", | ||
"@heir//include/Dialect/TensorExt/IR:ops_inc_gen", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:InferTypeOpInterface", | ||
], | ||
) | ||
|
||
cc_library( | ||
name = "TensorExtOps", | ||
srcs = [ | ||
"TensorExtOps.cpp", | ||
], | ||
hdrs = [ | ||
"@heir//include/Dialect/TensorExt/IR:TensorExtDialect.h", | ||
"@heir//include/Dialect/TensorExt/IR:TensorExtOps.h", | ||
], | ||
deps = [ | ||
"@heir//include/Dialect/TensorExt/IR:dialect_inc_gen", | ||
"@heir//include/Dialect/TensorExt/IR:ops_inc_gen", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:InferTypeOpInterface", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h" | ||
|
||
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project | ||
|
||
// NOLINTNEXTLINE(misc-include-cleaner): Required to define TensorExtOps | ||
#include "include/Dialect/TensorExt/IR/TensorExtOps.h" | ||
|
||
// Generated definitions | ||
#include "include/Dialect/TensorExt/IR/TensorExtDialect.cpp.inc" | ||
|
||
#define GET_OP_CLASSES | ||
#include "include/Dialect/TensorExt/IR/TensorExtOps.cpp.inc" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace tensor_ext { | ||
|
||
void TensorExtDialect::initialize() { | ||
addOperations< | ||
#define GET_OP_LIST | ||
#include "include/Dialect/TensorExt/IR/TensorExtOps.cpp.inc" | ||
>(); | ||
} | ||
|
||
} // namespace tensor_ext | ||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
#include "include/Dialect/TensorExt/IR/TensorExtOps.h" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
load("//bazel:lit.bzl", "glob_lit_tests") | ||
|
||
package(default_applicable_licenses = ["@heir//:license"]) | ||
|
||
glob_lit_tests( | ||
name = "all_tests", | ||
data = ["@heir//tests:test_utilities"], | ||
driver = "@heir//tests:run_lit.sh", | ||
test_file_exts = ["mlir"], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// RUN: heir-opt %s | ||
|
||
// Test for syntax | ||
|
||
func.func @test_rotate(%0: tensor<16xi32>) -> tensor<16xi32> { | ||
%c1 = arith.constant 1 : i32 | ||
%1 = tensor_ext.rotate %0, %c1 : tensor<16xi32>, i32 | ||
return %1 : tensor<16xi32> | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters