From 2c3f817b940a6308c23d4ee73a93d900df5401ab Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Thu, 22 Feb 2024 11:03:21 +0100 Subject: [PATCH] [SMT] Add function application operation, function and uninterpreted sort types --- include/circt/Dialect/SMT/SMTArrayOps.td | 2 +- include/circt/Dialect/SMT/SMTOps.td | 60 +++++++++++++-- include/circt/Dialect/SMT/SMTTypes.h | 3 + include/circt/Dialect/SMT/SMTTypes.td | 78 ++++++++++++++++++++ lib/Conversion/VerifToSMT/VerifToSMT.cpp | 3 +- lib/Dialect/SMT/SMTOps.cpp | 8 +- lib/Dialect/SMT/SMTTypes.cpp | 34 ++++++++- test/Conversion/VerifToSMT/verif-to-smt.mlir | 6 +- test/Dialect/SMT/basic.mlir | 27 +++++-- test/Dialect/SMT/core-errors.mlir | 45 ++++++++++- test/Dialect/SMT/cse-test.mlir | 10 +-- 11 files changed, 245 insertions(+), 31 deletions(-) diff --git a/include/circt/Dialect/SMT/SMTArrayOps.td b/include/circt/Dialect/SMT/SMTArrayOps.td index 9afde82f7b5..3cbb10d6790 100644 --- a/include/circt/Dialect/SMT/SMTArrayOps.td +++ b/include/circt/Dialect/SMT/SMTArrayOps.td @@ -75,7 +75,7 @@ def ArrayBroadcastOp : SMTArrayOp<"broadcast", [ This operation represents a broadcast of the 'value' operand to all indices of the array. It is equivalent to ``` - %0 = smt.declare_const "array" : !smt.array<[!smt.int -> !smt.bool]> + %0 = smt.declare "array" : !smt.array<[!smt.int -> !smt.bool]> %1 = smt.forall ["idx"] { ^bb0(%idx: !smt.int): %2 = smt.array.select %0[%idx] : !smt.array<[!smt.int -> !smt.bool]> diff --git a/include/circt/Dialect/SMT/SMTOps.td b/include/circt/Dialect/SMT/SMTOps.td index f8be6695b80..1d1c381e1b0 100644 --- a/include/circt/Dialect/SMT/SMTOps.td +++ b/include/circt/Dialect/SMT/SMTOps.td @@ -21,20 +21,38 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" class SMTOp traits = []> : Op; -def DeclareConstOp : SMTOp<"declare_const", [ +def DeclareOp : SMTOp<"declare", [ DeclareOpInterfaceMethods ]> { let summary = "declare a symbolic value of a given sort"; let description = [{ - This operation declares a symbolic value just as the `declare-const` - statement in SMT-LIB 2.6. The result type determines the SMT sort of the - symbolic value. The returned value can then be used to refer to the symbolic - value instead of using the identifier like in SMT-LIB. + This operation declares a symbolic value just as the `declare-const` and + `declare-func` statements in SMT-LIB 2.6. The result type determines the SMT + sort of the symbolic value. The returned value can then be used to refer to + the symbolic value instead of using the identifier like in SMT-LIB. The optionally provided string will be used as a prefix for the newly generated identifier (useful for easier readability when exporting to - SMT-LIB). Each `declare_const` will always provide a unique new symbolic - value even if the identifier strings are the same. + SMT-LIB). Each `declare` will always provide a unique new symbolic value + even if the identifier strings are the same. + + Note that there does not exist a separate operation equivalent to + SMT-LIBs `define-fun` since + ``` + (define-fun f (a Int) Int (-a)) + ``` + is only syntactic sugar for + ``` + %f = smt.declare : !smt.func<(!smt.int) !smt.int> + %0 = smt.forall ["a"] { + ^bb0(%arg0: !smt.int): + %1 = smt.apply_func %f(%arg0) : !smt.func<(!smt.int) !smt.int> + %2 = smt.int.neg %1 + %3 = smt.eq %1, %2 : !smt.int + smt.yield %3 : !smt.bool + } + smt.assert %0 + ``` Note that this operation cannot be marked as Pure since two operations (even with the same identifier string) could then be CSEd, leading to incorrect @@ -99,7 +117,7 @@ def SolverOp : SMTOp<"solver", [ ```mlir %0:2 = smt.solver (%in) {smt.some_attr} : (i8) -> (i8, i32) { ^bb0(%arg0: i8): - %c = smt.declare_const "c" : !smt.bool + %c = smt.declare "c" : !smt.bool smt.assert %c %1 = smt.check sat { %c1_i32 = arith.constant 1 : i32 @@ -190,6 +208,29 @@ def YieldOp : SMTOp<"yield", [ }]>]; } +def ApplyFuncOp : SMTOp<"apply_func", [ + Pure, + TypesMatchWith<"summary", "func", "result", + "cast($_self).getRangeType()">, + RangedTypesMatchWith<"summary", "func", "args", + "cast($_self).getDomainTypes()"> +]> { + let summary = "apply a function"; + let description = [{ + This operation performs a function application as described in the + [SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf). + It is part of the language itself rather than a theory or logic. + }]; + + let arguments = (ins SMTFuncType:$func, + Variadic:$args); + let results = (outs AnyNonFuncSMTType:$result); + + let assemblyFormat = [{ + $func `(` $args `)` attr-dict `:` qualified(type($func)) + }]; +} + def EqOp : SMTOp<"eq", [Pure, SameTypeOperands]> { let summary = "returns true iff all operands are identical"; let description = [{ @@ -358,6 +399,9 @@ class QuantifierOp : SMTOp { let genVerifyDecl = true; } +def SMTFuncType : SMTTypeDef<"SMTFunc"> { + let mnemonic = "func"; + let description = [{ + This type represents the SMT function sort as described in the + [SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf). + It is part of the language itself rather than a theory or logic. + + A function in SMT can have an arbitrary domain size, but always has exactly + one range sort. + + Since SMT only supports first-order logic, it is not possible to nest + function types. + + Example: `!smt.func<(!smt.bool, !smt.int) !smt.bool>` is equivalent to + `((Bool Int) Bool)` in SMT-LIB. + }]; + + let parameters = (ins + OptionalArrayRefParameter<"mlir::Type", "domain types">:$domainTypes, + "mlir::Type":$rangeType + ); + + // Note: We are not printing the parentheses when no domain type is present + // because the default MLIR parser thinks it is a builtin function type + // otherwise. + let assemblyFormat = "`<` (`(` $domainTypes^ `)` ` `)? $rangeType `>`"; + + let builders = [ + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$domainTypes, + "mlir::Type":$rangeType), [{ + return $_get(rangeType.getContext(), domainTypes, rangeType); + }]>, + TypeBuilderWithInferredContext<(ins "mlir::Type":$rangeType), [{ + return $_get(rangeType.getContext(), + llvm::ArrayRef{}, rangeType); + }]> + ]; + + let genVerifyDecl = true; +} + +def SortType : SMTTypeDef<"Sort"> { + let mnemonic = "sort"; + let description = [{ + This type represents uninterpreted sorts. The usage of a type like + `!smt.sort<"sort_name"[!smt.bool, !smt.sort<"other_sort">]>` implies a + `declare-sort sort_name 2` and a `declare-sort other_sort 0` in SMT-LIB. + This type represents concrete use-sites of such such declared sorts, in this + particular case it would be equivalent to `(sort_name Bool other_sort)` in + SMT-LIB. More details about the semantics can be found in the + [SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf). + }]; + + let parameters = (ins + "mlir::StringAttr":$identifier, + OptionalArrayRefParameter<"mlir::Type", "sort parameters">:$sortParams + ); + + let assemblyFormat = "`<` $identifier (`[` $sortParams^ `]`)? `>`"; + + let builders = [ + TypeBuilder<(ins "llvm::StringRef":$identifier, + "llvm::ArrayRef":$sortParams), [{ + return $_get($_ctxt, mlir::StringAttr::get($_ctxt, identifier), + sortParams); + }]>, + TypeBuilder<(ins "llvm::StringRef":$identifier), [{ + return $_get($_ctxt, mlir::StringAttr::get($_ctxt, identifier), + llvm::ArrayRef{}); + }]>, + ]; + + let genVerifyDecl = true; +} + def AnySMTType : Type, "any SMT value type">; +def AnyNonFuncSMTType : Type, + "any non-function SMT value type">; def AnyNonSMTType : Type, "any non-smt type">; #endif // CIRCT_DIALECT_SMT_SMTTYPES_TD diff --git a/lib/Conversion/VerifToSMT/VerifToSMT.cpp b/lib/Conversion/VerifToSMT/VerifToSMT.cpp index 6cc381af749..c721f3438a0 100644 --- a/lib/Conversion/VerifToSMT/VerifToSMT.cpp +++ b/lib/Conversion/VerifToSMT/VerifToSMT.cpp @@ -85,8 +85,7 @@ struct LogicEquivalenceCheckingOpConversion // Second, create the symbolic values we replace the block arguments with SmallVector inputs; for (auto arg : adaptor.getFirstCircuit().getArguments()) - inputs.push_back( - rewriter.create(loc, arg.getType())); + inputs.push_back(rewriter.create(loc, arg.getType())); // Third, inline the blocks // Note: the argument value replacement does not happen immediately, but diff --git a/lib/Dialect/SMT/SMTOps.cpp b/lib/Dialect/SMT/SMTOps.cpp index 2a97a558bd9..7a1fa8afa08 100644 --- a/lib/Dialect/SMT/SMTOps.cpp +++ b/lib/Dialect/SMT/SMTOps.cpp @@ -44,10 +44,10 @@ OpFoldResult BVConstantOp::fold(FoldAdaptor adaptor) { } //===----------------------------------------------------------------------===// -// DeclareConstOp +// DeclareOp //===----------------------------------------------------------------------===// -void DeclareConstOp::getAsmResultNames( +void DeclareOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), getNamePrefix().has_value() ? *getNamePrefix() : ""); } @@ -322,6 +322,10 @@ static LogicalResult verifyQuantifierRegions(QuantifierOp op) { if (op.getBody().getNumArguments() != op.getBoundVarNames().size()) return op.emitOpError( "number of bound variable names must match number of block arguments"); + if (!llvm::all_of(op.getBody().getArgumentTypes(), isAnyNonFuncSMTValueType)) + return op.emitOpError() + << "bound variables must by any non-function SMT value"; + if (op.getBody().front().getTerminator()->getNumOperands() != 1) return op.emitOpError("must have exactly one yielded value"); if (!isa( diff --git a/lib/Dialect/SMT/SMTTypes.cpp b/lib/Dialect/SMT/SMTTypes.cpp index ffdda3cf728..ad1c0063265 100644 --- a/lib/Dialect/SMT/SMTTypes.cpp +++ b/lib/Dialect/SMT/SMTTypes.cpp @@ -26,8 +26,12 @@ void SMTDialect::registerTypes() { >(); } +bool smt::isAnyNonFuncSMTValueType(Type type) { + return isa(type); +} + bool smt::isAnySMTValueType(Type type) { - return isa(type); + return isAnyNonFuncSMTValueType(type) || isa(type); } //===----------------------------------------------------------------------===// @@ -55,3 +59,31 @@ LogicalResult ArrayType::verify(function_ref emitError, return success(); } + +//===----------------------------------------------------------------------===// +// SMTFuncType +//===----------------------------------------------------------------------===// + +LogicalResult SMTFuncType::verify(function_ref emitError, + ArrayRef domainTypes, Type rangeType) { + if (!llvm::all_of(domainTypes, isAnyNonFuncSMTValueType)) + return emitError() << "domain types must be any non-function SMT type"; + if (!isAnySMTValueType(rangeType)) + return emitError() << "range type must be any non-function SMT type"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// SortType +//===----------------------------------------------------------------------===// + +LogicalResult SortType::verify(function_ref emitError, + StringAttr identifier, + ArrayRef sortParams) { + if (!llvm::all_of(sortParams, isAnyNonFuncSMTValueType)) + return emitError() + << "sort parameter types must be any non-function SMT type"; + + return success(); +} diff --git a/test/Conversion/VerifToSMT/verif-to-smt.mlir b/test/Conversion/VerifToSMT/verif-to-smt.mlir index c07a5fff5d6..0b75b3d8fde 100644 --- a/test/Conversion/VerifToSMT/verif-to-smt.mlir +++ b/test/Conversion/VerifToSMT/verif-to-smt.mlir @@ -12,9 +12,9 @@ func.func @test(%arg0: !smt.bv<1>) -> (i1, i1, i1) { // CHECK: [[EQ:%.+]] = smt.solver() : () -> i1 // CHECK: [[TRUE:%.+]] = arith.constant true // CHECK: [[FALSE:%.+]] = arith.constant false - // CHECK: [[IN0:%.+]] = smt.declare_const : !smt.bv<32> + // CHECK: [[IN0:%.+]] = smt.declare : !smt.bv<32> // CHECK: [[V0:%.+]] = builtin.unrealized_conversion_cast [[IN0]] : !smt.bv<32> to i32 - // CHECK: [[IN1:%.+]] = smt.declare_const : !smt.bv<32> + // CHECK: [[IN1:%.+]] = smt.declare : !smt.bv<32> // CHECK: [[V1:%.+]] = builtin.unrealized_conversion_cast [[IN1]] : !smt.bv<32> to i32 // CHECK: [[V2:%.+]]:2 = "some_op"([[V0]], [[V1]]) : (i32, i32) -> (i32, i32) // CHECK: [[V3:%.+]] = builtin.unrealized_conversion_cast [[V2]]#0 : i32 to !smt.bv<32> @@ -38,7 +38,7 @@ func.func @test(%arg0: !smt.bv<1>) -> (i1, i1, i1) { } // CHECK: [[EQ2:%.+]] = smt.solver() : () -> i1 - // CHECK: [[V9:%.+]] = smt.declare_const : !smt.bv<32> + // CHECK: [[V9:%.+]] = smt.declare : !smt.bv<32> // CHECK: [[V10:%.+]] = smt.distinct [[V9]], [[V9]] : !smt.bv<32> // CHECK: smt.assert [[V10]] %2 = verif.lec first { diff --git a/test/Dialect/SMT/basic.mlir b/test/Dialect/SMT/basic.mlir index 6c68cf9fd0f..34a4a87be00 100644 --- a/test/Dialect/SMT/basic.mlir +++ b/test/Dialect/SMT/basic.mlir @@ -1,18 +1,24 @@ // RUN: circt-opt %s | circt-opt | FileCheck %s // CHECK-LABEL: func @types -// CHECK-SAME: (%{{.*}}: !smt.bool, %{{.*}}: !smt.bv<32>, %{{.*}}: !smt.int) -func.func @types(%arg0: !smt.bool, %arg1: !smt.bv<32>, %arg2: !smt.int) { +// CHECK-SAME: (%{{.*}}: !smt.bool, %{{.*}}: !smt.bv<32>, %{{.*}}: !smt.int, %{{.*}}: !smt.sort<"uninterpreted_sort">, %{{.*}}: !smt.sort<"uninterpreted_sort"[!smt.bool, !smt.int]>, %{{.*}}: !smt.func<(!smt.bool, !smt.bool) !smt.bool>, %{{.*}}: !smt.func) +func.func @types(%arg0: !smt.bool, %arg1: !smt.bv<32>, %arg2: !smt.int, %arg3: !smt.sort<"uninterpreted_sort">, %arg4: !smt.sort<"uninterpreted_sort"[!smt.bool, !smt.int]>, %arg5: !smt.func<(!smt.bool, !smt.bool) !smt.bool>, %arg6: !smt.func) { return } func.func @core(%in: i8) { - // CHECK: %a = smt.declare_const "a" {smt.some_attr} : !smt.bool - %a = smt.declare_const "a" {smt.some_attr} : !smt.bool - // CHECK: smt.declare_const {smt.some_attr} : !smt.bv<32> - %b = smt.declare_const {smt.some_attr} : !smt.bv<32> - // CHECK: smt.declare_const {smt.some_attr} : !smt.int - %c = smt.declare_const {smt.some_attr} : !smt.int + // CHECK: %a = smt.declare "a" {smt.some_attr} : !smt.bool + %a = smt.declare "a" {smt.some_attr} : !smt.bool + // CHECK: smt.declare {smt.some_attr} : !smt.bv<32> + %b = smt.declare {smt.some_attr} : !smt.bv<32> + // CHECK: smt.declare {smt.some_attr} : !smt.int + %c = smt.declare {smt.some_attr} : !smt.int + // CHECK: smt.declare {smt.some_attr} : !smt.sort<"uninterpreted_sort"> + %d = smt.declare {smt.some_attr} : !smt.sort<"uninterpreted_sort"> + // CHECK: smt.declare {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool> + %e = smt.declare {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool> + // CHECK: smt.declare {smt.some_attr} : !smt.func + %f = smt.declare {smt.some_attr} : !smt.func // CHECK: smt.constant true {smt.some_attr} %true = smt.constant true {smt.some_attr} @@ -82,6 +88,11 @@ func.func @core(%in: i8) { // CHECK: %{{.*}} = smt.implies %{{.*}}, %{{.*}} {smt.some_attr} %10 = smt.implies %a, %a {smt.some_attr} + // CHECK: smt.apply_func %{{.*}}(%{{.*}}, %{{.*}}) {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool> + %11 = smt.apply_func %e(%c, %a) {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool> + // CHECK: smt.apply_func %{{.*}}() {smt.some_attr} : !smt.func + %12 = smt.apply_func %f() {smt.some_attr} : !smt.func + return } diff --git a/test/Dialect/SMT/core-errors.mlir b/test/Dialect/SMT/core-errors.mlir index 7fc445eec28..59320435176 100644 --- a/test/Dialect/SMT/core-errors.mlir +++ b/test/Dialect/SMT/core-errors.mlir @@ -25,7 +25,7 @@ func.func @no_smt_value_enters_solver(%arg0: !smt.bool) { func.func @no_smt_value_exits_solver() { // expected-error @below {{result #0 must be variadic of any non-smt type, but got '!smt.bool'}} %0 = smt.solver() : () -> !smt.bool { - %a = smt.declare_const "a" : !smt.bool + %a = smt.declare "a" : !smt.bool smt.yield %a : !smt.bool } return @@ -409,3 +409,46 @@ func.func @forall_patterns_region_block_args_used_at_least_once() { } return } + +// ----- + +func.func @exists_bound_variable_type_invalid() { + // expected-error @below {{bound variables must by any non-function SMT value}} + %1 = smt.exists ["a", "b"] { + ^bb0(%arg2: !smt.func<(!smt.int) !smt.int>, %arg3: !smt.bool): + smt.yield %arg3 : !smt.bool + } + return +} + +// ----- + +func.func @forall_bound_variable_type_invalid() { + // expected-error @below {{bound variables must by any non-function SMT value}} + %1 = smt.forall ["a", "b"] { + ^bb0(%arg2: !smt.func<(!smt.int) !smt.int>, %arg3: !smt.bool): + smt.yield %arg3 : !smt.bool + } + return +} + +// ----- + +// expected-error @below {{domain types must be any non-function SMT type}} +func.func @func_domain_no_smt_type(%arg0: !smt.func<(i32) !smt.bool>) { + return +} + +// ----- + +// expected-error @below {{range type must be any non-function SMT type}} +func.func @func_range_no_smt_type(%arg0: !smt.func<(!smt.bool) i32>) { + return +} + +// ----- + +// expected-error @below {{sort parameter types must be any non-function SMT type}} +func.func @sort_type_no_smt_type(%arg0: !smt.sort<"sortname"[i32]>) { + return +} diff --git a/test/Dialect/SMT/cse-test.mlir b/test/Dialect/SMT/cse-test.mlir index c1a356ce1db..82315c14c03 100644 --- a/test/Dialect/SMT/cse-test.mlir +++ b/test/Dialect/SMT/cse-test.mlir @@ -1,12 +1,12 @@ // RUN: circt-opt %s --cse | FileCheck %s func.func @declare_const_cse(%in: i8) -> (!smt.bool, !smt.bool){ - // CHECK: smt.declare_const "a" : !smt.bool - %a = smt.declare_const "a" : !smt.bool - // CHECK-NEXT: smt.declare_const "a" : !smt.bool - %b = smt.declare_const "a" : !smt.bool + // CHECK: smt.declare "a" : !smt.bool + %a = smt.declare "a" : !smt.bool + // CHECK-NEXT: smt.declare "a" : !smt.bool + %b = smt.declare "a" : !smt.bool // CHECK-NEXT: return - %c = smt.declare_const "a" : !smt.bool + %c = smt.declare "a" : !smt.bool return %a, %b : !smt.bool, !smt.bool }