diff --git a/cadical/src/lib.rs b/cadical/src/lib.rs index d0c88651..c69ba06b 100644 --- a/cadical/src/lib.rs +++ b/cadical/src/lib.rs @@ -39,7 +39,10 @@ macro_rules! handle_oom { ($val:expr) => {{ let val = $val; if val == crate::OUT_OF_MEM { - return anyhow::Context::context(Err(rustsat::OutOfMemory), "cadical out of memory"); + return anyhow::Context::context( + Err(rustsat::OutOfMemory::ExternalApi), + "cadical out of memory", + ); } val }}; diff --git a/capi/src/lib.rs b/capi/src/lib.rs index 095fcb7d..688600df 100644 --- a/capi/src/lib.rs +++ b/capi/src/lib.rs @@ -63,6 +63,25 @@ pub mod encodings { fn n_clauses(&self) -> usize { self.n_clauses } + + fn extend_clauses(&mut self, cl_iter: T) -> Result<(), rustsat::OutOfMemory> + where + T: IntoIterator, + { + cl_iter.into_iter().for_each(|cl| { + cl.into_iter() + .for_each(|l| (self.ccol)(l.to_ipasir(), self.cdata)); + (self.ccol)(0, self.cdata); + }); + Ok(()) + } + + fn add_clause(&mut self, cl: Clause) -> Result<(), rustsat::OutOfMemory> { + cl.into_iter() + .for_each(|l| (self.ccol)(l.to_ipasir(), self.cdata)); + (self.ccol)(0, self.cdata); + Ok(()) + } } impl Extend for ClauseCollector { @@ -186,7 +205,9 @@ pub mod encodings { let mut collector = ClauseCollector::new(collector, collector_data); let mut var_manager = VarManager::new(n_vars_used); let mut boxed = unsafe { Box::from_raw(tot) }; - boxed.encode_ub_change(min_bound..=max_bound, &mut collector, &mut var_manager); + boxed + .encode_ub_change(min_bound..=max_bound, &mut collector, &mut var_manager) + .expect("clause collect returned out of memory"); Box::into_raw(boxed); } @@ -358,7 +379,9 @@ pub mod encodings { let mut collector = ClauseCollector::new(collector, collector_data); let mut var_manager = VarManager::new(n_vars_used); let mut boxed = unsafe { Box::from_raw(dpw) }; - boxed.encode_ub_change(min_bound..=max_bound, &mut collector, &mut var_manager); + boxed + .encode_ub_change(min_bound..=max_bound, &mut collector, &mut var_manager) + .expect("clause collector returned out of memory"); Box::into_raw(boxed); } diff --git a/glucose/src/lib.rs b/glucose/src/lib.rs index bb03055c..af4fc23e 100644 --- a/glucose/src/lib.rs +++ b/glucose/src/lib.rs @@ -80,7 +80,10 @@ macro_rules! handle_oom { ($val:expr) => {{ let val = $val; if val == crate::OUT_OF_MEM { - return anyhow::Context::context(Err(rustsat::OutOfMemory), "glucose out of memory"); + return anyhow::Context::context( + Err(rustsat::OutOfMemory::ExternalApi), + "glucose out of memory", + ); } val }}; diff --git a/minisat/src/lib.rs b/minisat/src/lib.rs index 799e3b18..695a32f3 100644 --- a/minisat/src/lib.rs +++ b/minisat/src/lib.rs @@ -80,7 +80,10 @@ macro_rules! handle_oom { ($val:expr) => {{ let val = $val; if val == crate::OUT_OF_MEM { - return anyhow::Context::context(Err(rustsat::OutOfMemory), "minisat out of memory"); + return anyhow::Context::context( + Err(rustsat::OutOfMemory::ExternalApi), + "minisat out of memory", + ); } val }}; diff --git a/pyapi/src/encodings.rs b/pyapi/src/encodings.rs index e7f1bba4..b5563651 100644 --- a/pyapi/src/encodings.rs +++ b/pyapi/src/encodings.rs @@ -19,6 +19,7 @@ use rustsat::{ }; use crate::{ + handle_oom, instances::{Cnf, VarManager}, types::Lit, }; @@ -89,12 +90,18 @@ impl Totalizer { /// Incrementally builds the totalizer encoding to that upper bounds /// in the range `max_ub..=min_ub` can be enforced. New variables will /// be taken from `var_manager`. - fn encode_ub(&mut self, max_ub: usize, min_ub: usize, var_manager: &mut VarManager) -> Cnf { + fn encode_ub( + &mut self, + max_ub: usize, + min_ub: usize, + var_manager: &mut VarManager, + ) -> PyResult { let mut cnf = RsCnf::new(); let var_manager: &mut BasicVarManager = var_manager.into(); - self.0 - .encode_ub_change(max_ub..=min_ub, &mut cnf, var_manager); - cnf.into() + handle_oom!(self + .0 + .encode_ub_change(max_ub..=min_ub, &mut cnf, var_manager)); + Ok(cnf.into()) } /// Gets assumptions to enforce the given upper bound. Make sure that @@ -165,12 +172,18 @@ impl GeneralizedTotalizer { /// Incrementally builds the GTE encoding to that upper bounds /// in the range `max_ub..=min_ub` can be enforced. New variables will /// be taken from `var_manager`. - fn encode_ub(&mut self, max_ub: usize, min_ub: usize, var_manager: &mut VarManager) -> Cnf { + fn encode_ub( + &mut self, + max_ub: usize, + min_ub: usize, + var_manager: &mut VarManager, + ) -> PyResult { let mut cnf = RsCnf::new(); let var_manager: &mut BasicVarManager = var_manager.into(); - self.0 - .encode_ub_change(max_ub..=min_ub, &mut cnf, var_manager); - cnf.into() + handle_oom!(self + .0 + .encode_ub_change(max_ub..=min_ub, &mut cnf, var_manager)); + Ok(cnf.into()) } /// Gets assumptions to enforce the given upper bound. Make sure that @@ -235,12 +248,18 @@ impl DynamicPolyWatchdog { /// Incrementally builds the DPW encoding to that upper bounds /// in the range `max_ub..=min_ub` can be enforced. New variables will /// be taken from `var_manager`. - fn encode_ub(&mut self, max_ub: usize, min_ub: usize, var_manager: &mut VarManager) -> Cnf { + fn encode_ub( + &mut self, + max_ub: usize, + min_ub: usize, + var_manager: &mut VarManager, + ) -> PyResult { let mut cnf = RsCnf::new(); let var_manager: &mut BasicVarManager = var_manager.into(); - self.0 - .encode_ub_change(max_ub..=min_ub, &mut cnf, var_manager); - cnf.into() + handle_oom!(self + .0 + .encode_ub_change(max_ub..=min_ub, &mut cnf, var_manager)); + Ok(cnf.into()) } /// Gets assumptions to enforce the given upper bound. Make sure that diff --git a/pyapi/src/lib.rs b/pyapi/src/lib.rs index f5d411bf..7d91d6a3 100644 --- a/pyapi/src/lib.rs +++ b/pyapi/src/lib.rs @@ -68,3 +68,13 @@ fn rustsat(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { Ok(()) } + +macro_rules! handle_oom { + ($result:expr) => {{ + match $result { + Ok(val) => val, + Err(err) => return Err(pyo3::exceptions::PyMemoryError::new_err(format!("{}", err))), + } + }}; +} +pub(crate) use handle_oom; diff --git a/rustsat/examples/profiling.rs b/rustsat/examples/profiling.rs index 0cb4f122..954065fd 100644 --- a/rustsat/examples/profiling.rs +++ b/rustsat/examples/profiling.rs @@ -73,7 +73,7 @@ fn build_full_ub>(lits: &[(Lit, usi let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(max_var + 1); let mut collector = Cnf::new(); - enc.encode_ub(.., &mut collector, &mut var_manager); + enc.encode_ub(.., &mut collector, &mut var_manager).unwrap(); } fn main() { diff --git a/rustsat/src/encodings.rs b/rustsat/src/encodings.rs index 9e00e7b0..7ca3f55b 100644 --- a/rustsat/src/encodings.rs +++ b/rustsat/src/encodings.rs @@ -4,7 +4,7 @@ use thiserror::Error; -use crate::types::Lit; +use crate::types::{Clause, Lit}; pub mod am1; pub mod atomics; @@ -13,9 +13,21 @@ pub mod pb; /// Trait for collecting clauses. Mainly used when generating encodings and implemented by /// [`crate::instances::Cnf`], and solvers. -pub trait CollectClauses: Extend { +pub trait CollectClauses { /// Gets the number of clauses in the collection fn n_clauses(&self) -> usize; + /// Extends the clause collector with an iterator of clauses + /// + /// # Error + /// + /// If the collector runs out of memory, return an [`crate::OutOfMemory`] error. + fn extend_clauses(&mut self, cl_iter: T) -> Result<(), crate::OutOfMemory> + where + T: IntoIterator; + /// Adds one clause to the collector + fn add_clause(&mut self, cl: Clause) -> Result<(), crate::OutOfMemory> { + self.extend_clauses([cl]) + } } /// Errors from encodings diff --git a/rustsat/src/encodings/am1.rs b/rustsat/src/encodings/am1.rs index 63648aa4..1190e021 100644 --- a/rustsat/src/encodings/am1.rs +++ b/rustsat/src/encodings/am1.rs @@ -20,7 +20,7 @@ //! enc.encode(&mut encoding, &mut var_manager).unwrap(); //! ``` -use super::{CollectClauses, Error}; +use super::CollectClauses; use crate::instances::ManageVars; mod pairwise; @@ -35,7 +35,7 @@ pub trait Encode { &mut self, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) -> Result<(), Error> + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses; } diff --git a/rustsat/src/encodings/am1/pairwise.rs b/rustsat/src/encodings/am1/pairwise.rs index e4513a56..dac0efb0 100644 --- a/rustsat/src/encodings/am1/pairwise.rs +++ b/rustsat/src/encodings/am1/pairwise.rs @@ -7,7 +7,7 @@ use super::Encode; use crate::{ clause, - encodings::{CollectClauses, EncodeStats, Error, IterInputs}, + encodings::{CollectClauses, EncodeStats, IterInputs}, instances::ManageVars, types::Lit, }; @@ -34,7 +34,7 @@ impl Encode for Pairwise { &mut self, collector: &mut Col, _var_manager: &mut dyn ManageVars, - ) -> Result<(), Error> + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, { @@ -43,7 +43,7 @@ impl Encode for Pairwise { let clause_iter = (0..self.in_lits.len()).flat_map(|first| { (first + 1..self.in_lits.len()).map(move |second| clause![!lits[first], !lits[second]]) }); - collector.extend(clause_iter); + collector.extend_clauses(clause_iter)?; self.n_clauses = collector.n_clauses() - prev_clauses; Ok(()) } diff --git a/rustsat/src/encodings/card.rs b/rustsat/src/encodings/card.rs index 9ba9f4d9..a0825641 100644 --- a/rustsat/src/encodings/card.rs +++ b/rustsat/src/encodings/card.rs @@ -67,7 +67,8 @@ pub trait BoundUpper: Encode { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds; /// Returns assumptions/units for enforcing an upper bound (`sum of lits <= @@ -76,24 +77,28 @@ pub trait BoundUpper: Encode { /// [`Error::NotEncoded`] will be returned. fn enforce_ub(&self, ub: usize) -> Result, Error>; /// Encodes an upper bound cardinality constraint to CNF + /// + /// # Errors + /// + /// Either an [`Error`] of [`crate::OutOfMemory`] fn encode_ub_constr( constr: CardUBConstr, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) -> Result<(), Error> + ) -> anyhow::Result<()> where Col: CollectClauses, Self: FromIterator + Sized, { let (lits, ub) = constr.decompose(); let mut enc = Self::from_iter(lits); - enc.encode_ub(ub..ub + 1, collector, var_manager); - collector.extend( + enc.encode_ub(ub..ub + 1, collector, var_manager)?; + collector.extend_clauses( enc.enforce_ub(ub) .unwrap() .into_iter() .map(|unit| clause![unit]), - ); + )?; Ok(()) } } @@ -110,7 +115,8 @@ pub trait BoundLower: Encode { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds; /// Returns assumptions/units for enforcing a lower bound (`sum of lits >= @@ -121,24 +127,28 @@ pub trait BoundLower: Encode { /// returned. fn enforce_lb(&self, lb: usize) -> Result, Error>; /// Encodes a lower bound cardinality constraint to CNF + /// + /// # Errors + /// + /// Either an [`Error`] of [`crate::OutOfMemory`] fn encode_lb_constr( constr: CardLBConstr, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) -> Result<(), Error> + ) -> anyhow::Result<()> where Col: CollectClauses, Self: FromIterator + Sized, { let (lits, lb) = constr.decompose(); let mut enc = Self::from_iter(lits); - enc.encode_lb(lb..lb + 1, collector, var_manager); - collector.extend( + enc.encode_lb(lb..lb + 1, collector, var_manager)?; + collector.extend_clauses( enc.enforce_lb(lb) .unwrap() .into_iter() .map(|unit| clause![unit]), - ); + )?; Ok(()) } } @@ -154,12 +164,14 @@ pub trait BoundBoth: BoundUpper + BoundLower { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds + Clone, { - self.encode_ub(range.clone(), collector, var_manager); - self.encode_lb(range, collector, var_manager); + self.encode_ub(range.clone(), collector, var_manager)?; + self.encode_lb(range, collector, var_manager)?; + Ok(()) } /// Returns assumptions for enforcing an equality (`sum of lits = b`) or an /// error if the encoding does not support one of the two required bound @@ -174,32 +186,40 @@ pub trait BoundBoth: BoundUpper + BoundLower { Ok(assumps) } /// Encodes an equality cardinality constraint to CNF + /// + /// # Errors + /// + /// Either an [`Error`] of [`crate::OutOfMemory`] fn encode_eq_constr( constr: CardEQConstr, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) -> Result<(), Error> + ) -> anyhow::Result<()> where Col: CollectClauses, Self: FromIterator + Sized, { let (lits, b) = constr.decompose(); let mut enc = Self::from_iter(lits); - enc.encode_both(b..b + 1, collector, var_manager); - collector.extend( + enc.encode_both(b..b + 1, collector, var_manager)?; + collector.extend_clauses( enc.enforce_eq(b) .unwrap() .into_iter() .map(|unit| clause![unit]), - ); + )?; Ok(()) } /// Encodes any cardinality constraint to CNF + /// + /// # Errors + /// + /// Either an [`Error`] of [`crate::OutOfMemory`] fn encode_constr( constr: CardConstraint, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) -> Result<(), Error> + ) -> anyhow::Result<()> where Col: CollectClauses, Self: FromIterator + Sized, @@ -235,7 +255,8 @@ pub trait BoundUpperIncremental: BoundUpper + EncodeIncremental { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds; } @@ -252,7 +273,8 @@ pub trait BoundLowerIncremental: BoundLower + EncodeIncremental { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds; } @@ -268,11 +290,12 @@ pub trait BoundBothIncremental: BoundUpperIncremental + BoundLowerIncremental { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds + Clone, { - self.encode_ub_change(range.clone(), collector, var_manager); + self.encode_ub_change(range.clone(), collector, var_manager)?; self.encode_lb_change(range, collector, var_manager) } } @@ -330,7 +353,7 @@ pub fn default_encode_cardinality_constraint( constr: CardConstraint, collector: &mut Col, var_manager: &mut dyn ManageVars, -) { +) -> Result<(), crate::OutOfMemory> { encode_cardinality_constraint::(constr, collector, var_manager) } @@ -339,27 +362,28 @@ pub fn encode_cardinality_constraint, Col: Col constr: CardConstraint, collector: &mut Col, var_manager: &mut dyn ManageVars, -) { +) -> Result<(), crate::OutOfMemory> { if constr.is_tautology() { - return; + return Ok(()); } if constr.is_unsat() { - collector.extend([Clause::new()]); - return; + return collector.add_clause(Clause::new()); } if constr.is_positive_assignment() { - collector.extend(constr.into_lits().into_iter().map(|lit| clause![lit])); - return; + return collector.extend_clauses(constr.into_lits().into_iter().map(|lit| clause![lit])); } if constr.is_negative_assignment() { - collector.extend(constr.into_lits().into_iter().map(|lit| clause![!lit])); - return; + return collector.extend_clauses(constr.into_lits().into_iter().map(|lit| clause![!lit])); } if constr.is_clause() { - collector.extend([constr.into_clause().unwrap()]); - return; + return collector.add_clause(constr.into_clause().unwrap()); + } + match CE::encode_constr(constr, collector, var_manager) { + Ok(_) => Ok(()), + Err(err) => Err(err + .downcast::() + .expect("unexpected error when encoding constraint")), } - CE::encode_constr(constr, collector, var_manager).unwrap() } fn prepare_ub_range>(enc: &Enc, range: R) -> Range { diff --git a/rustsat/src/encodings/card/dbtotalizer.rs b/rustsat/src/encodings/card/dbtotalizer.rs index a6f03071..bebfa304 100644 --- a/rustsat/src/encodings/card/dbtotalizer.rs +++ b/rustsat/src/encodings/card/dbtotalizer.rs @@ -98,7 +98,12 @@ impl EncodeIncremental for DbTotalizer { } impl BoundUpper for DbTotalizer { - fn encode_ub(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_ub( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, @@ -140,24 +145,26 @@ impl BoundUpperIncremental for DbTotalizer { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { let range = super::prepare_ub_range(self, range); if range.is_empty() { - return; + return Ok(()); } self.extend_tree(); if let Some(id) = self.root { let n_vars_before = var_manager.n_used(); let n_clauses_before = collector.n_clauses(); for idx in range { - self.db.define_pos_tot(id, idx, collector, var_manager); + self.db.define_pos_tot(id, idx, collector, var_manager)?; } self.n_clauses += collector.n_clauses() - n_clauses_before; self.n_vars += var_manager.n_used() - n_vars_before; - } + }; + Ok(()) } } @@ -635,7 +642,7 @@ impl TotDb { val: usize, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) -> Option + ) -> Result, crate::OutOfMemory> where Col: CollectClauses, { @@ -645,23 +652,28 @@ impl TotDb { Node::Leaf(lit) => { debug_assert_eq!(val, 1); if val != 1 { - return None; + return Ok(None); } - Some(*lit) + Ok(Some(*lit)) } Node::Unit(node) => { if val > node.lits.len() || val == 0 { - return None; + return Ok(None); } // Check if already encoded if let LitData::Lit { lit, enc_pos, .. } = node.lits[val - 1] { if enc_pos { - return Some(lit); + return Ok(Some(lit)); } } - Some(self.define_pos_tot(id, val - 1, collector, var_manager)) + Ok(Some(self.define_pos_tot( + id, + val - 1, + collector, + var_manager, + )?)) } Node::General(node) => { // Check if already encoded @@ -670,10 +682,10 @@ impl TotDb { lit, enc_pos: true, .. } = lit_data { - return Some(*lit); + return Ok(Some(*lit)); } } else { - return None; + return Ok(None); } debug_assert!(val <= node.max_val); @@ -694,16 +706,16 @@ impl TotDb { // Propagate value if lcon.is_possible(val) && lcon.rev_map(val) <= self[lcon.id].max_val() { if let Some(llit) = - self.define_pos(lcon.id, lcon.rev_map(val), collector, var_manager) + self.define_pos(lcon.id, lcon.rev_map(val), collector, var_manager)? { - collector.extend([atomics::lit_impl_lit(llit, olit)]); + collector.add_clause(atomics::lit_impl_lit(llit, olit))?; } } if rcon.is_possible(val) && rcon.rev_map(val) <= self[rcon.id].max_val() { if let Some(rlit) = - self.define_pos(rcon.id, rcon.rev_map(val), collector, var_manager) + self.define_pos(rcon.id, rcon.rev_map(val), collector, var_manager)? { - collector.extend([atomics::lit_impl_lit(rlit, olit)]); + collector.add_clause(atomics::lit_impl_lit(rlit, olit))?; }; } @@ -717,15 +729,16 @@ impl TotDb { let rval_rev = rcon.rev_map(rval); if rcon.is_possible(rval) && rval_rev <= rmax { if let Some(rlit) = - self.define_pos(rcon.id, rval_rev, collector, var_manager) + self.define_pos(rcon.id, rval_rev, collector, var_manager)? { debug_assert!( lcon.len_limit.is_none() || lcon.offset() + 1 == lval ); let llit = self - .define_pos(lcon.id, lval, collector, var_manager) + .define_pos(lcon.id, lval, collector, var_manager)? .unwrap(); - collector.extend([atomics::cube_impl_lit(&[llit, rlit], olit)]); + collector + .add_clause(atomics::cube_impl_lit(&[llit, rlit], olit))?; } } } @@ -737,7 +750,7 @@ impl TotDb { LitData::Lit { enc_pos, .. } => *enc_pos = true, }; - Some(olit) + Ok(Some(olit)) } } } @@ -748,7 +761,7 @@ impl TotDb { idx: usize, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) -> Lit + ) -> Result where Col: CollectClauses, { @@ -756,7 +769,7 @@ impl TotDb { debug_assert!(idx < node.max_val()); if node.is_leaf() { debug_assert_eq!(idx, 0); - return node[1]; + return Ok(node[1]); } let lcon = node.left().unwrap(); let rcon = node.right().unwrap(); @@ -769,7 +782,7 @@ impl TotDb { // Check if already encoded if let LitData::Lit { lit, enc_pos, .. } = node.lits[idx] { if enc_pos { - return lit; + return Ok(lit); } } @@ -795,10 +808,10 @@ impl TotDb { // Encode children (recurse) for lidx in l_min_idx..=l_max_idx { - self.define_pos_tot(lcon.id, con_idx(lidx, lcon), collector, var_manager); + self.define_pos_tot(lcon.id, con_idx(lidx, lcon), collector, var_manager)?; } for ridx in r_min_idx..=r_max_idx { - self.define_pos_tot(rcon.id, con_idx(ridx, rcon), collector, var_manager); + self.define_pos_tot(rcon.id, con_idx(ridx, rcon), collector, var_manager)?; } // Reserve variable for this node, if needed @@ -832,16 +845,16 @@ impl TotDb { // Encode this node if l_max_idx == idx { - collector.extend([atomics::lit_impl_lit( + collector.add_clause(atomics::lit_impl_lit( *llits[con_idx(idx, lcon)].lit().unwrap(), olit, - )]); + ))?; } if r_max_idx == idx { - collector.extend([atomics::lit_impl_lit( + collector.add_clause(atomics::lit_impl_lit( *rlits[con_idx(idx, rcon)].lit().unwrap(), olit, - )]); + ))?; } let clause_for_lidx = |lidx: usize| { let ridx = idx - lidx - 1; @@ -851,7 +864,7 @@ impl TotDb { atomics::cube_impl_lit(&[llit, rlit], olit) }; let clause_iter = (l_min_idx..cmp::min(l_max_idx + 1, idx)).map(clause_for_lidx); - collector.extend(clause_iter); + collector.extend_clauses(clause_iter)?; // Mark positive literal as encoded match &mut self[id].mut_unit().lits[idx] { @@ -859,7 +872,7 @@ impl TotDb { LitData::Lit { enc_pos, .. } => *enc_pos = true, }; - olit + Ok(olit) } /// Recursively reserves all variables in the subtree rooted at the given node @@ -1036,7 +1049,8 @@ pub mod referenced { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: std::ops::RangeBounds, { @@ -1072,7 +1086,8 @@ pub mod referenced { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: std::ops::RangeBounds, { @@ -1108,18 +1123,20 @@ pub mod referenced { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: std::ops::RangeBounds, { let range = super::super::prepare_ub_range(self, range); if range.is_empty() { - return; + return Ok(()); } for idx in range { self.db - .define_pos_tot(self.root, idx, collector, var_manager); + .define_pos_tot(self.root, idx, collector, var_manager)?; } + Ok(()) } } @@ -1129,19 +1146,21 @@ pub mod referenced { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: std::ops::RangeBounds, { let range = super::super::prepare_ub_range(self, range); if range.is_empty() { - return; + return Ok(()); } for idx in range { self.db .borrow_mut() - .define_pos_tot(self.root, idx, collector, var_manager); + .define_pos_tot(self.root, idx, collector, var_manager)?; } + Ok(()) } } } @@ -1168,22 +1187,26 @@ mod tests { var_manager.increase_next_free(var![4]); let mut cnf = Cnf::new(); - db.define_pos_tot(root, 0, &mut cnf, &mut var_manager); + db.define_pos_tot(root, 0, &mut cnf, &mut var_manager) + .unwrap(); debug_assert_eq!(cnf.len(), 6); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos_tot(root, 1, &mut cnf, &mut var_manager); + db.define_pos_tot(root, 1, &mut cnf, &mut var_manager) + .unwrap(); debug_assert_eq!(cnf.len(), 9); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos_tot(root, 2, &mut cnf, &mut var_manager); + db.define_pos_tot(root, 2, &mut cnf, &mut var_manager) + .unwrap(); debug_assert_eq!(cnf.len(), 8); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos_tot(root, 3, &mut cnf, &mut var_manager); + db.define_pos_tot(root, 3, &mut cnf, &mut var_manager) + .unwrap(); debug_assert_eq!(cnf.len(), 3); } @@ -1200,32 +1223,32 @@ mod tests { var_manager.increase_next_free(var![4]); let mut cnf = Cnf::new(); - db.define_pos(root, 1, &mut cnf, &mut var_manager); + db.define_pos(root, 1, &mut cnf, &mut var_manager).unwrap(); debug_assert_eq!(cnf.len(), 0); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos(root, 4, &mut cnf, &mut var_manager); + db.define_pos(root, 4, &mut cnf, &mut var_manager).unwrap(); debug_assert_eq!(cnf.len(), 3); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos(root, 7, &mut cnf, &mut var_manager); + db.define_pos(root, 7, &mut cnf, &mut var_manager).unwrap(); debug_assert_eq!(cnf.len(), 3); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos(root, 8, &mut cnf, &mut var_manager); + db.define_pos(root, 8, &mut cnf, &mut var_manager).unwrap(); debug_assert_eq!(cnf.len(), 2); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos(root, 15, &mut cnf, &mut var_manager); + db.define_pos(root, 15, &mut cnf, &mut var_manager).unwrap(); debug_assert_eq!(cnf.len(), 4); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos(root, 22, &mut cnf, &mut var_manager); + db.define_pos(root, 22, &mut cnf, &mut var_manager).unwrap(); debug_assert_eq!(cnf.len(), 3); } @@ -1242,32 +1265,32 @@ mod tests { var_manager.increase_next_free(var![3]); let mut cnf = Cnf::new(); - db.define_pos(root, 1, &mut cnf, &mut var_manager); + db.define_pos(root, 1, &mut cnf, &mut var_manager).unwrap(); debug_assert_eq!(cnf.len(), 2); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos(root, 2, &mut cnf, &mut var_manager); + db.define_pos(root, 2, &mut cnf, &mut var_manager).unwrap(); debug_assert_eq!(cnf.len(), 2); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos(root, 3, &mut cnf, &mut var_manager); + db.define_pos(root, 3, &mut cnf, &mut var_manager).unwrap(); debug_assert_eq!(cnf.len(), 3); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos(root, 4, &mut cnf, &mut var_manager); + db.define_pos(root, 4, &mut cnf, &mut var_manager).unwrap(); debug_assert_eq!(cnf.len(), 2); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos(root, 5, &mut cnf, &mut var_manager); + db.define_pos(root, 5, &mut cnf, &mut var_manager).unwrap(); debug_assert_eq!(cnf.len(), 2); db.reset_encoded(); let mut cnf = Cnf::new(); - db.define_pos(root, 6, &mut cnf, &mut var_manager); + db.define_pos(root, 6, &mut cnf, &mut var_manager).unwrap(); debug_assert_eq!(cnf.len(), 2); } @@ -1279,7 +1302,7 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); let mut cnf = Cnf::new(); - tot.encode_ub(0..5, &mut cnf, &mut var_manager); + tot.encode_ub(0..5, &mut cnf, &mut var_manager).unwrap(); assert_eq!(tot.depth(), 3); println!("len: {}, {:?}", cnf.len(), cnf); assert_eq!(cnf.len(), 14); @@ -1295,7 +1318,7 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); let mut cnf = Cnf::new(); - tot.encode_ub(3..4, &mut cnf, &mut var_manager); + tot.encode_ub(3..4, &mut cnf, &mut var_manager).unwrap(); assert_eq!(tot.depth(), 3); assert_eq!(cnf.len(), 3); assert_eq!(cnf.len(), tot.n_clauses()); @@ -1308,14 +1331,15 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); let mut cnf1 = Cnf::new(); - tot1.encode_ub(0..5, &mut cnf1, &mut var_manager); + tot1.encode_ub(0..5, &mut cnf1, &mut var_manager).unwrap(); let mut tot2 = DbTotalizer::default(); tot2.extend(vec![lit![0], lit![1], lit![2], lit![3]]); let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); let mut cnf2 = Cnf::new(); - tot2.encode_ub(0..3, &mut cnf2, &mut var_manager); - tot2.encode_ub_change(0..5, &mut cnf2, &mut var_manager); + tot2.encode_ub(0..3, &mut cnf2, &mut var_manager).unwrap(); + tot2.encode_ub_change(0..5, &mut cnf2, &mut var_manager) + .unwrap(); assert_eq!(cnf1.len(), cnf2.len()); assert_eq!(cnf1.len(), tot1.n_clauses()); assert_eq!(cnf2.len(), tot2.n_clauses()); diff --git a/rustsat/src/encodings/card/simulators.rs b/rustsat/src/encodings/card/simulators.rs index ca8a1634..7f4b8ed5 100644 --- a/rustsat/src/encodings/card/simulators.rs +++ b/rustsat/src/encodings/card/simulators.rs @@ -121,7 +121,12 @@ impl BoundUpper for Inverted where CE: BoundLower, { - fn encode_ub(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_ub( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, @@ -147,7 +152,12 @@ impl BoundLower for Inverted where CE: BoundUpper, { - fn encode_lb(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_lb( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, @@ -178,7 +188,8 @@ where range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { @@ -199,7 +210,8 @@ where range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { @@ -325,7 +337,12 @@ where UBE: BoundUpper, LBE: BoundLower, { - fn encode_ub(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_ub( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, @@ -343,7 +360,12 @@ where UBE: BoundUpper, LBE: BoundLower, { - fn encode_lb(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_lb( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, @@ -366,7 +388,8 @@ where range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { @@ -384,7 +407,8 @@ where range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { diff --git a/rustsat/src/encodings/card/totalizer.rs b/rustsat/src/encodings/card/totalizer.rs index dc569f0a..e471cbcb 100644 --- a/rustsat/src/encodings/card/totalizer.rs +++ b/rustsat/src/encodings/card/totalizer.rs @@ -121,14 +121,19 @@ impl EncodeIncremental for Totalizer { } impl BoundUpper for Totalizer { - fn encode_ub(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_ub( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, { let range = super::prepare_ub_range(self, range); if range.is_empty() { - return; + return Ok(()); }; self.extend_tree(); match &mut self.root { @@ -136,11 +141,12 @@ impl BoundUpper for Totalizer { Some(root) => { let n_vars_before = var_manager.n_used(); let n_clauses_before = collector.n_clauses(); - root.rec_encode_ub(range, collector, var_manager); + root.rec_encode_ub(range, collector, var_manager)?; self.n_clauses += collector.n_clauses() - n_clauses_before; self.n_vars += var_manager.n_used() - n_vars_before; } - } + }; + Ok(()) } fn enforce_ub(&self, ub: usize) -> Result, Error> { @@ -169,14 +175,19 @@ impl BoundUpper for Totalizer { } impl BoundLower for Totalizer { - fn encode_lb(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_lb( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, { let range = super::prepare_lb_range(self, range); if range.is_empty() { - return; + return Ok(()); }; self.extend_tree(); match &mut self.root { @@ -184,11 +195,12 @@ impl BoundLower for Totalizer { Some(root) => { let n_vars_before = var_manager.n_used(); let n_clauses_before = collector.n_clauses(); - root.rec_encode_lb(range, collector, var_manager); + root.rec_encode_lb(range, collector, var_manager)?; self.n_clauses += collector.n_clauses() - n_clauses_before; self.n_vars += var_manager.n_used() - n_vars_before; } - } + }; + Ok(()) } fn enforce_lb(&self, lb: usize) -> Result, Error> { @@ -224,13 +236,14 @@ impl BoundUpperIncremental for Totalizer { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { let range = super::prepare_ub_range(self, range); if range.is_empty() { - return; + return Ok(()); }; self.extend_tree(); match &mut self.root { @@ -238,11 +251,12 @@ impl BoundUpperIncremental for Totalizer { Some(root) => { let n_vars_before = var_manager.n_used(); let n_clauses_before = collector.n_clauses(); - root.rec_encode_ub_change(range, collector, var_manager); + root.rec_encode_ub_change(range, collector, var_manager)?; self.n_clauses += collector.n_clauses() - n_clauses_before; self.n_vars += var_manager.n_used() - n_vars_before; } - } + }; + Ok(()) } } @@ -252,13 +266,14 @@ impl BoundLowerIncremental for Totalizer { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { let range = super::prepare_lb_range(self, range); if range.is_empty() { - return; + return Ok(()); }; self.extend_tree(); match &mut self.root { @@ -266,11 +281,12 @@ impl BoundLowerIncremental for Totalizer { Some(root) => { let n_vars_before = var_manager.n_used(); let n_clauses_before = collector.n_clauses(); - root.rec_encode_lb_change(range, collector, var_manager); + root.rec_encode_lb_change(range, collector, var_manager)?; self.n_clauses += collector.n_clauses() - n_clauses_before; self.n_vars += var_manager.n_used() - n_vars_before; } - } + }; + Ok(()) } } @@ -389,12 +405,13 @@ impl Node { range: Range, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, { let range = self.limit_range(range); if range.is_empty() { - return; + return Ok(()); } // Reserve vars if needed @@ -447,9 +464,11 @@ impl Node { (0..=right_lits.len()) .filter_map(move |right_val| clause_for_vals(left_val, right_val)) }); - collector.extend(clause_iter); + collector.extend_clauses(clause_iter)?; } - } + }; + + Ok(()) } /// Encodes the lower bound adder for this node in a given range. This @@ -460,12 +479,13 @@ impl Node { range: Range, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, { let range = self.limit_range(range); if range.is_empty() { - return; + return Ok(()); } // Reserve vars if needed @@ -525,9 +545,11 @@ impl Node { (0..=right_lits.len()) .filter_map(move |right_val| clause_for_vals(left_val, right_val)) }); - collector.extend(clause_iter); + collector.extend_clauses(clause_iter)?; } - } + }; + + Ok(()) } /// Encodes the upper bound adder from the children to this node in a given @@ -538,12 +560,13 @@ impl Node { range: Range, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, { let range = self.limit_range(range); if range.is_empty() { - return; + return Ok(()); } match self { @@ -560,16 +583,17 @@ impl Node { let left_range = Node::compute_required_range(range.clone(), right.max_val()); let right_range = Node::compute_required_range(range.clone(), left.max_val()); // Recurse - left.rec_encode_ub(left_range, collector, var_manager); - right.rec_encode_ub(right_range, collector, var_manager); + left.rec_encode_ub(left_range, collector, var_manager)?; + right.rec_encode_ub(right_range, collector, var_manager)?; // Ignore all previous encoding and encode from scratch let n_clauses_before = collector.n_clauses(); - self.encode_ub_range(range.clone(), collector, var_manager); + self.encode_ub_range(range.clone(), collector, var_manager)?; self.update_stats(range, lb_range, collector.n_clauses() - n_clauses_before); } } + Ok(()) } /// Encodes the lower bound adder from the children to this node in a given @@ -580,12 +604,13 @@ impl Node { range: Range, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, { let range = self.limit_range(range); if range.is_empty() { - return; + return Ok(()); } match self { @@ -602,16 +627,18 @@ impl Node { let left_range = Node::compute_required_range(range.clone(), right.max_val()); let right_range = Node::compute_required_range(range.clone(), left.max_val()); // Recurse - left.rec_encode_lb(left_range, collector, var_manager); - right.rec_encode_lb(right_range, collector, var_manager); + left.rec_encode_lb(left_range, collector, var_manager)?; + right.rec_encode_lb(right_range, collector, var_manager)?; // Ignore all previous encoding and encode from scratch let n_clauses_before = collector.n_clauses(); - self.encode_lb_range(range.clone(), collector, var_manager); + self.encode_lb_range(range.clone(), collector, var_manager)?; self.update_stats(ub_range, range, collector.n_clauses() - n_clauses_before); } - } + }; + + Ok(()) } /// Encodes the upper bound adder from the children to this node in a given @@ -621,12 +648,13 @@ impl Node { range: Range, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, { let range = self.limit_range(range); if range.is_empty() { - return; + return Ok(()); } match self { @@ -645,27 +673,29 @@ impl Node { let left_range = Node::compute_required_range(range.clone(), right.max_val()); let right_range = Node::compute_required_range(range.clone(), left.max_val()); // Recurse - left.rec_encode_ub_change(left_range, collector, var_manager); - right.rec_encode_ub_change(right_range, collector, var_manager); + left.rec_encode_ub_change(left_range, collector, var_manager)?; + right.rec_encode_ub_change(right_range, collector, var_manager)?; // Encode changes for current node let n_clauses_before = collector.n_clauses(); if ub_range.is_empty() { // First time encoding this node - self.encode_ub_range(range.clone(), collector, var_manager) + self.encode_ub_range(range.clone(), collector, var_manager)?; } else { // Part already encoded if range.start < ub_range.start { - self.encode_ub_range(range.start..ub_range.start, collector, var_manager); + self.encode_ub_range(range.start..ub_range.start, collector, var_manager)?; }; if range.end > ub_range.end { - self.encode_ub_range(ub_range.end..range.end, collector, var_manager); + self.encode_ub_range(ub_range.end..range.end, collector, var_manager)?; }; }; self.update_stats(range, lb_range, collector.n_clauses() - n_clauses_before); } - } + }; + + Ok(()) } /// Encodes the lower bound adder from the children to this node in a given @@ -675,7 +705,8 @@ impl Node { range: Range, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, { match self { @@ -694,27 +725,29 @@ impl Node { let left_range = Node::compute_required_range(range.clone(), right.max_val()); let right_range = Node::compute_required_range(range.clone(), left.max_val()); // Recurse - left.rec_encode_lb_change(left_range, collector, var_manager); - right.rec_encode_lb_change(right_range, collector, var_manager); + left.rec_encode_lb_change(left_range, collector, var_manager)?; + right.rec_encode_lb_change(right_range, collector, var_manager)?; // Encode changes for current node let n_clauses_before = collector.n_clauses(); if lb_range.is_empty() { // First time encoding this node - self.encode_lb_range(range.clone(), collector, var_manager) + self.encode_lb_range(range.clone(), collector, var_manager)?; } else { // Part already encoded if range.start < lb_range.start { - self.encode_lb_range(range.start..lb_range.start, collector, var_manager); + self.encode_lb_range(range.start..lb_range.start, collector, var_manager)?; }; if range.end > lb_range.end { - self.encode_lb_range(lb_range.end..range.end, collector, var_manager); + self.encode_lb_range(lb_range.end..range.end, collector, var_manager)?; }; }; self.update_stats(ub_range, range, collector.n_clauses() - n_clauses_before); } - } + }; + + Ok(()) } /// Reserves variables this node might need for indices in a given range. @@ -852,8 +885,10 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![2]); let mut cnf = Cnf::new(); - node.encode_ub_range(0..3, &mut cnf, &mut var_manager); - node.encode_lb_range(0..3, &mut cnf, &mut var_manager); + node.encode_ub_range(0..3, &mut cnf, &mut var_manager) + .unwrap(); + node.encode_lb_range(0..3, &mut cnf, &mut var_manager) + .unwrap(); match &node { Node::Leaf { .. } => panic!(), Node::Internal { out_lits, .. } => assert_eq!(out_lits.len(), 2), @@ -890,8 +925,10 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![5]); let mut cnf = Cnf::new(); - node.encode_ub_range(0..5, &mut cnf, &mut var_manager); - node.encode_lb_range(0..5, &mut cnf, &mut var_manager); + node.encode_ub_range(0..5, &mut cnf, &mut var_manager) + .unwrap(); + node.encode_lb_range(0..5, &mut cnf, &mut var_manager) + .unwrap(); match &node { Node::Leaf { .. } => panic!(), Node::Internal { out_lits, .. } => assert_eq!(out_lits.len(), 4), @@ -908,7 +945,8 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![2]); let mut cnf = Cnf::new(); - node.encode_lb_range(0..2, &mut cnf, &mut var_manager); + node.encode_lb_range(0..2, &mut cnf, &mut var_manager) + .unwrap(); match &node { Node::Leaf { .. } => panic!(), Node::Internal { out_lits, .. } => assert_eq!(out_lits.len(), 1), @@ -945,8 +983,10 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![7]); let mut cnf = Cnf::new(); - node.encode_ub_range(0..4, &mut cnf, &mut var_manager); - node.encode_lb_range(0..4, &mut cnf, &mut var_manager); + node.encode_ub_range(0..4, &mut cnf, &mut var_manager) + .unwrap(); + node.encode_lb_range(0..4, &mut cnf, &mut var_manager) + .unwrap(); match &node { Node::Leaf { .. } => panic!(), Node::Internal { out_lits, .. } => assert_eq!(out_lits.len(), 4), @@ -983,8 +1023,10 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![7]); let mut cnf = Cnf::new(); - node.encode_ub_range(3..3, &mut cnf, &mut var_manager); - node.encode_lb_range(3..3, &mut cnf, &mut var_manager); + node.encode_ub_range(3..3, &mut cnf, &mut var_manager) + .unwrap(); + node.encode_lb_range(3..3, &mut cnf, &mut var_manager) + .unwrap(); assert_eq!(cnf.len(), 0); } @@ -1017,8 +1059,10 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![7]); let mut cnf = Cnf::new(); - node.encode_ub_range(2..4, &mut cnf, &mut var_manager); - node.encode_lb_range(2..4, &mut cnf, &mut var_manager); + node.encode_ub_range(2..4, &mut cnf, &mut var_manager) + .unwrap(); + node.encode_lb_range(2..4, &mut cnf, &mut var_manager) + .unwrap(); match &node { Node::Leaf { .. } => panic!(), Node::Internal { out_lits, .. } => assert_eq!(out_lits.len(), 4), @@ -1035,8 +1079,8 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); let mut cnf = Cnf::new(); - tot.encode_ub(0..5, &mut cnf, &mut var_manager); - tot.encode_lb(0..5, &mut cnf, &mut var_manager); + tot.encode_ub(0..5, &mut cnf, &mut var_manager).unwrap(); + tot.encode_lb(0..5, &mut cnf, &mut var_manager).unwrap(); assert_eq!(tot.depth(), 3); assert_eq!(cnf.len(), 28); assert_eq!(tot.n_clauses(), 28); @@ -1052,7 +1096,7 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); let mut cnf = Cnf::new(); - tot.encode_both(3..4, &mut cnf, &mut var_manager); + tot.encode_both(3..4, &mut cnf, &mut var_manager).unwrap(); assert_eq!(tot.depth(), 3); assert_eq!(cnf.len(), 12); assert_eq!(cnf.len(), tot.n_clauses()); @@ -1065,14 +1109,15 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); let mut cnf1 = Cnf::new(); - tot1.encode_ub(0..5, &mut cnf1, &mut var_manager); + tot1.encode_ub(0..5, &mut cnf1, &mut var_manager).unwrap(); let mut tot2 = Totalizer::default(); tot2.extend(vec![lit![0], lit![1], lit![2], lit![3]]); let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); let mut cnf2 = Cnf::new(); - tot2.encode_ub(0..3, &mut cnf2, &mut var_manager); - tot2.encode_ub_change(0..5, &mut cnf2, &mut var_manager); + tot2.encode_ub(0..3, &mut cnf2, &mut var_manager).unwrap(); + tot2.encode_ub_change(0..5, &mut cnf2, &mut var_manager) + .unwrap(); assert_eq!(cnf1.len(), cnf2.len()); assert_eq!(cnf1.len(), tot1.n_clauses()); assert_eq!(cnf2.len(), tot2.n_clauses()); @@ -1085,14 +1130,15 @@ mod tests { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); let mut cnf1 = Cnf::new(); - tot1.encode_lb(0..5, &mut cnf1, &mut var_manager); + tot1.encode_lb(0..5, &mut cnf1, &mut var_manager).unwrap(); let mut tot2 = Totalizer::default(); tot2.extend(vec![lit![0], lit![1], lit![2], lit![3]]); let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); let mut cnf2 = Cnf::new(); - tot2.encode_lb(0..3, &mut cnf2, &mut var_manager); - tot2.encode_lb_change(0..5, &mut cnf2, &mut var_manager); + tot2.encode_lb(0..3, &mut cnf2, &mut var_manager).unwrap(); + tot2.encode_lb_change(0..5, &mut cnf2, &mut var_manager) + .unwrap(); assert_eq!(cnf1.len(), cnf2.len()); assert_eq!(cnf1.len(), tot1.n_clauses()); assert_eq!(cnf2.len(), tot2.n_clauses()); diff --git a/rustsat/src/encodings/pb.rs b/rustsat/src/encodings/pb.rs index ca7557e5..370497c7 100644 --- a/rustsat/src/encodings/pb.rs +++ b/rustsat/src/encodings/pb.rs @@ -89,7 +89,8 @@ pub trait BoundUpper: Encode { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds; /// Returns assumptions/units for enforcing an upper bound (`weighted sum of @@ -98,29 +99,33 @@ pub trait BoundUpper: Encode { /// [`Error::NotEncoded`] will be returned. fn enforce_ub(&self, ub: usize) -> Result, Error>; /// Encodes an upper bound pseudo-boolean constraint to CNF + /// + /// # Errors + /// + /// Either an [`Error`] of [`crate::OutOfMemory`] fn encode_ub_constr( constr: PBUBConstr, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) -> Result<(), Error> + ) -> anyhow::Result<()> where Col: CollectClauses, Self: FromIterator<(Lit, usize)> + Sized, { let (lits, ub) = constr.decompose(); let ub = if ub < 0 { - return Err(Error::Unsat); + anyhow::bail!(Error::Unsat); } else { ub as usize }; let mut enc = Self::from_iter(lits); - enc.encode_ub(ub..ub + 1, collector, var_manager); - collector.extend( + enc.encode_ub(ub..ub + 1, collector, var_manager)?; + collector.extend_clauses( enc.enforce_ub(ub) .unwrap() .into_iter() .map(|unit| clause![unit]), - ); + )?; Ok(()) } /// Gets the next smaller upper bound value that can be _easily_ encoded. This @@ -143,7 +148,8 @@ pub trait BoundLower: Encode { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds; /// Returns assumptions/units for enforcing a lower bound (`sum of lits >= @@ -154,11 +160,15 @@ pub trait BoundLower: Encode { /// is returned. fn enforce_lb(&self, lb: usize) -> Result, Error>; /// Encodes a lower bound pseudo-boolean constraint to CNF + /// + /// # Errors + /// + /// Either an [`Error`] of [`crate::OutOfMemory`] fn encode_lb_constr( constr: PBLBConstr, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) -> Result<(), Error> + ) -> anyhow::Result<()> where Col: CollectClauses, Self: FromIterator<(Lit, usize)> + Sized, @@ -170,13 +180,13 @@ pub trait BoundLower: Encode { lb as usize }; let mut enc = Self::from_iter(lits); - enc.encode_lb(lb..lb + 1, collector, var_manager); - collector.extend( + enc.encode_lb(lb..lb + 1, collector, var_manager)?; + collector.extend_clauses( enc.enforce_lb(lb) .unwrap() .into_iter() .map(|unit| clause![unit]), - ); + )?; Ok(()) } /// Gets the next greater lower bound value that can be _easily_ encoded. This @@ -197,12 +207,14 @@ pub trait BoundBoth: BoundUpper + BoundLower { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds + Clone, { - self.encode_ub(range.clone(), collector, var_manager); - self.encode_lb(range, collector, var_manager); + self.encode_ub(range.clone(), collector, var_manager)?; + self.encode_lb(range, collector, var_manager)?; + Ok(()) } /// Returns assumptions for enforcing an equality (`sum of lits = b`) or an /// error if the encoding does not support one of the two required bound @@ -217,37 +229,45 @@ pub trait BoundBoth: BoundUpper + BoundLower { Ok(assumps) } /// Encodes an equality pseudo-boolean constraint to CNF + /// + /// # Errors + /// + /// Either an [`Error`] of [`crate::OutOfMemory`] fn encode_eq_constr( constr: PBEQConstr, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) -> Result<(), Error> + ) -> anyhow::Result<()> where Col: CollectClauses, Self: FromIterator<(Lit, usize)> + Sized, { let (lits, b) = constr.decompose(); let b = if b < 0 { - return Err(Error::Unsat); + anyhow::bail!(Error::Unsat); } else { b as usize }; let mut enc = Self::from_iter(lits); - enc.encode_both(b..b + 1, collector, var_manager); - collector.extend( + enc.encode_both(b..b + 1, collector, var_manager)?; + collector.extend_clauses( enc.enforce_eq(b) .unwrap() .into_iter() .map(|unit| clause![unit]), - ); + )?; Ok(()) } /// Encodes any pseudo-boolean constraint to CNF + /// + /// # Errors + /// + /// Either an [`Error`] of [`crate::OutOfMemory`] fn encode_constr( constr: PBConstraint, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) -> Result<(), Error> + ) -> anyhow::Result<()> where Col: CollectClauses, Self: FromIterator<(Lit, usize)> + Sized, @@ -283,7 +303,8 @@ pub trait BoundUpperIncremental: BoundUpper + EncodeIncremental { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds; } @@ -300,7 +321,8 @@ pub trait BoundLowerIncremental: BoundLower + EncodeIncremental { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds; } @@ -316,12 +338,14 @@ pub trait BoundBothIncremental: BoundUpperIncremental + BoundLowerIncremental { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds + Clone, { - self.encode_ub_change(range.clone(), collector, var_manager); - self.encode_lb_change(range, collector, var_manager); + self.encode_ub_change(range.clone(), collector, var_manager)?; + self.encode_lb_change(range, collector, var_manager)?; + Ok(()) } } @@ -379,7 +403,7 @@ pub fn default_encode_pb_constraint( constr: PBConstraint, collector: &mut Col, var_manager: &mut dyn ManageVars, -) { +) -> Result<(), crate::OutOfMemory> { encode_pb_constraint::(constr, collector, var_manager) } @@ -388,31 +412,34 @@ pub fn encode_pb_constraint, Col: Co constr: PBConstraint, collector: &mut Col, var_manager: &mut dyn ManageVars, -) { +) -> Result<(), crate::OutOfMemory> { if constr.is_tautology() { - return; + return Ok(()); } if constr.is_unsat() { - collector.extend([Clause::new()]); - return; + return collector.add_clause(Clause::new()); } if constr.is_positive_assignment() { - collector.extend(constr.into_lits().into_iter().map(|(lit, _)| clause![lit])); - return; + return collector + .extend_clauses(constr.into_lits().into_iter().map(|(lit, _)| clause![lit])); } if constr.is_negative_assignment() { - collector.extend(constr.into_lits().into_iter().map(|(lit, _)| clause![!lit])); - return; + return collector + .extend_clauses(constr.into_lits().into_iter().map(|(lit, _)| clause![!lit])); } if constr.is_clause() { - collector.extend([constr.into_clause().unwrap()]); - return; + return collector.add_clause(constr.into_clause().unwrap()); } if constr.is_card() { let card = constr.into_card_constr().unwrap(); return card::default_encode_cardinality_constraint(card, collector, var_manager); } - PBE::encode_constr(constr, collector, var_manager).unwrap() + match PBE::encode_constr(constr, collector, var_manager) { + Ok(_) => Ok(()), + Err(err) => Err(err + .downcast::() + .expect("unexpected error when encoding constraint")), + } } fn prepare_ub_range>(enc: &Enc, range: R) -> Range { diff --git a/rustsat/src/encodings/pb/dbgte.rs b/rustsat/src/encodings/pb/dbgte.rs index 3b370b9c..bf772418 100644 --- a/rustsat/src/encodings/pb/dbgte.rs +++ b/rustsat/src/encodings/pb/dbgte.rs @@ -156,13 +156,18 @@ impl EncodeIncremental for DbGte { } impl BoundUpper for DbGte { - fn encode_ub(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_ub( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, { self.db.reset_encoded(); - self.encode_ub_change(range, collector, var_manager); + self.encode_ub_change(range, collector, var_manager) } fn enforce_ub(&self, ub: usize) -> Result, Error> { @@ -219,13 +224,14 @@ impl BoundUpperIncremental for DbGte { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { let range = super::prepare_ub_range(self, range); if range.is_empty() { - return; + return Ok(()); } let n_vars_before = var_manager.n_used(); let n_clauses_before = collector.n_clauses(); @@ -236,14 +242,16 @@ impl BoundUpperIncremental for DbGte { con.rev_map_round_up(range.start + 1) ..=con.rev_map(range.end + self.max_leaf_weight), ) - .for_each(|val| { + .try_for_each(|val| { self.db - .define_pos(con.id, val, collector, var_manager) + .define_pos(con.id, val, collector, var_manager)? .unwrap(); - }) + Ok::<(), crate::OutOfMemory>(()) + })? } self.n_clauses += collector.n_clauses() - n_clauses_before; self.n_vars += var_manager.n_used() - n_vars_before; + Ok(()) } } @@ -432,7 +440,8 @@ pub mod referenced { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { @@ -487,7 +496,8 @@ pub mod referenced { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { @@ -542,24 +552,27 @@ pub mod referenced { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { let range = super::super::prepare_ub_range(self, range); if range.is_empty() { - return; + return Ok(()); } self.db[self.root.id] .vals( self.root.rev_map_round_up(range.start + 1) ..=self.root.rev_map(range.end + self.max_leaf_weight), ) - .for_each(|val| { + .try_for_each(|val| { self.db - .define_pos(self.root.id, val, collector, var_manager) + .define_pos(self.root.id, val, collector, var_manager)? .unwrap(); - }); + Ok::<(), crate::OutOfMemory>(()) + })?; + Ok(()) } } @@ -569,24 +582,27 @@ pub mod referenced { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { let range = super::super::prepare_ub_range(self, range); if range.is_empty() { - return; + return Ok(()); } - let vals = self.db.borrow()[self.root.id].vals( + let mut vals = self.db.borrow()[self.root.id].vals( self.root.rev_map_round_up(range.start + 1) ..=self.root.rev_map(range.end + self.max_leaf_weight), ); - vals.for_each(|val| { + vals.try_for_each(|val| { self.db .borrow_mut() - .define_pos(self.root.id, val, collector, var_manager) + .define_pos(self.root.id, val, collector, var_manager)? .unwrap(); - }); + Ok::<(), crate::OutOfMemory>(()) + })?; + Ok(()) } } } @@ -617,7 +633,8 @@ mod tests { gte.extend(lits); assert_eq!(gte.enforce_ub(4), Err(Error::NotEncoded)); let mut var_manager = BasicVarManager::default(); - gte.encode_ub(0..7, &mut Cnf::new(), &mut var_manager); + gte.encode_ub(0..7, &mut Cnf::new(), &mut var_manager) + .unwrap(); assert_eq!(gte.depth(), 3); assert_eq!(gte.n_vars(), 10); } @@ -633,13 +650,14 @@ mod tests { gte1.extend(lits.clone()); let mut var_manager = BasicVarManager::default(); let mut cnf1 = Cnf::new(); - gte1.encode_ub(0..5, &mut cnf1, &mut var_manager); + gte1.encode_ub(0..5, &mut cnf1, &mut var_manager).unwrap(); let mut gte2 = DbGte::default(); gte2.extend(lits); let mut var_manager = BasicVarManager::default(); let mut cnf2 = Cnf::new(); - gte2.encode_ub(0..3, &mut cnf2, &mut var_manager); - gte2.encode_ub_change(0..5, &mut cnf2, &mut var_manager); + gte2.encode_ub(0..3, &mut cnf2, &mut var_manager).unwrap(); + gte2.encode_ub_change(0..5, &mut cnf2, &mut var_manager) + .unwrap(); assert_eq!(cnf1.len(), cnf2.len()); assert_eq!(cnf1.len(), gte1.n_clauses()); assert_eq!(cnf2.len(), gte2.n_clauses()); @@ -656,7 +674,7 @@ mod tests { gte1.extend(lits); let mut var_manager = BasicVarManager::default(); let mut cnf1 = Cnf::new(); - gte1.encode_ub(0..5, &mut cnf1, &mut var_manager); + gte1.encode_ub(0..5, &mut cnf1, &mut var_manager).unwrap(); let mut gte2 = DbGte::default(); let mut lits = RsHashMap::default(); lits.insert(lit![0], 10); @@ -666,7 +684,7 @@ mod tests { gte2.extend(lits); let mut var_manager = BasicVarManager::default(); let mut cnf2 = Cnf::new(); - gte2.encode_ub(0..9, &mut cnf2, &mut var_manager); + gte2.encode_ub(0..9, &mut cnf2, &mut var_manager).unwrap(); assert_eq!(cnf1.len(), cnf2.len()); assert_eq!(cnf1.len(), gte1.n_clauses()); assert_eq!(cnf2.len(), gte2.n_clauses()); @@ -689,7 +707,8 @@ mod tests { lits.insert(lit![6], 1); gte.extend(lits); let mut gte_cnf = Cnf::new(); - gte.encode_ub(3..8, &mut gte_cnf, &mut var_manager_gte); + gte.encode_ub(3..8, &mut gte_cnf, &mut var_manager_gte) + .unwrap(); // Set up Tot let mut tot = card::Totalizer::default(); tot.extend(vec![ @@ -702,7 +721,7 @@ mod tests { lit![6], ]); let mut tot_cnf = Cnf::new(); - card::BoundUpper::encode_ub(&mut tot, 3..8, &mut tot_cnf, &mut var_manager_tot); + card::BoundUpper::encode_ub(&mut tot, 3..8, &mut tot_cnf, &mut var_manager_tot).unwrap(); println!("{:?}", gte_cnf); println!("{:?}", tot_cnf); assert_eq!(var_manager_gte.new_var(), var_manager_tot.new_var()); diff --git a/rustsat/src/encodings/pb/dpw.rs b/rustsat/src/encodings/pb/dpw.rs index a8cdb164..c9eaf776 100644 --- a/rustsat/src/encodings/pb/dpw.rs +++ b/rustsat/src/encodings/pb/dpw.rs @@ -129,7 +129,12 @@ impl EncodeIncremental for DynamicPolyWatchdog { } impl BoundUpper for DynamicPolyWatchdog { - fn encode_ub(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_ub( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, @@ -179,13 +184,14 @@ impl BoundUpperIncremental for DynamicPolyWatchdog { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { let range = super::prepare_ub_range(self, range); if range.is_empty() || self.in_lits.len() <= 1 { - return; + return Ok(()); } let n_vars_before = var_manager.n_used(); if self.structure.is_none() && !self.in_lits.is_empty() { @@ -201,13 +207,14 @@ impl BoundUpperIncremental for DynamicPolyWatchdog { let output_weight = 1 << (structure.output_power()); let output_range = range.start / output_weight..(range.end - 1) / output_weight + 1; for oidx in output_range { - encode_output(structure, oidx, &mut self.db, collector, var_manager); + encode_output(structure, oidx, &mut self.db, collector, var_manager)?; } self.n_clauses += collector.n_clauses() - n_clauses_before; self.n_vars += var_manager.n_used() - n_vars_before; } None => (), - } + }; + Ok(()) } } @@ -343,7 +350,8 @@ pub mod referenced { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { @@ -367,7 +375,8 @@ pub mod referenced { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { @@ -391,19 +400,21 @@ pub mod referenced { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { let range = super::super::prepare_ub_range(self, range); if range.is_empty() { - return; + return Ok(()); } let output_weight = 1 << self.structure.output_power(); let output_range = range.start / output_weight..(range.end - 1) / output_weight + 1; for oidx in output_range { - encode_output(self.structure, oidx, self.db, collector, var_manager); + encode_output(self.structure, oidx, self.db, collector, var_manager)?; } + Ok(()) } } @@ -413,13 +424,14 @@ pub mod referenced { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { let range = super::super::prepare_ub_range(self, range); if range.is_empty() { - return; + return Ok(()); } let output_weight = 1 << self.structure.output_power(); let output_range = range.start / output_weight..(range.end - 1) / output_weight + 1; @@ -430,8 +442,9 @@ pub mod referenced { &mut self.db.borrow_mut(), collector, var_manager, - ); + )?; } + Ok(()) } } } @@ -581,13 +594,15 @@ fn encode_output( tot_db: &mut TotDb, collector: &mut Col, var_manager: &mut dyn ManageVars, -) where +) -> Result<(), crate::OutOfMemory> +where Col: CollectClauses, { if oidx >= tot_db[dpw.root].max_val() { - return; + return Ok(()); } - tot_db.define_pos_tot(dpw.root, oidx, collector, var_manager); + tot_db.define_pos_tot(dpw.root, oidx, collector, var_manager)?; + Ok(()) } #[cfg_attr(feature = "internals", visibility::make(pub))] @@ -643,7 +658,7 @@ mod tests { let mut dpw = DynamicPolyWatchdog::from(lits); let mut var_manager = BasicVarManager::from_next_free(Var::new(4)); let mut cnf = Cnf::new(); - dpw.encode_ub(0..=6, &mut cnf, &mut var_manager); + dpw.encode_ub(0..=6, &mut cnf, &mut var_manager).unwrap(); assert_eq!(dpw.n_vars(), 9); assert_eq!(cnf.len(), 13); } @@ -655,7 +670,7 @@ mod tests { let mut dpw = DynamicPolyWatchdog::from(lits); let mut var_manager = BasicVarManager::from_next_free(Var::new(1)); let mut cnf = Cnf::new(); - dpw.encode_ub(0..=6, &mut cnf, &mut var_manager); + dpw.encode_ub(0..=6, &mut cnf, &mut var_manager).unwrap(); assert_eq!(dpw.n_vars(), 0); assert_eq!(cnf.len(), 0); debug_assert!(dpw.enforce_ub(4).unwrap().is_empty()); @@ -670,7 +685,7 @@ mod tests { let mut dpw = DynamicPolyWatchdog::from(lits); let mut var_manager = BasicVarManager::default(); let mut cnf = Cnf::new(); - dpw.encode_ub(0..=6, &mut cnf, &mut var_manager); + dpw.encode_ub(0..=6, &mut cnf, &mut var_manager).unwrap(); assert_eq!(dpw.n_vars(), 0); assert_eq!(cnf.len(), 0); debug_assert!(dpw.enforce_ub(4).unwrap().is_empty()); @@ -687,7 +702,7 @@ mod tests { let mut dpw = DynamicPolyWatchdog::from(lits); let mut var_manager = BasicVarManager::default(); let mut cnf = Cnf::new(); - dpw.encode_ub(0..23, &mut cnf, &mut var_manager); + dpw.encode_ub(0..23, &mut cnf, &mut var_manager).unwrap(); for ub in 7..23 { let coarse_ub = dpw.coarse_ub(ub); debug_assert!(coarse_ub <= ub); @@ -709,7 +724,7 @@ mod tests { let mut dpw = DynamicPolyWatchdog::from(lits); let mut var_manager = BasicVarManager::default(); let mut cnf = Cnf::new(); - dpw.encode_ub(0..=4, &mut cnf, &mut var_manager); + dpw.encode_ub(0..=4, &mut cnf, &mut var_manager).unwrap(); for ub in 0..4 { let coarse_ub = dpw.coarse_ub(ub); debug_assert_eq!(coarse_ub, ub); diff --git a/rustsat/src/encodings/pb/gte.rs b/rustsat/src/encodings/pb/gte.rs index b42d7fe8..91856994 100644 --- a/rustsat/src/encodings/pb/gte.rs +++ b/rustsat/src/encodings/pb/gte.rs @@ -154,14 +154,19 @@ impl EncodeIncremental for GeneralizedTotalizer { } impl BoundUpper for GeneralizedTotalizer { - fn encode_ub(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_ub( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, { let range = super::prepare_ub_range(self, range); if range.is_empty() { - return; + return Ok(()); }; let n_vars_before = var_manager.n_used(); let n_clauses_before = collector.n_clauses(); @@ -172,10 +177,11 @@ impl BoundUpper for GeneralizedTotalizer { range.start + 1..range.end + self.max_leaf_weight + 1, collector, var_manager, - ), + )?, }; self.n_clauses += collector.n_clauses() - n_clauses_before; self.n_vars += var_manager.n_used() - n_vars_before; + Ok(()) } fn enforce_ub(&self, ub: usize) -> Result, Error> { @@ -236,13 +242,14 @@ impl BoundUpperIncremental for GeneralizedTotalizer { range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { let range = super::prepare_ub_range(self, range); if range.is_empty() { - return; + return Ok(()); }; let n_vars_before = var_manager.n_used(); let n_clauses_before = collector.n_clauses(); @@ -252,10 +259,11 @@ impl BoundUpperIncremental for GeneralizedTotalizer { range.start + 1..range.end + self.max_leaf_weight, collector, var_manager, - ); + )?; } self.n_clauses += collector.n_clauses() - n_clauses_before; self.n_vars += var_manager.n_used() - n_vars_before; + Ok(()) } } @@ -410,12 +418,13 @@ impl Node { range: Range, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, { let range = self.limit_range(range); if range.is_empty() { - return; + return Ok(()); } // Reserve vars if needed @@ -434,19 +443,17 @@ impl Node { let right_lits = right.lit_map(&mut right_tmp_map); // Encode adder for current node // Propagate left value - for (&left_val, &left_lit) in left_lits.range(range.clone()) { - collector.extend([atomics::lit_impl_lit( - left_lit, - *out_lits.get(&left_val).unwrap(), - )]); - } + collector.extend_clauses(left_lits.range(range.clone()).map( + |(left_val, &left_lit)| { + atomics::lit_impl_lit(left_lit, *out_lits.get(left_val).unwrap()) + }, + ))?; // Propagate right value - for (&right_val, &right_lit) in right_lits.range(range.clone()) { - collector.extend([atomics::lit_impl_lit( - right_lit, - *out_lits.get(&right_val).unwrap(), - )]); - } + collector.extend_clauses(right_lits.range(range.clone()).map( + |(right_val, &right_lit)| { + atomics::lit_impl_lit(right_lit, *out_lits.get(right_val).unwrap()) + }, + ))?; // Propagate sum if range.end > 1 { let clause_from_data = @@ -477,10 +484,11 @@ impl Node { clause_from_data(left_val, right_val, left_lit, right_lit) }) }); - collector.extend(clause_iter); + collector.extend_clauses(clause_iter)?; } } - } + }; + Ok(()) } /// Encodes the output literals from the children to this node in a given @@ -491,12 +499,13 @@ impl Node { range: Range, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, { let range = self.limit_range(range); if range.is_empty() { - return; + return Ok(()); } // Ignore all previous encoding and encode from scratch @@ -506,16 +515,18 @@ impl Node { let left_range = Node::compute_required_min_enc(range.clone(), right.max_val()); let right_range = Node::compute_required_min_enc(range.clone(), left.max_val()); // Recurse - left.rec_encode(left_range, collector, var_manager); - right.rec_encode(right_range, collector, var_manager); + left.rec_encode(left_range, collector, var_manager)?; + right.rec_encode(right_range, collector, var_manager)?; // Encode current node let n_clauses_before = collector.n_clauses(); - self.encode_range(range.clone(), collector, var_manager); + self.encode_range(range.clone(), collector, var_manager)?; self.update_stats(range, collector.n_clauses() - n_clauses_before); } - } + }; + + Ok(()) } /// Encodes the output literals from the children to this node in a given @@ -525,12 +536,13 @@ impl Node { range: Range, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, { let range = self.limit_range(range); if range.is_empty() { - return; + return Ok(()); } match self { @@ -547,27 +559,28 @@ impl Node { let left_range = Node::compute_required_min_enc(range.clone(), right.max_val()); let right_range = Node::compute_required_min_enc(range.clone(), left.max_val()); // Recurse - left.rec_encode_change(left_range, collector, var_manager); - right.rec_encode_change(right_range, collector, var_manager); + left.rec_encode_change(left_range, collector, var_manager)?; + right.rec_encode_change(right_range, collector, var_manager)?; // Encode changes for current node let n_clauses_before = collector.n_clauses(); if enc_range.is_empty() { // First time encoding this node - self.encode_range(range.clone(), collector, var_manager); + self.encode_range(range.clone(), collector, var_manager)?; } else { // Partially encoded if range.start < enc_range.start { - self.encode_range(range.start..enc_range.start, collector, var_manager); + self.encode_range(range.start..enc_range.start, collector, var_manager)?; }; if range.end > enc_range.end { - self.encode_range(enc_range.end..range.end, collector, var_manager); + self.encode_range(enc_range.end..range.end, collector, var_manager)?; }; }; self.update_stats(range, collector.n_clauses() - n_clauses_before); } - } + }; + Ok(()) } /// Reserves variables this node might need in a given range @@ -713,7 +726,7 @@ mod tests { let mut node = Node::new_internal(child1, child2); let mut var_manager = BasicVarManager::default(); let mut cnf = Cnf::new(); - node.encode_range(0..9, &mut cnf, &mut var_manager); + node.encode_range(0..9, &mut cnf, &mut var_manager).unwrap(); match &node { Node::Leaf { .. } => panic!(), Node::Internal { out_lits, .. } => assert_eq!(out_lits.len(), 3), @@ -755,7 +768,7 @@ mod tests { let mut node = Node::new_internal(child1, child2); let mut var_manager = BasicVarManager::default(); let mut cnf = Cnf::new(); - node.encode_range(0..7, &mut cnf, &mut var_manager); + node.encode_range(0..7, &mut cnf, &mut var_manager).unwrap(); match &node { Node::Leaf { .. } => panic!(), Node::Internal { out_lits, .. } => assert_eq!(out_lits.len(), 3), @@ -797,7 +810,7 @@ mod tests { let mut node = Node::new_internal(child1, child2); let mut var_manager = BasicVarManager::default(); let mut cnf = Cnf::new(); - node.encode_range(4..7, &mut cnf, &mut var_manager); + node.encode_range(4..7, &mut cnf, &mut var_manager).unwrap(); match &node { Node::Leaf { .. } => panic!(), Node::Internal { out_lits, .. } => assert_eq!(out_lits.len(), 2), @@ -839,7 +852,7 @@ mod tests { let mut node = Node::new_internal(child1, child2); let mut var_manager = BasicVarManager::default(); let mut cnf = Cnf::new(); - node.encode_range(6..5, &mut cnf, &mut var_manager); + node.encode_range(6..5, &mut cnf, &mut var_manager).unwrap(); assert_eq!(cnf.len(), 0); } @@ -854,7 +867,8 @@ mod tests { gte.extend(lits); assert_eq!(gte.enforce_ub(4), Err(Error::NotEncoded)); let mut var_manager = BasicVarManager::default(); - gte.encode_ub(0..7, &mut Cnf::new(), &mut var_manager); + gte.encode_ub(0..7, &mut Cnf::new(), &mut var_manager) + .unwrap(); assert_eq!(gte.depth(), 3); assert_eq!(gte.n_vars(), 10); } @@ -870,13 +884,14 @@ mod tests { gte1.extend(lits.clone()); let mut var_manager = BasicVarManager::default(); let mut cnf1 = Cnf::new(); - gte1.encode_ub(0..5, &mut cnf1, &mut var_manager); + gte1.encode_ub(0..5, &mut cnf1, &mut var_manager).unwrap(); let mut gte2 = GeneralizedTotalizer::default(); gte2.extend(lits); let mut var_manager = BasicVarManager::default(); let mut cnf2 = Cnf::new(); - gte2.encode_ub(0..3, &mut cnf2, &mut var_manager); - gte2.encode_ub_change(0..5, &mut cnf2, &mut var_manager); + gte2.encode_ub(0..3, &mut cnf2, &mut var_manager).unwrap(); + gte2.encode_ub_change(0..5, &mut cnf2, &mut var_manager) + .unwrap(); assert_eq!(cnf1.len(), cnf2.len()); assert_eq!(cnf1.len(), gte1.n_clauses()); assert_eq!(cnf2.len(), gte2.n_clauses()); @@ -893,7 +908,7 @@ mod tests { gte1.extend(lits); let mut var_manager = BasicVarManager::default(); let mut cnf1 = Cnf::new(); - gte1.encode_ub(0..5, &mut cnf1, &mut var_manager); + gte1.encode_ub(0..5, &mut cnf1, &mut var_manager).unwrap(); let mut gte2 = GeneralizedTotalizer::default(); let mut lits = RsHashMap::default(); lits.insert(lit![0], 10); @@ -903,7 +918,7 @@ mod tests { gte2.extend(lits); let mut var_manager = BasicVarManager::default(); let mut cnf2 = Cnf::new(); - gte2.encode_ub(0..9, &mut cnf2, &mut var_manager); + gte2.encode_ub(0..9, &mut cnf2, &mut var_manager).unwrap(); assert_eq!(cnf1.len(), cnf2.len()); assert_eq!(cnf1.len(), gte1.n_clauses()); assert_eq!(cnf2.len(), gte2.n_clauses()); @@ -926,7 +941,8 @@ mod tests { lits.insert(lit![6], 1); gte.extend(lits); let mut gte_cnf = Cnf::new(); - gte.encode_ub(3..8, &mut gte_cnf, &mut var_manager_gte); + gte.encode_ub(3..8, &mut gte_cnf, &mut var_manager_gte) + .unwrap(); // Set up Tot let mut tot = card::Totalizer::default(); tot.extend(vec![ @@ -939,7 +955,7 @@ mod tests { lit![6], ]); let mut tot_cnf = Cnf::new(); - card::BoundUpper::encode_ub(&mut tot, 3..8, &mut tot_cnf, &mut var_manager_tot); + card::BoundUpper::encode_ub(&mut tot, 3..8, &mut tot_cnf, &mut var_manager_tot).unwrap(); println!("{:?}", gte_cnf); println!("{:?}", tot_cnf); assert_eq!(var_manager_gte.new_var(), var_manager_tot.new_var()); diff --git a/rustsat/src/encodings/pb/simulators.rs b/rustsat/src/encodings/pb/simulators.rs index 022d9578..20528062 100644 --- a/rustsat/src/encodings/pb/simulators.rs +++ b/rustsat/src/encodings/pb/simulators.rs @@ -132,7 +132,12 @@ impl BoundUpper for Inverted where PBE: BoundLower, { - fn encode_ub(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_ub( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, @@ -158,7 +163,12 @@ impl BoundLower for Inverted where PBE: BoundUpper, { - fn encode_lb(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_lb( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, @@ -189,7 +199,8 @@ where range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { @@ -210,7 +221,8 @@ where range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { @@ -347,7 +359,12 @@ where UBE: BoundUpper, LBE: BoundLower, { - fn encode_ub(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_ub( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, @@ -365,7 +382,12 @@ where UBE: BoundUpper, LBE: BoundLower, { - fn encode_lb(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_lb( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, @@ -388,7 +410,8 @@ where range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { @@ -406,7 +429,8 @@ where range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { @@ -529,7 +553,12 @@ impl BoundUpper for Card where CE: card::BoundUpper, { - fn encode_ub(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_ub( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, @@ -546,7 +575,12 @@ impl BoundLower for Card where CE: card::BoundLower, { - fn encode_lb(&mut self, range: R, collector: &mut Col, var_manager: &mut dyn ManageVars) + fn encode_lb( + &mut self, + range: R, + collector: &mut Col, + var_manager: &mut dyn ManageVars, + ) -> Result<(), crate::OutOfMemory> where Col: CollectClauses, R: RangeBounds, @@ -568,7 +602,8 @@ where range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { @@ -586,7 +621,8 @@ where range: R, collector: &mut Col, var_manager: &mut dyn ManageVars, - ) where + ) -> Result<(), crate::OutOfMemory> + where Col: CollectClauses, R: RangeBounds, { diff --git a/rustsat/src/instances/multiopt.rs b/rustsat/src/instances/multiopt.rs index 4ebeb142..48813562 100644 --- a/rustsat/src/instances/multiopt.rs +++ b/rustsat/src/instances/multiopt.rs @@ -245,8 +245,14 @@ impl MultiOptInstance { pub fn to_dimacs(self, writer: &mut W) -> Result<(), io::Error> { #[allow(deprecated)] self.to_dimacs_with_encoders( - card::default_encode_cardinality_constraint, - pb::default_encode_pb_constraint, + |constr, cnf, vm| { + card::default_encode_cardinality_constraint(constr, cnf, vm) + .expect("cardinality encoding ran out of memory") + }, + |constr, cnf, vm| { + pb::default_encode_pb_constraint(constr, cnf, vm) + .expect("pb encoding ran out of memory") + }, writer, ) } diff --git a/rustsat/src/instances/opt.rs b/rustsat/src/instances/opt.rs index 9c6e6193..e42ecbd8 100644 --- a/rustsat/src/instances/opt.rs +++ b/rustsat/src/instances/opt.rs @@ -1102,6 +1102,10 @@ impl OptInstance { /// Converts the instance to a set of hard and soft clauses, an objective /// offset and a variable manager + /// + /// # Panic + /// + /// This might panic if the conversion to [`Cnf`] runs out of memory. pub fn into_hard_cls_soft_cls(self) -> (Cnf, (impl WClsIter, isize), VM) { let (cnf, mut vm) = self.constrs.into_cnf(); if let Some(mv) = self.obj.max_var() { @@ -1112,6 +1116,10 @@ impl OptInstance { /// Converts the instance to a set of hard clauses and soft literals, an /// objective offset and a variable manager + /// + /// # Panic + /// + /// This might panic if the conversion to [`Cnf`] runs out of memory. pub fn into_hard_cls_soft_lits(self) -> (Cnf, (impl WLitIter, isize), VM) { let (mut cnf, mut vm) = self.constrs.into_cnf(); if let Some(mv) = self.obj.max_var() { @@ -1170,8 +1178,14 @@ impl OptInstance { pub fn to_dimacs(self, writer: &mut W) -> Result<(), io::Error> { #[allow(deprecated)] self.to_dimacs_with_encoders( - card::default_encode_cardinality_constraint, - pb::default_encode_pb_constraint, + |constr, cnf, vm| { + card::default_encode_cardinality_constraint(constr, cnf, vm) + .expect("cardinality encoding ran out of memory") + }, + |constr, cnf, vm| { + pb::default_encode_pb_constraint(constr, cnf, vm) + .expect("pb encoding ran out of memory") + }, writer, ) } diff --git a/rustsat/src/instances/sat.rs b/rustsat/src/instances/sat.rs index 287cc81d..06e473f2 100644 --- a/rustsat/src/instances/sat.rs +++ b/rustsat/src/instances/sat.rs @@ -1,6 +1,6 @@ //! # Satsifiability Instance Representations -use std::{collections::TryReserveError, io, ops::Index, path::Path}; +use std::{cmp, collections::TryReserveError, io, ops::Index, path::Path}; use crate::{ clause, @@ -10,6 +10,7 @@ use crate::{ constraints::{CardConstraint, PBConstraint}, Assignment, Clause, Lit, Var, }, + utils::LimitedIter, RequiresClausal, }; @@ -217,6 +218,28 @@ impl CollectClauses for Cnf { fn n_clauses(&self) -> usize { self.clauses.len() } + + fn extend_clauses(&mut self, cl_iter: T) -> Result<(), crate::OutOfMemory> + where + T: IntoIterator, + { + let cl_iter = cl_iter.into_iter(); + if let Some(ub) = cl_iter.size_hint().1 { + self.try_reserve(ub)?; + self.extend(cl_iter); + } else { + // Extend by reserving in exponential chunks + let mut cl_iter = cl_iter.peekable(); + while cl_iter.peek().is_some() { + let additional = (self.len() + cmp::max(cl_iter.size_hint().0, 1)) + .next_power_of_two() + - self.len(); + self.try_reserve(additional)?; + self.extend(LimitedIter::new(&mut cl_iter, additional)); + } + } + Ok(()) + } } impl IntoIterator for Cnf { @@ -505,10 +528,20 @@ impl SatInstance { /// Uses the default encoders from the `encodings` module. /// /// See [`Self::convert_to_cnf`] for converting in place + /// + /// # Panic + /// + /// This might panic if the conversion to [`Cnf`] runs out of memory. pub fn into_cnf(self) -> (Cnf, VM) { self.into_cnf_with_encoders( - card::default_encode_cardinality_constraint, - pb::default_encode_pb_constraint, + |constr, cnf, vm| { + card::default_encode_cardinality_constraint(constr, cnf, vm) + .expect("cardinality encoding ran out of memory") + }, + |constr, cnf, vm| { + pb::default_encode_pb_constraint(constr, cnf, vm) + .expect("pb encoding ran out of memory") + }, ) } @@ -516,10 +549,20 @@ impl SatInstance { /// Uses the default encoders from the `encodings` module. /// /// See [`Self::into_cnf`] if you don't need to convert in place + /// + /// # Panic + /// + /// This might panic if the conversion to [`Cnf`] runs out of memory. pub fn convert_to_cnf(&mut self) { self.convert_to_cnf_with_encoders( - card::default_encode_cardinality_constraint, - pb::default_encode_pb_constraint, + |constr, cnf, vm| { + card::default_encode_cardinality_constraint(constr, cnf, vm) + .expect("cardinality encoding ran out of memory") + }, + |constr, cnf, vm| { + pb::default_encode_pb_constraint(constr, cnf, vm) + .expect("pb encoding ran out of memory") + }, ) } @@ -545,6 +588,10 @@ impl SatInstance { /// converters for non-clausal constraints. /// /// See [`Self::into_cnf_with_encoders`] to convert in place + /// + /// # Panic + /// + /// The encoder functions might panic if the conversion runs out of memory. pub fn into_cnf_with_encoders( mut self, card_encoder: CardEnc, @@ -562,6 +609,10 @@ impl SatInstance { /// converters for non-clausal constraints. /// /// See [`Self::into_cnf_with_encoders`] if you don't need to convert in place + /// + /// # Panic + /// + /// The encoder functions might panic if the conversion runs out of memory. pub fn convert_to_cnf_with_encoders( &mut self, mut card_encoder: CardEnc, @@ -632,8 +683,14 @@ impl SatInstance { pub fn to_dimacs(self, writer: &mut W) -> Result<(), io::Error> { #[allow(deprecated)] self.to_dimacs_with_encoders( - card::default_encode_cardinality_constraint, - pb::default_encode_pb_constraint, + |constr, cnf, vm| { + card::default_encode_cardinality_constraint(constr, cnf, vm) + .expect("cardinality encoding ran out of memory") + }, + |constr, cnf, vm| { + pb::default_encode_pb_constraint(constr, cnf, vm) + .expect("pb encoding ran out of memory") + }, writer, ) } @@ -943,3 +1000,16 @@ impl From for SatInstance { inst } } + +impl CollectClauses for SatInstance { + fn n_clauses(&self) -> usize { + self.n_clauses() + } + + fn extend_clauses(&mut self, cl_iter: T) -> Result<(), crate::OutOfMemory> + where + T: IntoIterator, + { + self.cnf.extend_clauses(cl_iter) + } +} diff --git a/rustsat/src/lib.rs b/rustsat/src/lib.rs index d63bca7d..5984039c 100644 --- a/rustsat/src/lib.rs +++ b/rustsat/src/lib.rs @@ -66,6 +66,7 @@ #![cfg_attr(feature = "bench", feature(test))] use core::fmt; +use std::collections::TryReserveError; use thiserror::Error; @@ -89,9 +90,19 @@ impl fmt::Display for NotAllowed { #[cfg(test)] mod bench; -#[derive(Error, Debug)] -#[error("operation ran out of memory")] -pub struct OutOfMemory; +#[derive(Error, Debug, PartialEq, Eq)] +pub enum OutOfMemory { + #[error("try reserve error: {0}")] + TryReserve(TryReserveError), + #[error("external API operation ran out of memory")] + ExternalApi, +} + +impl From for OutOfMemory { + fn from(value: TryReserveError) -> Self { + OutOfMemory::TryReserve(value) + } +} #[derive(Error, Debug)] #[error("operation requires a clausal constraint(s) but it is not")] diff --git a/rustsat/src/solvers.rs b/rustsat/src/solvers.rs index 08917610..bf4ddebc 100644 --- a/rustsat/src/solvers.rs +++ b/rustsat/src/solvers.rs @@ -466,8 +466,34 @@ impl fmt::Display for StateError { } } +macro_rules! pass_oom_or_panic { + ($result:expr) => {{ + match $result { + Ok(res) => res, + Err(err) => match err.downcast::() { + Ok(oom) => return Err(oom), + Err(err) => panic!("unexpected error in clause collector: {err}"), + }, + } + }}; +} + impl CollectClauses for S { fn n_clauses(&self) -> usize { self.n_clauses() } + + fn extend_clauses(&mut self, cl_iter: T) -> Result<(), crate::OutOfMemory> + where + T: IntoIterator, + { + for cl in cl_iter { + pass_oom_or_panic!(self.add_clause(cl)); + } + Ok(()) + } + + fn add_clause(&mut self, cl: Clause) -> Result<(), crate::OutOfMemory> { + Ok(pass_oom_or_panic!(self.add_clause(cl))) + } } diff --git a/rustsat/src/utils.rs b/rustsat/src/utils.rs index f2732e54..dab597f6 100644 --- a/rustsat/src/utils.rs +++ b/rustsat/src/utils.rs @@ -32,6 +32,32 @@ pub(crate) fn digits(mut number: usize, mut basis: u8) -> u32 { digits } +pub struct LimitedIter<'iter, I> { + iter: &'iter mut I, + remaining: usize, +} + +impl<'iter, I> LimitedIter<'iter, I> { + pub fn new(iter: &'iter mut I, remaining: usize) -> Self { + Self { iter, remaining } + } +} + +impl Iterator for LimitedIter<'_, I> +where + I: Iterator, +{ + type Item = ::Item; + + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + self.remaining -= 1; + self.iter.next() + } +} + #[cfg(test)] mod tests { #[test] diff --git a/rustsat/tests/card_encodings.rs b/rustsat/tests/card_encodings.rs index 1b5e3af3..ba31a267 100644 --- a/rustsat/tests/card_encodings.rs +++ b/rustsat/tests/card_encodings.rs @@ -39,7 +39,8 @@ fn test_inc_both_card + Default>() { let mut enc = CE::default(); enc.extend(vec![lit![0], lit![1], lit![2], lit![3], lit![4]]); - enc.encode_both(2..3, &mut solver, &mut var_manager); + enc.encode_both(2..3, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_lb(2).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Sat); @@ -48,31 +49,36 @@ fn test_inc_both_card + Default>() { let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Unsat); - enc.encode_both_change(0..4, &mut solver, &mut var_manager); + enc.encode_both_change(0..4, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(3).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Sat); enc.extend(vec![lit![5]]); - enc.encode_both_change(0..4, &mut solver, &mut var_manager); + enc.encode_both_change(0..4, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(3).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Unsat); - enc.encode_both_change(0..5, &mut solver, &mut var_manager); + enc.encode_both_change(0..5, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(4).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Sat); enc.extend(vec![lit![6], lit![7], lit![8], lit![9], lit![10]]); - enc.encode_both_change(0..5, &mut solver, &mut var_manager); + enc.encode_both_change(0..5, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(4).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Unsat); - enc.encode_both_change(0..8, &mut solver, &mut var_manager); + enc.encode_both_change(0..8, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(7).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Sat); @@ -103,36 +109,41 @@ fn test_inc_ub_card + Default>() { let mut enc = CE::default(); enc.extend(vec![lit![0], lit![1], lit![2], lit![3], lit![4]]); - enc.encode_ub(2..3, &mut solver, &mut var_manager); + enc.encode_ub(2..3, &mut solver, &mut var_manager).unwrap(); let assumps = enc.enforce_ub(2).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Unsat); - enc.encode_ub_change(0..4, &mut solver, &mut var_manager); + enc.encode_ub_change(0..4, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(3).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Sat); enc.extend(vec![lit![5]]); - enc.encode_ub_change(0..4, &mut solver, &mut var_manager); + enc.encode_ub_change(0..4, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(3).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Unsat); - enc.encode_ub_change(0..5, &mut solver, &mut var_manager); + enc.encode_ub_change(0..5, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(4).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Sat); enc.extend(vec![lit![6], lit![7], lit![8], lit![9], lit![10]]); - enc.encode_ub_change(0..5, &mut solver, &mut var_manager); + enc.encode_ub_change(0..5, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(4).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Unsat); - enc.encode_ub_change(0..8, &mut solver, &mut var_manager); + enc.encode_ub_change(0..8, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(7).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Sat); @@ -156,7 +167,8 @@ fn test_both_card>>() { // Set up totalizer let mut enc = CE::from(vec![!lit![0], !lit![1], !lit![2], !lit![3], !lit![4]]); - enc.encode_both(2..4, &mut solver, &mut var_manager); + enc.encode_both(2..4, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(2).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Sat); @@ -179,7 +191,8 @@ fn test_both_card_min_enc>>() { let mut enc = CE::from(vec![lit![0], lit![1], lit![2], lit![3]]); - enc.encode_both(3..4, &mut solver, &mut var_manager); + enc.encode_both(3..4, &mut solver, &mut var_manager) + .unwrap(); let mut assumps = enc.enforce_eq(3).unwrap(); assumps.extend(vec![lit![0], lit![1], lit![2], !lit![3]]); let res = solver.solve_assumps(&assumps).unwrap(); @@ -274,7 +287,7 @@ fn test_ub_exhaustive>>() { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); - enc.encode_ub(0..1, &mut solver, &mut var_manager); + enc.encode_ub(0..1, &mut solver, &mut var_manager).unwrap(); let assumps = enc.enforce_ub(0).unwrap(); test_all!( @@ -297,7 +310,8 @@ fn test_ub_exhaustive>>() { Sat // 0000 ); - enc.encode_ub_change(1..2, &mut solver, &mut var_manager); + enc.encode_ub_change(1..2, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(1).unwrap(); test_all!( @@ -320,7 +334,8 @@ fn test_ub_exhaustive>>() { Sat // 0000 ); - enc.encode_ub_change(2..3, &mut solver, &mut var_manager); + enc.encode_ub_change(2..3, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(2).unwrap(); test_all!( @@ -343,7 +358,8 @@ fn test_ub_exhaustive>>() { Sat // 0000 ); - enc.encode_ub_change(3..4, &mut solver, &mut var_manager); + enc.encode_ub_change(3..4, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(3).unwrap(); test_all!( @@ -366,7 +382,8 @@ fn test_ub_exhaustive>>() { Sat // 0000 ); - enc.encode_ub_change(4..5, &mut solver, &mut var_manager); + enc.encode_ub_change(4..5, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(4).unwrap(); test_all!( @@ -396,7 +413,8 @@ fn test_both_exhaustive>>() { let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); - enc.encode_both(0..1, &mut solver, &mut var_manager); + enc.encode_both(0..1, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_eq(0).unwrap(); test_all!( @@ -419,7 +437,8 @@ fn test_both_exhaustive>>() { Sat // 0000 ); - enc.encode_both_change(1..2, &mut solver, &mut var_manager); + enc.encode_both_change(1..2, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_eq(1).unwrap(); test_all!( @@ -442,7 +461,8 @@ fn test_both_exhaustive>>() { Unsat // 0000 ); - enc.encode_both_change(2..3, &mut solver, &mut var_manager); + enc.encode_both_change(2..3, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_eq(2).unwrap(); test_all!( @@ -465,7 +485,8 @@ fn test_both_exhaustive>>() { Unsat // 0000 ); - enc.encode_both_change(3..4, &mut solver, &mut var_manager); + enc.encode_both_change(3..4, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_eq(3).unwrap(); test_all!( @@ -488,7 +509,8 @@ fn test_both_exhaustive>>() { Unsat // 0000 ); - enc.encode_both_change(4..5, &mut solver, &mut var_manager); + enc.encode_both_change(4..5, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_eq(4).unwrap(); test_all!( diff --git a/rustsat/tests/pb_encodings.rs b/rustsat/tests/pb_encodings.rs index 45dc8b80..176611bf 100644 --- a/rustsat/tests/pb_encodings.rs +++ b/rustsat/tests/pb_encodings.rs @@ -49,17 +49,19 @@ fn test_inc_pb_ub + Default>() let mut enc = PBE::default(); enc.extend(lits); - enc.encode_ub(0..3, &mut solver, &mut var_manager); + enc.encode_ub(0..3, &mut solver, &mut var_manager).unwrap(); let assumps = enc.enforce_ub(2).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Unsat); - enc.encode_ub_change(0..5, &mut solver, &mut var_manager); + enc.encode_ub_change(0..5, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(4).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Unsat); - enc.encode_ub_change(0..6, &mut solver, &mut var_manager); + enc.encode_ub_change(0..6, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(5).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Sat); @@ -68,12 +70,14 @@ fn test_inc_pb_ub + Default>() lits.insert(lit![5], 4); enc.extend(lits); - enc.encode_ub_change(0..6, &mut solver, &mut var_manager); + enc.encode_ub_change(0..6, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(5).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Unsat); - enc.encode_ub_change(0..10, &mut solver, &mut var_manager); + enc.encode_ub_change(0..10, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(9).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Sat); @@ -86,12 +90,14 @@ fn test_inc_pb_ub + Default>() lits.insert(lit![10], 2); enc.extend(lits); - enc.encode_ub_change(0..10, &mut solver, &mut var_manager); + enc.encode_ub_change(0..10, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(9).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Unsat); - enc.encode_ub_change(0..15, &mut solver, &mut var_manager); + enc.encode_ub_change(0..15, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(14).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Sat); @@ -109,7 +115,8 @@ fn test_pb_eq>>() { lits.insert(lit![2], 2); let mut enc = PBE::from(lits); - enc.encode_both(4..5, &mut solver, &mut var_manager); + enc.encode_both(4..5, &mut solver, &mut var_manager) + .unwrap(); let mut assumps = enc.enforce_eq(4).unwrap(); assumps.extend(vec![lit![0], lit![1], lit![2]]); @@ -170,7 +177,7 @@ fn test_pb_lb>>() { lits.insert(lit![2], 3); let mut enc = PBE::from(lits); - enc.encode_lb(0..11, &mut solver, &mut var_manager); + enc.encode_lb(0..11, &mut solver, &mut var_manager).unwrap(); let assumps = enc.enforce_lb(10).unwrap(); let res = solver.solve_assumps(&assumps).unwrap(); assert_eq!(res, SolverResult::Unsat); @@ -193,7 +200,7 @@ fn test_pb_ub_min_enc>>() { lits.insert(lit![2], 1); let mut enc = PBE::from(lits); - enc.encode_ub(2..3, &mut solver, &mut var_manager); + enc.encode_ub(2..3, &mut solver, &mut var_manager).unwrap(); let mut assumps = enc.enforce_ub(2).unwrap(); assumps.extend(vec![lit![0], lit![1], lit![2]]); let res = solver.solve_assumps(&assumps).unwrap(); @@ -291,7 +298,7 @@ fn test_ub_exhaustive>>( let mut var_manager = BasicVarManager::default(); var_manager.increase_next_free(var![4]); - let max_val = weights.iter().fold(0, |sum, &w| sum + w); + let max_val = weights.iter().sum::(); let expected = |assign: usize, bound: usize| { let sum = (0..4).fold(0, |sum, idx| sum + ((assign >> idx) & 1) * weights[3 - idx]); if sum <= bound { @@ -306,7 +313,8 @@ fn test_ub_exhaustive>>( bound = max_val - bound; } - enc.encode_ub_change(bound..bound + 1, &mut solver, &mut var_manager); + enc.encode_ub_change(bound..bound + 1, &mut solver, &mut var_manager) + .unwrap(); let assumps = enc.enforce_ub(bound).unwrap(); test_all!(