Perform data type conversion automatically (#133)
* Perform data type conversion automatically

Signed-off-by: Khor Shu Heng <khor.heng@gojek.com>

* Use wheel installation for local setup to avoid module not found issue

Signed-off-by: Khor Shu Heng <khor.heng@gojek.com>

Co-authored-by: Khor Shu Heng <khor.heng@gojek.com>
khorshuheng committed Apr 1, 2022
1 parent 8cec4ce commit 7d3aa9d
Showing 6 changed files with 117 additions and 35 deletions.
3 changes: 2 additions & 1 deletion Makefile
@@ -34,7 +34,8 @@ install-python-ci-dependencies:

# Supports feast-dev repo master branch
install-python: install-python-ci-dependencies
cd ${ROOT_DIR}/python; python setup.py install
pip install --user --upgrade setuptools wheel grpcio-tools mypy-protobuf
cd ${ROOT_DIR}/python; rm -rf dist; python setup.py bdist_wheel; pip install --find-links=dist feast-spark

lint-python:
cd ${ROOT_DIR}/python ; mypy feast_spark/ tests/
53 changes: 37 additions & 16 deletions spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala
@@ -16,13 +16,17 @@
*/
package feast.ingestion

import feast.ingestion.utils.TypeConversion
import feast.ingestion.validation.TypeCheck
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Column, SparkSession}

object BasePipeline {

def createSparkSession(jobConfig: IngestionJobConfig): SparkSession = {
// workaround for issue with arrow & netty
// see https://github.com/apache/arrow/tree/master/java#java-properties
@@ -110,22 +114,39 @@ object BasePipeline {
def inputProjection(
source: Source,
features: Seq[Field],
entities: Seq[Field]
entities: Seq[Field],
inputSchema: StructType
): Array[Column] = {
val featureColumns = features
.filter(f => !source.fieldMapping.contains(f.name))
.map(f => (f.name, f.name)) ++ source.fieldMapping

val timestampColumn = Seq((source.eventTimestampColumn, source.eventTimestampColumn))
val entitiesColumns =
entities
.filter(e => !source.fieldMapping.contains(e.name))
.map(e => (e.name, e.name))

(featureColumns ++ entitiesColumns ++ timestampColumn).map { case (alias, source) =>
expr(source).alias(alias)
}.toArray
val typeByField =
(entities ++ features).map(f => f.name -> f.`type`).toMap
val columnDataTypes = inputSchema.fields
.map(f => f.name -> f.dataType)
.toMap

val entitiesFeaturesColumns: Seq[(String, String)] = (entities ++ features)
.map {
case f if source.fieldMapping.contains(f.name) => (f.name, source.fieldMapping(f.name))
case f => (f.name, f.name)
}

val entitiesFeaturesProjection: Seq[Column] = entitiesFeaturesColumns
.map {
case (alias, source) if !columnDataTypes.contains(source) =>
expr(source).alias(alias)
case (alias, source)
if TypeCheck.typesMatch(
typeByField(alias),
columnDataTypes(source)
) =>
col(source).alias(alias)
case (alias, source) =>
col(source).cast(TypeConversion.feastTypeToSqlType(typeByField(alias))).alias(alias)
}

val timestampProjection = Seq(col(source.eventTimestampColumn))
(entitiesFeaturesProjection ++ timestampProjection).toArray
}

}

trait BasePipeline {
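The rewritten inputProjection casts a column whenever the source column's Spark type does not match the declared Feast type, rather than leaving the mismatch for a later TypeCheck failure. A minimal, self-contained sketch of that mechanic using plain Spark (the column names and the INT32-to-INT64 scenario are illustrative, not taken from the repository):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.LongType

object CastSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("cast-sketch").getOrCreate()
  import spark.implicits._

  // The source delivers feature1 as IntegerType, but the feature is declared as INT64 (LongType).
  val input = Seq((1, "driver_1"), (2, "driver_2")).toDF("feature1", "driver_id")

  // Equivalent of the projection's mismatch branch: cast the column instead of rejecting the frame.
  val projected = input.select(col("feature1").cast(LongType).alias("feature1"), col("driver_id"))
  projected.printSchema() // feature1 is now LongType

  spark.stop()
}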
spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala
@@ -40,8 +40,6 @@ object BatchPipeline extends BasePipeline {
config: IngestionJobConfig
): Option[StreamingQuery] = {
val featureTable = config.featureTable
val projection =
BasePipeline.inputProjection(config.source, featureTable.features, featureTable.entities)
val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
val metrics = new IngestionPipelineMetrics

@@ -62,18 +60,20 @@
)
}

val projection =
BasePipeline.inputProjection(
config.source,
featureTable.features,
featureTable.entities,
input.schema
)

val projected = if (config.deadLetterPath.nonEmpty) {
input.select(projection: _*).cache()
} else {
input.select(projection: _*)
}

TypeCheck.allTypesMatch(projected.schema, featureTable) match {
case Some(error) =>
throw new RuntimeException(s"Dataframe columns don't match expected feature types: $error")
case _ => ()
}

implicit val rowEncoder: Encoder[Row] = RowEncoder(projected.schema)

val validRows = projected
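Because the projection is now derived from the observed input schema, it can only be built after the source DataFrame has been loaded, which is why the call moves below the read in both pipelines. The call pattern, sketched under the assumption that `input` is the already-loaded source DataFrame and that `config` and `featureTable` are in scope:

// The projection can only be derived once the source has been read,
// because the observed schema drives the cast decisions.
val projection = BasePipeline.inputProjection(
  config.source,
  featureTable.features,
  featureTable.entities,
  input.schema
)
val projected = input.select(projection: _*)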
spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
@@ -54,9 +54,7 @@ object StreamingPipeline extends BasePipeline with Serializable {
): Option[StreamingQuery] = {
import sparkSession.implicits._

val featureTable = config.featureTable
val projection =
BasePipeline.inputProjection(config.source, featureTable.features, featureTable.entities)
val featureTable = config.featureTable
val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
val metrics = new IngestionPipelineMetrics
val streamingMetrics = new StreamingMetrics
@@ -104,16 +102,17 @@
val parsed = input
.withColumn("features", featureStruct)
.select(metadata :+ col("features.*"): _*)
val projection =
BasePipeline.inputProjection(
config.source,
featureTable.features,
featureTable.entities,
parsed.schema
)

val projected = parsed
.select(projection ++ metadata: _*)

TypeCheck.allTypesMatch(projected.schema, featureTable) match {
case Some(error) =>
throw new RuntimeException(s"Dataframe columns don't match expected feature types: $error")
case _ => ()
}

val sink = projected.writeStream
.foreachBatch { (batchDF: DataFrame, batchID: Long) =>
val rowsAfterValidation = if (validationUDF.nonEmpty) {
spark/ingestion/src/main/scala/feast/ingestion/utils/TypeConversion.scala
@@ -17,15 +17,34 @@
package feast.ingestion.utils

import java.sql

import com.google.protobuf.{ByteString, Message, Timestamp}
import feast.proto.types.ValueProto
import feast.proto.types.ValueProto.ValueType
import org.apache.spark.sql.types._

import scala.collection.JavaConverters._
import scala.collection.mutable

object TypeConversion {

def feastTypeToSqlType(feastType: ValueType.Enum): DataType = {
feastType match {
case ValueType.Enum.BOOL => BooleanType
case ValueType.Enum.INT32 => IntegerType
case ValueType.Enum.INT64 => LongType
case ValueType.Enum.FLOAT => FloatType
case ValueType.Enum.DOUBLE => DoubleType
case ValueType.Enum.BYTES => BinaryType
case ValueType.Enum.BOOL_LIST => ArrayType(BooleanType)
case ValueType.Enum.INT32_LIST => ArrayType(IntegerType)
case ValueType.Enum.INT64_LIST => ArrayType(LongType)
case ValueType.Enum.FLOAT_LIST => ArrayType(FloatType)
case ValueType.Enum.DOUBLE_LIST => ArrayType(DoubleType)
case ValueType.Enum.BYTES_LIST => ArrayType(BinaryType)
case _ => throw new IllegalArgumentException(s"unsupported type conversion (${feastType})")
}
}

def sqlTypeToProtoValue(value: Any, `type`: DataType): Message = {
(`type` match {
case IntegerType => ValueProto.Value.newBuilder().setInt32Val(value.asInstanceOf[Int])
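The new feastTypeToSqlType helper is a straight mapping from Feast value types to Spark SQL types; a small usage sketch (assumes the feast-spark module and the Feast value protos are on the classpath):

import feast.ingestion.utils.TypeConversion
import feast.proto.types.ValueProto.ValueType
import org.apache.spark.sql.types.{ArrayType, DoubleType, LongType}

// Scalar and list Feast types resolve to the matching Spark SQL types;
// types outside the match raise IllegalArgumentException.
assert(TypeConversion.feastTypeToSqlType(ValueType.Enum.INT64) == LongType)
assert(TypeConversion.feastTypeToSqlType(ValueType.Enum.DOUBLE_LIST) == ArrayType(DoubleType))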
spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala
@@ -128,6 +128,48 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
)
}

"Parquet source file" should "be coerced to the correct column types" in new Scope {
val gen = rowGenerator(DateTime.parse("2020-08-01"), DateTime.parse("2020-08-03"))
val rows = generateDistinctRows(gen, 100, groupByEntity)
val tempPath = storeAsParquet(sparkSession, rows)
val configWithDifferentColumnTypes = config.copy(
source = FileSource(tempPath, Map.empty, "eventTimestamp"),
featureTable = config.featureTable.copy(
features = Seq(
Field("feature1", ValueType.Enum.INT64),
Field("feature2", ValueType.Enum.DOUBLE)
)
)
)

BatchPipeline.createPipeline(sparkSession, configWithDifferentColumnTypes)

val featureKeyEncoder: String => String = encodeFeatureKey(config.featureTable)

rows.foreach(r => {
val encodedEntityKey = encodeEntityKey(r, config.featureTable)
val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap
storedValues should beStoredRow(
Map(
featureKeyEncoder("feature1") -> r.feature1.toLong,
featureKeyEncoder("feature2") -> r.feature2.toDouble,
murmurHashHexString("_ts:test-fs") -> r.eventTimestamp
)
)
val keyTTL = jedis.ttl(encodedEntityKey).toInt
keyTTL shouldEqual -1

})

SparkEnv.get.metricsSystem.report()
statsDStub.receivedMetrics should contain.allElementsOf(
Map(
"driver.ingestion_pipeline.read_from_source_count" -> rows.length,
"driver.redis_sink.feature_row_ingested_count" -> rows.length
)
)
}

"Parquet source file" should "be ingested in redis with expiry time equal to the largest of (event_timestamp + max_age) for" +
"all feature tables associated with the entity" in new Scope {
val startDate = new DateTime().minusDays(1).withTimeAtStartOfDay()
