Fix/pr comments 3 (#7)
* add save/load fixes

* refactoring of the scala library

---------

Co-authored-by: fonhorst <fonhorst@alipoov.nb@gmail.com>
fonhorst and fonhorst committed Jun 28, 2023
1 parent 8d0677b commit 29b8cf8
Showing 7 changed files with 150 additions and 52 deletions.
@@ -10,7 +10,6 @@ object functions {
val vectorAveragingUdf: UserDefinedFunction = udf { (vecs: Seq[DenseVector], vec_size: Int) =>
val not_null_vecs = vecs.count(_ != null)
if (not_null_vecs == 0){
// null
new DenseVector(linalg.DenseVector.fill(vec_size)(0.0).toArray)
}
else {
@@ -40,20 +39,6 @@ object functions {
case (_, 0) => 0
case (sm, cnt) => sm / cnt
}


// val not_null_vecs = vecs.count(_.nonEmpty)
// if (not_null_vecs == 0){
//// null
// 0.0
// }
// else {
// val result = vecs.map {
// case None => 0.0
// case Some(v) => v
// }.sum
// result / not_null_vecs
// }
}


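As an aside on the null handling these UDFs implement: a plausible reading of vectorAveragingUdf is that it returns a zero vector when every input vector is null and the element-wise mean over the non-null vectors otherwise (matching the sum/count pattern visible in the scalar case). A minimal NumPy sketch of that semantics, illustrative only and not code from this repository:

from typing import Optional, Sequence

import numpy as np

def average_vectors(vecs: Sequence[Optional[np.ndarray]], vec_size: int) -> np.ndarray:
    # Mirror of the UDF's null handling: all-null input yields a zero vector
    non_null = [v for v in vecs if v is not None]
    if not non_null:
        return np.zeros(vec_size)
    # element-wise sum divided by the number of non-null vectors
    return np.sum(non_null, axis=0) / len(non_null)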
@@ -22,8 +22,6 @@ class IsHolidayTransformer(override val uid: String, private var holidays_dates:
with HasOutputCols
with MLWritable {

val dt_format = "yyyy-MM-dd"

def this(uid: String, holidays_dates: java.util.Map[String, java.util.Set[String]]) =
this(
uid,
@@ -87,10 +85,10 @@ class IsHolidayTransformer(override val uid: String, private var holidays_dates:
val outColumns = getInputCols.zip(getOutputCols).map {
case (in_col, out_col) =>
val dt_col = dataset.schema(in_col).dataType match {
case _: DateType => date_format(to_timestamp(col(in_col)), dt_format)
case _: TimestampType => date_format(col(in_col), dt_format)
case _: LongType => date_format(col(in_col).cast(TimestampType), dt_format)
case _: IntegerType => date_format(col(in_col).cast(TimestampType), dt_format)
case _: DateType => date_format(to_timestamp(col(in_col)), IsHolidayTransformer.dt_format)
case _: TimestampType => date_format(col(in_col), IsHolidayTransformer.dt_format)
case _: LongType => date_format(col(in_col).cast(TimestampType), IsHolidayTransformer.dt_format)
case _: IntegerType => date_format(col(in_col).cast(TimestampType), IsHolidayTransformer.dt_format)
case _: StringType => col(in_col)
}
func(lit(in_col).cast(StringType), dt_col).alias(out_col)
@@ -115,6 +113,8 @@ class IsHolidayTransformer(override val uid: String, private var holidays_dates:


object IsHolidayTransformer extends MLReadable[IsHolidayTransformer] {
val dt_format = "yyyy-MM-dd"

override def read: MLReader[IsHolidayTransformer] = new IsHolidayTransformerReader

override def load(path: String): IsHolidayTransformer = super.load(path)
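The refactoring above moves the "yyyy-MM-dd" constant into the companion object, so call sites now read IsHolidayTransformer.dt_format rather than an instance field. For readers on the Python side, the per-type normalisation the transformer applies can be sketched in PySpark as follows; this is illustrative only, and the function name plus the assumption that string columns are already formatted are mine:

from pyspark.sql import DataFrame, functions as sf
from pyspark.sql.types import DateType, IntegerType, LongType, TimestampType

DT_FORMAT = "yyyy-MM-dd"  # mirrors IsHolidayTransformer.dt_format

def normalized_date_col(df: DataFrame, in_col: str):
    # Reproduce the transformer's type dispatch: every supported type becomes a 'yyyy-MM-dd' string
    dtype = df.schema[in_col].dataType
    if isinstance(dtype, DateType):
        return sf.date_format(sf.to_timestamp(sf.col(in_col)), DT_FORMAT)
    if isinstance(dtype, TimestampType):
        return sf.date_format(sf.col(in_col), DT_FORMAT)
    if isinstance(dtype, (LongType, IntegerType)):
        # integer epochs are cast to timestamps first, as in the Scala code
        return sf.date_format(sf.col(in_col).cast(TimestampType()), DT_FORMAT)
    return sf.col(in_col)  # strings are assumed to already be 'yyyy-MM-dd'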
@@ -15,7 +15,6 @@ class TestPrefferedLocsPartitionCoalescer extends AnyFunSuite with BeforeAndAfte
val spark: SparkSession = SparkSession
.builder()
.master(s"local-cluster[$num_workers, $num_cores, 1024]")
// .config("spark.jars", "target/scala-2.12/spark-lightautoml_2.12-0.1.1.jar,target/scala-2.12/spark-lightautoml_2.12-0.1.1-tests.jar")
.config("spark.jars", "target/scala-2.12/spark-lightautoml_2.12-0.1.1.jar")
.config("spark.default.parallelism", "6")
.config("spark.sql.shuffle.partitions", "6")
@@ -50,8 +49,6 @@ class TestPrefferedLocsPartitionCoalescer extends AnyFunSuite with BeforeAndAfte

coalesced_df.count()

// should not produce any exception
// coalesced_df.rdd.barrier().mapPartitions(SomeFunctions.func).count()
SomeFunctions.test_func(coalesced_df)
}
}
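For context on the removed comment: it exercised Spark's barrier execution mode over the coalesced dataframe and simply had to complete without raising. A rough PySpark equivalent of that check, assuming the coalesced_df from the test above and with identity_partition as an illustrative stand-in for SomeFunctions.func:

def identity_partition(iterator):
    # pass partitions through unchanged; the point is only that barrier scheduling succeeds
    yield from iterator

coalesced_df.rdd.barrier().mapPartitions(identity_partition).count()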
100 changes: 93 additions & 7 deletions sparklightautoml/dataset/base.py
@@ -1,16 +1,20 @@
import functools
import importlib
import json
import logging
import os
import pickle
import uuid
import warnings
from abc import ABC, abstractmethod
from collections import Counter
from copy import copy
from copy import copy, deepcopy
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from json import JSONEncoder, JSONDecoder
from typing import Sequence, Any, Tuple, Union, Optional, List, cast, Dict, Set, Callable

import numpy as np
import pandas as pd
from lightautoml.dataset.base import (
LAMLDataset,
@@ -23,7 +27,8 @@
array_attr_roles,
)
from lightautoml.dataset.np_pd_dataset import PandasDataset, NumpyDataset, NpRoles
from lightautoml.dataset.roles import ColumnRole, NumericRole, DropRole
from lightautoml.dataset.roles import ColumnRole, NumericRole, DropRole, DatetimeRole, GroupRole, WeightsRole, \
FoldsRole, PathRole, TreatmentRole, DateRole
from lightautoml.tasks import Task
from pyspark.ml.functions import vector_to_array
from pyspark.sql import functions as sf, Column
@@ -57,6 +62,74 @@ def unpersist(self):
...


class SparkDatasetMetadataJsonEncoder(JSONEncoder):
def default(self, o: Any) -> Any:
if isinstance(o, ColumnRole):
dct = deepcopy(o.__dict__)
dct["__type__"] = ".".join([type(o).__module__, type(o).__name__])
dct['dtype'] = o.dtype.__name__

if isinstance(o, DatetimeRole):
dct['date_format'] = dct['format']
del dct['format']
dct['origin'] = o.origin.timestamp() if isinstance(o.origin, datetime) else o.origin

return dct

from sparklightautoml.tasks.base import SparkTask

if isinstance(o, SparkTask):
dct = {
"__type__": type(o).__name__,
"name": o.name,
"loss": o.loss_name,
"metric": o.metric_name,
# convert to python bool from numpy.bool_
"greater_is_better": True if o.greater_is_better else False
}

return dct

raise ValueError(f"Cannot handle object which is neither a ColumnRole nor a SparkTask. Object type: {type(o)}")


class SparkDatasetMetadataJsonDecoder(JSONDecoder):
@staticmethod
def _column_roles_object_hook(json_object):
from sparklightautoml.tasks.base import SparkTask
if json_object.get("__type__", None) == "SparkTask":
del json_object["__type__"]
return SparkTask(**json_object)

if "__type__" in json_object:
components = json_object["__type__"].split(".")
module_name = ".".join(components[:-1])
class_name = components[-1]
module = importlib.import_module(module_name)
clazz = getattr(module, class_name)

del json_object["__type__"]
if clazz in [DatetimeRole, DateRole]:
json_object["seasonality"] = tuple(json_object["seasonality"])
if isinstance(json_object["origin"], float):
json_object["origin"] = datetime.fromtimestamp(json_object["origin"])

json_object["dtype"] = getattr(np, json_object["dtype"])

if clazz in [GroupRole, DropRole, WeightsRole, FoldsRole, PathRole, TreatmentRole]:
del json_object["dtype"]

instance = clazz(**json_object)

return instance

return json_object

def __init__(self, **kwargs):
kwargs.setdefault("object_hook", self._column_roles_object_hook)
super().__init__(**kwargs)


class SparkDataset(LAMLDataset, Unpersistable):
"""
Implements a dataset that uses a ``pyspark.sql.DataFrame`` internally,
@@ -92,7 +165,7 @@ def load(cls,

# reading metadata
metadata_df = spark.read.format(file_format).options(**file_format_options).load(metadata_file_path)
metadata = pickle.loads(metadata_df.select('metadata').first().asDict()['metadata'])
metadata = json.loads(metadata_df.select('metadata').first().asDict()['metadata'], cls=SparkDatasetMetadataJsonDecoder)

# reading data
data_df = spark.read.format(file_format).options(**file_format_options).load(file_path)
@@ -505,7 +578,11 @@ def freeze(self) -> 'SparkDataset':
ds.set_data(self.data, self.features, self.roles, frozen=True)
return ds

def save(self, path: str, save_mode: str = 'error', file_format: str = 'parquet', file_format_options: Optional[Dict[str, Any]] = None):
def save(self,
path: str,
save_mode: str = 'error',
file_format: str = 'parquet',
file_format_options: Optional[Dict[str, Any]] = None):
metadata_file_path = os.path.join(path, f"metadata.{file_format}")
file_path = os.path.join(path, f"data.{file_format}")
file_format_options = file_format_options or dict()
@@ -518,7 +595,7 @@ def save(self, path: str, save_mode: str = 'error', file_format: str = 'parquet'
"target": self.target_column,
"folds": self.folds_column,
}
metadata_str = pickle.dumps(metadata)
metadata_str = json.dumps(metadata, cls=SparkDatasetMetadataJsonEncoder)
metadata_df = self.spark_session.createDataFrame([{"metadata": metadata_str}])

# create directory that will store data and metadata as separate files of dataframes
@@ -528,7 +605,16 @@ def save(self, path: str, save_mode: str = 'error', file_format: str = 'parquet'
metadata_df.write.format(file_format).mode(save_mode).options(**file_format_options).save(metadata_file_path)
# fix name of columns: parquet cannot have columns with '(' or ')' in the name
name_fixed_cols = (sf.col(c).alias(c.replace('(', '[').replace(')', ']')) for c in self.data.columns)
self.data.select(*name_fixed_cols).write.format(file_format).mode(save_mode).options(**file_format_options).save(file_path)
# write dataframe with fixed column names
(
self.data
.select(*name_fixed_cols)
.write
.format(file_format)
.mode(save_mode)
.options(**file_format_options)
.save(file_path)
)

def to_pandas(self) -> PandasDataset:
data, target_data, folds_data, roles = self._materialize_to_pandas()
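Taken together, the base.py changes replace pickle with JSON for the dataset metadata and keep the on-disk layout as a data parquet plus a metadata parquet under one directory. A hedged usage sketch: ds stands for an existing SparkDataset, and the keyword names passed to load beyond the path are assumptions, since its full signature is collapsed in this diff:

from sparklightautoml.dataset.base import SparkDataset

# write data plus JSON-encoded metadata as two parquet datasets under one directory
ds.save(
    "/tmp/my_dataset.slama",
    save_mode="overwrite",
    file_format="parquet",
    file_format_options={"compression": "snappy"},
)

# read it back later, on the same or another Spark session
restored = SparkDataset.load("/tmp/my_dataset.slama", file_format="parquet")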
24 changes: 6 additions & 18 deletions sparklightautoml/dataset/roles.py
@@ -32,21 +32,9 @@ def __init__(
prob: If input number is probability.
"""
super().__init__(dtype, force_input, prob, discretization)
self._size = size
self._element_col_name_template = element_col_name_template
self._is_vector = is_vector

@property
def size(self):
return self._size

@property
def element_col_name_template(self):
return self._element_col_name_template

@property
def is_vector(self):
return self._is_vector
self.size = size
self.element_col_name_template = element_col_name_template
self.is_vector = is_vector

def feature_name_at(self, position: int) -> str:
"""
@@ -61,7 +49,7 @@ def feature_name_at(self, position: int) -> str:
"""
assert 0 <= position < self.size

if isinstance(self._element_col_name_template, str):
return self._element_col_name_template.format(position)
if isinstance(self.element_col_name_template, str):
return self.element_col_name_template.format(position)

return self._element_col_name_template[position]
return self.element_col_name_template[position]
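The private _size-style fields and their properties are collapsed into plain attributes because SparkDatasetMetadataJsonEncoder serialises role.__dict__ and SparkDatasetMetadataJsonDecoder rebuilds the role as RoleClass(**decoded_dict), so attribute names have to match the constructor's keyword arguments. A small round-trip sketch under that assumption ("feat_{}" is an arbitrary template):

import json

from sparklightautoml.dataset.base import (
    SparkDatasetMetadataJsonDecoder,
    SparkDatasetMetadataJsonEncoder,
)
from sparklightautoml.dataset.roles import NumericVectorOrArrayRole

role = NumericVectorOrArrayRole(size=20, element_col_name_template="feat_{}")
payload = json.dumps(role, cls=SparkDatasetMetadataJsonEncoder)
restored = json.loads(payload, cls=SparkDatasetMetadataJsonDecoder)
assert restored == role  # holds as long as attribute names and constructor kwargs line up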
3 changes: 2 additions & 1 deletion sparklightautoml/tasks/base.py
@@ -3,7 +3,7 @@
import numpy as np
import pandas as pd
from lightautoml.tasks import Task as LAMATask
from lightautoml.tasks.base import LAMLMetric, _default_losses
from lightautoml.tasks.base import LAMLMetric
from pyspark.ml.evaluation import (
BinaryClassificationEvaluator,
RegressionEvaluator,
@@ -144,6 +144,7 @@ def __init__(
f"The following metrics are supported: {list(self._supported_metrics[self.name].keys())}"
)

self.loss_name = loss
self.metric_name = metric

def get_dataset_metric(self) -> LAMLMetric:
45 changes: 43 additions & 2 deletions tests/spark/unit/test_spark_dataset.py
@@ -1,14 +1,18 @@
import json
import os
import shutil
from datetime import datetime
from typing import Optional

import numpy as np
from lightautoml.dataset.roles import NumericRole
from lightautoml.dataset.roles import NumericRole, CategoryRole, DatetimeRole, DropRole, TextRole, DateRole, IdRole, \
TargetRole, GroupRole, WeightsRole, FoldsRole, PathRole, TreatmentRole
from lightautoml.tasks import Task
from pandas.testing import assert_frame_equal
from pyspark.sql import SparkSession

from sparklightautoml.dataset.base import SparkDataset
from sparklightautoml.dataset.base import SparkDataset, SparkDatasetMetadataJsonEncoder, SparkDatasetMetadataJsonDecoder
from sparklightautoml.dataset.roles import NumericVectorOrArrayRole
from sparklightautoml.tasks.base import SparkTask
from . import spark as spark_sess

@@ -31,6 +35,43 @@ def compare_dfs(dataset_a: SparkDataset, dataset_b: SparkDataset):
assert_frame_equal(df_a, df_b)


def test_column_roles_json_encoder_and_decoder():
roles = {
"num_role": NumericRole(),
"num_vect_role": NumericVectorOrArrayRole(size=20, element_col_name_template="some_template"),
"cat_role": CategoryRole(),
"id_role": IdRole(),
"dt_role": DatetimeRole(origin=datetime.now()),
"d_role": DateRole(),
"text_role": TextRole(),
"target_role": TargetRole(),
"group_role": GroupRole(),
"drop_role": DropRole(),
"weights_role": WeightsRole(),
"folds_role": FoldsRole(),
"path_role": PathRole(),
"treatment_role": TreatmentRole()
}

js_roles = json.dumps(roles, cls=SparkDatasetMetadataJsonEncoder)
deser_roles = json.loads(js_roles, cls=SparkDatasetMetadataJsonDecoder)

assert deser_roles == roles


def test_spark_task_json_encoder_decoder():
stask = SparkTask("binary")

js_stask = json.dumps(stask, cls=SparkDatasetMetadataJsonEncoder)
deser_stask = json.loads(js_stask, cls=SparkDatasetMetadataJsonDecoder)

stask_internals = [stask.name, stask.loss_name, stask.metric_name, stask.greater_is_better]
deser_stask_internals = [deser_stask.name, deser_stask.loss_name,
deser_stask.metric_name, deser_stask.greater_is_better]

assert deser_stask_internals == stask_internals


def test_spark_dataset_save_load(spark: SparkSession):
path = "/tmp/test_slama_ds.dataset"
partitions_num = 37
