Perform preprocessing first before hyperopt when possible #1415

Merged
17 commits merged on Oct 30, 2021
93 changes: 85 additions & 8 deletions ludwig/backend/ray.py
@@ -161,6 +161,7 @@ def train_fn(
executable_kwargs: Dict[str, Any] = None,
remote_model: RayRemoteModel = None,
training_set_metadata: Dict[str, Any] = None,
features: Dict[str, Dict] = None,
train_shards: List[DatasetPipeline] = None,
val_shards: List[DatasetPipeline] = None,
test_shards: List[DatasetPipeline] = None,
@@ -174,26 +175,23 @@ def train_fn(

train_shard = RayDatasetShard(
train_shards[rt.world_rank()], #rt.get_dataset_shard("train"),
model.input_features,
model.output_features,
features,
training_set_metadata,
)

val_shard = val_shards[rt.world_rank()] if val_shards else None #rt.get_dataset_shard("val")
if val_shard is not None:
val_shard = RayDatasetShard(
val_shard,
model.input_features,
model.output_features,
features,
training_set_metadata,
)

test_shard = test_shards[rt.world_rank()] if test_shards else None #rt.get_dataset_shard("test")
if test_shard is not None:
test_shard = RayDatasetShard(
test_shard,
model.input_features,
model.output_features,
features,
training_set_metadata,
)

@@ -239,6 +237,7 @@ def train(self, model, training_set, validation_set=None, test_set=None, **kwarg
'val_shards': val_shards,
'test_shards': test_shards,
'training_set_metadata': training_set.training_set_metadata,
'features': training_set.features,
**kwargs
}

@@ -275,6 +274,54 @@ def shutdown(self):
self.trainer.shutdown()


def legacy_train_fn(
trainer: RayRemoteTrainer = None,
remote_model: RayRemoteModel = None,
training_set_metadata: Dict[str, Any] = None,
features: Dict[str, Dict] = None,
train_shards: List[DatasetPipeline] = None,
val_shards: List[DatasetPipeline] = None,
test_shards: List[DatasetPipeline] = None,
**kwargs
):
# Pin GPU before loading the model to prevent memory leaking onto other devices
hvd = initialize_horovod()
initialize_tensorflow(horovod=hvd)

model = remote_model.load()

train_shard = RayDatasetShard(
train_shards[hvd.rank()],
features,
training_set_metadata,
)

val_shard = val_shards[hvd.rank()] if val_shards else None
if val_shard is not None:
val_shard = RayDatasetShard(
val_shard,
features,
training_set_metadata,
)

test_shard = test_shards[hvd.rank()] if test_shards else None
if test_shard is not None:
test_shard = RayDatasetShard(
test_shard,
features,
training_set_metadata,
)

results = trainer.train(
model,
train_shard,
val_shard,
test_shard,
**kwargs
)
return results


class RayLegacyTrainer(BaseTrainer):
def __init__(self, horovod_kwargs, executable_kwargs):
# TODO ray: make this more configurable by allowing YAML overrides of timeout_s, etc.
@@ -285,10 +332,40 @@ def __init__(self, horovod_kwargs, executable_kwargs):
self.executor.start(executable_cls=RayRemoteTrainer,
executable_kwargs=executable_kwargs)

def train(self, model, *args, **kwargs):
def train(self, model, training_set, validation_set=None, test_set=None, **kwargs):
remote_model = RayRemoteModel(model)

# TODO(travis): enable after dropping petastorm
# workers = self.executor.driver.workers
# train_shards = training_set.pipeline().split(
# n=len(workers), locality_hints=workers, equal=True
# )
# val_shards = validation_set.pipeline(shuffle=False).split(
# n=len(workers), locality_hints=workers
# ) if validation_set else None
# test_shards = test_set.pipeline(shuffle=False).split(
# n=len(workers), locality_hints=workers
# ) if test_set else None

# results = self.executor.execute(
# lambda trainer: legacy_train_fn(
# trainer,
# remote_model,
# training_set.training_set_metadata,
# training_set.features,
# train_shards,
# val_shards,
# test_shards,
# **kwargs)
# )

results = self.executor.execute(
lambda trainer: trainer.train(remote_model.load(), *args, **kwargs)
lambda trainer: trainer.train(
remote_model.load(),
training_set,
validation_set,
test_set,
**kwargs)
)

weights, *stats = results[0]
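Note on the signature change above: train_fn and RayDatasetShard now receive a plain features dict (processed column name mapped to its feature config) plus training_set_metadata, rather than model.input_features and model.output_features, so dataset shards can be described without the model objects. A rough, illustrative sketch of what those two dicts look like; the column suffixes and config fields below are made up, not Ludwig's exact internals:

# Illustrative only: approximate shape of the `features` and
# `training_set_metadata` dicts threaded through train_fn above.
# Keys of `features` are processed column names; values are feature configs.
features = {
    "text_a1b2c3": {"name": "text", "type": "text", "column": "text"},
    "label_d4e5f6": {"name": "label", "type": "category", "column": "label"},
}

training_set_metadata = {
    "text": {"max_sequence_length": 128},
    "label": {"idx2str": ["negative", "positive"]},
}

# A worker's shard wrapper only needs these plain dicts, e.g.:
#     RayDatasetShard(train_shards[rt.world_rank()], features, training_set_metadata)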
13 changes: 13 additions & 0 deletions ludwig/callbacks.py
@@ -44,7 +44,20 @@ def on_preprocess_end(
def on_hyperopt_init(self, experiment_name):
pass

def on_hyperopt_preprocessing_start(self, experiment_name):
pass

def on_hyperopt_preprocessing_end(self, experiment_name):
pass

def on_hyperopt_start(self, experiment_name):
pass

def on_hyperopt_end(self, experiment_name):
pass

def on_hyperopt_finish(self, experiment_name):
# TODO(travis): remove in favor of on_hyperopt_end for naming consistency
pass

def on_hyperopt_trial_start(self, parameters):
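As a usage sketch of the new hooks (not part of this diff): a user-defined callback could time the preprocessing pass that now runs once before the hyperopt trials. The base class name Callback is assumed from the ludwig/callbacks.py file being patched here; the timer itself is illustrative.

import time

from ludwig.callbacks import Callback  # assumed base class from the patched file


class HyperoptPreprocessingTimer(Callback):
    """Illustrative callback that times the shared hyperopt preprocessing pass."""

    def __init__(self):
        self._start = None

    def on_hyperopt_preprocessing_start(self, experiment_name):
        self._start = time.time()

    def on_hyperopt_preprocessing_end(self, experiment_name):
        elapsed = time.time() - self._start
        print(f"{experiment_name}: hyperopt preprocessing took {elapsed:.1f}s")

Registering an instance of such a callback with hyperopt would report the cost of the single shared preprocessing pass rather than per-trial preprocessing.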
1 change: 1 addition & 0 deletions ludwig/constants.py
@@ -121,6 +121,7 @@
MINIMIZE = "minimize"
MAXIMIZE = "maximize"
SAMPLER = "sampler"
PARAMETERS = "parameters"

NAME = "name"
COLUMN = "column"
64 changes: 34 additions & 30 deletions ludwig/data/dataset/ray.py
@@ -26,13 +26,13 @@
import pandas as pd

import ray
from ray.data import from_dask
from ray.data import from_dask, read_parquet
from ray.data.dataset_pipeline import DatasetPipeline
from ray.data.extensions import TensorDtype

from ludwig.constants import NAME
from ludwig.data.batcher.base import Batcher
from ludwig.data.dataset.base import Dataset
from ludwig.features.base_feature import InputFeature, OutputFeature
from ludwig.utils.data_utils import DATA_TRAIN_HDF5_FP
from ludwig.utils.misc_utils import get_proc_features
from ludwig.utils.types import DataFrame
@@ -41,11 +41,12 @@
_ray18 = LooseVersion(ray.__version__) >= LooseVersion("1.8")


class RayDataset(object):
class RayDataset(Dataset):
""" Wrapper around ray.data.Dataset. """

def __init__(self, df: DataFrame, features: Dict[str, Dict], training_set_metadata: Dict[str, Any]):
self.ds = from_dask(df)
# TODO(travis): move read_parquet to cache layer after removing petastorm
self.ds = from_dask(df) if not isinstance(df, str) else read_parquet(df)
self.features = features
self.training_set_metadata = training_set_metadata
self.data_hdf5_fp = training_set_metadata.get(DATA_TRAIN_HDF5_FP)
@@ -66,9 +67,28 @@ def pipeline(self, shuffle=True) -> DatasetPipeline:
pipe = pipe.random_shuffle()
return pipe

@contextlib.contextmanager
def initialize_batcher(self, batch_size=128,
should_shuffle=True,
shuffle_buffer_size=None,
seed=0,
ignore_last=False,
horovod=None):
yield RayDatasetBatcher(
self.ds.repeat().iter_datasets(),
self.features,
self.training_set_metadata,
batch_size,
self.size,
)

def __len__(self):
return self.ds.count()

@property
def size(self):
return len(self)


class RayDatasetManager(object):
def __init__(self, backend):
@@ -99,13 +119,11 @@ def save(
training_set_metadata: Dict[str, Any],
tag: str
):
# TODO(travis): optionally save dataset to Parquet for reuse
return dataset
self.backend.df_engine.to_parquet(dataset, cache_path)
return cache_path

def can_cache(self, skip_save_processed_input):
# TODO(travis): enable caching
# return self.backend.is_coordinator()
return False
return not skip_save_processed_input

@property
def data_format(self):
@@ -116,13 +134,11 @@ class RayDatasetShard(Dataset):
def __init__(
self,
dataset_shard: DatasetPipeline,
input_features: Dict[str, InputFeature],
output_features: Dict[str, OutputFeature],
features: Dict[str, Dict],
training_set_metadata: Dict[str, Any],
):
self.dataset_shard = dataset_shard
self.input_features = input_features
self.output_features = output_features
self.features = features
self.training_set_metadata = training_set_metadata
self.dataset_iter = dataset_shard.iter_datasets()

@@ -135,8 +151,7 @@ def initialize_batcher(self, batch_size=128,
horovod=None):
yield RayDatasetBatcher(
self.dataset_iter,
self.input_features,
self.output_features,
self.features,
self.training_set_metadata,
batch_size,
self.size,
@@ -156,8 +171,7 @@ class RayDatasetBatcher(Batcher):
def __init__(
self,
dataset_epoch_iterator: Iterator[ray.data.Dataset],
input_features: Dict[str, InputFeature],
output_features: Dict[str, OutputFeature],
features: Dict[str, Dict],
training_set_metadata: Dict[str, Any],
batch_size: int,
samples_per_epoch: int,
@@ -167,20 +181,10 @@ def __init__(
self.samples_per_epoch = samples_per_epoch
self.training_set_metadata = training_set_metadata

self.columns = [
f.proc_column for f in input_features.values()
] + [
f.proc_column for f in output_features.values()
]

features = {
**input_features,
**output_features,
}

self.columns = list(features.keys())
self.reshape_map = {
f.proc_column: training_set_metadata[f.feature_name].get('reshape')
for f in features.values()
proc_column: training_set_metadata[feature[NAME]].get('reshape')
for proc_column, feature in features.items()
}

self.dataset_batch_iter = None
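To make the reshape_map change at the bottom of RayDatasetBatcher concrete: features maps processed column names to plain feature-config dicts, while training_set_metadata is keyed by the original feature name, so the lookup goes from proc_column to feature[NAME] to the metadata entry. A standalone illustration with placeholder values:

# Standalone illustration of the reshape_map construction above; column names,
# shapes, and metadata values are placeholders, not real Ludwig output.
NAME = "name"  # mirrors ludwig.constants.NAME

features = {
    "image_path_a1b2": {"name": "image_path", "type": "image"},
    "label_c3d4": {"name": "label", "type": "category"},
}
training_set_metadata = {
    "image_path": {"reshape": (32, 32, 3)},
    "label": {},
}

reshape_map = {
    proc_column: training_set_metadata[feature[NAME]].get("reshape")
    for proc_column, feature in features.items()
}
print(reshape_map)  # {'image_path_a1b2': (32, 32, 3), 'label_c3d4': None}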
3 changes: 2 additions & 1 deletion ludwig/data/preprocessing.py
@@ -1618,7 +1618,8 @@ def _preprocess_file_for_training(
skip_save_processed_input=skip_save_processed_input
)

if backend.is_coordinator() and not skip_save_processed_input:
# TODO(travis): implement saving split for Ray
if backend.is_coordinator() and not skip_save_processed_input and SPLIT in data.columns:
# save split values for use by visualization routines
split_fp = get_split_path(dataset)
save_array(split_fp, data[SPLIT])