diff --git a/src/solvers/Makefile b/src/solvers/Makefile index ccf66d4cfa6..6c46d6e048f 100644 --- a/src/solvers/Makefile +++ b/src/solvers/Makefile @@ -128,6 +128,7 @@ SRC = $(BOOLEFORCE_SRC) \ flattening/boolbv_unary_minus.cpp \ flattening/boolbv_union.cpp \ flattening/boolbv_update.cpp \ + flattening/boolbv_update_bit.cpp \ flattening/boolbv_update_bits.cpp \ flattening/boolbv_width.cpp \ flattening/boolbv_with.cpp \ diff --git a/src/solvers/flattening/boolbv.cpp b/src/solvers/flattening/boolbv.cpp index 3f026d4b3b7..64a9816c9ba 100644 --- a/src/solvers/flattening/boolbv.cpp +++ b/src/solvers/flattening/boolbv.cpp @@ -109,6 +109,8 @@ bvt boolbvt::convert_bitvector(const exprt &expr) return convert_with(to_with_expr(expr)); else if(expr.id()==ID_update) return convert_update(to_update_expr(expr)); + else if(expr.id() == ID_update_bit) + return convert_update_bit(to_update_bit_expr(expr)); else if(expr.id()==ID_case) return convert_case(expr); else if(expr.id()==ID_cond) diff --git a/src/solvers/flattening/boolbv.h b/src/solvers/flattening/boolbv.h index 5992bdc4b0d..48d32ece23b 100644 --- a/src/solvers/flattening/boolbv.h +++ b/src/solvers/flattening/boolbv.h @@ -39,6 +39,7 @@ class overflow_result_exprt; class replication_exprt; class unary_overflow_exprt; class union_typet; +class update_bit_exprt; class update_bits_exprt; class boolbvt:public arrayst @@ -177,6 +178,7 @@ class boolbvt:public arrayst virtual bvt convert_member(const member_exprt &expr); virtual bvt convert_with(const with_exprt &expr); virtual bvt convert_update(const update_exprt &); + virtual bvt convert_update_bit(const update_bit_exprt &); virtual bvt convert_update_bits(const update_bits_exprt &); virtual bvt convert_case(const exprt &expr); virtual bvt convert_cond(const cond_exprt &); diff --git a/src/solvers/flattening/boolbv_update_bit.cpp b/src/solvers/flattening/boolbv_update_bit.cpp new file mode 100644 index 00000000000..2fd6feee706 --- /dev/null +++ b/src/solvers/flattening/boolbv_update_bit.cpp @@ -0,0 +1,16 @@ +/*******************************************************************\ + +Module: + +Author: Daniel Kroening, dkr@amazon.com + +\*******************************************************************/ + +#include + +#include "boolbv.h" + +bvt boolbvt::convert_update_bit(const update_bit_exprt &expr) +{ + return convert_bv(expr.lower()); +} diff --git a/src/solvers/flattening/boolbv_with.cpp b/src/solvers/flattening/boolbv_with.cpp index f04bf78575a..5eb83b0b3c0 100644 --- a/src/solvers/flattening/boolbv_with.cpp +++ b/src/solvers/flattening/boolbv_with.cpp @@ -22,9 +22,32 @@ bvt boolbvt::convert_with(const with_exprt &expr) type.id() == ID_bv || type.id() == ID_unsignedbv || type.id() == ID_signedbv) { + if(expr.operands().size() > 3) + { + std::size_t s = expr.operands().size(); + + // strip off the trailing two operands + with_exprt tmp = expr; + tmp.operands().resize(s - 2); + + with_exprt new_with_expr( + tmp, expr.operands()[s - 2], expr.operands().back()); + + // recursive call + return convert_with(new_with_expr); + } + PRECONDITION(expr.operands().size() == 3); - return convert_bv( - update_bits_exprt(expr.old(), expr.where(), expr.new_value())); + if(expr.new_value().type().id() == ID_bool) + { + return convert_bv( + update_bit_exprt(expr.old(), expr.where(), expr.new_value())); + } + else + { + return convert_bv( + update_bits_exprt(expr.old(), expr.where(), expr.new_value())); + } } bvt bv = convert_bv(expr.old()); diff --git a/src/solvers/smt2/smt2_conv.cpp b/src/solvers/smt2/smt2_conv.cpp index 5ddc20ce70f..16967c683cc 100644 --- a/src/solvers/smt2/smt2_conv.cpp +++ b/src/solvers/smt2/smt2_conv.cpp @@ -1685,6 +1685,10 @@ void smt2_convt::convert_expr(const exprt &expr) { convert_update(to_update_expr(expr)); } + else if(expr.id() == ID_update_bit) + { + convert_update_bit(to_update_bit_expr(expr)); + } else if(expr.id() == ID_update_bits) { convert_update_bits(to_update_bits_expr(expr)); @@ -4290,8 +4294,16 @@ void smt2_convt::convert_with(const with_exprt &expr) expr_type.id()==ID_unsignedbv || expr_type.id()==ID_signedbv) { - convert_update_bits( - update_bits_exprt(expr.old(), expr.where(), expr.new_value())); + if(expr.new_value().type().id() == ID_bool) + { + convert_update_bit( + update_bit_exprt(expr.old(), expr.where(), expr.new_value())); + } + else + { + convert_update_bits( + update_bits_exprt(expr.old(), expr.where(), expr.new_value())); + } } else UNEXPECTEDCASE( @@ -4306,6 +4318,11 @@ void smt2_convt::convert_update(const update_exprt &expr) SMT2_TODO("smt2_convt::convert_update to be implemented"); } +void smt2_convt::convert_update_bit(const update_bit_exprt &expr) +{ + return convert_expr(expr.lower()); +} + void smt2_convt::convert_update_bits(const update_bits_exprt &expr) { return convert_expr(expr.lower()); diff --git a/src/solvers/smt2/smt2_conv.h b/src/solvers/smt2/smt2_conv.h index 406b18cf8d1..4a043824831 100644 --- a/src/solvers/smt2/smt2_conv.h +++ b/src/solvers/smt2/smt2_conv.h @@ -32,6 +32,7 @@ Author: Daniel Kroening, kroening@kroening.com class floatbv_typecast_exprt; class ieee_float_op_exprt; class union_typet; +class update_bit_exprt; class update_bits_exprt; class smt2_convt : public stack_decision_proceduret @@ -150,6 +151,7 @@ class smt2_convt : public stack_decision_proceduret void convert_with(const with_exprt &expr); void convert_update(const update_exprt &); + void convert_update_bit(const update_bit_exprt &); void convert_update_bits(const update_bits_exprt &); void convert_expr(const exprt &); diff --git a/src/util/bitvector_expr.cpp b/src/util/bitvector_expr.cpp index eb8fa4807a5..c62a50e594c 100644 --- a/src/util/bitvector_expr.cpp +++ b/src/util/bitvector_expr.cpp @@ -42,6 +42,34 @@ extractbits_exprt::extractbits_exprt( from_integer(_lower, integer_typet())); } +exprt update_bit_exprt::lower() const +{ + const auto width = to_bitvector_type(type()).get_width(); + auto src_bv_type = bv_typet(width); + + // build a mask 0...0 1 + auto mask_bv = + make_bvrep(width, [](std::size_t index) { return index == 0; }); + auto mask_expr = constant_exprt(mask_bv, src_bv_type); + + // shift the mask by the index + auto mask_shifted = shl_exprt(mask_expr, index()); + + auto src_masked = bitand_exprt( + typecast_exprt(src(), src_bv_type), bitnot_exprt(mask_shifted)); + + // zero-extend the replacement bit to match src + auto new_value_casted = typecast_exprt( + typecast_exprt(new_value(), unsignedbv_typet(width)), src_bv_type); + + // shift the replacement bits + auto new_value_shifted = shl_exprt(new_value_casted, index()); + + // or the masked src and the shifted replacement bits + return typecast_exprt( + bitor_exprt(src_masked, new_value_shifted), src().type()); +} + exprt update_bits_exprt::lower() const { const auto width = to_bitvector_type(type()).get_width(); diff --git a/src/util/bitvector_expr.h b/src/util/bitvector_expr.h index 3d40b356798..1276bf287ff 100644 --- a/src/util/bitvector_expr.h +++ b/src/util/bitvector_expr.h @@ -531,6 +531,93 @@ inline extractbits_exprt &to_extractbits_expr(exprt &expr) return ret; } +/// \brief Replaces a sub-range of a bit-vector operand +class update_bit_exprt : public expr_protectedt +{ +public: + /// Replaces the bit [\p _index] from \p _src to produce a result of + /// the same type as \p _src. The index counts from the + /// least-significant bit. Updates outside of the range of \p _src + /// yield an expression equal to \p _src. + update_bit_exprt(exprt _src, exprt _index, exprt _new_value) + : expr_protectedt( + ID_update_bit, + _src.type(), + {_src, std::move(_index), std::move(_new_value)}) + { + PRECONDITION(new_value().type().id() == ID_bool); + } + + update_bit_exprt(exprt _src, const std::size_t _index, exprt _new_value); + + exprt &src() + { + return op0(); + } + + exprt &index() + { + return op1(); + } + + exprt &new_value() + { + return op2(); + } + + const exprt &src() const + { + return op0(); + } + + const exprt &index() const + { + return op1(); + } + + const exprt &new_value() const + { + return op2(); + } + + static void check( + const exprt &expr, + const validation_modet vm = validation_modet::INVARIANT) + { + validate_operands(expr, 3, "update_bit must have three operands"); + } + + /// A lowering to masking, shifting, or. + exprt lower() const; +}; + +template <> +inline bool can_cast_expr(const exprt &base) +{ + return base.id() == ID_update_bit; +} + +/// \brief Cast an exprt to an \ref update_bit_exprt +/// +/// \a expr must be known to be \ref update_bit_exprt. +/// +/// \param expr: Source expression +/// \return Object of type \ref update_bit_exprt +inline const update_bit_exprt &to_update_bit_expr(const exprt &expr) +{ + PRECONDITION(expr.id() == ID_update_bit); + update_bit_exprt::check(expr); + return static_cast(expr); +} + +/// \copydoc to_update_bit_expr(const exprt &) +inline update_bit_exprt &to_update_bit_expr(exprt &expr) +{ + PRECONDITION(expr.id() == ID_update_bit); + update_bit_exprt::check(expr); + return static_cast(expr); +} + /// \brief Replaces a sub-range of a bit-vector operand class update_bits_exprt : public expr_protectedt { diff --git a/src/util/irep_ids.def b/src/util/irep_ids.def index c06f6293055..eea285397a5 100644 --- a/src/util/irep_ids.def +++ b/src/util/irep_ids.def @@ -185,6 +185,7 @@ IREP_ID_ONE(exists) IREP_ID_ONE(repeat) IREP_ID_ONE(extractbit) IREP_ID_ONE(extractbits) +IREP_ID_ONE(update_bit) IREP_ID_ONE(update_bits) IREP_ID_TWO(C_reference, #reference) IREP_ID_TWO(C_rvalue_reference, #rvalue_reference)