Cleanup Scala code (#28)
Signed-off-by: 陈易生 <chenyisheng14418@ipalfish.com>
Signed-off-by: Yik San Chan <evan.chanyiksan@gmail.com>

Co-authored-by: 陈易生 <chenyisheng14418@ipalfish.com>
YikSanChan and 陈易生 committed Mar 26, 2021
1 parent ecc4083 commit 6ae6176
Showing 10 changed files with 47 additions and 56 deletions.
@@ -22,7 +22,7 @@ import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.streaming.StreamingQuery

trait BasePipeline {
object BasePipeline {
def createSparkSession(jobConfig: IngestionJobConfig): SparkSession = {
// workaround for issue with arrow & netty
// see https://github.com/apache/arrow/tree/master/java#java-properties
@@ -75,8 +75,6 @@ trait BasePipeline {
.getOrCreate()
}

def createPipeline(sparkSession: SparkSession, config: IngestionJobConfig): Option[StreamingQuery]

/**
* Build column projection using custom mapping with fallback to feature|entity names.
*/
@@ -100,3 +98,8 @@ trait BasePipeline {
}.toArray
}
}

trait BasePipeline {

def createPipeline(sparkSession: SparkSession, config: IngestionJobConfig): Option[StreamingQuery]
}
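
The net effect of this first hunk: the shared helpers (createSparkSession, inputProjection) now live on object BasePipeline, while the trait keeps only the abstract createPipeline. A minimal sketch of how a concrete pipeline consumes the split, assuming the imports shown in the hunk above; MyPipeline is a hypothetical name, not part of this commit:

object MyPipeline extends BasePipeline {
  override def createPipeline(
      sparkSession: SparkSession,
      config: IngestionJobConfig
  ): Option[StreamingQuery] = {
    // shared helpers are reached through the object, not through inheritance
    val projection = BasePipeline.inputProjection(
      config.source,
      config.featureTable.features,
      config.featureTable.entities
    )
    // ... read, project, validate and write the data here ...
    None // batch pipelines return None; streaming pipelines return Some(query)
  }
}

// callers obtain the session the same way:
// val spark = BasePipeline.createSparkSession(config)
// MyPipeline.createPipeline(spark, config)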
30 changes: 16 additions & 14 deletions spark/ingestion/src/main/scala/feast/ingestion/BatchPipeline.scala
@@ -23,6 +23,7 @@ import feast.ingestion.validation.{RowValidator, TypeCheck}
import org.apache.commons.lang.StringUtils
import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.{Encoder, Row, SaveMode, SparkSession}

/**
@@ -34,10 +35,13 @@ import org.apache.spark.sql.{Encoder, Row, SaveMode, SparkSession}
* 5. Store invalid rows in parquet format at `deadletter` destination
*/
object BatchPipeline extends BasePipeline {
override def createPipeline(sparkSession: SparkSession, config: IngestionJobConfig) = {
override def createPipeline(
sparkSession: SparkSession,
config: IngestionJobConfig
): Option[StreamingQuery] = {
val featureTable = config.featureTable
val projection =
inputProjection(config.source, featureTable.features, featureTable.entities)
BasePipeline.inputProjection(config.source, featureTable.features, featureTable.entities)
val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
val metrics = new IngestionPipelineMetrics

@@ -60,14 +64,14 @@ object BatchPipeline extends BasePipeline {

val projected = input.select(projection: _*).cache()

implicit def rowEncoder: Encoder[Row] = RowEncoder(projected.schema)

TypeCheck.allTypesMatch(projected.schema, featureTable) match {
case Some(error) =>
throw new RuntimeException(s"Dataframe columns don't match expected feature types: $error")
case _ => ()
}

implicit val rowEncoder: Encoder[Row] = RowEncoder(projected.schema)

val validRows = projected
.map(metrics.incrementRead)
.filter(rowValidator.allChecks)
@@ -81,16 +85,14 @@ object BatchPipeline extends BasePipeline {
.option("max_age", config.featureTable.maxAge.getOrElse(0L))
.save()

config.deadLetterPath match {
case Some(path) =>
projected
.filter(!rowValidator.allChecks)
.map(metrics.incrementDeadLetters)
.write
.format("parquet")
.mode(SaveMode.Append)
.save(StringUtils.stripEnd(path, "/") + "/" + SparkEnv.get.conf.getAppId)
case _ => None
config.deadLetterPath foreach { path =>
projected
.filter(!rowValidator.allChecks)
.map(metrics.incrementDeadLetters)
.write
.format("parquet")
.mode(SaveMode.Append)
.save(StringUtils.stripEnd(path, "/") + "/" + SparkEnv.get.conf.getAppId)
}

None
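
The dead-letter branch above now uses Option.foreach instead of a match whose fallback arm returned a meaningless None. A self-contained sketch of the idiom, with illustrative values only:

object DeadLetterSketch {
  def main(args: Array[String]): Unit = {
    val deadLetterPath: Option[String] = Some("/tmp/deadletters")

    // before: a pattern match where the non-matching arm only added noise
    deadLetterPath match {
      case Some(path) => println(s"writing invalid rows to $path")
      case _          => ()
    }

    // after: foreach runs the side effect only when a value is present
    deadLetterPath.foreach(path => println(s"writing invalid rows to $path"))
  }
}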
@@ -54,7 +54,7 @@ object IngestionJob {
}
})
.required()
.text("JSON-encoded source object (e.g. {\"kafka\":{\"bootstrapServers\":...}}")
.text("""JSON-encoded source object (e.g. {"kafka":{"bootstrapServers":...}}""")

opt[String](name = "feature-table")
.action((x, c) => {
@@ -103,10 +103,10 @@ object IngestionJob {
println(s"Starting with config $config")
config.mode match {
case Modes.Offline =>
val sparkSession = BatchPipeline.createSparkSession(config)
val sparkSession = BasePipeline.createSparkSession(config)
BatchPipeline.createPipeline(sparkSession, config)
case Modes.Online =>
val sparkSession = BatchPipeline.createSparkSession(config)
val sparkSession = BasePipeline.createSparkSession(config)
StreamingPipeline.createPipeline(sparkSession, config).get.awaitTermination
}
case None =>
@@ -28,12 +28,12 @@ abstract class StoreConfig

case class RedisConfig(host: String, port: Int, ssl: Boolean) extends StoreConfig

abstract class MetricConfig
sealed trait MetricConfig

case class StatsDConfig(host: String, port: Int) extends MetricConfig

abstract class DataFormat
case class ParquetFormat() extends DataFormat
case object ParquetFormat extends DataFormat
case class ProtoFormat(classPath: String) extends DataFormat
case class AvroFormat(schemaJson: String) extends DataFormat
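
This hunk tightens the config ADTs: MetricConfig becomes a sealed trait and the field-less ParquetFormat becomes a case object. A small sketch, with hypothetical names, of what that buys: exhaustiveness checking on matches and a single shared instance for field-less variants.

object SealedAdtSketch {
  sealed trait Format
  case object Parquet extends Format
  final case class Proto(classPath: String) extends Format

  // because Format is sealed, the compiler warns if a variant is missing here
  def describe(format: Format): String = format match {
    case Parquet          => "parquet files"
    case Proto(classPath) => s"protobuf messages of class $classPath"
  }

  def main(args: Array[String]): Unit =
    println(describe(Proto("com.example.FeatureRow"))) // made-up class path for illustration
}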

@@ -16,9 +16,6 @@
*/
package feast.ingestion

import java.io.File
import java.util.concurrent.TimeUnit

import feast.ingestion.metrics.IngestionPipelineMetrics
import feast.ingestion.registry.proto.ProtoRegistryFactory
import feast.ingestion.utils.ProtoReflection
@@ -27,16 +24,19 @@ import feast.ingestion.validation.{RowValidator, TypeCheck}
import org.apache.commons.io.FileUtils
import org.apache.commons.lang.StringUtils
import org.apache.spark.api.python.DynamicPythonFunction
import org.apache.spark.sql.avro._
import org.apache.spark.sql._
import org.apache.spark.sql.avro.functions.from_avro
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger
import org.apache.spark.sql.functions.{expr, struct, udf}
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql._
import org.apache.spark.{SparkEnv, SparkFiles}

import java.io.File
import java.util.concurrent.TimeUnit

/**
* Streaming pipeline (currently in micro-batches mode only, since we need to have multiple sinks: redis & deadletters).
* Flow:
@@ -55,7 +55,7 @@ object StreamingPipeline extends BasePipeline with Serializable {

val featureTable = config.featureTable
val projection =
inputProjection(config.source, featureTable.features, featureTable.entities)
BasePipeline.inputProjection(config.source, featureTable.features, featureTable.entities)
val rowValidator = new RowValidator(featureTable, config.source.eventTimestampColumn)
val metrics = new IngestionPipelineMetrics
val validationUDF = createValidationUDF(sparkSession, config)
@@ -105,7 +105,7 @@ object StreamingPipeline extends BasePipeline with Serializable {
}
rowsAfterValidation.persist()

implicit def rowEncoder: Encoder[Row] = RowEncoder(rowsAfterValidation.schema)
implicit val rowEncoder: Encoder[Row] = RowEncoder(rowsAfterValidation.schema)

rowsAfterValidation
.map(metrics.incrementRead)
@@ -23,29 +23,15 @@ import org.apache.spark.sql.Row
class IngestionPipelineMetrics extends Serializable {

def incrementDeadLetters(row: Row): Row = {
if (metricSource.nonEmpty)
metricSource.get.METRIC_DEADLETTER_ROWS_INSERTED.inc()

metricSource.foreach(_.METRIC_DEADLETTER_ROWS_INSERTED.inc())
row
}

def incrementRead(row: Row): Row = {
if (metricSource.nonEmpty)
metricSource.get.METRIC_ROWS_READ_FROM_SOURCE.inc()

metricSource.foreach(_.METRIC_ROWS_READ_FROM_SOURCE.inc())
row
}

def incrementRead(inc: Long): Unit = {
if (metricSource.nonEmpty)
metricSource.get.METRIC_ROWS_READ_FROM_SOURCE.inc(inc)
}

def incrementDeadLetters(inc: Long): Unit = {
if (metricSource.nonEmpty)
metricSource.get.METRIC_DEADLETTER_ROWS_INSERTED.inc(inc)
}

private lazy val metricSource: Option[IngestionPipelineMetricSource] = {
val metricsSystem = SparkEnv.get.metricsSystem
IngestionPipelineMetricsLock.synchronized {
@@ -34,13 +34,10 @@ object BigQueryReader {
.format("bigquery")
.option("viewsEnabled", "true")

source.materialization match {
case Some(materializationConfig) =>
reader
.option("materializationProject", materializationConfig.project)
.option("materializationDataset", materializationConfig.dataset)

case _ => ()
source.materialization foreach { materializationConfig =>
reader
.option("materializationProject", materializationConfig.project)
.option("materializationDataset", materializationConfig.dataset)
}

reader
@@ -168,8 +168,8 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
}

SparkEnv.get.metricsSystem.getSourcesByName(RedisSinkMetricSource.sourceName) match {
case Seq(head) => Some(head.asInstanceOf[RedisSinkMetricSource])
case _ => None
case Seq(source: RedisSinkMetricSource) => Some(source)
case _ => None
}
}
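
The lookup above now binds the element with a typed pattern (case Seq(source: RedisSinkMetricSource)) instead of matching Seq(head) and casting. A standalone sketch with stand-in types showing the benefit: an element of an unexpected type falls through to None instead of failing on asInstanceOf.

object TypedPatternSketch {
  trait MetricSource
  class RedisMetrics extends MetricSource { override def toString = "redis-metrics" }
  class OtherMetrics extends MetricSource

  // type check and cast happen in one pattern; anything else yields None
  def singleRedisSource(sources: Seq[MetricSource]): Option[RedisMetrics] =
    sources match {
      case Seq(source: RedisMetrics) => Some(source)
      case _                         => None
    }

  def main(args: Array[String]): Unit = {
    println(singleRedisSource(Seq(new RedisMetrics))) // Some(redis-metrics)
    println(singleRedisSource(Seq(new OtherMetrics))) // None, no ClassCastException
  }
}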

@@ -33,7 +33,6 @@ object TypeConversion {
case StringType => ValueProto.Value.newBuilder().setStringVal(value.asInstanceOf[String])
case DoubleType => ValueProto.Value.newBuilder().setDoubleVal(value.asInstanceOf[Double])
case FloatType => ValueProto.Value.newBuilder().setFloatVal(value.asInstanceOf[Float])
case StringType => ValueProto.Value.newBuilder().setStringVal(value.asInstanceOf[String])
case BooleanType => ValueProto.Value.newBuilder().setBoolVal(value.asInstanceOf[Boolean])
case BinaryType =>
ValueProto.Value
@@ -45,7 +45,11 @@ object DynamicPythonFunction {

def pythonVersion: String = {
runCommand(
List(pythonExec, "-c", "import sys; print(\"{0.major}.{0.minor}\".format(sys.version_info))")
List(
pythonExec,
"-c",
"""import sys; print("{0.major}.{0.minor}".format(sys.version_info))"""
)
)
}
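
Here (and in the IngestionJob help text earlier) escaped string literals are replaced with triple-quoted ones, so embedded double quotes need no backslashes. A tiny runnable check of the equivalence, using the same Python one-liner as an example string:

object TripleQuoteSketch {
  def main(args: Array[String]): Unit = {
    val escaped      = "import sys; print(\"{0.major}.{0.minor}\".format(sys.version_info))"
    val tripleQuoted = """import sys; print("{0.major}.{0.minor}".format(sys.version_info))"""
    println(escaped == tripleQuoted) // true: same string, different literal syntax
  }
}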

