Skip to content

Commit

Permalink
feat: catch memory outs in clause collector
Browse files Browse the repository at this point in the history
extend `OutOfMemory` error and make `CollectClauses` and encodings
return it if not enough memory is available
  • Loading branch information
chrjabs committed Apr 25, 2024
1 parent 0506820 commit 4e46b0c
Show file tree
Hide file tree
Showing 27 changed files with 881 additions and 394 deletions.
5 changes: 4 additions & 1 deletion cadical/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ macro_rules! handle_oom {
($val:expr) => {{
let val = $val;
if val == crate::OUT_OF_MEM {
return anyhow::Context::context(Err(rustsat::OutOfMemory), "cadical out of memory");
return anyhow::Context::context(
Err(rustsat::OutOfMemory::ExternalApi),
"cadical out of memory",
);
}
val
}};
Expand Down
27 changes: 25 additions & 2 deletions capi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,25 @@ pub mod encodings {
fn n_clauses(&self) -> usize {
self.n_clauses
}

fn extend_clauses<T>(&mut self, cl_iter: T) -> Result<(), rustsat::OutOfMemory>
where
T: IntoIterator<Item = Clause>,
{
cl_iter.into_iter().for_each(|cl| {
cl.into_iter()
.for_each(|l| (self.ccol)(l.to_ipasir(), self.cdata));
(self.ccol)(0, self.cdata);
});
Ok(())
}

fn add_clause(&mut self, cl: Clause) -> Result<(), rustsat::OutOfMemory> {
cl.into_iter()
.for_each(|l| (self.ccol)(l.to_ipasir(), self.cdata));
(self.ccol)(0, self.cdata);
Ok(())
}
}

impl Extend<Clause> for ClauseCollector {
Expand Down Expand Up @@ -186,7 +205,9 @@ pub mod encodings {
let mut collector = ClauseCollector::new(collector, collector_data);
let mut var_manager = VarManager::new(n_vars_used);
let mut boxed = unsafe { Box::from_raw(tot) };
boxed.encode_ub_change(min_bound..=max_bound, &mut collector, &mut var_manager);
boxed
.encode_ub_change(min_bound..=max_bound, &mut collector, &mut var_manager)
.expect("clause collect returned out of memory");
Box::into_raw(boxed);
}

Expand Down Expand Up @@ -358,7 +379,9 @@ pub mod encodings {
let mut collector = ClauseCollector::new(collector, collector_data);
let mut var_manager = VarManager::new(n_vars_used);
let mut boxed = unsafe { Box::from_raw(dpw) };
boxed.encode_ub_change(min_bound..=max_bound, &mut collector, &mut var_manager);
boxed
.encode_ub_change(min_bound..=max_bound, &mut collector, &mut var_manager)
.expect("clause collector returned out of memory");
Box::into_raw(boxed);
}

Expand Down
5 changes: 4 additions & 1 deletion glucose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ macro_rules! handle_oom {
($val:expr) => {{
let val = $val;
if val == crate::OUT_OF_MEM {
return anyhow::Context::context(Err(rustsat::OutOfMemory), "glucose out of memory");
return anyhow::Context::context(
Err(rustsat::OutOfMemory::ExternalApi),
"glucose out of memory",
);
}
val
}};
Expand Down
5 changes: 4 additions & 1 deletion minisat/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ macro_rules! handle_oom {
($val:expr) => {{
let val = $val;
if val == crate::OUT_OF_MEM {
return anyhow::Context::context(Err(rustsat::OutOfMemory), "minisat out of memory");
return anyhow::Context::context(
Err(rustsat::OutOfMemory::ExternalApi),
"minisat out of memory",
);
}
val
}};
Expand Down
43 changes: 31 additions & 12 deletions pyapi/src/encodings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use rustsat::{
};

use crate::{
handle_oom,
instances::{Cnf, VarManager},
types::Lit,
};
Expand Down Expand Up @@ -89,12 +90,18 @@ impl Totalizer {
/// Incrementally builds the totalizer encoding to that upper bounds
/// in the range `max_ub..=min_ub` can be enforced. New variables will
/// be taken from `var_manager`.
fn encode_ub(&mut self, max_ub: usize, min_ub: usize, var_manager: &mut VarManager) -> Cnf {
fn encode_ub(
&mut self,
max_ub: usize,
min_ub: usize,
var_manager: &mut VarManager,
) -> PyResult<Cnf> {
let mut cnf = RsCnf::new();
let var_manager: &mut BasicVarManager = var_manager.into();
self.0
.encode_ub_change(max_ub..=min_ub, &mut cnf, var_manager);
cnf.into()
handle_oom!(self
.0
.encode_ub_change(max_ub..=min_ub, &mut cnf, var_manager));
Ok(cnf.into())
}

/// Gets assumptions to enforce the given upper bound. Make sure that
Expand Down Expand Up @@ -165,12 +172,18 @@ impl GeneralizedTotalizer {
/// Incrementally builds the GTE encoding to that upper bounds
/// in the range `max_ub..=min_ub` can be enforced. New variables will
/// be taken from `var_manager`.
fn encode_ub(&mut self, max_ub: usize, min_ub: usize, var_manager: &mut VarManager) -> Cnf {
fn encode_ub(
&mut self,
max_ub: usize,
min_ub: usize,
var_manager: &mut VarManager,
) -> PyResult<Cnf> {
let mut cnf = RsCnf::new();
let var_manager: &mut BasicVarManager = var_manager.into();
self.0
.encode_ub_change(max_ub..=min_ub, &mut cnf, var_manager);
cnf.into()
handle_oom!(self
.0
.encode_ub_change(max_ub..=min_ub, &mut cnf, var_manager));
Ok(cnf.into())
}

/// Gets assumptions to enforce the given upper bound. Make sure that
Expand Down Expand Up @@ -235,12 +248,18 @@ impl DynamicPolyWatchdog {
/// Incrementally builds the DPW encoding to that upper bounds
/// in the range `max_ub..=min_ub` can be enforced. New variables will
/// be taken from `var_manager`.
fn encode_ub(&mut self, max_ub: usize, min_ub: usize, var_manager: &mut VarManager) -> Cnf {
fn encode_ub(
&mut self,
max_ub: usize,
min_ub: usize,
var_manager: &mut VarManager,
) -> PyResult<Cnf> {
let mut cnf = RsCnf::new();
let var_manager: &mut BasicVarManager = var_manager.into();
self.0
.encode_ub_change(max_ub..=min_ub, &mut cnf, var_manager);
cnf.into()
handle_oom!(self
.0
.encode_ub_change(max_ub..=min_ub, &mut cnf, var_manager));
Ok(cnf.into())
}

/// Gets assumptions to enforce the given upper bound. Make sure that
Expand Down
10 changes: 10 additions & 0 deletions pyapi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,13 @@ fn rustsat(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {

Ok(())
}

macro_rules! handle_oom {
($result:expr) => {{
match $result {
Ok(val) => val,
Err(err) => return Err(pyo3::exceptions::PyMemoryError::new_err(format!("{}", err))),
}
}};
}
pub(crate) use handle_oom;
2 changes: 1 addition & 1 deletion rustsat/examples/profiling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ fn build_full_ub<PBE: BoundUpper + FromIterator<(Lit, usize)>>(lits: &[(Lit, usi
let mut var_manager = BasicVarManager::default();
var_manager.increase_next_free(max_var + 1);
let mut collector = Cnf::new();
enc.encode_ub(.., &mut collector, &mut var_manager);
enc.encode_ub(.., &mut collector, &mut var_manager).unwrap();
}

fn main() {
Expand Down
16 changes: 14 additions & 2 deletions rustsat/src/encodings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

use thiserror::Error;

use crate::types::Lit;
use crate::types::{Clause, Lit};

pub mod am1;
pub mod atomics;
Expand All @@ -13,9 +13,21 @@ pub mod pb;

/// Trait for collecting clauses. Mainly used when generating encodings and implemented by
/// [`crate::instances::Cnf`], and solvers.
pub trait CollectClauses: Extend<crate::types::Clause> {
pub trait CollectClauses {
/// Gets the number of clauses in the collection
fn n_clauses(&self) -> usize;
/// Extends the clause collector with an iterator of clauses
///
/// # Error
///
/// If the collector runs out of memory, return an [`crate::OutOfMemory`] error.
fn extend_clauses<T>(&mut self, cl_iter: T) -> Result<(), crate::OutOfMemory>
where
T: IntoIterator<Item = Clause>;
/// Adds one clause to the collector
fn add_clause(&mut self, cl: Clause) -> Result<(), crate::OutOfMemory> {
self.extend_clauses([cl])
}
}

/// Errors from encodings
Expand Down
4 changes: 2 additions & 2 deletions rustsat/src/encodings/am1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
//! enc.encode(&mut encoding, &mut var_manager).unwrap();
//! ```

use super::{CollectClauses, Error};
use super::CollectClauses;
use crate::instances::ManageVars;

mod pairwise;
Expand All @@ -35,7 +35,7 @@ pub trait Encode {
&mut self,
collector: &mut Col,
var_manager: &mut dyn ManageVars,
) -> Result<(), Error>
) -> Result<(), crate::OutOfMemory>
where
Col: CollectClauses;
}
Expand Down
6 changes: 3 additions & 3 deletions rustsat/src/encodings/am1/pairwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use super::Encode;
use crate::{
clause,
encodings::{CollectClauses, EncodeStats, Error, IterInputs},
encodings::{CollectClauses, EncodeStats, IterInputs},
instances::ManageVars,
types::Lit,
};
Expand All @@ -34,7 +34,7 @@ impl Encode for Pairwise {
&mut self,
collector: &mut Col,
_var_manager: &mut dyn ManageVars,
) -> Result<(), Error>
) -> Result<(), crate::OutOfMemory>
where
Col: CollectClauses,
{
Expand All @@ -43,7 +43,7 @@ impl Encode for Pairwise {
let clause_iter = (0..self.in_lits.len()).flat_map(|first| {
(first + 1..self.in_lits.len()).map(move |second| clause![!lits[first], !lits[second]])
});
collector.extend(clause_iter);
collector.extend_clauses(clause_iter)?;
self.n_clauses = collector.n_clauses() - prev_clauses;
Ok(())
}
Expand Down

0 comments on commit 4e46b0c

Please sign in to comment.