forked from sb-ai-lab/SLAMA
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add save/load functionality for SparkDataset (including handling of square brackets in column names) * add test for save/load SparkDataset * add feature-processing-only and simple optuna examples --------- Co-authored-by: fonhorst <fonhorst@alipoov.nb@gmail.com>
- Loading branch information
Showing 5 changed files with 287 additions and 27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from examples.spark.examples_utils import get_spark_session, get_dataset_attrs | ||
from sparklightautoml.pipelines.features.lgb_pipeline import SparkLGBAdvancedPipeline, SparkLGBSimpleFeatures | ||
from sparklightautoml.pipelines.features.linear_pipeline import SparkLinearFeatures | ||
from sparklightautoml.reader.base import SparkToSparkReader | ||
from sparklightautoml.tasks.base import SparkTask | ||
|
||
|
||
# Mapping from a short pipeline alias to its feature-processing implementation.
feature_pipelines = dict(
    linear=SparkLinearFeatures(),
    lgb_simple=SparkLGBSimpleFeatures(),
    lgb_adv=SparkLGBAdvancedPipeline(),
)
|
||
|
||
if __name__ == "__main__":
    spark = get_spark_session()

    # settings and data
    cv = 5
    feat_pipe = "lgb_adv"  # linear, lgb_simple or lgb_adv
    dataset_name = "lama_test_dataset"
    path, task_type, roles, dtype = get_dataset_attrs(dataset_name)
    df = spark.read.csv(path, header=True)

    task = SparkTask(name=task_type)
    reader = SparkToSparkReader(task=task, cv=cv, advanced_roles=False)
    feature_pipe = feature_pipelines.get(feat_pipe, None)

    # Raise explicitly instead of `assert`: asserts are stripped under
    # `python -O`, which would let an unknown alias fall through as None.
    if feature_pipe is None:
        raise ValueError(f"Unsupported feat pipe {feat_pipe}")

    # read raw data into a SparkDataset, then run the chosen feature pipeline
    ds = reader.fit_read(train_data=df, roles=roles)
    ds = feature_pipe.fit_transform(ds)

    # save processed data for downstream examples (e.g. the optuna example)
    ds.save(f"/tmp/{dataset_name}__{feat_pipe}__features.dataset")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import logging | ||
import pickle | ||
from logging import config | ||
from typing import Tuple, Union, Callable | ||
|
||
import optuna | ||
from lightautoml.ml_algo.tuning.optuna import OptunaTuner, TunableAlgo | ||
from lightautoml.ml_algo.utils import tune_and_fit_predict | ||
from lightautoml.validation.base import TrainValidIterator | ||
from pyspark.sql import functions as sf | ||
|
||
from examples.spark.examples_utils import get_spark_session | ||
from sparklightautoml.dataset.base import SparkDataset | ||
from sparklightautoml.dataset.persistence import PlainCachePersistenceManager | ||
from sparklightautoml.ml_algo.boost_lgbm import SparkBoostLGBM | ||
from sparklightautoml.utils import logging_config, VERBOSE_LOGGING_FORMAT | ||
from sparklightautoml.validation.iterators import SparkHoldoutIterator, SparkFoldsIterator | ||
|
||
# Route verbose debug output to /tmp/slama.log (via dictConfig) and to the
# root logger's handlers (via basicConfig); module-level logger per convention.
logging.config.dictConfig(logging_config(level=logging.DEBUG, log_filename='/tmp/slama.log'))
logging.basicConfig(level=logging.DEBUG, format=VERBOSE_LOGGING_FORMAT)
logger = logging.getLogger(__name__)
|
||
|
||
class ProgressReportingOptunaTuner(OptunaTuner):
    """OptunaTuner that logs the score of every finished objective evaluation."""

    def _get_objective(self, ml_algo: TunableAlgo, estimated_n_trials: int, train_valid_iterator: TrainValidIterator) \
            -> Callable[[optuna.trial.Trial], Union[float, int]]:
        # Wrap the parent's objective so each trial's score is reported.
        base_objective = super()._get_objective(ml_algo, estimated_n_trials, train_valid_iterator)

        def reporting_objective(*args, **kwargs):
            obj_score = base_objective(*args, **kwargs)
            logger.info(f"Objective score: {obj_score}")
            return obj_score

        return reporting_objective
|
||
|
||
def train_test_split(dataset: SparkDataset, test_slice_or_fold_num: Union[float, int] = 0.2) \
        -> Tuple[SparkDataset, SparkDataset]:
    """Split a SparkDataset into train and test parts.

    A float argument is interpreted as the test fraction for a random split;
    an int argument selects the rows of that fold (per the dataset's folds
    column) as the test part.
    """
    data = dataset.data
    if isinstance(test_slice_or_fold_num, float):
        assert 0 <= test_slice_or_fold_num <= 1
        train_df, test_df = data.randomSplit([1 - test_slice_or_fold_num, test_slice_or_fold_num])
    else:
        fold = sf.col(dataset.folds_column)
        train_df = data.where(fold != test_slice_or_fold_num)
        test_df = data.where(fold == test_slice_or_fold_num)

    # Wrap both splits into fresh datasets carrying the original metadata.
    train_dataset, test_dataset = dataset.empty(), dataset.empty()
    train_dataset.set_data(train_df, dataset.features, roles=dataset.roles)
    test_dataset.set_data(test_df, dataset.features, roles=dataset.roles)

    return train_dataset, test_dataset
|
||
|
||
if __name__ == "__main__":
    spark = get_spark_session()

    feat_pipe = "lgb_adv"  # linear, lgb_simple or lgb_adv
    dataset_name = "lama_test_dataset"

    # load data previously produced by the feature-processing example
    ds = SparkDataset.load(
        path=f"/tmp/{dataset_name}__{feat_pipe}__features.dataset",
        persistence_manager=PlainCachePersistenceManager()
    )
    # hold out fold 4 as the test part
    train_ds, test_ds = train_test_split(ds, test_slice_or_fold_num=4)

    # create main entities
    iterator = SparkFoldsIterator(train_ds).convert_to_holdout_iterator()
    tuner = ProgressReportingOptunaTuner(n_trials=101, timeout=3000)
    ml_algo = SparkBoostLGBM()
    score = ds.task.get_dataset_metric()

    # fit and predict
    model, oof_preds = tune_and_fit_predict(ml_algo, tuner, iterator)
    test_preds = ml_algo.predict(test_ds)

    # reporting trials
    # TODO: reporting to mlflow
    # TODO: quality curves on different datasets
    with open("/tmp/trials.pickle", "wb") as f:
        pickle.dump(tuner.study.trials, f)

    # estimate oof and test metrics
    oof_metric_value = score(oof_preds.data.select(
        SparkDataset.ID_COLUMN,
        sf.col(ds.target_column).alias('target'),
        sf.col(ml_algo.prediction_feature).alias('prediction')
    ))

    test_metric_value = score(test_preds.data.select(
        SparkDataset.ID_COLUMN,
        sf.col(ds.target_column).alias('target'),
        sf.col(ml_algo.prediction_feature).alias('prediction')
    ))

    print(f"OOF metric: {oof_metric_value}")
    # BUG FIX: previously printed oof_metric_value under the "Test metric" label
    print(f"Test metric: {test_metric_value}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import os | ||
import shutil | ||
from typing import Optional | ||
|
||
import numpy as np | ||
from lightautoml.dataset.roles import NumericRole | ||
from lightautoml.tasks import Task | ||
from pandas.testing import assert_frame_equal | ||
from pyspark.sql import SparkSession | ||
|
||
from sparklightautoml.dataset.base import SparkDataset | ||
from sparklightautoml.tasks.base import SparkTask | ||
from . import spark as spark_sess | ||
|
||
spark = spark_sess | ||
|
||
|
||
def compare_tasks(task_a: "Optional[Task]", task_b: "Optional[Task]"):
    """Assert that two tasks are either both absent or describe the same task.

    Raises:
        AssertionError: if exactly one task is missing, or any compared
            attribute (name, metric_name, greater_is_better) differs.
    """
    assert (task_a and task_b) or (not task_a and not task_b)
    # BUG FIX: the original dereferenced .name even when both tasks were
    # None/absent, raising AttributeError instead of passing.
    if not task_a:
        return
    assert task_a.name == task_b.name
    assert task_a.metric_name == task_b.metric_name
    assert task_a.greater_is_better == task_b.greater_is_better
|
||
|
||
def compare_dfs(dataset_a: SparkDataset, dataset_b: SparkDataset):
    """Assert that two SparkDatasets carry identical schemas and row content."""
    assert dataset_a.data.schema == dataset_b.data.schema

    # Spark does not guarantee row order, so sort both sides by the service
    # id column before the element-wise comparison.
    pdf_a = dataset_a.data.orderBy(SparkDataset.ID_COLUMN).toPandas()
    pdf_b = dataset_b.data.orderBy(SparkDataset.ID_COLUMN).toPandas()
    assert_frame_equal(pdf_a, pdf_b)
|
||
|
||
def test_spark_dataset_save_load(spark: SparkSession):
    """Round-trip a SparkDataset through save()/load() and verify equality."""
    save_path = "/tmp/test_slama_ds.dataset"

    # start from a clean slate
    if os.path.exists(save_path):
        shutil.rmtree(save_path)

    # feature name containing brackets — save/load must handle such columns
    tricky_feature = "scaler__fillnamed__fillinf__logodds__oof__inter__(CODE_GENDER__EMERGENCYSTATE_MODE)"

    # creating test data
    rows = [
        {
            SparkDataset.ID_COLUMN: row_id,
            "a": row_id + 1,
            "b": row_id * 10 + 1,
            "this_is_target": 0,
            "this_is_fold": 0,
            tricky_feature: 12.0,
        }
        for row_id in range(10)
    ]
    df = spark.createDataFrame(rows)

    ds = SparkDataset(
        data=df,
        task=SparkTask("reg"),
        target="this_is_target",
        folds="this_is_fold",
        roles={
            "a": NumericRole(dtype=np.int32),
            "b": NumericRole(dtype=np.int32),
            tricky_feature: NumericRole(),
        },
    )

    ds.save(path=save_path)
    loaded_ds = SparkDataset.load(path=save_path)

    # checking metadata (uid must be fresh, everything else preserved)
    assert loaded_ds.uid
    assert loaded_ds.uid != ds.uid
    assert loaded_ds.name == ds.name
    assert loaded_ds.target_column == ds.target_column
    assert loaded_ds.folds_column == ds.folds_column
    assert loaded_ds.service_columns == ds.service_columns
    assert loaded_ds.features == ds.features
    assert loaded_ds.roles == ds.roles
    compare_tasks(loaded_ds.task, ds.task)
    compare_dfs(loaded_ds, ds)

    # cleanup
    if os.path.exists(save_path):
        shutil.rmtree(save_path)