Separate invalid from unstable.
lazycal committed Jul 1, 2022
1 parent 4ac51da commit e05b1f5
Showing 3 changed files with 23 additions and 8 deletions.
13 changes: 8 additions & 5 deletions nnsmith/abstract/op.py
@@ -543,14 +543,17 @@ def __repr__(self) -> str:

     def numeric_valid(self, outputs, inputs) -> bool:
         with torch.no_grad():
-            cond1 = not any([torch.isnan(out).any() or torch.isinf(
+            return not any([torch.isnan(out).any() or torch.isinf(
                 out).any() for out in outputs])
-            cond2 = True
-            if hasattr(self, 'torch_loss'):
+
+    def numeric_stable(self, outputs, inputs) -> bool:
+        with torch.no_grad():
+            cond = True
+            if hasattr(self, 'torch_loss') and isinstance(self, (Ceil, Floor, Round, Cast, Sin, Cos)):
                 loss = self.torch_loss(*inputs)
                 loss = loss[1] if isinstance(loss, tuple) else loss
-                cond2 = torch.all(loss <= 0)
-            return cond1 and cond2
+                cond = torch.all(loss <= 0)
+            return cond


def concretize(op: AbsOpBase, model: Optional[z3.ModelRef]) -> AbsOpBase:
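
Note: the change splits the old combined check in two. numeric_valid now reports only hard NaN/Inf failures, while numeric_stable reports boundary-sensitive ops (Ceil, Floor, Round, Cast, Sin, Cos) whose inputs sit close enough to a discontinuity that a tiny perturbation could flip the output. A minimal standalone restatement of the two predicates, assuming only the torch_loss convention visible in the diff (all(loss <= 0) means stable); the isinstance filter on the op classes is elided here:

    import torch

    def numeric_valid(outputs):
        # Hard failure: any NaN/Inf anywhere in the outputs.
        with torch.no_grad():
            return not any(torch.isnan(o).any() or torch.isinf(o).any()
                           for o in outputs)

    def numeric_stable(op, inputs):
        # Soft failure: the op-specific loss is still positive, i.e. the
        # inputs sit in a perturbation-sensitive region.
        with torch.no_grad():
            if not hasattr(op, 'torch_loss'):
                return True
            loss = op.torch_loss(*inputs)
            loss = loss[1] if isinstance(loss, tuple) else loss
            return bool(torch.all(loss <= 0))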
17 changes: 14 additions & 3 deletions nnsmith/graph_gen.py
@@ -183,6 +183,7 @@ def __init__(self, graph: nx.MultiDiGraph, model: z3.ModelRef, verbose=False, al
         self.enable_training()
         self.check_intermediate_numeric = False
         self.invalid_found_last = None
+        self.unstable_as_invalid = False
 
     def concretize_svars(self, node_id, shape_vars, model, op):
         # concretize shapevars
@@ -290,6 +291,7 @@ def backward(self):
     def training_reset(self):
         self.loss = None
         self.stop_updating_loss = False
+        self.unstable_as_invalid = False
 
     def stop_training(self):
         self.use_gradient = False
@@ -363,9 +365,15 @@ def grad_input_gen(self, init_tensors, use_cuda=False,
                 _ = self(*inputs)
                 if self.invalid_found_last:  # need_to_train
                     self.backward()
-                else:
-                    sat_inputs = [v.data for v in inputs]
-                    break
+                    continue
+                self.training_reset()
+                self.unstable_as_invalid = True
+                _ = self(*inputs)
+                if self.invalid_found_last:  # need_to_train
+                    self.backward()
+                    continue
+                sat_inputs = [v.data for v in inputs]
+                break
             except ConstraintError as e:
                 if __INPUT_FOUND_INF_MSG__ in str(e) or __INPUT_FOUND_NAN_MSG__ in str(e):
                     # flush NaN/Inf in inputs
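
Note: the rewritten loop searches in two phases. Phase 1 trains the inputs until no NaN/Inf remains; only then does it reset, set unstable_as_invalid, and train again, so the same gradient machinery also pushes inputs away from unstable regions before they are accepted. A self-contained toy of what phase 2 effectively optimizes (the margin eps and the Floor-style loss are illustrative assumptions, not nnsmith code):

    import torch

    x = torch.tensor([0.9999], requires_grad=True)  # finite, but hugs a Floor boundary
    opt = torch.optim.SGD([x], lr=0.1)
    eps = 1e-2  # required distance from the nearest integer

    for _ in range(50):
        # Positive while x is within eps of an integer; <= 0 once it is safe.
        loss = eps - torch.abs(x - torch.round(x))
        if torch.all(loss <= 0):
            break  # stable: floor(x) is now robust to small perturbations
        opt.zero_grad()
        loss.sum().backward()
        opt.step()

    print(x.detach())  # nudged away from 1.0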
@@ -473,6 +481,9 @@ def forward(self, *args, **kwargs):
                 vul_op_loss = None
             self.invalid_found_last |= not op.numeric_valid(
                 outputs, input_tensors)
+            if self.unstable_as_invalid:
+                self.invalid_found_last |= not op.numeric_stable(
+                    outputs, input_tensors)
 
             if self.invalid_found_last and (self.use_gradient and not self.stop_updating_loss):
                 if self.print_grad >= 1:
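
Note: with the flag off, forward() behaves exactly as before; with it on, a stability violation also flips invalid_found_last and therefore feeds the training loss. A toy demonstration of the gate semantics (ToyFloor and its eps threshold are stand-ins, not nnsmith classes):

    import torch

    class ToyFloor:
        def numeric_valid(self, outputs, inputs):
            return not any(torch.isnan(o).any() or torch.isinf(o).any()
                           for o in outputs)

        def numeric_stable(self, outputs, inputs):
            x = inputs[0]
            return bool(torch.all(torch.abs(x - torch.round(x)) > 1e-3))

    op = ToyFloor()
    inputs = [torch.tensor([0.9999])]   # finite but boundary-hugging
    outputs = [torch.floor(inputs[0])]
    for unstable_as_invalid in (False, True):
        invalid = not op.numeric_valid(outputs, inputs)
        if unstable_as_invalid:
            invalid |= not op.numeric_stable(outputs, inputs)
        print(unstable_as_invalid, invalid)  # (False, False) then (True, True)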
1 change: 1 addition & 0 deletions nnsmith/input_gen.py
@@ -65,6 +65,7 @@ class SamplingSearch(InputSearchBase):
     def search_one(self, start_inp, timeout_ms: int = None) -> List[torch.Tensor]:
         with torch.no_grad():
             self.net.check_intermediate_numeric = True
+            self.net.unstable_as_invalid = True
             _ = self.net(*start_inp)
             if not self.net.invalid_found_last:
                 return start_inp
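
Note: SamplingSearch now rejects unstable candidates too, so the inputs it returns are robust rather than merely finite. A hedged sketch of how such a sampling driver could look; only the three attributes shown come from the diff, and the resampling policy is an assumption:

    import torch

    def sampling_search(net, start_inp, n_trials=100):
        with torch.no_grad():
            net.check_intermediate_numeric = True
            net.unstable_as_invalid = True   # unstable now counts as invalid
            for _ in range(n_trials):
                _ = net(*start_inp)
                if not net.invalid_found_last:
                    return start_inp         # numerically safe inputs found
                start_inp = [torch.rand_like(t) for t in start_inp]  # resample
        return None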
