Skip to content

Commit

Permalink
Merge pull request #95 from dlshriver/fix-trivialprops
Browse files Browse the repository at this point in the history
fix: special case for trivial (network-free) properties
  • Loading branch information
dlshriver committed Jul 16, 2022
2 parents 7d6d643 + 347e53c commit 4d4b124
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 63 deletions.
30 changes: 24 additions & 6 deletions dnnv/verifiers/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,18 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)

from dnnv.properties import Expression, LogicalExpression

Expand All @@ -23,7 +34,7 @@
class Parameter:
def __init__(
self,
dtype: Type,
dtype: Callable[[Any], Any],
default: Optional[Any] = None,
choices: Optional[List[Any]] = None,
help: Optional[str] = None,
Expand All @@ -43,7 +54,7 @@ def __init__(
self.choices = choices
self.help = help

def as_type(self, value):
def as_type(self, value: Any):
return self.type(value) if value is not None else None


Expand Down Expand Up @@ -116,7 +127,14 @@ def run(self) -> Tuple[PropertyCheckResult, Optional[Any]]:
tempfile.tempdir = tempdir
result = UNSAT
for subproperty in self.reduce_property():
subproperty_result, cex = self.check(subproperty)
is_trivial, *trivial_result = subproperty.is_trivial()
if is_trivial:
subproperty_result, cex = trivial_result[0]
self.logger.warning(
"Property is trivially %s.", subproperty_result
)
else:
subproperty_result, cex = self.check(subproperty)
result |= subproperty_result
if result == SAT:
if cex is not None:
Expand All @@ -136,13 +154,13 @@ def validate_counter_example(self, prop: Property, cex: Any) -> bool:
return is_valid

@abstractmethod
def build_inputs(self, prop: Property) -> Tuple[Any, ...]:
def build_inputs(self, prop: Property) -> Sequence: # pragma: no cover
raise NotImplementedError()

@abstractmethod
def parse_results(
self, prop: Property, results: Any
) -> Tuple[PropertyCheckResult, Optional[Any]]:
) -> Tuple[PropertyCheckResult, Optional[Any]]: # pragma: no cover
raise NotImplementedError()


Expand Down
12 changes: 10 additions & 2 deletions dnnv/verifiers/common/reductions/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
from abc import ABC, abstractmethod
from typing import Any, Iterator, Optional, Tuple, Type
from typing import Any, Iterator, Optional, Tuple, Type, Union

from dnnv.errors import DNNVError
from dnnv.properties import Expression

from ..results import PropertyCheckResult


class ReductionError(DNNVError):
pass


class Property:
class Property(ABC):
@abstractmethod
def is_trivial(
self,
) -> Union[Tuple[bool], Tuple[bool, Tuple[PropertyCheckResult, Any]]]:
raise NotImplementedError()

@abstractmethod
def validate_counter_example(self, cex: Any) -> Tuple[bool, Optional[str]]:
raise NotImplementedError()
Expand Down
16 changes: 10 additions & 6 deletions dnnv/verifiers/common/reductions/iopolytope/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
class Variable:
_count = 0

def __init__(self, shape: Tuple[int, ...], name: Optional[str] = None):
def __init__(self, shape: Sequence[int], name: Optional[str] = None):
self.shape = shape
self.name = name
if self.name is None:
self.name = f"x_{Variable._count}"
Variable._count += 1

def size(self) -> int:
return np.product(self.shape)
return int(np.product(self.shape))

def __str__(self):
return self.name
Expand Down Expand Up @@ -58,7 +58,7 @@ def add_variable(self, variable: Variable) -> Constraint:
self.variables[variable] = self.size()
return self

def unravel_index(self, index: int) -> Tuple[Variable, Tuple[np.intp, ...]]:
def unravel_index(self, index: int) -> Tuple[Variable, Sequence[np.intp]]:
c_size = self.size()
for variable, size in sorted(self.variables.items(), key=lambda kv: -kv[1]):
if size <= index < c_size:
Expand All @@ -75,7 +75,7 @@ def as_bounds(self) -> Tuple[np.ndarray, np.ndarray]:
def update_constraint(
self,
variables: Sequence[Variable],
indices: Sequence[Tuple[int, ...]],
indices: Sequence[Sequence[int]],
coefficients: Sequence[float],
b: float,
is_open=False,
Expand Down Expand Up @@ -148,6 +148,8 @@ def is_consistent(self):
A, b = self.as_matrix_inequality()
obj = np.zeros(A.shape[1])
lb, ub = self.as_bounds()
if A.size == 0:
return np.all(lb <= ub)
# linprog breaks if bounds are too big or too small
bounds = list(
zip(
Expand Down Expand Up @@ -264,7 +266,7 @@ def _update_bounds(
def update_constraint(
self,
variables: Sequence[Variable],
indices: Sequence[Tuple[int, ...]],
indices: Sequence[Sequence[int]],
coefficients: Sequence[float],
b: float,
is_open=False,
Expand Down Expand Up @@ -293,11 +295,13 @@ def update_constraint(
self._update_bounds(flat_indices, coefficients, b, is_open=is_open)

def validate(self, *x: np.ndarray, threshold: float = 1e-6) -> bool:
if self.size() == 0:
return True
if len(x) != len(self.variables):
return False
x_flat_ = []
for x_, v in zip(x, self.variables):
if x_.shape != v.shape:
if x_.size != v.size():
return False
x_flat_.append(x_.flatten())
x_flat = np.concatenate(x_flat_)
Expand Down
34 changes: 33 additions & 1 deletion dnnv/verifiers/common/reductions/iopolytope/property.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from typing import List, Optional, Sequence, Tuple, Union
from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np
from scipy.optimize import linprog

from .....nn import OperationGraph, OperationTransformer, operations
from .....properties import Network
from ...results import SAT, PropertyCheckResult
from ..base import Property
from .base import HalfspacePolytope, HyperRectangle

Expand Down Expand Up @@ -59,6 +61,36 @@ def __str__(self):
]
return "\n".join(strs)

def is_trivial(
self,
) -> Union[Tuple[bool], Tuple[bool, Tuple[PropertyCheckResult, Any]]]:
is_trivial = (
self.output_constraint.size() == 0 and self.input_constraint.is_consistent
)
if not is_trivial:
return (False,)
A, b = self.input_constraint.as_matrix_inequality()
obj = np.zeros(A.shape[1])
lb, ub = self.input_constraint.as_bounds()
if A.size == 0:
cex = (lb + ub) / 2
else:
bounds = list(
zip(
(b if b > -1e6 else None for b in lb),
(b if b < 1e6 else None for b in ub),
)
)
result = linprog(
obj,
A_ub=A,
b_ub=b,
bounds=bounds,
method="highs",
)
cex = result.x
return (is_trivial, (SAT, cex))

def validate_counter_example(
self, cex: np.ndarray, threshold=1e-6
) -> Tuple[bool, Optional[str]]:
Expand Down
66 changes: 18 additions & 48 deletions dnnv/verifiers/common/reductions/iopolytope/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@ def __init__(
self.output_constraint_type = output_constraint_type
self.logger = logging.getLogger(__name__)
self._stack: List[Expression] = []
self._network_input_shapes: Dict[Expression, Tuple[int, ...]] = {}
self._network_output_shapes: Dict[Network, Tuple[int, ...]] = {}
self.initialize()

def initialize(self):
self.input = None
self.networks = []
self.input_constraint = None
self.output_constraint = None
self.input_constraint = self.input_constraint_type()
self.output_constraint = self.output_constraint_type()
self.variables: Dict[Expression, Variable] = {}
self.indices: Dict[Expression, np.ndarray] = {}
self.coefs: Dict[Expression, np.ndarray] = {}
Expand All @@ -52,7 +50,10 @@ def build_property(self):
def _reduce(self, expression: And) -> Iterator[Property]:
self.initialize()
if len(expression.variables) != 1:
raise self.reduction_error("Exactly one network input is required")
raise self.reduction_error(
"At most one symbolic variable is allowed."
f" Received: {tuple(str(v) for v in expression.variables)}"
)
self.visit(expression)
prop = self.build_property()
if not prop.input_constraint.is_consistent:
Expand Down Expand Up @@ -128,26 +129,15 @@ def visit_Call(self, expression: Call):
"Unsupported property: Executing networks with keyword arguments"
" is not currently supported"
)
for arg, d in zip(expression.args, input_details):
if arg in self._network_input_shapes:
if any(
i1 != i2 and i2 > 0
for i1, i2 in zip(
self._network_input_shapes[arg], tuple(d.shape)
)
):
raise self.reduction_error(
f"Invalid property: variable with multiple shapes: '{arg}'"
)
self._network_input_shapes[arg] = tuple(
i if i > 0 else 1 for i in d.shape
)
for arg in expression.args:
self.visit(arg)
shape = self._network_output_shapes[expression.function]
self.variables[expression] = self.variables[expression.function]
shape = expression.ctx.shapes[expression]
variable = Variable(shape, str(expression.function))
self.output_constraint.add_variable(variable)
self.variables[expression] = variable
self.indices[expression] = np.array(
[i for i in np.ndindex(*shape)]
).reshape(shape + (len(shape),))
).reshape(*shape, len(shape))
self.coefs[expression] = np.ones(shape)
else:
raise self.reduction_error(
Expand Down Expand Up @@ -299,26 +289,6 @@ def visit_Network(self, expression: Network):
"Networks with multiple output operations"
" are not currently supported"
)
if expression not in self._network_output_shapes:
self._network_output_shapes[expression] = expression.value.output_shape[
0
]
elif (
self._network_output_shapes[expression]
!= expression.value.output_shape[0]
):
raise self.reduction_error(
f"Invalid property: network with multiple shapes: '{expression}'"
)
variable = Variable(
self._network_output_shapes[expression], str(expression)
)
if self.output_constraint is None:
self.output_constraint = self.output_constraint_type(variable)
else:
self.output_constraint = self.output_constraint.add_variable(variable)
variable = Variable(self._network_output_shapes[expression], str(expression))
self.variables[expression] = variable
return expression

def visit_Subscript(self, expression: Subscript):
Expand All @@ -333,18 +303,18 @@ def visit_Subscript(self, expression: Subscript):
def visit_Symbol(self, expression: Symbol):
if self.input is None:
self.input = expression
if expression not in self._network_input_shapes:
if expression not in expression.ctx.shapes:
raise self.reduction_error(f"Unknown shape for variable {expression}")
variable = Variable(self._network_input_shapes[expression], str(expression))
self.input_constraint = self.input_constraint_type(variable)
variable = Variable(expression.ctx.shapes[expression], str(expression))
self.input_constraint.add_variable(variable)
elif self.input is not expression:
raise self.reduction_error("Multiple inputs detected in property")
shape = self._network_input_shapes[expression]
shape = expression.ctx.shapes[expression]
self.variables[expression] = Variable(
self._network_input_shapes[expression], str(expression)
expression.ctx.shapes[expression], str(expression)
)
self.indices[expression] = np.array(list(np.ndindex(*shape))).reshape(
shape + (len(shape),)
*shape, len(shape)
)
self.coefs[expression] = np.ones(shape)

Expand Down

0 comments on commit 4d4b124

Please sign in to comment.