diff --git a/nnsmith/abstract/op.py b/nnsmith/abstract/op.py
index 080994dc..e3aef49c 100644
--- a/nnsmith/abstract/op.py
+++ b/nnsmith/abstract/op.py
@@ -543,14 +543,17 @@ def __repr__(self) -> str:
 
     def numeric_valid(self, outputs, inputs) -> bool:
         with torch.no_grad():
-            cond1 = not any([torch.isnan(out).any() or torch.isinf(
+            return not any([torch.isnan(out).any() or torch.isinf(
                 out).any() for out in outputs])
-            cond2 = True
-            if hasattr(self, 'torch_loss'):
+
+    def numeric_stable(self, outputs, inputs) -> bool:
+        with torch.no_grad():
+            cond = True
+            if hasattr(self, 'torch_loss') and isinstance(self, (Ceil, Floor, Round, Cast, Sin, Cos)):
                 loss = self.torch_loss(*inputs)
                 loss = loss[1] if isinstance(loss, tuple) else loss
-                cond2 = torch.all(loss <= 0)
-            return cond1 and cond2
+                cond = torch.all(loss <= 0)
+            return cond
 
 
 def concretize(op: AbsOpBase, model: Optional[z3.ModelRef]) -> AbsOpBase:
diff --git a/nnsmith/graph_gen.py b/nnsmith/graph_gen.py
index d848e6ff..eb9ff764 100644
--- a/nnsmith/graph_gen.py
+++ b/nnsmith/graph_gen.py
@@ -183,6 +183,7 @@ def __init__(self, graph: nx.MultiDiGraph, model: z3.ModelRef, verbose=False, al
         self.enable_training()
         self.check_intermediate_numeric = False
         self.invalid_found_last = None
+        self.unstable_as_invalid = False
 
     def concretize_svars(self, node_id, shape_vars, model, op):
         # concretize shapevars
@@ -290,6 +291,7 @@ def backward(self):
     def training_reset(self):
         self.loss = None
         self.stop_updating_loss = False
+        self.unstable_as_invalid = False
 
     def stop_training(self):
         self.use_gradient = False
@@ -363,9 +365,15 @@ def grad_input_gen(self, init_tensors, use_cuda=False,
                 _ = self(*inputs)
                 if self.invalid_found_last:  # need_to_train
                     self.backward()
-                else:
-                    sat_inputs = [v.data for v in inputs]
-                    break
+                    continue
+                self.training_reset()
+                self.unstable_as_invalid = True
+                _ = self(*inputs)
+                if self.invalid_found_last:  # need_to_train
+                    self.backward()
+                    continue
+                sat_inputs = [v.data for v in inputs]
+                break
             except ConstraintError as e:
                 if __INPUT_FOUND_INF_MSG__ in str(e) or __INPUT_FOUND_NAN_MSG__ in str(e):
                     # flush NaN/Inf in inputs
@@ -473,6 +481,9 @@ def forward(self, *args, **kwargs):
                     vul_op_loss = None
                 self.invalid_found_last |= not op.numeric_valid(
                     outputs, input_tensors)
+                if self.unstable_as_invalid:
+                    self.invalid_found_last |= not op.numeric_stable(
+                        outputs, input_tensors)
                 if self.invalid_found_last and (self.use_gradient and
                                                 not self.stop_updating_loss):
                     if self.print_grad >= 1:
diff --git a/nnsmith/input_gen.py b/nnsmith/input_gen.py
index 55fc8a02..d6699b09 100644
--- a/nnsmith/input_gen.py
+++ b/nnsmith/input_gen.py
@@ -65,6 +65,7 @@ class SamplingSearch(InputSearchBase):
     def search_one(self, start_inp, timeout_ms: int = None) -> List[torch.Tensor]:
         with torch.no_grad():
             self.net.check_intermediate_numeric = True
+            self.net.unstable_as_invalid = True
             _ = self.net(*start_inp)
             if not self.net.invalid_found_last:
                 return start_inp