Skip to content

Commit

Permalink
Add a predict method to DecomonModel for single batch numpy arrays
Browse files Browse the repository at this point in the history
The idea is to avoid calling predict() which is known to be
not designed for small arrays, and leads to memory leaks when used in loops.

See https://keras.io/api/models/model_training_apis/#predict-method and
tensorflow/tensorflow#44711

Use it in wrapper instead of predict().
  • Loading branch information
nhuet committed Nov 2, 2023
1 parent ce0fbd9 commit c520e54
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
24 changes: 24 additions & 0 deletions src/decomon/models/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Optional, Union

import keras
import numpy as np

from decomon.core import (
BoxDomain,
Expand Down Expand Up @@ -71,6 +72,29 @@ def reset_finetuning(self) -> None:
if hasattr(layer, "reset_finetuning"):
layer.reset_finetuning()

def predict_on_single_batch_np(
self, inputs: Union[np.ndarray, List[np.ndarray]]
) -> Union[np.ndarray, List[np.ndarray]]:
"""Make predictions on numpy arrays fitting in one batch
Avoid using `self.predict()` known to be not designed for small arrays,
and leading to memory leaks when used in loops.
See https://keras.io/api/models/model_training_apis/#predict-method and
https://github.com/tensorflow/tensorflow/issues/44711
Args:
inputs:
Returns:
"""
output_tensors = self(inputs)
if isinstance(output_tensors, list):
return [output.numpy() for output in output_tensors]
else:
return output_tensors.numpy()


def _check_domain(
perturbation_domain_prev: PerturbationDomain, perturbation_domain: PerturbationDomain
Expand Down
12 changes: 6 additions & 6 deletions src/decomon/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ def get_adv_box(
output: npt.NDArray[np.float_]
if decomon_model.backward_bounds:
C = np.diag([1] * n_label)[None] - source_labels[:, :, None]
output = decomon_model.predict([z, C], verbose=0)
output = decomon_model.predict_on_single_batch_np([z, C])
else:
output = decomon_model.predict(z, verbose=0)
output = decomon_model.predict_on_single_batch_np(z)

def get_ibp_score(
u_c: npt.NDArray[np.float_],
Expand Down Expand Up @@ -361,7 +361,7 @@ def check_adv_box(
)

else:
output = decomon_model.predict(z, verbose=0)
output = decomon_model.predict_on_single_batch_np(z)

if not affine:
# translate into affine information
Expand Down Expand Up @@ -532,7 +532,7 @@ def get_range_box(
ibp = decomon_model.ibp
affine = decomon_model.affine

output = decomon_model.predict(z, verbose=0)
output = decomon_model.predict_on_single_batch_np(z)
shape = list(output[-1].shape[1:])
output_dim = np.prod(shape)

Expand Down Expand Up @@ -701,7 +701,7 @@ def get_range_noise(
ibp = decomon_model.ibp
affine = decomon_model.affine

output = decomon_model.predict(x_reshaped, verbose=0)
output = decomon_model.predict_on_single_batch_np(x_reshaped)
shape = list(output[-1].shape[1:])
output_dim = np.prod(shape)

Expand Down Expand Up @@ -1006,7 +1006,7 @@ def get_adv_noise(
else:
ibp = decomon_model.ibp
affine = decomon_model.affine
output = decomon_model.predict(x_reshaped, verbose=0)
output = decomon_model.predict_on_single_batch_np(x_reshaped)

def get_ibp_score(
u_c: npt.NDArray[np.float_],
Expand Down

0 comments on commit c520e54

Please sign in to comment.