Skip to content

Commit

Permalink
[SMT] Add function application operation, function and uninterpreted …
Browse files Browse the repository at this point in the history
…sort types
  • Loading branch information
maerhart committed Mar 18, 2024
1 parent 672949c commit 2c3f817
Show file tree
Hide file tree
Showing 11 changed files with 245 additions and 31 deletions.
2 changes: 1 addition & 1 deletion include/circt/Dialect/SMT/SMTArrayOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def ArrayBroadcastOp : SMTArrayOp<"broadcast", [
This operation represents a broadcast of the 'value' operand to all indices
of the array. It is equivalent to
```
%0 = smt.declare_const "array" : !smt.array<[!smt.int -> !smt.bool]>
%0 = smt.declare "array" : !smt.array<[!smt.int -> !smt.bool]>
%1 = smt.forall ["idx"] {
^bb0(%idx: !smt.int):
%2 = smt.array.select %0[%idx] : !smt.array<[!smt.int -> !smt.bool]>
Expand Down
60 changes: 52 additions & 8 deletions include/circt/Dialect/SMT/SMTOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,38 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
class SMTOp<string mnemonic, list<Trait> traits = []> :
Op<SMTDialect, mnemonic, traits>;

def DeclareConstOp : SMTOp<"declare_const", [
def DeclareOp : SMTOp<"declare", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "declare a symbolic value of a given sort";
let description = [{
This operation declares a symbolic value just as the `declare-const`
statement in SMT-LIB 2.6. The result type determines the SMT sort of the
symbolic value. The returned value can then be used to refer to the symbolic
value instead of using the identifier like in SMT-LIB.
This operation declares a symbolic value just as the `declare-const` and
`declare-func` statements in SMT-LIB 2.6. The result type determines the SMT
sort of the symbolic value. The returned value can then be used to refer to
the symbolic value instead of using the identifier like in SMT-LIB.

The optionally provided string will be used as a prefix for the newly
generated identifier (useful for easier readability when exporting to
SMT-LIB). Each `declare_const` will always provide a unique new symbolic
value even if the identifier strings are the same.
SMT-LIB). Each `declare` will always provide a unique new symbolic value
even if the identifier strings are the same.

Note that there does not exist a separate operation equivalent to
SMT-LIBs `define-fun` since
```
(define-fun f (a Int) Int (-a))
```
is only syntactic sugar for
```
%f = smt.declare : !smt.func<(!smt.int) !smt.int>
%0 = smt.forall ["a"] {
^bb0(%arg0: !smt.int):
%1 = smt.apply_func %f(%arg0) : !smt.func<(!smt.int) !smt.int>
%2 = smt.int.neg %1
%3 = smt.eq %1, %2 : !smt.int
smt.yield %3 : !smt.bool
}
smt.assert %0
```

Note that this operation cannot be marked as Pure since two operations (even
with the same identifier string) could then be CSEd, leading to incorrect
Expand Down Expand Up @@ -99,7 +117,7 @@ def SolverOp : SMTOp<"solver", [
```mlir
%0:2 = smt.solver (%in) {smt.some_attr} : (i8) -> (i8, i32) {
^bb0(%arg0: i8):
%c = smt.declare_const "c" : !smt.bool
%c = smt.declare "c" : !smt.bool
smt.assert %c
%1 = smt.check sat {
%c1_i32 = arith.constant 1 : i32
Expand Down Expand Up @@ -190,6 +208,29 @@ def YieldOp : SMTOp<"yield", [
}]>];
}

def ApplyFuncOp : SMTOp<"apply_func", [
Pure,
TypesMatchWith<"summary", "func", "result",
"cast<SMTFuncType>($_self).getRangeType()">,
RangedTypesMatchWith<"summary", "func", "args",
"cast<SMTFuncType>($_self).getDomainTypes()">
]> {
let summary = "apply a function";
let description = [{
This operation performs a function application as described in the
[SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf).
It is part of the language itself rather than a theory or logic.
}];

let arguments = (ins SMTFuncType:$func,
Variadic<AnyNonFuncSMTType>:$args);
let results = (outs AnyNonFuncSMTType:$result);

let assemblyFormat = [{
$func `(` $args `)` attr-dict `:` qualified(type($func))
}];
}

def EqOp : SMTOp<"eq", [Pure, SameTypeOperands]> {
let summary = "returns true iff all operands are identical";
let description = [{
Expand Down Expand Up @@ -358,6 +399,9 @@ class QuantifierOp<string mnemonic> : SMTOp<mnemonic, [
'patterns' region can yield an arbitrary number (but at least one) of SMT
values.

The bound variables can be any SMT type except of functions, since SMT only
supports first-order logic.

The 'no_patterns' attribute is only allowed when no 'patterns' region is
specified and forbids the solver to generate and use patterns for this
quantifier.
Expand Down
3 changes: 3 additions & 0 deletions include/circt/Dialect/SMT/SMTTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ namespace smt {
/// Returns whether the given type is an SMT value type.
bool isAnySMTValueType(mlir::Type type);

/// Returns whether the given type is an SMT value type (excluding functions).
bool isAnyNonFuncSMTValueType(mlir::Type type);

} // namespace smt
} // namespace circt

Expand Down
78 changes: 78 additions & 0 deletions include/circt/Dialect/SMT/SMTTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,86 @@ def ArrayType : SMTTypeDef<"Array"> {
let genVerifyDecl = true;
}

def SMTFuncType : SMTTypeDef<"SMTFunc"> {
let mnemonic = "func";
let description = [{
This type represents the SMT function sort as described in the
[SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf).
It is part of the language itself rather than a theory or logic.

A function in SMT can have an arbitrary domain size, but always has exactly
one range sort.

Since SMT only supports first-order logic, it is not possible to nest
function types.

Example: `!smt.func<(!smt.bool, !smt.int) !smt.bool>` is equivalent to
`((Bool Int) Bool)` in SMT-LIB.
}];

let parameters = (ins
OptionalArrayRefParameter<"mlir::Type", "domain types">:$domainTypes,
"mlir::Type":$rangeType
);

// Note: We are not printing the parentheses when no domain type is present
// because the default MLIR parser thinks it is a builtin function type
// otherwise.
let assemblyFormat = "`<` (`(` $domainTypes^ `)` ` `)? $rangeType `>`";

let builders = [
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<mlir::Type>":$domainTypes,
"mlir::Type":$rangeType), [{
return $_get(rangeType.getContext(), domainTypes, rangeType);
}]>,
TypeBuilderWithInferredContext<(ins "mlir::Type":$rangeType), [{
return $_get(rangeType.getContext(),
llvm::ArrayRef<mlir::Type>{}, rangeType);
}]>
];

let genVerifyDecl = true;
}

def SortType : SMTTypeDef<"Sort"> {
let mnemonic = "sort";
let description = [{
This type represents uninterpreted sorts. The usage of a type like
`!smt.sort<"sort_name"[!smt.bool, !smt.sort<"other_sort">]>` implies a
`declare-sort sort_name 2` and a `declare-sort other_sort 0` in SMT-LIB.
This type represents concrete use-sites of such such declared sorts, in this
particular case it would be equivalent to `(sort_name Bool other_sort)` in
SMT-LIB. More details about the semantics can be found in the
[SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf).
}];

let parameters = (ins
"mlir::StringAttr":$identifier,
OptionalArrayRefParameter<"mlir::Type", "sort parameters">:$sortParams
);

let assemblyFormat = "`<` $identifier (`[` $sortParams^ `]`)? `>`";

let builders = [
TypeBuilder<(ins "llvm::StringRef":$identifier,
"llvm::ArrayRef<mlir::Type>":$sortParams), [{
return $_get($_ctxt, mlir::StringAttr::get($_ctxt, identifier),
sortParams);
}]>,
TypeBuilder<(ins "llvm::StringRef":$identifier), [{
return $_get($_ctxt, mlir::StringAttr::get($_ctxt, identifier),
llvm::ArrayRef<mlir::Type>{});
}]>,
];

let genVerifyDecl = true;
}

def AnySMTType : Type<CPred<"smt::isAnySMTValueType($_self)">,
"any SMT value type">;
def AnyNonFuncSMTType : Type<CPred<"smt::isAnyNonFuncSMTValueType($_self)">,
"any non-function SMT value type">;
def AnyNonSMTType : Type<Neg<AnySMTType.predicate>, "any non-smt type">;

#endif // CIRCT_DIALECT_SMT_SMTTYPES_TD
3 changes: 1 addition & 2 deletions lib/Conversion/VerifToSMT/VerifToSMT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ struct LogicEquivalenceCheckingOpConversion
// Second, create the symbolic values we replace the block arguments with
SmallVector<Value> inputs;
for (auto arg : adaptor.getFirstCircuit().getArguments())
inputs.push_back(
rewriter.create<smt::DeclareConstOp>(loc, arg.getType()));
inputs.push_back(rewriter.create<smt::DeclareOp>(loc, arg.getType()));

// Third, inline the blocks
// Note: the argument value replacement does not happen immediately, but
Expand Down
8 changes: 6 additions & 2 deletions lib/Dialect/SMT/SMTOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ OpFoldResult BVConstantOp::fold(FoldAdaptor adaptor) {
}

//===----------------------------------------------------------------------===//
// DeclareConstOp
// DeclareOp
//===----------------------------------------------------------------------===//

void DeclareConstOp::getAsmResultNames(
void DeclareOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), getNamePrefix().has_value() ? *getNamePrefix() : "");
}
Expand Down Expand Up @@ -322,6 +322,10 @@ static LogicalResult verifyQuantifierRegions(QuantifierOp op) {
if (op.getBody().getNumArguments() != op.getBoundVarNames().size())
return op.emitOpError(
"number of bound variable names must match number of block arguments");
if (!llvm::all_of(op.getBody().getArgumentTypes(), isAnyNonFuncSMTValueType))
return op.emitOpError()
<< "bound variables must by any non-function SMT value";

if (op.getBody().front().getTerminator()->getNumOperands() != 1)
return op.emitOpError("must have exactly one yielded value");
if (!isa<BoolType>(
Expand Down
34 changes: 33 additions & 1 deletion lib/Dialect/SMT/SMTTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ void SMTDialect::registerTypes() {
>();
}

bool smt::isAnyNonFuncSMTValueType(Type type) {
return isa<BoolType, BitVectorType, ArrayType, IntType, SortType>(type);
}

bool smt::isAnySMTValueType(Type type) {
return isa<BoolType, BitVectorType, ArrayType, IntType>(type);
return isAnyNonFuncSMTValueType(type) || isa<SMTFuncType>(type);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -55,3 +59,31 @@ LogicalResult ArrayType::verify(function_ref<InFlightDiagnostic()> emitError,

return success();
}

//===----------------------------------------------------------------------===//
// SMTFuncType
//===----------------------------------------------------------------------===//

LogicalResult SMTFuncType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> domainTypes, Type rangeType) {
if (!llvm::all_of(domainTypes, isAnyNonFuncSMTValueType))
return emitError() << "domain types must be any non-function SMT type";
if (!isAnySMTValueType(rangeType))
return emitError() << "range type must be any non-function SMT type";

return success();
}

//===----------------------------------------------------------------------===//
// SortType
//===----------------------------------------------------------------------===//

LogicalResult SortType::verify(function_ref<InFlightDiagnostic()> emitError,
StringAttr identifier,
ArrayRef<Type> sortParams) {
if (!llvm::all_of(sortParams, isAnyNonFuncSMTValueType))
return emitError()
<< "sort parameter types must be any non-function SMT type";

return success();
}
6 changes: 3 additions & 3 deletions test/Conversion/VerifToSMT/verif-to-smt.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ func.func @test(%arg0: !smt.bv<1>) -> (i1, i1, i1) {
// CHECK: [[EQ:%.+]] = smt.solver() : () -> i1
// CHECK: [[TRUE:%.+]] = arith.constant true
// CHECK: [[FALSE:%.+]] = arith.constant false
// CHECK: [[IN0:%.+]] = smt.declare_const : !smt.bv<32>
// CHECK: [[IN0:%.+]] = smt.declare : !smt.bv<32>
// CHECK: [[V0:%.+]] = builtin.unrealized_conversion_cast [[IN0]] : !smt.bv<32> to i32
// CHECK: [[IN1:%.+]] = smt.declare_const : !smt.bv<32>
// CHECK: [[IN1:%.+]] = smt.declare : !smt.bv<32>
// CHECK: [[V1:%.+]] = builtin.unrealized_conversion_cast [[IN1]] : !smt.bv<32> to i32
// CHECK: [[V2:%.+]]:2 = "some_op"([[V0]], [[V1]]) : (i32, i32) -> (i32, i32)
// CHECK: [[V3:%.+]] = builtin.unrealized_conversion_cast [[V2]]#0 : i32 to !smt.bv<32>
Expand All @@ -38,7 +38,7 @@ func.func @test(%arg0: !smt.bv<1>) -> (i1, i1, i1) {
}

// CHECK: [[EQ2:%.+]] = smt.solver() : () -> i1
// CHECK: [[V9:%.+]] = smt.declare_const : !smt.bv<32>
// CHECK: [[V9:%.+]] = smt.declare : !smt.bv<32>
// CHECK: [[V10:%.+]] = smt.distinct [[V9]], [[V9]] : !smt.bv<32>
// CHECK: smt.assert [[V10]]
%2 = verif.lec first {
Expand Down
27 changes: 19 additions & 8 deletions test/Dialect/SMT/basic.mlir
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
// RUN: circt-opt %s | circt-opt | FileCheck %s

// CHECK-LABEL: func @types
// CHECK-SAME: (%{{.*}}: !smt.bool, %{{.*}}: !smt.bv<32>, %{{.*}}: !smt.int)
func.func @types(%arg0: !smt.bool, %arg1: !smt.bv<32>, %arg2: !smt.int) {
// CHECK-SAME: (%{{.*}}: !smt.bool, %{{.*}}: !smt.bv<32>, %{{.*}}: !smt.int, %{{.*}}: !smt.sort<"uninterpreted_sort">, %{{.*}}: !smt.sort<"uninterpreted_sort"[!smt.bool, !smt.int]>, %{{.*}}: !smt.func<(!smt.bool, !smt.bool) !smt.bool>, %{{.*}}: !smt.func<!smt.bool>)
func.func @types(%arg0: !smt.bool, %arg1: !smt.bv<32>, %arg2: !smt.int, %arg3: !smt.sort<"uninterpreted_sort">, %arg4: !smt.sort<"uninterpreted_sort"[!smt.bool, !smt.int]>, %arg5: !smt.func<(!smt.bool, !smt.bool) !smt.bool>, %arg6: !smt.func<!smt.bool>) {
return
}

func.func @core(%in: i8) {
// CHECK: %a = smt.declare_const "a" {smt.some_attr} : !smt.bool
%a = smt.declare_const "a" {smt.some_attr} : !smt.bool
// CHECK: smt.declare_const {smt.some_attr} : !smt.bv<32>
%b = smt.declare_const {smt.some_attr} : !smt.bv<32>
// CHECK: smt.declare_const {smt.some_attr} : !smt.int
%c = smt.declare_const {smt.some_attr} : !smt.int
// CHECK: %a = smt.declare "a" {smt.some_attr} : !smt.bool
%a = smt.declare "a" {smt.some_attr} : !smt.bool
// CHECK: smt.declare {smt.some_attr} : !smt.bv<32>
%b = smt.declare {smt.some_attr} : !smt.bv<32>
// CHECK: smt.declare {smt.some_attr} : !smt.int
%c = smt.declare {smt.some_attr} : !smt.int
// CHECK: smt.declare {smt.some_attr} : !smt.sort<"uninterpreted_sort">
%d = smt.declare {smt.some_attr} : !smt.sort<"uninterpreted_sort">
// CHECK: smt.declare {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool>
%e = smt.declare {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool>
// CHECK: smt.declare {smt.some_attr} : !smt.func<!smt.bool>
%f = smt.declare {smt.some_attr} : !smt.func<!smt.bool>

// CHECK: smt.constant true {smt.some_attr}
%true = smt.constant true {smt.some_attr}
Expand Down Expand Up @@ -82,6 +88,11 @@ func.func @core(%in: i8) {
// CHECK: %{{.*}} = smt.implies %{{.*}}, %{{.*}} {smt.some_attr}
%10 = smt.implies %a, %a {smt.some_attr}

// CHECK: smt.apply_func %{{.*}}(%{{.*}}, %{{.*}}) {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool>
%11 = smt.apply_func %e(%c, %a) {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool>
// CHECK: smt.apply_func %{{.*}}() {smt.some_attr} : !smt.func<!smt.bool>
%12 = smt.apply_func %f() {smt.some_attr} : !smt.func<!smt.bool>

return
}

Expand Down

0 comments on commit 2c3f817

Please sign in to comment.