Generic Input/Output of DataFrames #475

Merged (12 commits) on Aug 13, 2022
build.sbt (3 additions, 1 deletion)
@@ -44,7 +44,9 @@ val localAndCloudCommonDependencies = Seq(
"net.snowflake" % "spark-snowflake_2.12" % "2.10.0-spark_3.2",
"org.apache.commons" % "commons-lang3" % "3.12.0",
"org.xerial" % "sqlite-jdbc" % "3.36.0.3",
"com.github.changvvb" %% "jackson-module-caseclass" % "1.1.1"
"com.github.changvvb" %% "jackson-module-caseclass" % "1.1.1",
"com.azure.cosmos.spark" % "azure-cosmos-spark_3-1_2-12" % "4.11.1",
"org.eclipse.jetty" % "jetty-util" % "9.3.24.v20180605"
) // Common deps

val jdbcDrivers = Seq(
@@ -13,7 +13,7 @@ import com.linkedin.feathr.offline.ErasedEntityTaggedFeature
import com.linkedin.feathr.offline.anchored.anchorExtractor.{SQLConfigurableAnchorExtractor, SimpleConfigurableAnchorExtractor, TimeWindowConfigurableAnchorExtractor}
import com.linkedin.feathr.offline.anchored.feature.{FeatureAnchor, FeatureAnchorWithSource}
import com.linkedin.feathr.offline.anchored.keyExtractor.{MVELSourceKeyExtractor, SQLSourceKeyExtractor}
import com.linkedin.feathr.offline.config.location.{InputLocation, Jdbc, KafkaEndpoint, LocationUtils, SimplePath}
import com.linkedin.feathr.offline.config.location.{DataLocation, KafkaEndpoint, LocationUtils, SimplePath}
import com.linkedin.feathr.offline.derived._
import com.linkedin.feathr.offline.derived.functions.{MvelFeatureDerivationFunction, SQLFeatureDerivationFunction, SeqJoinDerivationFunction, SimpleMvelDerivationFunction}
import com.linkedin.feathr.offline.source.{DataSource, SourceFormatType, TimeWindowParams}
@@ -712,7 +712,7 @@ private[offline] class DataSourceLoader extends JsonDeserializer[DataSource] {
 * 2. a placeholder with the reserved string "PASSTHROUGH" for anchor-defined pass-through features,
 * since anchor-defined pass-through features do not have a path
*/
val path: InputLocation = dataSourceType match {
val path: DataLocation = dataSourceType match {
case "KAFKA" =>
Option(node.get("config")) match {
case Some(field: ObjectNode) =>
@@ -725,7 +725,7 @@
case "PASSTHROUGH" => SimplePath("PASSTHROUGH")
case _ => Option(node.get("location")) match {
case Some(field: ObjectNode) =>
LocationUtils.getMapper().treeToValue(field, classOf[InputLocation])
LocationUtils.getMapper().treeToValue(field, classOf[DataLocation])
case None => throw new FeathrConfigException(ErrorLabel.FEATHR_USER_ERROR,
s"Data location is not defined for data source ${node.toPrettyString()}")
case _ => throw new FeathrConfigException(ErrorLabel.FEATHR_USER_ERROR,
@@ -1,15 +1,19 @@
package com.linkedin.feathr.offline.config.location

import com.fasterxml.jackson.annotation.{JsonSubTypes, JsonTypeInfo}
import com.fasterxml.jackson.core.JacksonException
import com.fasterxml.jackson.databind.module.SimpleModule
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.module.caseclass.mapper.CaseClassObjectMapper
import com.jasonclawson.jackson.dataformat.hocon.HoconFactory
import com.linkedin.feathr.common.FeathrJacksonScalaModule
import com.linkedin.feathr.common.{FeathrJacksonScalaModule, Header}
import com.linkedin.feathr.offline.config.DataSourceLoader
import com.linkedin.feathr.offline.source.DataSource
import com.typesafe.config.{Config, ConfigException}
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.collection.JavaConverters._

/**
 * An InputLocation is a data source definition; it can be either HDFS files or a JDBC database connection
*/
@@ -20,38 +24,50 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
new JsonSubTypes.Type(value = classOf[SimplePath], name = "path"),
new JsonSubTypes.Type(value = classOf[PathList], name = "pathlist"),
new JsonSubTypes.Type(value = classOf[Jdbc], name = "jdbc"),
new JsonSubTypes.Type(value = classOf[GenericLocation], name = "generic"),
))
trait InputLocation {
trait DataLocation {
/**
* Backward Compatibility
 * Much existing code expects a simple path
*
* @return the `path` or `url` of the data source
*
* WARN: This method is deprecated, you must use match/case on InputLocation,
* and get `path` from `SimplePath` only
*/
@deprecated("Do not use this method in any new code, it will be removed soon")
def getPath: String

/**
* Backward Compatibility
*
* @return the `path` or `url` of the data source, wrapped in an List
*
* WARN: This method is deprecated, you must use match/case on InputLocation,
* and get `paths` from `PathList` only
*/
@deprecated("Do not use this method in any new code, it will be removed soon")
def getPathList: List[String]

/**
* Load DataFrame from Spark session
*
* @param ss SparkSession
 * @return the loaded DataFrame
*/
def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame

/**
* Write DataFrame to the location
* @param ss SparkSession
* @param df DataFrame to write
*/
def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header])

/**
* Tell if this location is file based
*
* @return boolean
*/
def isFileBasedLocation(): Boolean
@@ -67,6 +83,7 @@ object LocationUtils {
/**
* String template substitution, replace "...${VAR}.." with corresponding System property or environment variable
* Non-existent pattern is replaced by empty string.
*
* @param s String template to be processed
* @return Processed result
*/
@@ -76,6 +93,7 @@

/**
* Get an ObjectMapper to deserialize DataSource
*
* @return the ObjectMapper
*/
def getMapper(): ObjectMapper = {
@@ -86,3 +104,50 @@
.registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader))
}
}
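
For reference, a minimal sketch (not part of this diff) of the substitution behaviour documented on envSubstitute above. COSMOS_KEY is a placeholder variable name, and the sketch object is placed in the same package as LocationUtils in case the helper is not public.

package com.linkedin.feathr.offline.config.location

object EnvSubstituteSketch extends App {
  // "${COSMOS_KEY}" resolves to the system property or environment variable of the same name;
  // a non-existent variable is replaced by an empty string, as documented above
  val resolved = LocationUtils.envSubstitute("AccountKey=${COSMOS_KEY}")
  println(s"resolved: '$resolved'")
}
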

object DataLocation {
/**
* Create DataLocation from string, try parsing the string as JSON and fallback to SimplePath
* @param cfg the input string
* @return DataLocation
*/
def apply(cfg: String): DataLocation = {
val jackson = (new ObjectMapper(new HoconFactory) with CaseClassObjectMapper)
.registerModule(FeathrJacksonScalaModule) // DefaultScalaModule causes a fail on holdem
.configure(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY, true)
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
.registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader))
try {
// Cfg is either a plain path or a JSON object
if (cfg.trim.startsWith("{")) {
val location = jackson.readValue(cfg, classOf[DataLocation])
location
} else {
SimplePath(cfg)
}
} catch {
case _ @ (_: ConfigException | _: JacksonException) => SimplePath(cfg)
}
}

def apply(cfg: Config): DataLocation = {
apply(cfg.root().keySet().asScala.map(key ⇒ key → cfg.getString(key)).toMap)
}

def apply(cfg: Any): DataLocation = {
val jackson = (new ObjectMapper(new HoconFactory) with CaseClassObjectMapper)
.registerModule(FeathrJacksonScalaModule) // DefaultScalaModule causes a fail on holdem
.configure(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY, true)
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
.registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader))
try {
val location = jackson.convertValue(cfg, classOf[DataLocation])
location
} catch {
case e: JacksonException => {
print(e)
SimplePath(cfg.toString)
}
}
}
}
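
A minimal sketch (not part of this diff) of how the DataLocation.apply overloads above are expected to behave. It assumes the JSON type discriminator property is "type", matching the subtype names registered via @JsonSubTypes; the sample path is a placeholder.

import com.linkedin.feathr.offline.config.location.{DataLocation, GenericLocation, SimplePath}

object DataLocationApplySketch extends App {
  // A plain string that is not a JSON/HOCON object falls back to SimplePath;
  // so does any string that fails to parse, via the catch clause above
  val simple = DataLocation("wasbs://container@account.blob.core.windows.net/some/path")
  println(simple) // expected: SimplePath(wasbs://container@account.blob.core.windows.net/some/path)

  // A JSON/HOCON object is dispatched to the matching subtype, e.g. "generic" -> GenericLocation
  val generic = DataLocation("""{ type: "generic", format: "cosmos.oltp" }""")
  generic match {
    case g: GenericLocation => println(s"format=${g.format}")
    case other              => println(s"parsed as: $other")
  }
}
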
@@ -0,0 +1,161 @@
package com.linkedin.feathr.offline.config.location

import com.fasterxml.jackson.annotation.JsonAnySetter
import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize
import com.linkedin.feathr.common.Header
import com.linkedin.feathr.common.exception.FeathrException
import com.linkedin.feathr.offline.generation.FeatureGenUtils
import com.linkedin.feathr.offline.join.DataFrameKeyCombiner
import net.minidev.json.annotate.JsonIgnore
import org.apache.log4j.Logger
import org.apache.spark.sql.functions.monotonically_increasing_id
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}

@CaseClassDeserialize()
case class GenericLocation(format: String, mode: Option[String] = None) extends DataLocation {
val log: Logger = Logger.getLogger(getClass)
val options: collection.mutable.Map[String, String] = collection.mutable.Map[String, String]()
val conf: collection.mutable.Map[String, String] = collection.mutable.Map[String, String]()

/**
* Backward Compatibility
 * Much existing code expects a simple path
*
* @return the `path` or `url` of the data source
*
* WARN: This method is deprecated, you must use match/case on DataLocation,
* and get `path` from `SimplePath` only
*/
override def getPath: String = s"GenericLocation(${format})"

/**
* Backward Compatibility
*
* @return the `path` or `url` of the data source, wrapped in an List
*
* WARN: This method is deprecated, you must use match/case on DataLocation,
* and get `paths` from `PathList` only
*/
override def getPathList: List[String] = List(getPath)

/**
* Load DataFrame from Spark session
*
* @param ss SparkSession
 * @return the loaded DataFrame
*/
override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String]): DataFrame = {
GenericLocationFixes.readDf(ss, this)
}

/**
* Write DataFrame to the location
*
* @param ss SparkSession
* @param df DataFrame to write
*/
override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = {
GenericLocationFixes.writeDf(ss, df, header, this)
}

/**
* Tell if this location is file based
*
* @return boolean
*/
override def isFileBasedLocation(): Boolean = false

@JsonAnySetter
def setOption(key: String, value: Any): Unit = {
println(s"GenericLocation.setOption(key: $key, value: $value)")
if (key == null) {
log.warn("Got null key, skipping")
return
}
if (value == null) {
log.warn(s"Got null value for key '$key', skipping")
return
}
val v = value.toString
if (v == null) {
log.warn(s"Got invalid value for key '$key', skipping")
return
}
if (key.startsWith("__conf__")) {
conf += (key.stripPrefix("__conf__").replace("__", ".") -> LocationUtils.envSubstitute(v))
} else {
options += (key.replace("__", ".") -> LocationUtils.envSubstitute(v))
}
}
}
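
A short sketch (not part of this diff) of the key-mangling rules implemented by setOption above: "__conf__"-prefixed keys land in the Spark conf map, all other keys become reader/writer options, and "__" is rewritten to "." in both cases. The endpoint value is a placeholder.

import com.linkedin.feathr.offline.config.location.GenericLocation

object GenericLocationOptionsSketch extends App {
  val loc = GenericLocation(format = "cosmos.oltp")

  loc.setOption("spark__cosmos__accountEndpoint", "https://example.documents.azure.com:443/")
  loc.setOption("__conf__spark__sql__shuffle__partitions", "8")

  println(loc.options) // Map(spark.cosmos.accountEndpoint -> https://example.documents.azure.com:443/)
  println(loc.conf)    // Map(spark.sql.shuffle.partitions -> 8)
}
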

/**
* Some Spark connectors need extra actions before read or write, namely CosmosDb and ElasticSearch
 * Need to run specific fixes based on `format`
*/
object GenericLocationFixes {
def readDf(ss: SparkSession, location: GenericLocation): DataFrame = {
location.conf.foreach(e => {
ss.conf.set(e._1, e._2)
})
ss.read.format(location.format)
.options(location.options)
.load()
}

def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header], location: GenericLocation) = {
location.conf.foreach(e => {
ss.conf.set(e._1, e._2)
})

location.format.toLowerCase() match {
case "cosmos.oltp" =>
// Ensure the database and the table exist before writing
val endpoint = location.options.getOrElse("spark.cosmos.accountEndpoint", throw new FeathrException("Missing spark__cosmos__accountEndpoint"))
val key = location.options.getOrElse("spark.cosmos.accountKey", throw new FeathrException("Missing spark__cosmos__accountKey"))
val databaseName = location.options.getOrElse("spark.cosmos.database", throw new FeathrException("Missing spark__cosmos__database"))
val tableName = location.options.getOrElse("spark.cosmos.container", throw new FeathrException("Missing spark__cosmos__container"))
ss.conf.set("spark.sql.catalog.cosmosCatalog", "com.azure.cosmos.spark.CosmosCatalog")
ss.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountEndpoint", endpoint)
ss.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountKey", key)
ss.sql(s"CREATE DATABASE IF NOT EXISTS cosmosCatalog.${databaseName};")
ss.sql(s"CREATE TABLE IF NOT EXISTS cosmosCatalog.${databaseName}.${tableName} using cosmos.oltp TBLPROPERTIES(partitionKeyPath = '/id')")

// CosmosDb requires the column `id` to exist and be the primary key, and `id` must be in `string` type
val keyDf = if (!df.columns.contains("id")) {
header match {
case Some(h) => {
// Generate key column from header info, which is required by CosmosDb
val (keyCol, keyedDf) = DataFrameKeyCombiner().combine(df, FeatureGenUtils.getKeyColumnsFromHeader(h))
// Rename key column to `id`
keyedDf.withColumnRenamed(keyCol, "id")
}
case None => {
// If there is no key column, we use an auto-generated monotonic id,
// but in this case the result could be duplicated if you run the job multiple times.
// This function is for offline-storage usage; ideally the user should create a new container for every run
df.withColumn("id", (monotonically_increasing_id().cast("string")))
}
}
} else {
// We already have an `id` column
// TODO: Should we do anything here?
// A corner case is when the `id` column exists but is not unique; then the output will be incomplete, as
// CosmosDb will overwrite the old entry with the new one that has the same `id`.
// We can either rename the existing `id` column and use a header/auto-generated key column, or tell the user
// to avoid using an `id` column for non-unique data, but both workarounds have pros and cons.
df
}
keyDf.write.format(location.format)
.options(location.options)
.mode(location.mode.getOrElse("append")) // CosmosDb doesn't support ErrorIfExists mode in batch mode
.save()
case _ =>
// Normal writing procedure, just set format and options then write
df.write.format(location.format)
.options(location.options)
.mode(location.mode.getOrElse("default"))
.save()
}
}
}
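
Finally, a hypothetical end-to-end sketch (not part of this diff) of the generic read/write path. It again assumes the "type" discriminator property, uses the built-in CSV source with placeholder path, mode and option values, and a real deployment would point `format` at a connector such as cosmos.oltp instead.

import com.linkedin.feathr.offline.config.location.{DataLocation, GenericLocation}
import org.apache.spark.sql.SparkSession

object GenericLocationRoundTripSketch extends App {
  val ss = SparkSession.builder().master("local[*]").appName("generic-location-sketch").getOrCreate()
  import ss.implicits._

  val df = Seq(("a", 1), ("b", 2)).toDF("name", "value")

  // Same shape a Feathr output config would produce; "__" stands in for "." in option keys
  val location = DataLocation(Map(
    "type"   -> "generic",
    "format" -> "csv",
    "mode"   -> "overwrite",
    "header" -> "true",
    "path"   -> "/tmp/feathr-generic-location-sketch"
  ))

  location match {
    case g: GenericLocation =>
      g.writeDf(ss, df, None)  // non-Cosmos formats take the plain write path above
      g.loadDf(ss).show()
    case other =>
      println(s"parsed as: $other")
  }
  ss.stop()
}
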