Update batch_size calculation in keras autolog #11224

Merged
merged 11 commits on Feb 23, 2024
40 changes: 26 additions & 14 deletions mlflow/keras/autologging.py
@@ -14,6 +14,7 @@
from mlflow.keras.save import log_model
from mlflow.keras.utils import get_model_signature
from mlflow.tracking.context import registry as context_registry
from mlflow.utils import is_iterator
from mlflow.utils.annotations import experimental
from mlflow.utils.autologging_utils import (
PatchFunction,
@@ -26,19 +27,6 @@
_logger = logging.getLogger(__name__)


def _infer_batch_size(*keras_fit_args, **keras_fit_kwargs):
    if "batch_size" in keras_fit_kwargs:
        return keras_fit_kwargs["batch_size"]

    training_data = keras_fit_kwargs["x"] if "x" in keras_fit_kwargs else keras_fit_args[0]
    batch_size = getattr(training_data, "batch_size", None) or getattr(
        training_data, "_batch_size", None
    )
    if batch_size:
        return batch_size
    return None


def _check_existing_mlflow_callback(callbacks):
    for callback in callbacks:
        if isinstance(callback, MLflowCallback):
@@ -212,7 +200,31 @@ def __init__(self):
    def _patch_implementation(self, original, inst, *args, **kwargs):
        unlogged_params = ["self", "x", "y", "callbacks", "validation_data", "verbose"]

        batch_size = _infer_batch_size(*args, **kwargs)
        batch_size = None
        if "batch_size" in kwargs:
            batch_size = kwargs["batch_size"]
        else:
            training_data = kwargs["x"] if "x" in kwargs else args[0]
            if _batch_size := getattr(training_data, "batch_size", None):
                batch_size = _batch_size
            elif _batch_size := getattr(training_data, "_batch_size", None):
                batch_size = (
                    _batch_size if isinstance(_batch_size, int) else _batch_size.numpy()
                )
            elif is_iterator(training_data):
                is_single_input_model = isinstance(inst.input_shape, tuple)
                peek = next(training_data)
                batch_size = len(peek[0]) if is_single_input_model else len(peek[0][0])

                def __restore_generator(prev_generator):
                    yield peek
                    yield from prev_generator

                restored_generator = __restore_generator(training_data)
                if "x" in kwargs:
                    kwargs["x"] = restored_generator
                else:
                    args = (restored_generator,) + args[1:]

        if batch_size is not None:
            mlflow.log_param("batch_size", batch_size)
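For readers skimming the hunk above: calling next() on the training iterator to measure the batch size consumes that first batch, so the peeked element is chained back in front of the remaining batches before the data reaches fit(). The sketch below is a minimal standalone illustration of that peek-and-restore idea; batch_generator and the other names are invented for the example and are not part of this PR.

import numpy as np

def batch_generator(n_batches=3, batch_size=4):
    # Yields (features, labels) tuples, mimicking a Keras data generator.
    for _ in range(n_batches):
        yield np.zeros((batch_size, 10)), np.zeros((batch_size, 1))

gen = batch_generator()
peek = next(gen)                    # consumes the first batch
inferred_batch_size = len(peek[0])  # 4 for this single-input case

def restore_generator(prev_generator):
    # Re-emit the peeked batch first, then the untouched remainder.
    yield peek
    yield from prev_generator

restored = restore_generator(gen)
assert inferred_batch_size == 4
assert sum(1 for _ in restored) == 3  # all three batches are still delivered

An equivalent alternative would be itertools.chain([peek], training_data), which also yields an iterator; the patch opts for a small local generator.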
2 changes: 1 addition & 1 deletion mlflow/ml-package-versions.yml
@@ -109,7 +109,7 @@ tensorflow:
maximum: "2.15.0.post1"
requirements:
# Requirements to run tests for keras
">= 0.0.0": ["scikit-learn", "pyspark", "pyarrow", "transformers"]
">= 0.0.0": ["scikit-learn", "pyspark", "pyarrow", "transformers!=4.38.0,!=4.38.1"]
"< 2.7.0": ["pandas>=1.3.5,<2.0"]
">= 2.7.0": ["pandas<2.0"]
# TensorFlow == 2.6.5 are incompatible with SQLAlchemy 2.x due to