Rename transformed_column_name param in transform_spark()
deliahu committed Apr 5, 2019
1 parent 4a67d14 commit 39e9fab
Showing 7 changed files with 22 additions and 22 deletions.
10 changes: 5 additions & 5 deletions cli/cmd/init.go
@@ -299,7 +299,7 @@ def create_estimator(run_config, model_config):
   # arg2: FLOAT
 `,

-	"implementations/transformers/transformer.py": `def transform_spark(data, columns, args, transformed_column):
+	"implementations/transformers/transformer.py": `def transform_spark(data, columns, args, transformed_column_name):
     """Transform a column in a PySpark context.
     This function is optional (recommended for large-scale data processing).
@@ -314,18 +314,18 @@ def create_estimator(run_config, model_config):
     args: A dict with the same structure as the transformer's input args
         containing the runtime values of the args.
-    transformed_column: The name of the column containing the transformed
+    transformed_column_name: The name of the column containing the transformed
         data that is to be appended to the dataframe.

     Returns:
-    The original 'data' dataframe with an added column with the name of the
-    transformed_column arg containing the transformed data.
+    The original 'data' dataframe with an added column named <transformed_column_name>
+    which contains the transformed data.
     """

     ## Sample transform_spark implementation:
     #
     # return data.withColumn(
-    #     transformed_column, ((data[columns["num"]] - args["mean"]) / args["stddev"])
+    #     transformed_column_name, ((data[columns["num"]] - args["mean"]) / args["stddev"])
     # )
     pass
12 changes: 6 additions & 6 deletions docs/applications/implementations/transformers.md
@@ -5,7 +5,7 @@ Transformers run both when transforming data before model training and when resp
 ## Implementation

 ```python
-def transform_spark(data, columns, args, transformed_column):
+def transform_spark(data, columns, args, transformed_column_name):
     """Transform a column in a PySpark context.
     This function is optional (recommended for large-scale data processing).
@@ -20,12 +20,12 @@ def transform_spark(data, columns, args, transformed_column):
     args: A dict with the same structure as the transformer's input args
         containing the runtime values of the args.
-    transformed_column: The name of the column containing the transformed
+    transformed_column_name: The name of the column containing the transformed
         data that is to be appended to the dataframe.

     Returns:
-    The original 'data' dataframe with an added column with the name of the
-    transformed_column arg containing the transformed data.
+    The original 'data' dataframe with an added column named <transformed_column_name>
+    which contains the transformed data.
     """
     pass
@@ -69,9 +69,9 @@ def reverse_transform_python(transformed_value, args):
 ## Example

 ```python
-def transform_spark(data, columns, args, transformed_column):
+def transform_spark(data, columns, args, transformed_column_name):
     return data.withColumn(
-        transformed_column, ((data[columns["num"]] - args["mean"]) / args["stddev"])
+        transformed_column_name, ((data[columns["num"]] - args["mean"]) / args["stddev"])
     )

 def transform_python(sample, args):
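For context, here is a minimal local harness exercising the example above with the renamed parameter. The SparkSession setup, sample data, and column names are illustrative assumptions, not part of this repository:

```python
# A hedged, self-contained sketch; assumes pyspark is installed locally.
from pyspark.sql import SparkSession

def transform_spark(data, columns, args, transformed_column_name):
    # Identical to the docs example above.
    return data.withColumn(
        transformed_column_name, ((data[columns["num"]] - args["mean"]) / args["stddev"])
    )

spark = SparkSession.builder.master("local[1]").appName("transformer-demo").getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ["price"])

# Cortex supplies these values at runtime; here they are hard-coded for illustration.
out = transform_spark(df, {"num": "price"}, {"mean": 2.0, "stddev": 1.0}, "price_normalized")
out.show()  # the original "price" column plus a new "price_normalized" column
```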
4 changes: 2 additions & 2 deletions examples/fraud/implementations/transformers/weight.py
@@ -1,9 +1,9 @@
-def transform_spark(data, columns, args, transformed_column):
+def transform_spark(data, columns, args, transformed_column_name):
     import pyspark.sql.functions as F

     distribution = args["class_distribution"]

     return data.withColumn(
-        transformed_column,
+        transformed_column_name,
         F.when(data[columns["col"]] == 0, distribution[1]).otherwise(distribution[0]),
     )
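A quick way to see what this transformer does with the renamed argument; the harness below and its sample values are assumptions for illustration, not part of the commit:

```python
# Hedged local sketch of the fraud class-weight transformer.
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

def transform_spark(data, columns, args, transformed_column_name):
    # Same logic as weight.py above.
    distribution = args["class_distribution"]
    return data.withColumn(
        transformed_column_name,
        F.when(data[columns["col"]] == 0, distribution[1]).otherwise(distribution[0]),
    )

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(0,), (1,), (0,)], ["is_fraud"])

# With a [0.9, 0.1] class distribution, majority-class rows get the smaller weight.
out = transform_spark(df, {"col": "is_fraud"}, {"class_distribution": [0.9, 0.1]}, "weight")
out.show()  # is_fraud == 0 -> 0.1, is_fraud == 1 -> 0.9
```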
6 changes: 3 additions & 3 deletions pkg/transformers/bucketize.py
@@ -13,15 +13,15 @@
 # limitations under the License.


-def transform_spark(data, columns, args, transformed_column):
+def transform_spark(data, columns, args, transformed_column_name):
     from pyspark.ml.feature import Bucketizer
     import pyspark.sql.functions as F

     new_b = Bucketizer(
-        splits=args["bucket_boundaries"], inputCol=columns["num"], outputCol=transformed_column
+        splits=args["bucket_boundaries"], inputCol=columns["num"], outputCol=transformed_column_name
     )
     return new_b.transform(data).withColumn(
-        transformed_column, F.col(transformed_column).cast("int")
+        transformed_column_name, F.col(transformed_column_name).cast("int")
     )


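A hedged sketch of how the bucketize transformer behaves with the renamed output-column argument; the local SparkSession, sample data, and bucket boundaries are assumptions for illustration:

```python
# Self-contained local sketch; assumes pyspark is installed.
from pyspark.sql import SparkSession
from pyspark.ml.feature import Bucketizer
import pyspark.sql.functions as F

def transform_spark(data, columns, args, transformed_column_name):
    # Same logic as pkg/transformers/bucketize.py above.
    new_b = Bucketizer(
        splits=args["bucket_boundaries"], inputCol=columns["num"], outputCol=transformed_column_name
    )
    return new_b.transform(data).withColumn(
        transformed_column_name, F.col(transformed_column_name).cast("int")
    )

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(5.0,), (25.0,), (75.0,)], ["age"])

# Boundaries [0, 18, 65, 100] define three buckets; Bucketizer emits doubles, hence the int cast.
out = transform_spark(df, {"num": "age"}, {"bucket_boundaries": [0.0, 18.0, 65.0, 100.0]}, "age_bucket")
out.show()  # 5.0 -> 0, 25.0 -> 1, 75.0 -> 2
```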
6 changes: 3 additions & 3 deletions pkg/transformers/index_string.py
@@ -13,16 +13,16 @@
 # limitations under the License.


-def transform_spark(data, columns, args, transformed_column):
+def transform_spark(data, columns, args, transformed_column_name):
     from pyspark.ml.feature import StringIndexerModel
     import pyspark.sql.functions as F

     indexer = StringIndexerModel.from_labels(
-        args["index"], inputCol=columns["text"], outputCol=transformed_column
+        args["index"], inputCol=columns["text"], outputCol=transformed_column_name
     )

     return indexer.transform(data).withColumn(
-        transformed_column, F.col(transformed_column).cast("int")
+        transformed_column_name, F.col(transformed_column_name).cast("int")
     )


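The string indexer maps each label to its position in the provided index list. A hedged local sketch (harness, labels, and column names are assumptions, not part of the commit):

```python
# Self-contained local sketch; assumes pyspark is installed.
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexerModel
import pyspark.sql.functions as F

def transform_spark(data, columns, args, transformed_column_name):
    # Same logic as pkg/transformers/index_string.py above.
    indexer = StringIndexerModel.from_labels(
        args["index"], inputCol=columns["text"], outputCol=transformed_column_name
    )
    return indexer.transform(data).withColumn(
        transformed_column_name, F.col(transformed_column_name).cast("int")
    )

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("red",), ("blue",), ("green",)], ["color"])

# Each label is encoded as its position in args["index"].
out = transform_spark(df, {"text": "color"}, {"index": ["red", "green", "blue"]}, "color_index")
out.show()  # red -> 0, green -> 1, blue -> 2
```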
4 changes: 2 additions & 2 deletions pkg/transformers/normalize.py
@@ -13,9 +13,9 @@
 # limitations under the License.


-def transform_spark(data, columns, args, transformed_column):
+def transform_spark(data, columns, args, transformed_column_name):
     return data.withColumn(
-        transformed_column, ((data[columns["num"]] - args["mean"]) / args["stddev"])
+        transformed_column_name, ((data[columns["num"]] - args["mean"]) / args["stddev"])
     )


2 changes: 1 addition & 1 deletion pkg/workloads/lib/context.py
@@ -481,7 +481,7 @@ def resource_status_key(self, resource):

 TRANSFORMER_IMPL_VALIDATION = {
     "optional": [
-        {"name": "transform_spark", "args": ["data", "columns", "args", "transformed_column"]},
+        {"name": "transform_spark", "args": ["data", "columns", "args", "transformed_column_name"]},
         {"name": "reverse_transform_python", "args": ["transformed_value", "args"]},
         {"name": "transform_python", "args": ["sample", "args"]},
     ]
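This validation table is why the rename has to touch context.py: user implementations are checked against these exact argument names. A hedged sketch of how such a signature check could work; the helper below is illustrative, not Cortex's actual validator:

```python
# Hypothetical signature check against the validation spec (illustrative only).
import inspect

TRANSFORMER_IMPL_VALIDATION = {
    "optional": [
        {"name": "transform_spark", "args": ["data", "columns", "args", "transformed_column_name"]},
        {"name": "reverse_transform_python", "args": ["transformed_value", "args"]},
        {"name": "transform_python", "args": ["sample", "args"]},
    ]
}

def validate_impl(module):
    for spec in TRANSFORMER_IMPL_VALIDATION["optional"]:
        fn = getattr(module, spec["name"], None)
        if fn is None:
            continue  # optional functions may be omitted
        actual = list(inspect.signature(fn).parameters)
        if actual != spec["args"]:
            raise ValueError(
                f"{spec['name']} must take args {spec['args']}, got {actual}"
            )
```

Under this scheme, a transformer still using the old `transformed_column` name would fail validation, which is why all built-in transformers and examples are updated in the same commit.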
