
Added export to MLflow pyfunc model format (#1192)
tgaddair committed Jun 3, 2021
1 parent fbb0b58 commit 0573b3a
Showing 8 changed files with 225 additions and 65 deletions.
109 changes: 55 additions & 54 deletions ludwig/api.py
@@ -484,64 +484,65 @@ def train(
if not skip_save_model:
self.save_config(model_dir)

train_stats = trainer.train(
self.model,
training_set,
validation_set=validation_set,
test_set=test_set,
save_path=model_dir,
)

self.model, train_trainset_stats, train_valiset_stats, train_testset_stats = train_stats
train_stats = {
TRAINING: train_trainset_stats,
VALIDATION: train_valiset_stats,
TEST: train_testset_stats
}

# save training statistics
if self.backend.is_coordinator():
if not skip_save_training_statistics:
save_json(training_stats_fn, train_stats)

# grab the results of the model with highest validation test performance
validation_field = trainer.validation_field
validation_metric = trainer.validation_metric
validation_field_result = train_valiset_stats[validation_field]

best_function = get_best_function(validation_metric)
# results of the model with highest validation test performance
if self.backend.is_coordinator() and validation_set is not None:
epoch_best_vali_metric, best_vali_metric = best_function(
enumerate(validation_field_result[validation_metric]),
key=lambda pair: pair[1]
)
logger.info(
'Best validation model epoch: {0}'.format(
epoch_best_vali_metric + 1)
try:
train_stats = trainer.train(
self.model,
training_set,
validation_set=validation_set,
test_set=test_set,
save_path=model_dir,
)
logger.info(
'Best validation model {0} on validation set {1}: {2}'.format(
validation_metric, validation_field, best_vali_metric
))
if test_set is not None:
best_vali_metric_epoch_test_metric = train_testset_stats[
validation_field][validation_metric][
epoch_best_vali_metric]

self.model, train_trainset_stats, train_valiset_stats, train_testset_stats = train_stats
train_stats = {
TRAINING: train_trainset_stats,
VALIDATION: train_valiset_stats,
TEST: train_testset_stats
}

# save training statistics
if self.backend.is_coordinator():
if not skip_save_training_statistics:
save_json(training_stats_fn, train_stats)

# grab the results of the model with highest validation test performance
validation_field = trainer.validation_field
validation_metric = trainer.validation_metric
validation_field_result = train_valiset_stats[validation_field]

best_function = get_best_function(validation_metric)
# results of the model with highest validation test performance
if self.backend.is_coordinator() and validation_set is not None:
epoch_best_vali_metric, best_vali_metric = best_function(
enumerate(validation_field_result[validation_metric]),
key=lambda pair: pair[1]
)
logger.info(
'Best validation model {0} on test set {1}: {2}'.format(
validation_metric,
validation_field,
best_vali_metric_epoch_test_metric
)
'Best validation model epoch: {0}'.format(
epoch_best_vali_metric + 1)
)
logger.info(
'\nFinished: {0}_{1}'.format(experiment_name, model_name))
logger.info('Saved to: {0}'.format(output_directory))

for callback in self.callbacks:
callback.on_train_end(output_directory)
logger.info(
'Best validation model {0} on validation set {1}: {2}'.format(
validation_metric, validation_field, best_vali_metric
))
if test_set is not None:
best_vali_metric_epoch_test_metric = train_testset_stats[
validation_field][validation_metric][
epoch_best_vali_metric]

logger.info(
'Best validation model {0} on test set {1}: {2}'.format(
validation_metric,
validation_field,
best_vali_metric_epoch_test_metric
)
)
logger.info(
'\nFinished: {0}_{1}'.format(experiment_name, model_name))
logger.info('Saved to: {0}'.format(output_directory))
finally:
for callback in self.callbacks:
callback.on_train_end(output_directory)

self.training_set_metadata = training_set_metadata
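
The notable change in this file: trainer.train(...) and the post-training bookkeeping now run inside try/finally, so every callback's on_train_end fires even when training raises. The MLflow callback below depends on this to log artifacts and end its run on failure as well as success. A minimal sketch of a callback relying on the same guarantee (CleanupCallback and its body are illustrative, not part of this commit):

from ludwig.callbacks import Callback


class CleanupCallback(Callback):
    def on_train_end(self, output_directory):
        # With the finally block above, this runs after both successful
        # and failed training, so external resources (an open MLflow run,
        # temp files, connections) can always be released here.
        print('training ended; outputs in', output_directory)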

2 changes: 1 addition & 1 deletion ludwig/callbacks.py
@@ -77,7 +77,7 @@ def on_test_end(self, trainer, progress_tracker, save_path):
def on_visualize_figure(self, fig):
pass

def prepare_ray_tune(self, train_fn, tune_config):
def prepare_ray_tune(self, train_fn, tune_config, tune_callbacks):
"""Configures Ray Tune to properly use this callback in each trial."""
return train_fn, tune_config

5 changes: 5 additions & 0 deletions ludwig/cli.py
@@ -55,6 +55,7 @@ def __init__(self):
collect_activations Collects tensors for each datapoint using a pretrained model
export_savedmodel Exports Ludwig models to SavedModel
export_neuropod Exports Ludwig models to Neuropod
export_mlflow Exports Ludwig models to MLflow
preprocess Preprocess data and saves it into HDF5 and JSON format
synthesize_dataset Creates synthetic data for testing purposes
''')
@@ -117,6 +118,10 @@ def export_neuropod(self):
from ludwig import export
export.cli_export_neuropod(sys.argv[2:])

def export_mlflow(self):
from ludwig import export
export.cli_export_mlflow(sys.argv[2:])

def preprocess(self):
from ludwig import preprocess
preprocess.cli(sys.argv[2:])
61 changes: 58 additions & 3 deletions ludwig/contribs/mlflow.py
@@ -1,6 +1,7 @@
import logging
import os

from ludwig.api import LudwigModel
from ludwig.callbacks import Callback
from ludwig.utils.data_utils import chunk_dict, flatten_dict, to_json_dict
from ludwig.utils.package_utils import LazyLoader
@@ -57,8 +58,7 @@ def on_train_start(self, config, **kwargs):
self._log_params({'training': config['training']})

def on_train_end(self, output_directory):
for fname in os.listdir(output_directory):
mlflow.log_artifact(os.path.join(output_directory, fname))
_log_artifacts(output_directory)
if self.run is not None:
mlflow.end_run()

@@ -68,13 +68,19 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
step=progress_tracker.steps
)

mlflow.pyfunc.log_model(
artifact_path='model',
**_export_kwargs(save_path)
)

def on_visualize_figure(self, fig):
# TODO: need to also include a filename for this figure
# mlflow.log_figure(fig)
pass

def prepare_ray_tune(self, train_fn, tune_config):
def prepare_ray_tune(self, train_fn, tune_config, tune_callbacks):
from ray.tune.integration.mlflow import mlflow_mixin

return mlflow_mixin(train_fn), {
**tune_config,
'mlflow': {
@@ -92,3 +98,52 @@ def __setstate__(self, d):
self.__dict__ = d
if self.tracking_uri:
mlflow.set_tracking_uri(self.tracking_uri)


class LudwigMlflowModel(mlflow.pyfunc.PythonModel):
def __init__(self):
super().__init__()
self._model = None

def load_context(self, context):
self._model = LudwigModel.load(context.artifacts['model'])

def predict(self, context, model_input):
pred_df, _ = self._model.predict(model_input)
return pred_df


def _export_kwargs(model_path):
return dict(
python_model=LudwigMlflowModel(),
artifacts={
'model': model_path,
},
)


def _log_artifacts(output_directory):
for fname in os.listdir(output_directory):
lpath = os.path.join(output_directory, fname)
if fname == 'model':
mlflow.pyfunc.log_model(
artifact_path='model',
**_export_kwargs(lpath)
)
else:
mlflow.log_artifact(lpath)


def export_model(model_path, output_path, registered_model_name=None):
kwargs = _export_kwargs(model_path)
if registered_model_name:
mlflow.pyfunc.log_model(
artifact_path=output_path,
registered_model_name=registered_model_name,
**kwargs
)
else:
mlflow.pyfunc.save_model(
path=output_path,
**kwargs
)
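
The exported artifact behaves like any other MLflow pyfunc model: load_context restores the LudwigModel and predict delegates to LudwigModel.predict. A hedged usage sketch (the paths and the 'text' input column are placeholders, not taken from this commit):

import mlflow.pyfunc
import pandas as pd

from ludwig.contribs import mlflow as ludwig_mlflow

# Export a trained Ludwig model directory to MLflow's pyfunc format
# (saved locally because no registered_model_name is given).
ludwig_mlflow.export_model('results/experiment_run/model', 'exported_model')

# Load it back through the generic pyfunc interface; predict() returns
# the predictions DataFrame from LudwigModel.predict().
model = mlflow.pyfunc.load_model('exported_model')
print(model.predict(pd.DataFrame({'text': ['example input']})))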
95 changes: 95 additions & 0 deletions ludwig/export.py
@@ -85,6 +85,36 @@ def export_neuropod(
logger.info('Saved to: {0}'.format(output_path))


def export_mlflow(
model_path,
output_path='mlflow',
registered_model_name=None,
**kwargs
):
"""Exports a model to MLflow
# Inputs
:param model_path: (str) filepath to pre-trained model.
:param output_path: (str, default: `'mlflow'`) directory to store the
mlflow model.
:param registered_model_name: (str, default: `None`) save mlflow under this
name in the model registry. Saved locally if `None`.
# Return
:returns: (`None`)
"""
logger.info('Model path: {}'.format(model_path))
logger.info('Output path: {}'.format(output_path))
logger.info('\n')

from ludwig.contribs import mlflow
mlflow.export_model(model_path, output_path, registered_model_name)

logger.info('Saved to: {0}'.format(output_path))
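
Per the docstring, the two code paths are exclusive: with registered_model_name the model is logged to the tracking server and registered under that name, otherwise it is written to output_path on disk. A hedged sketch of both calls (paths and names are placeholders):

from ludwig.export import export_mlflow

# Save the pyfunc model locally; the model registry is not involved.
export_mlflow('results/experiment_run/model', output_path='mlflow_model')

# With an MLflow tracking server configured, log the model and register
# it in the model registry instead.
export_mlflow(
    'results/experiment_run/model',
    output_path='model',
    registered_model_name='my_ludwig_model',
)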


def cli_export_savedmodel(sys_argv):
parser = argparse.ArgumentParser(
description='This script loads a pretrained model '
@@ -209,6 +239,71 @@ def cli_export_neuropod(sys_argv):
export_neuropod(**vars(args))


def cli_export_mlflow(sys_argv):
parser = argparse.ArgumentParser(
description='This script loads a pretrained model '
'and saves it as an MLFlow model.',
prog='ludwig export_mlflow',
usage='%(prog)s [options]'
)

# ----------------
# Model parameters
# ----------------
parser.add_argument(
'-m',
'--model_path',
help='model to load',
required=True
)
parser.add_argument(
'-mn',
'--registered_model_name',
help='model name to upload to in MLflow model registry',
default='mlflow'
)

# -----------------
# Output parameters
# -----------------
parser.add_argument(
'-od',
'--output_path',
type=str,
help='path where to save the exported model',
required=True
)

# ------------------
# Runtime parameters
# ------------------
parser.add_argument(
'-l',
'--logging_level',
default='info',
help='the level of logging to use',
choices=['critical', 'error', 'warning', 'info', 'debug', 'notset']
)

add_contrib_callback_args(parser)
args = parser.parse_args(sys_argv)

args.callbacks = args.callbacks or []
for callback in args.callbacks:
callback.on_cmdline('export_mlflow', *sys_argv)

args.logging_level = logging_level_registry[args.logging_level]
logging.getLogger('ludwig').setLevel(
args.logging_level
)
global logger
logger = logging.getLogger('ludwig.export')

print_ludwig('Export MLFlow', LUDWIG_VERSION)

export_mlflow(**vars(args))
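
An invocation might look like this (paths are placeholders): ludwig export_mlflow --model_path results/experiment_run/model --output_path mlflow_model. Note that the CLI's --registered_model_name defaults to 'mlflow' rather than the Python API's None, so a plain CLI invocation takes the registry path unless that flag is overridden.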


if __name__ == '__main__':
if len(sys.argv) > 1:
if sys.argv[1] == 'savedmodel':
3 changes: 3 additions & 0 deletions ludwig/hyperopt/execution.py
@@ -905,10 +905,12 @@ def run_experiment_trial(config, checkpoint_dir=None):
return self._run_experiment(config, checkpoint_dir, hyperopt_dict, self.decode_ctx)

tune_config = {}
tune_callbacks = []
for callback in callbacks or []:
run_experiment_trial, tune_config = callback.prepare_ray_tune(
run_experiment_trial,
tune_config,
tune_callbacks,
)

register_trainable(
@@ -933,6 +935,7 @@ def run_experiment_trial(config, checkpoint_dir=None):
mode=mode,
trial_name_creator=lambda trial: f"trial_{trial.trial_id}",
trial_dirname_creator=lambda trial: f"trial_{trial.trial_id}",
callbacks=tune_callbacks,
)

ordered_trials = analysis.results_df.sort_values(
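
The new tune_callbacks list threads through each callback's prepare_ray_tune into tune.run(callbacks=...), letting a Ludwig callback attach per-trial Ray Tune callbacks. A hedged sketch of a callback using the parameter (TrialLogger and MyCallback are hypothetical, not part of this commit):

from ludwig.callbacks import Callback
from ray.tune.callback import Callback as TuneCallback


class TrialLogger(TuneCallback):
    def on_trial_result(self, iteration, trials, trial, result, **info):
        # Print every result a trial reports back to Tune.
        print(trial.trial_id, result)


class MyCallback(Callback):
    def prepare_ray_tune(self, train_fn, tune_config, tune_callbacks):
        # Appending here places TrialLogger in the list handed to
        # tune.run(callbacks=tune_callbacks) above.
        tune_callbacks.append(TrialLogger())
        return train_fn, tune_config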
2 changes: 1 addition & 1 deletion tests/integration_tests/scripts/run_train_comet.py
@@ -7,7 +7,7 @@
# This test runs in an isolated environment to ensure TensorFlow imports are not leaked
# from previous tests.

# Comet must be imported before the libraries to wraps
# Comet must be imported before the libraries it wraps
import comet_ml

import argparse
