diff --git a/docs/algorithms/circuit_validator.rst b/docs/algorithms/circuit_validator.rst new file mode 100644 index 000000000..d097acbcf --- /dev/null +++ b/docs/algorithms/circuit_validator.rst @@ -0,0 +1,76 @@ +Functional equivalence of circuit nodes +--------------------------------------- + +**Header:** ``mockturtle/algorithms/circuit_validator.hpp`` + +**Example** + +The following code shows how to check functional equivalence of a root node to signals existing in the network, or created with nodes within the network. If not, get the counter example. + +.. code-block:: c++ + + /* derive some AIG (can be AIG, XAG, MIG, or XMG) */ + aig_network aig; + auto const a = aig.create_pi(); + auto const b = aig.create_pi(); + auto const f1 = aig.create_and( !a, b ); + auto const f2 = aig.create_and( a, !b ); + auto const f3 = aig.create_or( f1, f2 ); + + circuit_validator v( aig ); + + auto result = v.validate( f1, f2 ); + /* result is an optional, which is nullopt if SAT conflict limit was exceeded */ + if ( result ) + { + if ( *result ) + { + std::cout << "f1 and f2 are functionally equivalent\n"; + } + else + { + std::cout << "f1 and f2 have different values under PI assignment: " << v.cex[0] << v.cex[1] << "\n"; + } + } + + circuit_validator::gate::fanin fi1; + fi1.idx = 0; fi1.inv = true; + circuit_validator::gate::fanin fi2; + fi2.idx = 1; fi2.inv = true; + circuit_validator::gate g; + g.fanins = {fi1, fi2}; + g.type = circuit_validator::gate_type::AND; + + result = v.validate( f3, {aig.get_node( f1 ), aig.get_node( f2 )}, {g}, true ); + if ( result && *result ) + { + std::cout << "f3 is equivalent to NOT(NOT f1 AND NOT f2)\n"; + } + +**Parameters** + +.. doxygenstruct:: mockturtle::validator_params + :members: + +**Validate with existing signals** + +.. doxygenfunction:: mockturtle::circuit_validator::validate( signal const&, signal const& ) +.. doxygenfunction:: mockturtle::circuit_validator::validate( node const&, signal const& ) +.. doxygenfunction:: mockturtle::circuit_validator::validate( signal const&, bool ) +.. doxygenfunction:: mockturtle::circuit_validator::validate( node const&, bool ) + +**Validate with non-existing circuit** + +.. doxygenstruct:: mockturtle::circuit_validator::gate + :members: fanins, type +.. doxygenstruct:: mockturtle::circuit_validator::gate::fanin + :members: idx, inv + +.. doxygenfunction:: mockturtle::circuit_validator::validate( signal const&, std::vector const&, std::vector const&, bool ) +.. doxygenfunction:: mockturtle::circuit_validator::validate( node const&, std::vector const&, std::vector const&, bool ) +.. doxygenfunction:: mockturtle::circuit_validator::validate( signal const&, iterator_type, iterator_type, std::vector const&, bool ) +.. doxygenfunction:: mockturtle::circuit_validator::validate( node const&, iterator_type, iterator_type, std::vector const&, bool ) + +**Updating** +.. doxygenfunction:: mockturtle::circuit_validator::add_node +.. doxygenfunction:: mockturtle::circuit_validator::update diff --git a/docs/changelog.rst b/docs/changelog.rst index 939bf5068..b803318ef 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -29,9 +29,10 @@ v0.2 (not yet released) - Davio decomposition (`positive_davio_decomposition`, `positive_davio_decomposition`) `#308 `_ - Collapse network into single node per output network `#309 `_ - Generic balancing algorithm `#340 `_ + - Check functional equivalence (`circuit_validator`) `#346 `_ * Views: - Assign names to signals and outputs (`names_view`) `#181 `_ `#184 `_ - - Creates a CNF while creating a network (`cnf_view`) `#181 `_ `#184 `_ + - Creates a CNF while creating a network (`cnf_view`) `#274 `_ * I/O: - Write networks to DIMACS files for CNF (`write_dimacs`) `#146 `_ - Read BLIF files using *lorina* (`blif_reader`) `#167 `_ diff --git a/include/mockturtle/algorithms/circuit_validator.hpp b/include/mockturtle/algorithms/circuit_validator.hpp new file mode 100644 index 000000000..9946b5c81 --- /dev/null +++ b/include/mockturtle/algorithms/circuit_validator.hpp @@ -0,0 +1,535 @@ +/* mockturtle: C++ logic network library + * Copyright (C) 2018-2019 EPFL + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +/*! + \file circuit_validator.hpp + \brief Validate potential circuit optimization choices with SAT. + + \author Siang-Yun Lee +*/ + +#pragma once + +#include "../utils/node_map.hpp" +#include "cnf.hpp" +#include +#include +#include +#include + +namespace mockturtle +{ + +struct validator_params +{ + /*! \brief Whether to consider ODC, and how many levels. 0 = no. -1 = Consider TFO until PO. */ + int odc_levels{0}; + + /*! \brief Conflict limit of the SAT solver. */ + uint32_t conflict_limit{1000}; + + /*! \brief Seed for randomized solving. */ + uint32_t random_seed{0}; +}; + +template +class circuit_validator +{ +public: + using node = typename Ntk::node; + using signal = typename Ntk::signal; + using add_clause_fn_t = std::function const& )>; + + enum gate_type + { + AND, + XOR, + MAJ + }; + + struct gate + { + struct fanin + { + /*! \brief Index in the concatenated list of `divs` and `ckt`. */ + uint32_t idx; + + /*! \brief Input negation. */ + bool inv{false}; + }; + + /*! \brief Fanins of the gate. */ + std::vector fanins; + + /*! \brief Type of the gate. + * + * Supported types include AND, XOR, and MAJ. + */ + gate_type type{AND}; + }; + + explicit circuit_validator( Ntk const& ntk, validator_params const& ps = {} ) + : ntk( ntk ), ps( ps ), literals( ntk ), cex( ntk.num_pis() ) + { + static_assert( is_network_type_v, "Ntk is not a network type" ); + static_assert( has_foreach_fanin_v, "Ntk does not implement the foreach_fanin method" ); + static_assert( has_foreach_gate_v, "Ntk does not implement the foreach_gate method" ); + static_assert( has_foreach_pi_v, "Ntk does not implement the foreach_pi method" ); + static_assert( has_get_constant_v, "Ntk does not implement the get_constant method" ); + static_assert( has_get_node_v, "Ntk does not implement the get_node method" ); + static_assert( has_is_complemented_v, "Ntk does not implement the is_complemented method" ); + static_assert( has_make_signal_v, "Ntk does not implement the make_signal method" ); + static_assert( has_size_v, "Ntk does not implement the size method" ); + static_assert( has_is_and_v, "Ntk does not implement the is_and method" ); + static_assert( has_is_xor_v, "Ntk does not implement the is_xor method" ); + static_assert( has_is_xor3_v, "Ntk does not implement the is_xor3 method" ); + static_assert( has_is_maj_v, "Ntk does not implement the is_maj method" ); + + if constexpr ( use_pushpop ) + { +#if defined( BILL_HAS_Z3 ) + static_assert( Solver == bill::solvers::z3 || Solver == bill::solvers::bsat2, "Solver does not support push/pop" ); +#else + static_assert( Solver == bill::solvers::bsat2, "Solver does not support push/pop" ); +#endif + } + if constexpr ( randomize ) + { +#if defined( BILL_HAS_Z3 ) + static_assert( Solver == bill::solvers::z3 || Solver == bill::solvers::bsat2, "Solver does not support set_random" ); +#else + static_assert( Solver == bill::solvers::bsat2, "Solver does not support set_random" ); +#endif + } + if constexpr ( use_odc ) + { + static_assert( has_set_visited_v, "Ntk does not implement the set_visited method" ); + static_assert( has_visited_v, "Ntk does not implement the visited method" ); + static_assert( has_foreach_po_v, "Ntk does not implement the foreach_po method" ); + static_assert( has_foreach_fanout_v, "Ntk does not implement the foreach_fanout method" ); + } + + restart(); + } + + /*! \brief Validate functional equivalence of signals `f` and `d`. */ + std::optional validate( signal const& f, signal const& d ) + { + return validate( ntk.get_node( f ), lit_not_cond( literals[d], ntk.is_complemented( f ) ^ ntk.is_complemented( d ) ) ); + } + + /*! \brief Validate functional equivalence of node `root` and signal `d`. */ + std::optional validate( node const& root, signal const& d ) + { + return validate( root, lit_not_cond( literals[d], ntk.is_complemented( d ) ) ); + } + + /*! \brief Validate functional equivalence of signal `f` with a circuit. */ + std::optional validate( signal const& f, std::vector const& divs, std::vector const& ckt, bool output_negation = false ) + { + return validate( ntk.get_node( f ), divs.begin(), divs.end(), ckt, output_negation ^ ntk.is_complemented( f ) ); + } + + /*! \brief Validate functional equivalence of node `root` with a circuit. + * + * The circuit `ckt` uses `divs` as inputs, which are existing nodes in the network. + */ + std::optional validate( node const& root, std::vector const& divs, std::vector const& ckt, bool output_negation = false ) + { + return validate( root, divs.begin(), divs.end(), ckt, output_negation ); + } + + /*! \brief Validate functional equivalence of signal `f` with a circuit. */ + template + std::optional validate( signal const& f, iterator_type divs_begin, iterator_type divs_end, std::vector const& ckt, bool output_negation = false ) + { + return validate( ntk.get_node( f ), divs_begin, divs_end, ckt, output_negation ^ ntk.is_complemented( f ) ); + } + + /*! \brief Validate functional equivalence of node `root` with a circuit. */ + template + std::optional validate( node const& root, iterator_type divs_begin, iterator_type divs_end, std::vector const& ckt, bool output_negation = false ) + { + if constexpr ( use_pushpop ) + { + solver.push(); + } + + std::vector lits; + while ( divs_begin != divs_end ) + { + lits.emplace_back( literals[*divs_begin] ); + divs_begin++; + } + + for ( auto g : ckt ) + { + lits.emplace_back( add_tmp_gate( lits, g ) ); + } + + auto const res = validate( root, lit_not_cond( lits.back(), output_negation ) ); + + if constexpr ( use_pushpop ) + { + solver.pop(); + } + + return res; + } + + /*! \brief Validate whether signal `f` is a constant of `value`. */ + std::optional validate( signal const& f, bool value ) + { + return validate( ntk.get_node( f ), value ^ ntk.is_complemented( f ) ); + } + + /*! \brief Validate whether node `root` is a constant of `value`. */ + std::optional validate( node const& root, bool value ) + { + assert( literals[root].variable() != bill::var_type( 0 ) ); + if constexpr ( use_pushpop ) + { + solver.push(); + } + + std::optional res; + if constexpr ( use_odc ) + { + if ( ps.odc_levels != 0 ) + { + res = solve( {build_odc_window( root, ~literals[root] ), lit_not_cond( literals[root], value )} ); + } + else + { + res = solve( {lit_not_cond( literals[root], value )} ); + } + } + else + { + res = solve( {lit_not_cond( literals[root], value )} ); + } + + if constexpr ( use_pushpop ) + { + solver.pop(); + } + return res; + } + + /*! \brief Add CNF clauses for a newly created node. + * + * This function should be called when a new node is created after + * construction of circuit_validator. + * It can be called manually every time or be added to ntk.on_add events. + */ + void add_node( node const& n ) + { + std::vector lit_fi; + ntk.foreach_fanin( n, [&]( const auto& f ) { + lit_fi.emplace_back( lit_not_cond( literals[f], ntk.is_complemented( f ) ) ); + } ); + + literals.resize(); + assert( lit_fi.size() == 2u || lit_fi.size() == 3u ); + if ( lit_fi.size() == 2u ) + { + assert( ntk.is_and( n ) || ntk.is_xor( n ) ); + literals[n] = add_clauses_for_2input_gate( lit_fi[0], lit_fi[1], std::nullopt, ntk.is_and( n ) ? AND : XOR ); + } + else + { + assert( lit_fi.size() == 3u ); + assert( ntk.is_maj( n ) || ntk.is_xor3( n ) ); + literals[n] = add_clauses_for_3input_gate( lit_fi[0], lit_fi[1], lit_fi[2], std::nullopt, ntk.is_maj( n ) ? MAJ : XOR ); + } + } + + /*! \brief Update CNF clauses. + * + * This function should be called when the function of one or more nodes + * has been modified (typically when utilizing ODCs). + */ + void update() + { + restart(); + } + +private: + void restart() + { + solver.restart(); + if constexpr ( randomize ) + { + solver.set_random_phase( ps.random_seed ); + } + + literals.reset(); + /* constants are mapped to var 0 */ + literals[ntk.get_constant( false )] = bill::lit_type( 0, bill::lit_type::polarities::positive ); + if ( ntk.get_node( ntk.get_constant( false ) ) != ntk.get_node( ntk.get_constant( true ) ) ) + { + literals[ntk.get_constant( true )] = bill::lit_type( 0, bill::lit_type::polarities::negative ); + } + + /* first indexes (starting from 1) are for PIs */ + ntk.foreach_pi( [&]( auto const& n, auto i ) { + literals[n] = bill::lit_type( i + 1, bill::lit_type::polarities::positive ); + } ); + + /* compute literals for nodes */ + uint32_t next_var = ntk.num_pis() + 1; + ntk.foreach_gate( [&]( auto const& n ) { + literals[n] = bill::lit_type( next_var++, bill::lit_type::polarities::positive ); + } ); + + solver.add_variables( ntk.size() ); + generate_cnf( + ntk, [&]( auto const& clause ) { + solver.add_clause( clause ); + }, + literals ); + } + + bill::lit_type add_clauses_for_2input_gate( bill::lit_type a, bill::lit_type b, std::optional c = std::nullopt, gate_type type = AND ) + { + assert( type == AND || type == XOR ); + + auto nlit = c ? *c : bill::lit_type( solver.add_variable(), bill::lit_type::polarities::positive ); + if ( type == AND ) + { + detail::on_and( nlit, a, b, [&]( auto const& clause ) { + solver.add_clause( clause ); + } ); + } + else if ( type == XOR ) + { + detail::on_xor( nlit, a, b, [&]( auto const& clause ) { + solver.add_clause( clause ); + } ); + } + + return nlit; + } + + bill::lit_type add_clauses_for_3input_gate( bill::lit_type a, bill::lit_type b, bill::lit_type c, std::optional d = std::nullopt, gate_type type = MAJ ) + { + assert( type == MAJ || type == XOR ); + + auto nlit = d ? *d : bill::lit_type( solver.add_variable(), bill::lit_type::polarities::positive ); + if ( type == MAJ ) + { + detail::on_maj( nlit, a, b, c, [&]( auto const& clause ) { + solver.add_clause( clause ); + } ); + } + else if ( type == XOR ) + { + detail::on_xor3( nlit, a, b, c, [&]( auto const& clause ) { + solver.add_clause( clause ); + } ); + } + + return nlit; + } + + bill::lit_type add_tmp_gate( std::vector const& lits, gate const& g ) + { + /* currently supports AND2, XOR2, XOR3, MAJ3 */ + assert( g.fanins.size() == 2u || g.fanins.size() == 3u ); + + if ( g.fanins.size() == 2u ) + { + assert( g.fanins[0].idx < lits.size() ); + assert( g.fanins[1].idx < lits.size() ); + return add_clauses_for_2input_gate( lit_not_cond( lits[g.fanins[0].idx], g.fanins[0].inv ), lit_not_cond( lits[g.fanins[1].idx], g.fanins[1].inv ), std::nullopt, g.type ); + } + else + { + assert( g.fanins[0].idx < lits.size() ); + assert( g.fanins[1].idx < lits.size() ); + assert( g.fanins[2].idx < lits.size() ); + return add_clauses_for_3input_gate( lit_not_cond( lits[g.fanins[0].idx], g.fanins[0].inv ), lit_not_cond( lits[g.fanins[1].idx], g.fanins[1].inv ), lit_not_cond( lits[g.fanins[2].idx], g.fanins[2].inv ), std::nullopt, g.type ); + } + } + + std::optional solve( std::vector assumptions ) + { + auto const res = solver.solve( assumptions, ps.conflict_limit ); + if ( res == bill::result::states::satisfiable ) + { + auto model = solver.get_model().model(); + for ( auto i = 0u; i < ntk.num_pis(); ++i ) + { + cex.at( i ) = model.at( i + 1 ) == bill::lbool_type::true_; + } + return false; + } + else if ( res == bill::result::states::unsatisfiable ) + { + return true; + } + return std::nullopt; /* timeout or something wrong */ + } + + std::optional validate( node const& root, bill::lit_type const& lit ) + { + assert( literals[root].variable() != bill::var_type( 0 ) ); + if constexpr ( use_pushpop ) + { + solver.push(); + } + + std::optional res; + if constexpr ( use_odc ) + { + if ( ps.odc_levels != 0 ) + { + res = solve( {build_odc_window( root, lit )} ); + } + else + { + auto nlit = bill::lit_type( solver.add_variable(), bill::lit_type::polarities::positive ); + solver.add_clause( {literals[root], lit, nlit} ); + solver.add_clause( {~( literals[root] ), ~lit, nlit} ); + res = solve( {~nlit} ); + } + } + else + { + auto nlit = bill::lit_type( solver.add_variable(), bill::lit_type::polarities::positive ); + solver.add_clause( {literals[root], lit, nlit} ); + solver.add_clause( {~( literals[root] ), ~lit, nlit} ); + res = solve( {~nlit} ); + } + + if constexpr ( use_pushpop ) + { + solver.pop(); + } + return res; + } + +private: + template> + bill::lit_type build_odc_window( node const& root, bill::lit_type const& lit ) + { + /* literals for the duplicated fanout cone */ + unordered_node_map lits( ntk ); + /* literals of XORs in the miter */ + std::vector miter; + + lits[root] = lit; + ntk.incr_trav_id(); + make_lit_fanout_cone_rec( root, lits, miter, 1 ); + ntk.incr_trav_id(); + duplicate_fanout_cone_rec( root, lits, 1 ); + + /* miter for POs */ + ntk.foreach_po( [&]( auto const& f ) { + if ( !lits.has( ntk.get_node( f ) ) ) + return true; /* PO not in TFO, skip */ + add_miter_clauses( ntk.get_node( f ), lits, miter ); + return true; /* next */ + } ); + + assert( miter.size() > 0 && "max fanout depth < odc_levels (-1 is infinity) and there is no PO in TFO cone" ); + auto nlit2 = bill::lit_type( solver.add_variable(), bill::lit_type::polarities::positive ); + miter.emplace_back( nlit2 ); + solver.add_clause( miter ); + return ~nlit2; + } + + template> + void duplicate_fanout_cone_rec( node const& n, unordered_node_map const& lits, int level ) + { + ntk.foreach_fanout( n, [&]( auto const& fo ) { + if ( ntk.visited( fo ) == ntk.trav_id() ) + return true; /* skip */ + ntk.set_visited( fo, ntk.trav_id() ); + + std::vector l_fi; + ntk.foreach_fanin( fo, [&]( auto const& fi ) { + l_fi.emplace_back( lit_not_cond( lits.has( ntk.get_node( fi ) ) ? lits[fi] : literals[fi], ntk.is_complemented( fi ) ) ); + } ); + if ( l_fi.size() == 2u ) + { + assert( ntk.is_and( fo ) || ntk.is_xor( fo ) ); + add_clauses_for_2input_gate( l_fi[0], l_fi[1], lits[fo], ntk.is_and( fo ) ? AND : XOR ); + } + else + { + assert( l_fi.size() == 3u ); + assert( ntk.is_maj( fo ) || ntk.is_xor3( fo ) ); + add_clauses_for_3input_gate( l_fi[0], l_fi[1], l_fi[2], lits[fo], ntk.is_maj( fo ) ? MAJ : XOR ); + } + + if ( level == ps.odc_levels ) + return true; + + duplicate_fanout_cone_rec( fo, lits, level + 1 ); + return true; /* next */ + } ); + } + + template> + void make_lit_fanout_cone_rec( node const& n, unordered_node_map& lits, std::vector& miter, int level ) + { + ntk.foreach_fanout( n, [&]( auto const& fo ) { + if ( ntk.visited( fo ) == ntk.trav_id() ) + return true; /* skip */ + ntk.set_visited( fo, ntk.trav_id() ); + + lits[fo] = bill::lit_type( solver.add_variable(), bill::lit_type::polarities::positive ); + + if ( level == ps.odc_levels ) + { + add_miter_clauses( fo, lits, miter ); + return true; + } + + make_lit_fanout_cone_rec( fo, lits, miter, level + 1 ); + return true; /* next */ + } ); + } + + template> + void add_miter_clauses( node const& n, unordered_node_map const& lits, std::vector& miter ) + { + miter.emplace_back( add_clauses_for_2input_gate( literals[n], lits[n], std::nullopt, XOR ) ); + } + +private: + Ntk const& ntk; + + validator_params const& ps; + + node_map literals; + bill::solver solver; + +public: + std::vector cex; +}; + +} /* namespace mockturtle */ diff --git a/include/mockturtle/algorithms/cnf.hpp b/include/mockturtle/algorithms/cnf.hpp index 77ac45c06..cd12075fa 100644 --- a/include/mockturtle/algorithms/cnf.hpp +++ b/include/mockturtle/algorithms/cnf.hpp @@ -59,11 +59,21 @@ inline constexpr uint32_t lit_not( uint32_t lit ) return lit ^ 0x1; } +inline bill::lit_type lit_not( bill::lit_type lit ) +{ + return ~lit; +} + inline constexpr uint32_t lit_not_cond( uint32_t lit, bool cond ) { return cond ? lit ^ 0x1 : lit; } +inline bill::lit_type lit_not_cond( bill::lit_type lit, bool cond ) +{ + return cond ? ~lit : lit; +} + namespace detail { @@ -242,7 +252,8 @@ inline void on_function( bill::lit_type f, std::vector const& ch } // namespace detail /*! \brief Clause callback function for generate_cnf. */ -using clause_callback_t = std::function const& )>; +template +using clause_callback_t = std::function const& )>; /*! \brief Create a default node literal map. * @@ -254,8 +265,8 @@ using clause_callback_t = std::function const& )>; * independent sets of node literals for two networks, but keep the same indexes * for the primary inputs. */ -template -node_map node_literals( Ntk const& ntk, std::optional const& gate_offset = {} ) +template +node_map node_literals( Ntk const& ntk, std::optional const& gate_offset = {} ) { static_assert( is_network_type_v, "Ntk is not a network type" ); static_assert( has_num_pis_v, "Ntk does not implement the num_pis method" ); @@ -264,25 +275,48 @@ node_map node_literals( Ntk const& ntk, std::optional c static_assert( has_foreach_pi_v, "Ntk does not implement the foreach_pi method" ); static_assert( has_foreach_gate_v, "Ntk does not implement the foreach_gate method" ); - node_map node_lits( ntk ); + node_map node_lits( ntk ); - /* constants are mapped to var 0 */ - node_lits[ntk.get_constant( false )] = make_lit( 0 ); - if ( ntk.get_node( ntk.get_constant( false ) ) != ntk.get_node( ntk.get_constant( true ) ) ) + if constexpr ( std::is_same::value ) { - node_lits[ntk.get_constant( true )] = make_lit( 0, true ); + /* constants are mapped to var 0 */ + node_lits[ntk.get_constant( false )] = make_lit( 0 ); + if ( ntk.get_node( ntk.get_constant( false ) ) != ntk.get_node( ntk.get_constant( true ) ) ) + { + node_lits[ntk.get_constant( true )] = make_lit( 0, true ); + } + + /* first indexes (starting from 1) are for PIs */ + ntk.foreach_pi( [&]( auto const& n, auto i ) { + node_lits[n] = make_lit( i + 1 ); + } ); + + /* compute literals for nodes */ + uint32_t next_var = gate_offset ? *gate_offset : ntk.num_pis() + 1; + ntk.foreach_gate( [&]( auto const& n ) { + node_lits[n] = make_lit( next_var++ ); + } ); } + else if constexpr ( std::is_same::value ) + { + /* constants are mapped to var 0 */ + node_lits[ntk.get_constant( false )] = bill::lit_type( 0, bill::lit_type::polarities::positive ); + if ( ntk.get_node( ntk.get_constant( false ) ) != ntk.get_node( ntk.get_constant( true ) ) ) + { + node_lits[ntk.get_constant( true )] = bill::lit_type( 0, bill::lit_type::polarities::negative ); + } - /* first indexes (starting from 1) are for PIs */ - ntk.foreach_pi( [&]( auto const& n, auto i ) { - node_lits[n] = make_lit( i + 1 ); - } ); + /* first indexes (starting from 1) are for PIs */ + ntk.foreach_pi( [&]( auto const& n, auto i ) { + node_lits[n] = bill::lit_type( i + 1, bill::lit_type::polarities::positive ); + } ); - /* compute literals for nodes */ - uint32_t next_var = gate_offset ? *gate_offset : ntk.num_pis() + 1; - ntk.foreach_gate( [&]( auto const& n ) { - node_lits[n] = make_lit( next_var++ ); - } ); + /* compute literals for nodes */ + uint32_t next_var = gate_offset ? *gate_offset : ntk.num_pis() + 1; + ntk.foreach_gate( [&]( auto const& n ) { + node_lits[n] = bill::lit_type( next_var++, bill::lit_type::polarities::positive ); + } ); + } return node_lits; } @@ -290,29 +324,29 @@ node_map node_literals( Ntk const& ntk, std::optional c namespace detail { -template +template class generate_cnf_impl { public: - generate_cnf_impl( Ntk const& ntk, clause_callback_t const& fn, std::optional> const& node_lits ) + generate_cnf_impl( Ntk const& ntk, clause_callback_t const& fn, std::optional> const& node_lits ) : ntk_( ntk ), fn_( fn ), - node_lits_( node_lits ? *node_lits : node_literals( ntk ) ) + node_lits_( node_lits ? *node_lits : node_literals( ntk ) ) { } - std::vector run() + std::vector run() { /* unit clause for constant-0 */ - fn_( {1} ); + fn_( {lit_not( node_lits_[ntk_.get_constant( false )])} ); /* compute clauses for nodes */ ntk_.foreach_gate( [&]( auto const& n ) { - std::vector child_lits; + std::vector child_lits; ntk_.foreach_fanin( n, [&]( auto const& f ) { child_lits.push_back( lit_not_cond( node_lits_[f], ntk_.is_complemented( f ) ) ); } ); - uint32_t node_lit = node_lits_[n]; + lit_t node_lit = node_lits_[n]; if constexpr ( has_is_and_v ) { @@ -399,7 +433,7 @@ class generate_cnf_impl return true; } ); - std::vector output_lits; + std::vector output_lits; ntk_.foreach_po( [&]( auto const& f ) { output_lits.push_back( lit_not_cond( node_lits_[f], ntk_.is_complemented( f ) ) ); } ); @@ -409,9 +443,9 @@ class generate_cnf_impl private: Ntk const& ntk_; - clause_callback_t const& fn_; + clause_callback_t const& fn_; - node_map node_lits_; + node_map node_lits_; }; } // namespace detail @@ -439,7 +473,22 @@ class generate_cnf_impl * \param node_lits (optional) custom node literal map */ template -std::vector generate_cnf( Ntk const& ntk, clause_callback_t const& fn, std::optional> const& node_lits = {} ) +std::vector generate_cnf( Ntk const& ntk, clause_callback_t const& fn, std::optional> const& node_lits = {} ) +{ + static_assert( is_network_type_v, "Ntk is not a network type" ); + static_assert( has_foreach_gate_v, "Ntk does not implement the foreach_gate method" ); + static_assert( has_foreach_po_v, "Ntk does not implement the foreach_po method" ); + static_assert( has_foreach_fanin_v, "Ntk does not implement the foreach_fanin method" ); + static_assert( has_is_complemented_v, "Ntk does not implement the is_complemented method" ); + static_assert( has_node_function_v, "Ntk does not implement the node_function method" ); + static_assert( has_fanin_size_v, "Ntk does not implement the fanin_size method" ); + + detail::generate_cnf_impl impl( ntk, fn, node_lits ); + return impl.run(); +} + +template +std::vector generate_cnf( Ntk const& ntk, clause_callback_t const& fn, std::optional> const& node_lits = {} ) { static_assert( is_network_type_v, "Ntk is not a network type" ); static_assert( has_foreach_gate_v, "Ntk does not implement the foreach_gate method" ); @@ -449,7 +498,7 @@ std::vector generate_cnf( Ntk const& ntk, clause_callback_t const& fn, static_assert( has_node_function_v, "Ntk does not implement the node_function method" ); static_assert( has_fanin_size_v, "Ntk does not implement the fanin_size method" ); - detail::generate_cnf_impl impl( ntk, fn, node_lits ); + detail::generate_cnf_impl impl( ntk, fn, node_lits ); return impl.run(); } diff --git a/lib/bill/bill/bill.hpp b/lib/bill/bill/bill.hpp new file mode 100644 index 000000000..bd734b8b9 --- /dev/null +++ b/lib/bill/bill/bill.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/lib/bill/bill/sat/interface/abc_bsat2.hpp b/lib/bill/bill/sat/interface/abc_bsat2.hpp index c43bec78d..043bb6f09 100644 --- a/lib/bill/bill/sat/interface/abc_bsat2.hpp +++ b/lib/bill/bill/sat/interface/abc_bsat2.hpp @@ -10,10 +10,10 @@ #include #include #include +#include namespace bill { -#if !defined(BILL_WINDOWS_PLATFORM) template<> class solver { using solver_type = pabc::sat_solver; @@ -41,6 +41,7 @@ class solver { { pabc::sat_solver_restart(solver_); state_ = result::states::undefined; + randomize = false; } var_type add_variable() @@ -111,6 +112,19 @@ class solver { if (num_variables() == 0u) return result::states::undefined; + if ( randomize ) + { + std::vector vars; + for ( auto i = 0u; i < num_variables(); ++i ) + { + if ( random() % 2 ) + { + vars.push_back( i ); + } + } + pabc::sat_solver_set_polarity( solver_, (int*)(const_cast(vars.data())), vars.size() ); + } + int result; if (assumptions.size() > 0u) { /* solve with assumptions */ @@ -152,6 +166,24 @@ class solver { } #pragma endregion + void push() + { + pabc::sat_solver_bookmark(solver_); + } + + void pop( uint32_t n = 1u ) + { + assert( n == 1u && "bsat does not support multiple step pop" ); (void)n; + pabc::sat_solver_rollback(solver_); + } + + void set_random_phase( uint32_t seed = 0u ) + { + randomize = true; + pabc::sat_solver_set_random(solver_, 1); + random.seed( seed ); + } + private: /*! \brief Backend solver */ solver_type* solver_ = nullptr; @@ -161,7 +193,9 @@ class solver { /*! \brief Temporary storage for one clause */ pabc::lit literals[2048]; + + std::default_random_engine random; + bool randomize = false; }; -#endif } // namespace bill diff --git a/lib/bill/bill/sat/interface/common.hpp b/lib/bill/bill/sat/interface/common.hpp index 01f2a1f07..7f17d15d3 100644 --- a/lib/bill/bill/sat/interface/common.hpp +++ b/lib/bill/bill/sat/interface/common.hpp @@ -12,6 +12,10 @@ disable : 4018 4127 4189 4200 4242 4244 4245 4305 4365 4388 4389 4456 4457 4459 4514 4552 4571 4583 4619 4623 4625 4626 4706 4710 4711 4774 4820 4820 4996 5026 5027 5039) #include "../solver/ghack.hpp" #include "../solver/glucose.hpp" +#define ABC_USE_NAMESPACE pabc +#define ABC_NAMESPACE pabc +#define ABC_USE_NO_READLINE +#include "../solver/abc.hpp" #pragma warning(pop) #else #pragma GCC diagnostic push @@ -27,7 +31,6 @@ #include "../solver/ghack.hpp" #include "../solver/glucose.hpp" #include "../solver/maple.hpp" -#if !defined(BILL_WINDOWS_PLATFORM) #ifndef LIN64 #define LIN64 #endif @@ -35,7 +38,6 @@ #define ABC_NAMESPACE pabc #define ABC_USE_NO_READLINE #include "../solver/abc.hpp" -#endif #pragma GCC diagnostic pop #endif @@ -113,6 +115,11 @@ class result { { return std::get(data_); } + + inline clause_type core() const + { + return std::get(data_); + } #pragma endregion #pragma region Overloads @@ -135,9 +142,9 @@ class result { enum class solvers { glucose_41, ghack, + bsat2, #if !defined(BILL_WINDOWS_PLATFORM) maple, - bsat2, bmcg, #endif #if defined(BILL_HAS_Z3) @@ -145,6 +152,8 @@ enum class solvers { #endif }; +/*! \brief Solver interface + */ template class solver; diff --git a/lib/bill/bill/sat/interface/z3.hpp b/lib/bill/bill/sat/interface/z3.hpp index 95c6da84c..bdae5103a 100644 --- a/lib/bill/bill/sat/interface/z3.hpp +++ b/lib/bill/bill/sat/interface/z3.hpp @@ -21,7 +21,7 @@ class solver { public: #pragma region Constructors solver() - : solver_(ctx_) + : solver_(ctx_), var_ctr_( 1, 0u ), cls_ctr_( 1, 0u ) { } @@ -38,15 +38,17 @@ class solver { { solver_.reset(); vars_.clear(); - var_ctr_ = 0u; - cls_ctr_ = 0u; + var_ctr_.clear(); + var_ctr_.emplace_back( 0u ); + cls_ctr_.clear(); + cls_ctr_.emplace_back( 0u ); state_ = result::states::undefined; } var_type add_variable() { - vars_.push_back(ctx_.bool_const(fmt::format("x{}", var_ctr_).c_str())); - return var_ctr_++; + vars_.push_back(ctx_.bool_const(fmt::format("x{}", var_ctr_.back()).c_str())); + return var_ctr_.back()++; } void add_variables(uint32_t num_variables = 1) @@ -66,7 +68,7 @@ class solver { ++it; } solver_.add(mk_or(vec)); - ++cls_ctr_; + ++cls_ctr_.back(); return result::states::dirty; } @@ -131,22 +133,47 @@ class solver { #pragma region Properties uint32_t num_variables() const { - return var_ctr_; + return var_ctr_.back(); } uint32_t num_clauses() const { - return cls_ctr_; + return cls_ctr_.back(); } #pragma endregion + void push() + { + solver_.push(); + var_ctr_.emplace_back( var_ctr_.back() ); + cls_ctr_.emplace_back( cls_ctr_.back() ); + } + + void pop( uint32_t n = 1u ) + { + assert( n < var_ctr_.size() ); + solver_.pop( n ); + var_ctr_.resize( var_ctr_.size() - n ); + cls_ctr_.resize( cls_ctr_.size() - n ); + if ( vars_.size() > var_ctr_.back() ) + { + vars_.erase( vars_.begin() + var_ctr_.back(), vars_.end() ); + } + } + + void set_random_phase( uint32_t seed = 0u ) + { + solver_.set("sat.random_seed", seed); + solver_.set("phase_selection", 5u); + } + private: z3::context ctx_; z3::solver solver_; result::states state_ = result::states::undefined; std::vector vars_; - uint32_t var_ctr_{}; - uint32_t cls_ctr_{}; + std::vector var_ctr_; + std::vector cls_ctr_; }; } // namespace bill diff --git a/lib/bill/bill/sat/tseytin.hpp b/lib/bill/bill/sat/tseytin.hpp index 1dcec4216..c94d2290d 100644 --- a/lib/bill/bill/sat/tseytin.hpp +++ b/lib/bill/bill/sat/tseytin.hpp @@ -9,6 +9,13 @@ namespace bill { +/*! \brief Adds CNF clauses for `y = (a and b)` to the solver. + * + * \param solver Solver + * \param a Literal + * \param b Literal + * \return Literal y + */ template lit_type add_tseytin_and(Solver& solver, lit_type const& a, lit_type const& b) { @@ -19,6 +26,12 @@ lit_type add_tseytin_and(Solver& solver, lit_type const& a, lit_type const& b) return lit_type(r, lit_type::polarities::positive); } +/*! \brief Adds CNF clauses for `y = (l_0 and ... and l_{n-1})` to the solver. + * + * \param solver Solver + * \param ls List of literals + * \return Literal y + */ template lit_type add_tseytin_and(Solver& solver, std::vector const& ls) { @@ -33,6 +46,13 @@ lit_type add_tseytin_and(Solver& solver, std::vector const& ls) return lit_type(r, lit_type::polarities::positive); } +/*! \brief Adds CNF clauses for `y = a or b` to the solver. + * + * \param solver Solver + * \param a Literal + * \param b Literal + * \return Literal y + */ template lit_type add_tseytin_or(Solver& solver, lit_type const& a, lit_type const& b) { @@ -43,6 +63,12 @@ lit_type add_tseytin_or(Solver& solver, lit_type const& a, lit_type const& b) return lit_type(r, lit_type::polarities::positive); } +/*! \brief Adds CNF clauses for `y = (l_0 or ... or l_{n-1})` to the solver. + * + * \param solver Solver + * \param ls List of literals + * \return Literal y + */ template lit_type add_tseytin_or(Solver& solver, std::vector const& ls) { @@ -55,6 +81,13 @@ lit_type add_tseytin_or(Solver& solver, std::vector const& ls) return lit_type(r, lit_type::polarities::positive); } +/*! \brief Adds CNF clauses for `y = (a xor b)` to the solver. + * + * \param solver Solver + * \param a Literal + * \param b Literal + * \return Literal y + */ template lit_type add_tseytin_xor(Solver& solver, lit_type const& a, lit_type const& b) { @@ -66,6 +99,13 @@ lit_type add_tseytin_xor(Solver& solver, lit_type const& a, lit_type const& b) return lit_type(r, lit_type::polarities::positive); } +/*! \brief Adds CNF clauses for `y = (a == b)` to the solver. + * + * \param solver Solver + * \param a Literal + * \param b Literal + * \return Literal y + */ template lit_type add_tseytin_equals(Solver& solver, lit_type const& a, lit_type const& b) { diff --git a/lib/bill/bill/sat/unsat_cores.hpp b/lib/bill/bill/sat/unsat_cores.hpp new file mode 100644 index 000000000..a08a636d5 --- /dev/null +++ b/lib/bill/bill/sat/unsat_cores.hpp @@ -0,0 +1,73 @@ +#pragma once + +namespace bill { + +namespace detail { + +template +inline std::vector copy_vector_without_index(std::vector const& vs, uint32_t index) +{ + assert(index < vs.size()); + std::vector copy(vs); + copy.erase(std::begin(copy) + index); + return copy; +} + +} // namespace detail + +template +inline result::clause_type trim_core_copy(Solver& solver, result::clause_type const& core, + uint32_t num_tries = 8u) +{ + auto current = core; + + uint32_t counter = 0u; + while (counter++ < num_tries && solver.solve(current) == result::states::unsatisfiable) { + auto const new_core = solver.get_core().core(); + if (new_core.size() == current.size()) + break; + + current = new_core; + } + + return current; +} + +template +inline void trim_core(Solver& solver, result::clause_type& core, uint32_t num_tries = 0u) +{ + core = trim_core_copy(solver, core, num_tries); +} + +template +inline result::clause_type minimize_core_copy(Solver& solver, result::clause_type& core, + int64_t budget = 1000) +{ + auto pos = 0u; + auto current = core; + + while (pos < current.size()) { + auto temp = detail::copy_vector_without_index(current, pos); + + auto result = solver.solve(temp, budget); + if (result == result::states::unsatisfiable) { + current = temp; + } else { + ++pos; + } + } + + if (current.size() < core.size()) { + return current; + } else { + return core; + } +} + +template +inline void minimize_core(Solver& solver, result::clause_type& core, int64_t budget = 1000) +{ + core = minimize_core_copy(solver, core, budget); +} + +} // namespace bill diff --git a/lib/bill/bill/sat/xor_clauses.hpp b/lib/bill/bill/sat/xor_clauses.hpp index 0a45617d5..b6794b367 100644 --- a/lib/bill/bill/sat/xor_clauses.hpp +++ b/lib/bill/bill/sat/xor_clauses.hpp @@ -11,6 +11,13 @@ namespace bill { +/*! \brief Adds CNF clauses for `y = ((l_0 ^ ... ^ l_{n-1}) == pol)` to the solver. + * + * \param solver Solver + * \param clause List of literals + * \param pol Clause polarity + * \return Literal y + */ template lit_type add_xor_clause(Solver& solver, std::vector const& clause, lit_type::polarities pol = lit_type::polarities::positive) diff --git a/test/algorithms/circuit_validator.cpp b/test/algorithms/circuit_validator.cpp new file mode 100644 index 000000000..a54d40ff7 --- /dev/null +++ b/test/algorithms/circuit_validator.cpp @@ -0,0 +1,160 @@ +#include + +#include +#include +#include +#include +#include +#include + +using namespace mockturtle; + +TEST_CASE( "Validating NEQ nodes and get CEX", "[validator]" ) +{ + /* original circuit */ + aig_network aig; + auto const a = aig.create_pi(); + auto const b = aig.create_pi(); + auto const f1 = aig.create_and( !a, b ); + auto const f2 = aig.create_and( a, !b ); + + circuit_validator v( aig ); + + CHECK( *( v.validate( f1, f2 ) ) == false ); + CHECK( unsigned( v.cex[0] ) + unsigned( v.cex[1] ) == 1u ); /* either 01 or 10 */ +} + +TEST_CASE( "Validating EQ nodes in XAG", "[validator]" ) +{ + /* original circuit */ + xag_network xag; + auto const a = xag.create_pi(); + auto const b = xag.create_pi(); + auto const f1 = xag.create_and( !a, b ); + auto const f2 = xag.create_and( a, !b ); + auto const f3 = xag.create_or( f1, f2 ); + auto const g = xag.create_xor( a, b ); + + circuit_validator v( xag ); + + CHECK( *( v.validate( f3, g ) ) == true ); +} + +TEST_CASE( "Validating EQ nodes in MIG", "[validator]" ) +{ + /* original circuit */ + mig_network mig; + auto const a = mig.create_pi(); + auto const b = mig.create_pi(); + auto const c = mig.create_pi(); + + auto const f1 = mig.create_maj( a, b, mig.get_constant( false ) ); // a & b + auto const f2 = mig.create_maj( f1, c, mig.get_constant( false ) ); // a & b & c + + auto const f3 = mig.create_maj( !b, !c, mig.get_constant( true ) ); // !b | !c + auto const f4 = mig.create_maj( f3, !a, mig.get_constant( true ) ); // !a | !b | !c + + circuit_validator v( mig ); + + CHECK( *( v.validate( mig.get_node( f2 ), !f4 ) ) == true ); +} + +TEST_CASE( "Validating with non-existing circuit", "[validator]" ) +{ + /* original circuit */ + aig_network aig; + auto const a = aig.create_pi(); + auto const b = aig.create_pi(); + auto const f1 = aig.create_and( !a, b ); + auto const f2 = aig.create_and( a, !b ); + auto const f3 = aig.create_or( f1, f2 ); + + circuit_validator v( aig ); + + circuit_validator::gate::fanin fi1; + fi1.idx = 0; fi1.inv = true; + circuit_validator::gate::fanin fi2; + fi2.idx = 1; fi2.inv = true; + circuit_validator::gate g; + g.fanins = {fi1, fi2}; + CHECK( *( v.validate( f3, {aig.get_node( f1 ), aig.get_node( f2 )}, {g}, true ) ) == true ); + CHECK( *( v.validate( aig.get_node( f3 ), {aig.get_node( f1 ), aig.get_node( f2 )}, {g}, false ) ) == true ); +} + +TEST_CASE( "Validating after circuit update", "[validator]" ) +{ + /* original circuit */ + aig_network aig; + auto const a = aig.create_pi(); + auto const b = aig.create_pi(); + auto const f1 = aig.create_and( !a, b ); + auto const f2 = aig.create_and( a, !b ); + auto const f3 = aig.create_or( f1, f2 ); + + circuit_validator v( aig ); + + /* new nodes created after construction of `circuit_validator` have to be added to it manually with `add_node` */ + auto const g1 = aig.create_and( a, b ); + auto const g2 = aig.create_and( !a, !b ); + auto const g3 = aig.create_or( g1, g2 ); + v.add_node( aig.get_node( g1 ) ); + v.add_node( aig.get_node( g2 ) ); + v.add_node( aig.get_node( g3 ) ); + + CHECK( *( v.validate( aig.get_node( f3 ), g3 ) ) == true ); +} + +TEST_CASE( "Validating const nodes", "[validator]" ) +{ + /* original circuit */ + aig_network aig; + auto const a = aig.create_pi(); + auto const b = aig.create_pi(); + auto const f1 = aig.create_and( !a, b ); + auto const f2 = aig.create_and( a, !b ); + auto const f3 = aig.create_or( f1, f2 ); // a ^ b + + auto const g1 = aig.create_and( a, b ); + auto const g2 = aig.create_and( !a, !b ); + auto const g3 = aig.create_or( g1, g2 ); // a == b + + auto const h = aig.create_and( f3, g3 ); // const 0 + + circuit_validator v( aig ); + + CHECK( *( v.validate( aig.get_node( h ), false ) ) == true ); + CHECK( *( v.validate( f1, false ) ) == false ); + CHECK( v.cex[0] == false ); + CHECK( v.cex[1] == true ); +} + +TEST_CASE( "Validating with ODC", "[validator]" ) +{ + /* original circuit */ + aig_network aig; + auto const a = aig.create_pi(); + auto const b = aig.create_pi(); + auto const f1 = aig.create_and( !a, b ); + auto const f2 = aig.create_and( a, !b ); + auto const f3 = aig.create_or( f1, f2 ); // a ^ b + + auto const g1 = aig.create_and( a, b ); + auto const g2 = aig.create_and( !a, !b ); + auto const g3 = aig.create_or( g1, g2 ); // a == b + + auto const h = aig.create_and( f3, g3 ); // const 0 + aig.create_po( h ); + + validator_params ps; + fanout_view view{aig}; + circuit_validator, bill::solvers::bsat2, false, false, true> v( view, ps ); + + /* considering only 1 level, f1 can not be substituted with const 0 */ + ps.odc_levels = 1; + CHECK( *( v.validate( aig.get_node( f1 ), false ) ) == false ); + + /* considering 2 levels, f1 can be substituted with const 0 */ + ps.odc_levels = 2; + CHECK( *( v.validate( f1, false ) ) == true ); + CHECK( *( v.validate( aig.get_node( f1 ), aig.get_constant( false ) ) ) == true ); +} \ No newline at end of file