Added additional backend configuration options #1195

Merged 16 commits on Jun 8, 2021
21 changes: 11 additions & 10 deletions ludwig/api.py
@@ -29,11 +29,10 @@
from typing import Dict, List, Optional, Tuple, Union

from ludwig.data.dataset.partitioned import PartitionedDataset
-from ludwig.utils.fs_utils import upload_output_directory, open_file, path_exists, makedirs
+from ludwig.utils.fs_utils import upload_output_directory, path_exists, makedirs

import numpy as np
import pandas as pd
-import yaml

from ludwig.backend import Backend, initialize_backend
from ludwig.callbacks import Callback
@@ -59,7 +58,7 @@
DICT_FORMATS,
external_data_reader_registry,
figure_data_format, generate_kfold_splits,
-load_json, save_json)
+load_json, save_json, load_yaml)
from ludwig.utils.defaults import default_random_seed, merge_with_defaults
from ludwig.utils.misc_utils import (get_experiment_description,
get_file_names, get_from_registry,
@@ -176,8 +175,7 @@ def __init__(
"""
# check if config is a path or a dict
if isinstance(config, str): # assume path
with open_file(config, 'r') as def_file:
config_dict = yaml.safe_load(def_file)
config_dict = load_yaml(config)
self.config_fp = config
else:
config_dict = copy.deepcopy(config)
@@ -192,7 +190,7 @@
        self.set_logging_level(logging_level)

        # setup Backend
-        self.backend = initialize_backend(backend)
+        self.backend = initialize_backend(backend or config.get('backend'))
        self.callbacks = callbacks if callbacks is not None else []

        # setup TensorFlow
@@ -1336,6 +1334,7 @@ def load(
"""
# Initialize Horovod and TensorFlow before calling `broadcast()` to prevent initializing
# TensorFlow with default parameters
backend_param = backend
backend = initialize_backend(backend)
backend.initialize_tensorflow(
gpus=gpus,
@@ -1350,6 +1349,10 @@
)
))

+        if backend_param is None and 'backend' in config:
+            # Reset backend from config
+            backend = initialize_backend(config.get('backend'))

# initialize model
ludwig_model = LudwigModel(
config,
@@ -1654,12 +1657,10 @@ def kfold_cross_validate(
`kfold_split_indices`: indices to split training data into
training fold and test fold.
"""
-    backend = initialize_backend(backend)
-
    # if config is a path, convert to dictionary
    if isinstance(config, str):  # assume path
-        with open_file(config, 'r') as def_file:
-            config = yaml.safe_load(def_file)
+        config = load_yaml(config)
+    backend = initialize_backend(backend or config.get('backend'))

# check for k_fold
if num_folds is None:
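Net effect of the api.py changes: the backend can now be declared in the model config itself, while an explicit `backend` argument to `LudwigModel` or `load()` still takes precedence. A minimal sketch of the resulting behavior (feature names and values are illustrative, not part of this diff):

from ludwig.api import LudwigModel

config = {
    'input_features': [{'name': 'text', 'type': 'text'}],
    'output_features': [{'name': 'label', 'type': 'category'}],
    # picked up by initialize_backend() when no backend argument is given
    'backend': {'type': 'local'},
}

model = LudwigModel(config)                   # backend resolved from config['backend']
model = LudwigModel(config, backend='local')  # explicit argument wins over the config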
37 changes: 29 additions & 8 deletions ludwig/backend/__init__.py
@@ -18,6 +18,11 @@
from ludwig.backend.base import Backend, LocalBackend
from ludwig.utils.horovod_utils import has_horovodrun

+try:
+    import ray as _ray
+except:
+    _ray = None


LOCAL_BACKEND = LocalBackend()

@@ -29,6 +34,17 @@
ALL_BACKENDS = [LOCAL, DASK, HOROVOD, RAY]


+def _has_ray():
+    if _ray is None:
+        return False
+
+    try:
+        _ray.init('auto', ignore_reinit_error=True)
+        return True
+    except:
+        return False


def get_local_backend(**kwargs):
return LocalBackend(**kwargs)

@@ -57,17 +73,22 @@ def create_ray_backend(**kwargs):
}


-def create_backend(name, **kwargs):
-    if isinstance(name, Backend):
-        return name
+def create_backend(type, **kwargs):
+    if isinstance(type, Backend):
+        return type

-    if name is None and has_horovodrun():
-        name = HOROVOD
+    if type is None and _has_ray():
+        type = RAY
+    elif type is None and has_horovodrun():
+        type = HOROVOD

-    return backend_registry[name](**kwargs)
+    return backend_registry[type](**kwargs)


-def initialize_backend(name, **kwargs):
-    backend = create_backend(name, **kwargs)
+def initialize_backend(backend):
+    if isinstance(backend, dict):
+        backend = create_backend(**backend)
+    else:
+        backend = create_backend(backend)
    backend.initialize()
    return backend
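With these changes, `initialize_backend` accepts a `Backend` instance, a name string, a config dict, or `None`. A sketch of the call forms (values are illustrative; the `None` case assumes the registry maps `None` to the local backend, as before):

from ludwig.backend import initialize_backend

backend = initialize_backend('local')          # by name
backend = initialize_backend({'type': 'ray'})  # dict form: 'type' selects the backend,
                                               # remaining keys become constructor kwargs
backend = initialize_backend(None)             # auto-detect: Ray if ray.init('auto') connects,
                                               # else Horovod under horovodrun, else local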
4 changes: 2 additions & 2 deletions ludwig/backend/base.py
@@ -27,8 +27,8 @@


class Backend(ABC):
-    def __init__(self, cache_dir=None, data_format=None):
-        self._dataset_manager = create_dataset_manager(self, data_format)
+    def __init__(self, cache_dir=None, cache_format=None):
+        self._dataset_manager = create_dataset_manager(self, cache_format)
        self._cache_manager = CacheManager(self._dataset_manager, cache_dir)

    @property
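The `data_format` to `cache_format` rename makes the kwarg describe what it actually controls: the storage format of the preprocessed-data cache. A hypothetical dict-form config exercising both `Backend.__init__` kwargs (the path is illustrative, and this assumes `LocalBackend` forwards its kwargs, as `get_local_backend` suggests):

backend = initialize_backend({
    'type': 'local',
    'cache_format': 'parquet',         # renamed from data_format in this PR
    'cache_dir': '/tmp/ludwig_cache',  # illustrative path
})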
8 changes: 4 additions & 4 deletions ludwig/backend/dask.py
@@ -81,13 +81,13 @@ def shutdown(self):


class DaskBackend(LocalTrainingMixin, Backend):
-    def __init__(self, data_format=PARQUET, engine=None, **kwargs):
-        super().__init__(data_format=data_format, **kwargs)
+    def __init__(self, cache_format=PARQUET, engine=None, **kwargs):
+        super().__init__(cache_format=cache_format, **kwargs)
        engine = engine or {}
        self._df_engine = DaskEngine(**engine)
-        if data_format != PARQUET:
+        if cache_format != PARQUET:
            raise ValueError(
-                f'Data format {data_format} is not supported when using the Dask backend. '
+                f'Data format {cache_format} is not supported when using the Dask backend. '
                f'Try setting to `parquet`.'
            )

31 changes: 25 additions & 6 deletions ludwig/backend/ray.py
@@ -27,6 +27,7 @@
from ludwig.backend.base import Backend, RemoteTrainingMixin
from ludwig.constants import NAME, PARQUET
from ludwig.data.dataframe.dask import DaskEngine
+from ludwig.data.dataframe.pandas import PandasEngine
from ludwig.data.dataset.partitioned import PartitionedDataset
from ludwig.models.predictor import BasePredictor, Predictor, get_output_columns
from ludwig.models.trainer import BaseTrainer, RemoteTrainer
@@ -74,10 +75,28 @@ def get_total_resources(bucket):
)


+_engine_registry = {
+    'dask': DaskEngine,
+    'pandas': PandasEngine,
+}


+def _get_df_engine(engine_config):
+    if engine_config is None:
+        return DaskEngine()
+
+    engine_config = engine_config.copy()
+
+    dtype = engine_config.pop('type', 'dask')
+    engine_cls = _engine_registry.get(dtype)
+    return engine_cls(**engine_config)


class RayRemoteModel:
    def __init__(self, model):
        buf = save_weights_to_buffer(model)
-        self.cls, self.args, state = list(model.__reduce__())
+        self.cls = type(model)
+        self.args = model.get_args()
        self.state = ray.put(buf)

    def load(self):
@@ -193,14 +212,14 @@ def shutdown(self):


class RayBackend(RemoteTrainingMixin, Backend):
-    def __init__(self, horovod_kwargs=None, data_format=PARQUET, **kwargs):
-        super().__init__(data_format=data_format, **kwargs)
-        self._df_engine = DaskEngine()
+    def __init__(self, horovod_kwargs=None, cache_format=PARQUET, engine=None, **kwargs):
+        super().__init__(cache_format=cache_format, **kwargs)
+        self._df_engine = _get_df_engine(engine)
        self._horovod_kwargs = horovod_kwargs or {}
        self._tensorflow_kwargs = {}
-        if data_format != PARQUET:
+        if cache_format != PARQUET:
            raise ValueError(
-                f'Data format {data_format} is not supported when using the Ray backend. '
+                f'Data format {cache_format} is not supported when using the Ray backend. '
                f'Try setting to `parquet`.'
            )

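The new `engine` section is a dict whose `type` key selects the dataframe engine (`'dask'` by default) and whose remaining keys are forwarded to the engine constructor. A sketch under those assumptions (values are illustrative):

from ludwig.backend.ray import RayBackend

backend = RayBackend(engine={'type': 'dask', 'parallelism': 16, 'persist': True})

# or through the config-driven path:
# initialize_backend({'type': 'ray', 'engine': {'type': 'pandas'}})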
22 changes: 17 additions & 5 deletions ludwig/contribs/mlflow.py
@@ -137,12 +137,24 @@ def _log_artifacts(output_directory):
def export_model(model_path, output_path, registered_model_name=None):
    kwargs = _export_kwargs(model_path)
    if registered_model_name:
-        mlflow.pyfunc.log_model(
-            artifact_path=output_path,
-            registered_model_name=registered_model_name,
-            **kwargs
-        )
+        if not model_path.startswith('runs:/') or output_path is not None:
+            # No run specified, so in order to register the model in mlflow, we need
+            # to create a new run and upload the model as an artifact first
+            output_path = output_path or 'model'
+            with mlflow.start_run():
+                mlflow.pyfunc.log_model(
+                    artifact_path=output_path,
+                    registered_model_name=registered_model_name,
+                    **kwargs
+                )
+        else:
+            # Registering a model from an artifact of an existing run
+            mlflow.register_model(
+                model_path,
+                registered_model_name,
+            )
    else:
+        # No model name means we only want to save the model locally
        mlflow.pyfunc.save_model(
            path=output_path,
            **kwargs
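The rewritten `export_model` now distinguishes three cases. A usage sketch (paths, names, and the run ID are placeholders):

from ludwig.contribs.mlflow import export_model

# local model + registered name: log_model inside a fresh run, then register
export_model('/path/to/model', 'model', registered_model_name='my_model')

# 'runs:/' URI + registered name, no output path: register the existing run artifact directly
export_model('runs:/<run_id>/model', None, registered_model_name='my_model')

# no registered name: save the pyfunc model locally
export_model('/path/to/model', '/path/to/exported_model')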
4 changes: 4 additions & 0 deletions ludwig/data/dataframe/base.py
@@ -69,3 +69,7 @@ def df_lib(self):
    @abstractmethod
    def partitioned(self):
        raise NotImplementedError()
+
+    @abstractmethod
+    def set_parallelism(self, parallelism):
+        raise NotImplementedError()
2 changes: 1 addition & 1 deletion ludwig/data/dataframe/dask.py
@@ -39,7 +39,7 @@ def set_scheduler(scheduler):


class DaskEngine(DataFrameEngine):
-    def __init__(self, parallelism=None, persist=False):
+    def __init__(self, parallelism=None, persist=False, **kwargs):
        self._parallelism = parallelism or multiprocessing.cpu_count()
        self._persist = persist

6 changes: 6 additions & 0 deletions ludwig/data/dataframe/pandas.py
@@ -22,6 +22,9 @@


class PandasEngine(DataFrameEngine):
+    def __init__(self, **kwargs):
+        super().__init__()
+
    def empty_df_like(self, df):
        return pd.DataFrame(index=df.index)

@@ -61,5 +64,8 @@ def df_lib(self):
    def partitioned(self):
        return False

+    def set_parallelism(self, parallelism):
+        pass


PANDAS = PandasEngine()
4 changes: 2 additions & 2 deletions ludwig/data/dataset/__init__.py
@@ -25,5 +25,5 @@
}


-def create_dataset_manager(backend, data_format, **kwargs):
-    return dataset_registry.get(data_format)(backend)
+def create_dataset_manager(backend, cache_format, **kwargs):
+    return dataset_registry.get(cache_format)(backend)
1 change: 1 addition & 0 deletions ludwig/data/dataset/base.py
@@ -28,6 +28,7 @@ def __len__(self):
    @abstractmethod
    def initialize_batcher(self, batch_size=128,
                           should_shuffle=True,
+                          shuffle_buffer_size=None,
                           seed=0,
                           ignore_last=False,
                           horovod=None):
1 change: 1 addition & 0 deletions ludwig/data/dataset/pandas.py
@@ -68,6 +68,7 @@ def __len__(self):
    @contextlib.contextmanager
    def initialize_batcher(self, batch_size=128,
                           should_shuffle=True,
+                          shuffle_buffer_size=None,
                           seed=0,
                           ignore_last=False,
                           horovod=None):
3 changes: 2 additions & 1 deletion ludwig/data/dataset/parquet.py
@@ -63,6 +63,7 @@ def __len__(self):
    def initialize_batcher(self,
                           batch_size=128,
                           should_shuffle=True,
+                          shuffle_buffer_size=None,
                           seed=0,
                           ignore_last=False,
                           horovod=None):
@@ -81,7 +82,7 @@ def initialize_batcher(self,
        dataset = dataset.unbatch()
        if should_shuffle:
            rows_per_piece = max([piece.get_metadata().num_rows for piece in reader.dataset.pieces])
-            buffer_size = min(rows_per_piece, local_samples)
+            buffer_size = shuffle_buffer_size or min(rows_per_piece, local_samples)
            dataset = dataset.shuffle(buffer_size)
        dataset = dataset.batch(batch_size)
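When `shuffle_buffer_size` is passed, it overrides the heuristic `min(rows_per_piece, local_samples)` buffer, which lets callers bound shuffle memory on large partitions. A sketch of the call (the `dataset` object is illustrative, and the context-manager usage mirrors the pandas implementation above):

# heuristic buffer: min(rows_per_piece, local_samples)
with dataset.initialize_batcher(batch_size=128, should_shuffle=True) as batcher:
    ...

# explicit buffer: shuffle within a 10,000-row window
with dataset.initialize_batcher(batch_size=128, shuffle_buffer_size=10000) as batcher:
    ...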
