Skip to content

Commit

Permalink
Make field mapping for batch source consistent with streaming source (#45)
Browse files Browse the repository at this point in the history

Signed-off-by: Khor Shu Heng <khor.heng@go-jek.com>
  • Loading branch information
khorshuheng committed Mar 17, 2021
1 parent 49142c2 commit 8ed2d33
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
11 changes: 6 additions & 5 deletions python/feast_spark/pyspark/historical_feature_retrieval_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class FileSource(Source):
created_timestamp_column (str): Column representing the creation timestamp. Required
only if the source corresponds to a feature table.
field_mapping (Dict[str, str]): Optional. If present, the source column will be renamed
based on the mapping.
based on the mapping. The key would be the final result and the value would be the source column.
options (Optional[Dict[str, str]]): Options to be passed to spark while reading the file source.
"""

Expand Down Expand Up @@ -124,7 +124,7 @@ class BigQuerySource(Source):
dataset (str): BQ dataset.
table (str): BQ table.
field_mapping (Dict[str, str]): Optional. If present, the source column will be renamed
based on the mapping.
based on the mapping. The key would be the final result and the value would be the source column.
event_timestamp_column (str): Column representing the event timestamp.
created_timestamp_column (str): Column representing the creation timestamp. Required
only if the source corresponds to a feature table.
Expand Down Expand Up @@ -279,8 +279,9 @@ class FileDestination(NamedTuple):


def _map_column(df: DataFrame, col_mapping: Dict[str, str]) -> DataFrame:
    """Rename dataframe columns according to ``col_mapping``.

    Args:
        df (DataFrame): Input Spark dataframe.
        col_mapping (Dict[str, str]): Mapping where the key is the desired
            (final) column name and the value is the source column name,
            consistent with the streaming-source field mapping convention.

    Returns:
        DataFrame: Dataframe with the mapped columns renamed. Columns that do
        not appear as a value in ``col_mapping`` keep their original names.
    """
    # col_mapping is keyed alias -> source; invert it so that, while iterating
    # over df.columns (source names), we can look up the alias for each column.
    source_to_alias_map = {v: k for k, v in col_mapping.items()}
    projection = [
        col(col_name).alias(source_to_alias_map.get(col_name, col_name))
        for col_name in df.columns
    ]
    return df.select(projection)
Expand Down Expand Up @@ -667,14 +668,14 @@ def retrieve_historical_features(
"format": {"jsonClass": "ParquetFormat"},
"path": "file:///some_dir/customer_driver_pairs.csv",
"options": {"inferSchema": "true", "header": "true"},
"field_mapping": {"id": "driver_id"}
"field_mapping": {"driver_id": "id"}
}
>>> feature_tables_sources_conf = [
{
"format": {"json_class": "ParquetFormat"},
"path": "gs://some_bucket/bookings.parquet",
"field_mapping": {"id": "driver_id"}
"field_mapping": {"driver_id": "id"}
},
{
"format": {"json_class": "AvroFormat", schema_json: "..avro schema.."},
Expand Down
4 changes: 2 additions & 2 deletions python/tests/test_historical_feature_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def test_historical_feature_retrieval_with_mapping(spark: SparkSession):
"format": {"json_class": "CSVFormat"},
"path": f"file://{path.join(test_data_dir, 'column_mapping_test_entity.csv')}",
"event_timestamp_column": "event_timestamp",
"field_mapping": {"id": "customer_id"},
"field_mapping": {"customer_id": "id"},
"options": {"inferSchema": "true", "header": "true"},
}
}
Expand Down Expand Up @@ -717,7 +717,7 @@ def test_large_historical_feature_retrieval(
"format": {"json_class": "CSVFormat"},
"path": f"file://{large_entity_csv_file}",
"event_timestamp_column": "event_timestamp",
"field_mapping": {"id": "customer_id"},
"field_mapping": {"customer_id": "id"},
"options": {"inferSchema": "true", "header": "true"},
}
}
Expand Down

0 comments on commit 8ed2d33

Please sign in to comment.