-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added Backend interface for abstracting DataFrame preprocessing steps (…
…#1014) Co-authored-by: w4nderlust <w4nderlust@gmail.com>
- Loading branch information
1 parent
e245fa2
commit 496009e
Showing
40 changed files
with
1,438 additions
and
894 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#! /usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright (c) 2020 Uber Technologies, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from ludwig.backend.base import Backend, LocalBackend | ||
|
||
|
||
LOCAL_BACKEND = LocalBackend() | ||
|
||
|
||
def get_local_backend(): | ||
return LOCAL_BACKEND | ||
|
||
|
||
backend_registry = { | ||
'local': get_local_backend, | ||
None: get_local_backend, | ||
} | ||
|
||
|
||
def create_backend(name): | ||
return backend_registry[name]() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#! /usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright (c) 2020 Uber Technologies, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
import os | ||
import tempfile | ||
import uuid | ||
|
||
from abc import ABC, abstractmethod | ||
from contextlib import contextmanager | ||
|
||
from ludwig.data.dataframe.pandas import PandasEngine | ||
|
||
|
||
class Backend(ABC): | ||
def __init__(self, cache_dir=None): | ||
self._cache_dir = cache_dir | ||
|
||
@property | ||
@abstractmethod | ||
def df_engine(self): | ||
raise NotImplementedError() | ||
|
||
@property | ||
@abstractmethod | ||
def supports_multiprocessing(self): | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def check_lazy_load_supported(self, feature): | ||
raise NotImplementedError() | ||
|
||
@property | ||
def cache_enabled(self): | ||
return self._cache_dir is not None | ||
|
||
def create_cache_entry(self): | ||
return os.path.join(self.cache_dir, str(uuid.uuid1())) | ||
|
||
@property | ||
def cache_dir(self): | ||
if not self._cache_dir: | ||
raise ValueError('Cache directory not available, try calling `with backend.create_cache_dir()`.') | ||
return self._cache_dir | ||
|
||
@contextmanager | ||
def create_cache_dir(self): | ||
prev_cache_dir = self._cache_dir | ||
try: | ||
if self._cache_dir: | ||
os.makedirs(self._cache_dir, exist_ok=True) | ||
yield self._cache_dir | ||
else: | ||
with tempfile.TemporaryDirectory() as tmpdir: | ||
self._cache_dir = tmpdir | ||
yield tmpdir | ||
finally: | ||
self._cache_dir = prev_cache_dir | ||
|
||
|
||
class LocalBackend(Backend): | ||
def __init__(self): | ||
super().__init__() | ||
self._df_engine = PandasEngine() | ||
|
||
@property | ||
def df_engine(self): | ||
return self._df_engine | ||
|
||
@property | ||
def supports_multiprocessing(self): | ||
return True | ||
|
||
def check_lazy_load_supported(self, feature): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#! /usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright (c) 2020 Uber Technologies, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#! /usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright (c) 2020 Uber Technologies, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
|
||
class DataFrameEngine(ABC): | ||
@abstractmethod | ||
def empty_df_like(self, df): | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def parallelize(self, data): | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def persist(self, data): | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def compute(self, data): | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def from_pandas(self, df): | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def map_objects(self, series, map_fn): | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def reduce_objects(self, series, reduce_fn): | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def create_dataset(self, dataset, tag, config, training_set_metadata): | ||
raise NotImplementedError() | ||
|
||
@property | ||
@abstractmethod | ||
def array_lib(self): | ||
raise NotImplementedError() | ||
|
||
@property | ||
@abstractmethod | ||
def df_lib(self): | ||
raise NotImplementedError() | ||
|
||
@property | ||
@abstractmethod | ||
def use_hdf5_cache(self): | ||
raise NotImplementedError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#! /usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright (c) 2020 Uber Technologies, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from ludwig.data.dataset import Dataset | ||
from ludwig.data.dataframe.base import DataFrameEngine | ||
from ludwig.utils.data_utils import DATA_TRAIN_HDF5_FP | ||
from ludwig.utils.misc_utils import get_proc_features | ||
|
||
|
||
class PandasEngine(DataFrameEngine): | ||
def empty_df_like(self, df): | ||
return pd.DataFrame(index=df.index) | ||
|
||
def parallelize(self, data): | ||
return data | ||
|
||
def persist(self, data): | ||
return data | ||
|
||
def compute(self, data): | ||
return data | ||
|
||
def from_pandas(self, df): | ||
return df | ||
|
||
def map_objects(self, series, map_fn): | ||
return series.map(map_fn) | ||
|
||
def reduce_objects(self, series, reduce_fn): | ||
return reduce_fn(series) | ||
|
||
def create_dataset(self, dataset, tag, config, training_set_metadata): | ||
return Dataset( | ||
dataset, | ||
get_proc_features(config), | ||
training_set_metadata.get(DATA_TRAIN_HDF5_FP) | ||
) | ||
|
||
@property | ||
def array_lib(self): | ||
return np | ||
|
||
@property | ||
def df_lib(self): | ||
return pd | ||
|
||
@property | ||
def use_hdf5_cache(self): | ||
return True | ||
|
||
|
||
PANDAS = PandasEngine() |
Oops, something went wrong.