Skip to content

Commit

Permalink
Merge 1df928c into 82bbdec
Browse files Browse the repository at this point in the history
  • Loading branch information
msakai committed Aug 11, 2022
2 parents 82bbdec + 1df928c commit c54b8bd
Show file tree
Hide file tree
Showing 9 changed files with 201 additions and 86 deletions.
23 changes: 18 additions & 5 deletions src/ToySolver/SAT/Encoder/Cardinality.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,27 @@ module ToySolver.SAT.Encoder.Cardinality
, newEncoder
, newEncoderWithStrategy
, encodeAtLeast
, encodeAtLeastWithPolarity
, getTseitinEncoder

-- XXX
, TotalizerDefinitions
, getTotalizerDefinitions
, evalTotalizerDefinitions

-- * Polarity
, Polarity (..)
, negatePolarity
, polarityPos
, polarityNeg
, polarityBoth
, polarityNone
) where

import Control.Monad.Primitive
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin
import ToySolver.SAT.Encoder.Tseitin (Polarity (..), negatePolarity, polarityPos, polarityNeg, polarityBoth, polarityNone)
import ToySolver.SAT.Encoder.Cardinality.Internal.Naive
import ToySolver.SAT.Encoder.Cardinality.Internal.ParallelCounter
import ToySolver.SAT.Encoder.PB.Internal.BDD as BDD
Expand Down Expand Up @@ -95,9 +105,12 @@ instance PrimMonad m => SAT.AddCardinality m (Encoder m) where
Totalizer -> Totalizer.addAtLeast base (lhs,rhs)

encodeAtLeast :: PrimMonad m => Encoder m -> SAT.AtLeast -> m SAT.Lit
encodeAtLeast (Encoder base@(Totalizer.Encoder tseitin _) strategy) =
encodeAtLeast enc = encodeAtLeastWithPolarity enc polarityBoth

encodeAtLeastWithPolarity :: PrimMonad m => Encoder m -> Polarity -> SAT.AtLeast -> m SAT.Lit
encodeAtLeastWithPolarity (Encoder base@(Totalizer.Encoder tseitin _) strategy) polarity =
case strategy of
Naive -> encodeAtLeastNaive tseitin
ParallelCounter -> encodeAtLeastParallelCounter tseitin
SequentialCounter -> \(lhs,rhs) -> BDD.encodePBLinAtLeastBDD tseitin ([(1,l) | l <- lhs], fromIntegral rhs)
Totalizer -> Totalizer.encodeAtLeast base
Naive -> encodeAtLeastWithPolarityNaive tseitin polarity
ParallelCounter -> encodeAtLeastWithPolarityParallelCounter tseitin polarity
SequentialCounter -> \(lhs,rhs) -> BDD.encodePBLinAtLeastWithPolarityBDD tseitin polarity ([(1,l) | l <- lhs], fromIntegral rhs)
Totalizer -> Totalizer.encodeAtLeastWithPolarity base polarity
14 changes: 7 additions & 7 deletions src/ToySolver/SAT/Encoder/Cardinality/Internal/Naive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.Cardinality.Internal.Naive
( addAtLeastNaive
, encodeAtLeastNaive
, encodeAtLeastWithPolarityNaive
) where

import Control.Monad.Primitive
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin
import ToySolver.SAT.Encoder.Tseitin (Polarity ())

addAtLeastNaive :: PrimMonad m => Tseitin.Encoder m -> SAT.AtLeast -> m ()
addAtLeastNaive enc (lhs,rhs) = do
Expand All @@ -27,15 +28,14 @@ addAtLeastNaive enc (lhs,rhs) = do
else do
mapM_ (SAT.addClause enc) (comb (n - rhs + 1) lhs)

-- TODO: consider polarity
encodeAtLeastNaive :: PrimMonad m => Tseitin.Encoder m -> SAT.AtLeast -> m SAT.Lit
encodeAtLeastNaive enc (lhs,rhs) = do
encodeAtLeastWithPolarityNaive :: PrimMonad m => Tseitin.Encoder m -> Polarity -> SAT.AtLeast -> m SAT.Lit
encodeAtLeastWithPolarityNaive enc polarity (lhs,rhs) = do
let n = length lhs
if n < rhs then do
Tseitin.encodeDisj enc []
Tseitin.encodeDisjWithPolarity enc polarity []
else do
ls <- mapM (Tseitin.encodeDisj enc) (comb (n - rhs + 1) lhs)
Tseitin.encodeConj enc ls
ls <- mapM (Tseitin.encodeDisjWithPolarity enc polarity) (comb (n - rhs + 1) lhs)
Tseitin.encodeConjWithPolarity enc polarity ls

comb :: Int -> [a] -> [[a]]
comb 0 _ = [[]]
Expand Down
31 changes: 15 additions & 16 deletions src/ToySolver/SAT/Encoder/Cardinality/Internal/ParallelCounter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.Cardinality.Internal.ParallelCounter
( addAtLeastParallelCounter
, encodeAtLeastParallelCounter
, encodeAtLeastWithPolarityParallelCounter
) where

import Control.Monad.Primitive
Expand All @@ -28,21 +28,20 @@ import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin

addAtLeastParallelCounter :: PrimMonad m => Tseitin.Encoder m -> SAT.AtLeast -> m ()
addAtLeastParallelCounter enc constr = do
l <- encodeAtLeastParallelCounter enc constr
l <- encodeAtLeastWithPolarityParallelCounter enc Tseitin.polarityPos constr
SAT.addClause enc [l]

-- TODO: consider polarity
encodeAtLeastParallelCounter :: forall m. PrimMonad m => Tseitin.Encoder m -> SAT.AtLeast -> m SAT.Lit
encodeAtLeastParallelCounter enc (lhs,rhs) = do
encodeAtLeastWithPolarityParallelCounter :: forall m. PrimMonad m => Tseitin.Encoder m -> Tseitin.Polarity -> SAT.AtLeast -> m SAT.Lit
encodeAtLeastWithPolarityParallelCounter enc polarity (lhs,rhs) = do
if rhs <= 0 then
Tseitin.encodeConj enc []
Tseitin.encodeConjWithPolarity enc polarity []
else if length lhs < rhs then
Tseitin.encodeDisj enc []
Tseitin.encodeDisjWithPolarity enc polarity []
else do
let rhs_bits = bits (fromIntegral rhs)
(cnt, overflowBits) <- encodeSumParallelCounter enc (length rhs_bits) lhs
isGE <- encodeGE enc cnt rhs_bits
Tseitin.encodeDisj enc $ isGE : overflowBits
isGE <- encodeGE enc polarity cnt rhs_bits
Tseitin.encodeDisjWithPolarity enc polarity $ isGE : overflowBits
where
bits :: Integer -> [Bool]
bits n = f n 0
Expand Down Expand Up @@ -81,17 +80,17 @@ encodeSumParallelCounter enc w lits = do

runStateT (f (V.fromList lits)) []

encodeGE :: forall m. PrimMonad m => Tseitin.Encoder m -> [SAT.Lit] -> [Bool] -> m SAT.Lit
encodeGE enc lhs rhs = do
encodeGE :: forall m. PrimMonad m => Tseitin.Encoder m -> Tseitin.Polarity -> [SAT.Lit] -> [Bool] -> m SAT.Lit
encodeGE enc polarity lhs rhs = do
let f :: [SAT.Lit] -> [Bool] -> SAT.Lit -> m SAT.Lit
f [] [] r = return r
f [] (True : _) _ = Tseitin.encodeDisj enc [] -- false
f [] (True : _) _ = Tseitin.encodeDisjWithPolarity enc polarity [] -- false
f [] (False : bs) r = f [] bs r
f (l : ls) (True : bs) r = do
f ls bs =<< Tseitin.encodeConj enc [l, r]
f ls bs =<< Tseitin.encodeConjWithPolarity enc polarity [l, r]
f (l : ls) (False : bs) r = do
f ls bs =<< Tseitin.encodeDisj enc [l, r]
f ls bs =<< Tseitin.encodeDisjWithPolarity enc polarity [l, r]
f (l : ls) [] r = do
f ls [] =<< Tseitin.encodeDisj enc [l, r]
t <- Tseitin.encodeConj enc [] -- true
f ls [] =<< Tseitin.encodeDisjWithPolarity enc polarity [l, r]
t <- Tseitin.encodeConjWithPolarity enc polarity [] -- true
f lhs rhs t
23 changes: 11 additions & 12 deletions src/ToySolver/SAT/Encoder/Cardinality/Internal/Totalizer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ module ToySolver.SAT.Encoder.Cardinality.Internal.Totalizer
, evalDefinitions

, addAtLeast
, encodeAtLeast
, encodeAtLeastWithPolarity

, addCardinality
, encodeCardinality
, encodeCardinalityWithPolarity

, encodeSum
) where
Expand Down Expand Up @@ -94,25 +94,24 @@ addCardinality enc lits (lb, ub) = do
forM_ (drop ub lits') $ \l -> SAT.addClause enc [- l]


-- TODO: consider polarity
encodeAtLeast :: PrimMonad m => Encoder m -> SAT.AtLeast -> m SAT.Lit
encodeAtLeast enc (lhs,rhs) = do
encodeCardinality enc lhs (rhs, length lhs)

encodeAtLeastWithPolarity :: PrimMonad m => Encoder m -> Tseitin.Polarity -> SAT.AtLeast -> m SAT.Lit
encodeAtLeastWithPolarity enc polarity (lhs,rhs) = do
encodeCardinalityWithPolarity enc polarity lhs (rhs, length lhs)

-- TODO: consider polarity
encodeCardinality :: PrimMonad m => Encoder m -> [SAT.Lit] -> (Int, Int) -> m SAT.Lit
encodeCardinality enc@(Encoder tseitin _) lits (lb, ub) = do

encodeCardinalityWithPolarity :: PrimMonad m => Encoder m -> Tseitin.Polarity -> [SAT.Lit] -> (Int, Int) -> m SAT.Lit
encodeCardinalityWithPolarity enc@(Encoder tseitin _) polarity lits (lb, ub) = do
let n = length lits
if lb <= 0 && n <= ub then
Tseitin.encodeConj tseitin []
Tseitin.encodeConjWithPolarity tseitin polarity []
else if n < lb || ub < 0 then
Tseitin.encodeDisj tseitin []
Tseitin.encodeDisjWithPolarity tseitin polarity []
else do
lits' <- encodeSum enc lits
forM_ (zip lits' (tail lits')) $ \(l1, l2) -> do
SAT.addClause enc [-l2, l1] -- l2→l1 or equivalently ¬l1→¬l2
Tseitin.encodeConj tseitin $
Tseitin.encodeConjWithPolarity tseitin polarity $
[lits' !! (lb - 1) | lb > 0] ++ [- (lits' !! (ub + 1 - 1)) | ub < n]


Expand Down
35 changes: 24 additions & 11 deletions src/ToySolver/SAT/Encoder/PB.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,24 @@
--
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.PB
( Encoder
( Encoder (..)
, newEncoder
, newEncoderWithStrategy
, encodePBLinAtLeast
, encodePBLinAtLeastWithPolarity

-- * Configulation
, Strategy (..)
, showStrategy
, parseStrategy

-- * Polarity
, Polarity (..)
, negatePolarity
, polarityPos
, polarityNeg
, polarityBoth
, polarityNone
) where

import Control.Monad.Primitive
Expand All @@ -38,9 +47,10 @@ import Data.Default.Class
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Cardinality as Card
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin
import ToySolver.SAT.Encoder.Tseitin (Polarity (..), negatePolarity, polarityPos, polarityNeg, polarityBoth, polarityNone)
import ToySolver.SAT.Encoder.PB.Internal.Adder (addPBLinAtLeastAdder, encodePBLinAtLeastAdder)
import ToySolver.SAT.Encoder.PB.Internal.BCCNF (addPBLinAtLeastBCCNF, encodePBLinAtLeastBCCNF)
import ToySolver.SAT.Encoder.PB.Internal.BDD (addPBLinAtLeastBDD, encodePBLinAtLeastBDD)
import ToySolver.SAT.Encoder.PB.Internal.BCCNF (addPBLinAtLeastBCCNF, encodePBLinAtLeastWithPolarityBCCNF)
import ToySolver.SAT.Encoder.PB.Internal.BDD (addPBLinAtLeastBDD, encodePBLinAtLeastWithPolarityBDD)
import ToySolver.SAT.Encoder.PB.Internal.Sorter (addPBLinAtLeastSorter, encodePBLinAtLeastSorter)

data Encoder m = Encoder (Card.Encoder m) Strategy
Expand Down Expand Up @@ -101,8 +111,11 @@ instance PrimMonad m => SAT.AddPBLin m (Encoder m) where
addPBLinAtLeast' enc (lhs',rhs')

encodePBLinAtLeast :: forall m. PrimMonad m => Encoder m -> SAT.PBLinAtLeast -> m SAT.Lit
encodePBLinAtLeast enc constr =
encodePBLinAtLeast' enc $ SAT.normalizePBLinAtLeast constr
encodePBLinAtLeast enc constr = encodePBLinAtLeastWithPolarity enc polarityBoth constr

encodePBLinAtLeastWithPolarity :: forall m. PrimMonad m => Encoder m -> Polarity -> SAT.PBLinAtLeast -> m SAT.Lit
encodePBLinAtLeastWithPolarity enc polarity constr =
encodePBLinAtLeastWithPolarity' enc polarity $ SAT.normalizePBLinAtLeast constr

-- -----------------------------------------------------------------------

Expand All @@ -115,14 +128,14 @@ addPBLinAtLeast' (Encoder card strategy) = do
BCCNF -> addPBLinAtLeastBCCNF card
_ -> addPBLinAtLeastBDD tseitin

encodePBLinAtLeast' :: PrimMonad m => Encoder m -> SAT.PBLinAtLeast -> m SAT.Lit
encodePBLinAtLeast' (Encoder card strategy) = do
encodePBLinAtLeastWithPolarity' :: PrimMonad m => Encoder m -> Polarity -> SAT.PBLinAtLeast -> m SAT.Lit
encodePBLinAtLeastWithPolarity' (Encoder card strategy) polarity constr = do
let tseitin = Card.getTseitinEncoder card
case strategy of
Adder -> encodePBLinAtLeastAdder tseitin
Sorter -> encodePBLinAtLeastSorter tseitin
BCCNF -> encodePBLinAtLeastBCCNF card
_ -> encodePBLinAtLeastBDD tseitin
Adder -> encodePBLinAtLeastAdder tseitin constr
Sorter -> encodePBLinAtLeastSorter tseitin constr
BCCNF -> encodePBLinAtLeastWithPolarityBCCNF card polarity constr
_ -> encodePBLinAtLeastWithPolarityBDD tseitin polarity constr

-- -----------------------------------------------------------------------

10 changes: 5 additions & 5 deletions src/ToySolver/SAT/Encoder/PB/Internal/BCCNF.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ module ToySolver.SAT.Encoder.PB.Internal.BCCNF
(
-- * Monadic interface
addPBLinAtLeastBCCNF
, encodePBLinAtLeastBCCNF
, encodePBLinAtLeastWithPolarityBCCNF

-- * High-level pure encoder
, encode
Expand Down Expand Up @@ -228,11 +228,11 @@ addPBLinAtLeastBCCNF enc constr = do
forM_ (encode constr) $ \clause -> do
addClause enc =<< mapM (Card.encodeAtLeast enc) clause

encodePBLinAtLeastBCCNF :: PrimMonad m => Card.Encoder m -> PBLinAtLeast -> m Lit
encodePBLinAtLeastBCCNF enc constr = do
encodePBLinAtLeastWithPolarityBCCNF :: PrimMonad m => Card.Encoder m -> Tseitin.Polarity -> PBLinAtLeast -> m Lit
encodePBLinAtLeastWithPolarityBCCNF enc polarity constr = do
let tseitin = Card.getTseitinEncoder enc
ls <- forM (encode constr) $ \clause -> do
Tseitin.encodeDisjWithPolarity tseitin Tseitin.polarityPos =<< mapM (Card.encodeAtLeast enc) clause
Tseitin.encodeConjWithPolarity tseitin Tseitin.polarityPos ls
Tseitin.encodeDisjWithPolarity tseitin polarity =<< mapM (Card.encodeAtLeastWithPolarity enc polarity) clause
Tseitin.encodeConjWithPolarity tseitin polarity ls

-- ------------------------------------------------------------------------
16 changes: 8 additions & 8 deletions src/ToySolver/SAT/Encoder/PB/Internal/BDD.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.PB.Internal.BDD
( addPBLinAtLeastBDD
, encodePBLinAtLeastBDD
, encodePBLinAtLeastWithPolarityBDD
) where

import Control.Monad.State.Strict
Expand All @@ -34,17 +34,17 @@ import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin

addPBLinAtLeastBDD :: PrimMonad m => Tseitin.Encoder m -> SAT.PBLinAtLeast -> m ()
addPBLinAtLeastBDD enc constr = do
l <- encodePBLinAtLeastBDD enc constr
l <- encodePBLinAtLeastWithPolarityBDD enc Tseitin.polarityPos constr
SAT.addClause enc [l]

encodePBLinAtLeastBDD :: forall m. PrimMonad m => Tseitin.Encoder m -> SAT.PBLinAtLeast -> m SAT.Lit
encodePBLinAtLeastBDD enc (lhs,rhs) = do
encodePBLinAtLeastWithPolarityBDD :: forall m. PrimMonad m => Tseitin.Encoder m -> Tseitin.Polarity -> SAT.PBLinAtLeast -> m SAT.Lit
encodePBLinAtLeastWithPolarityBDD enc polarity (lhs,rhs) = do
let lhs' = sortBy (flip (comparing fst)) lhs
flip evalStateT Map.empty $ do
let f :: SAT.PBLinSum -> Integer -> Integer -> StateT (Map (SAT.PBLinSum, Integer) SAT.Lit) m SAT.Lit
f xs rhs slack
| rhs <= 0 = lift $ Tseitin.encodeConj enc [] -- true
| slack < 0 = lift $ Tseitin.encodeDisj enc [] -- false
| rhs <= 0 = lift $ Tseitin.encodeConjWithPolarity enc polarity [] -- true
| slack < 0 = lift $ Tseitin.encodeDisjWithPolarity enc polarity [] -- false
| otherwise = do
m <- get
case Map.lookup (xs,rhs) m of
Expand All @@ -55,12 +55,12 @@ encodePBLinAtLeastBDD enc (lhs,rhs) = do
[(_,l)] -> return l
(c,l) : xs' -> do
thenLit <- f xs' (rhs - c) slack
l2 <- lift $ Tseitin.encodeConjWithPolarity enc Tseitin.polarityPos [l, thenLit]
l2 <- lift $ Tseitin.encodeConjWithPolarity enc polarity [l, thenLit]
l3 <- if c > slack then
return l2
else do
elseLit <- f xs' rhs (slack - c)
lift $ Tseitin.encodeDisjWithPolarity enc Tseitin.polarityPos [l2, elseLit]
lift $ Tseitin.encodeDisjWithPolarity enc polarity [l2, elseLit]
modify (Map.insert (xs,rhs) l3)
return l3
f lhs' rhs (sum [c | (c,_) <- lhs'] - rhs)

0 comments on commit c54b8bd

Please sign in to comment.