[SW-1353] Introduce logic for flattening data frames with arbitrarily nested structures #1279

Merged: 14 commits, Jun 21, 2019
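For context, this PR replaces the struct-only flattening with logic that also expands arrays and maps, deriving the flat schema from the data itself. A rough sketch of the naming scheme (illustrative column names, not taken from the PR): a column

    person: struct<name: string, tags: array<string>>

flattens into

    person_name, person_tags_0, person_tags_1, ...

with one flat column per array index (or map key) observed anywhere in the data; a column that is missing from some rows is marked nullable.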
@@ -52,7 +52,7 @@ private[h2o] object SparkDataFrameConverter extends Logging {
def toH2OFrame(hc: H2OContext, dataFrame: DataFrame, frameKeyName: Option[String]): H2OFrame = {
import H2OSchemaUtils._
// Flatten the Spark data frame so we don't have any nested rows
-val flatDataFrame = flattenDataFrame(dataFrame)
+val flatDataFrame = flattenStructsInDataFrame(dataFrame)
val dfRdd = flatDataFrame.rdd
val keyName = frameKeyName.getOrElse("frame_rdd_" + dfRdd.id + Key.rand())

221 changes: 214 additions & 7 deletions core/src/main/scala/org/apache/spark/h2o/utils/H2OSchemaUtils.scala
@@ -20,10 +20,12 @@ package org.apache.spark.h2o.utils
import org.apache.spark.h2o._
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.mllib.linalg.SparseVector
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._
-import org.apache.spark.{SparkContext, ml, mllib}
+import org.apache.spark.{ml, mllib}

+import scala.collection.mutable.ArrayBuffer

/**
* Utilities for working with Spark SQL component.
@@ -70,23 +72,228 @@ object H2OSchemaUtils {
StructType(types)
}

-def flattenSchema(schema: StructType, prefix: String = null, nullable: Boolean = false): StructType = {
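/**
 * Flattens a data frame with arbitrarily nested structs, arrays and maps into a flat
 * data frame. The flat schema is derived from the data itself (see flattenSchema below),
 * since array lengths and map key sets may differ from row to row.
 */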
def flattenDataFrame(df: DataFrame): DataFrame = {
val schema = flattenSchema(df)
flattenDataFrame(df, schema)
}

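// Projects every row of `df` onto `flatSchema`: each leaf value is copied to its
// flat-column position; paths missing from a given row are left null.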
def flattenDataFrame(df: DataFrame, flatSchema: StructType): DataFrame = {
implicit val rowEncoder = RowEncoder(flatSchema)
val numberOfColumns = flatSchema.fields.length
val nameToIndexMap = flatSchema.fields.map(_.name).zipWithIndex.toMap
val originalSchema = df.schema
df.map[Row] { row: Row =>
val result = ArrayBuffer.fill[Any](numberOfColumns)(null)
val fillBufferPartiallyApplied = fillBuffer(nameToIndexMap, result) _
originalSchema.fields.zipWithIndex.foreach { case (field, idx) =>
fillBufferPartiallyApplied(field, row(idx), None)
}
Row.fromSeq(result)
}
}

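// Recursively walks one field of a row. A leaf value is stored into `buffer` at the
// index that `flatSchemaIndexes` assigns to its fully qualified name; binary, map,
// array and struct values recurse with the qualified name as the new prefix.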
private def fillBuffer
(flatSchemaIndexes: Map[String, Int], buffer: ArrayBuffer[Any])
(field: StructField, data: Any, prefix: Option[String] = None): Unit = {
if (data != null) {
val StructField(name, dataType, _, _) = field
val qualifiedName = getQualifiedName(prefix, name)
dataType match {
case BinaryType => fillArray(ByteType, flatSchemaIndexes, buffer, data, qualifiedName)
case MapType(_, valueType, _) => fillMap(valueType, flatSchemaIndexes, buffer, data, qualifiedName)
case ArrayType(elementType, _) => fillArray(elementType, flatSchemaIndexes, buffer, data, qualifiedName)
case StructType(fields) => fillStruct(fields, flatSchemaIndexes, buffer, data, qualifiedName)
case _ => buffer(flatSchemaIndexes(qualifiedName)) = data
}
}
}

private def fillArray(
elementType: DataType,
flatSchemaIndexes: Map[String, Int],
buffer: ArrayBuffer[Any],
data: Any,
qualifiedName: String): Unit = {
val seq = data match {
  case bytes: Array[Byte] => bytes.toSeq // BinaryType values arrive as Array[Byte], not Seq
  case other => other.asInstanceOf[Seq[Any]]
}
val subRow = Row.fromSeq(seq)
val fillBufferPartiallyApplied = fillBuffer(flatSchemaIndexes, buffer) _
(0 until seq.size).foreach { idx =>
val arrayField = StructField(idx.toString, elementType)
fillBufferPartiallyApplied(arrayField, subRow(idx), Some(qualifiedName))
}
}

private def fillMap(
valueType: DataType,
flatSchemaIndexes: Map[String, Int],
buffer: ArrayBuffer[Any],
data: Any,
qualifiedName: String): Unit = {
val map = data.asInstanceOf[Map[Any, Any]]
val subRow = Row.fromSeq(map.values.toSeq)
val fillBufferPartiallyApplied = fillBuffer(flatSchemaIndexes, buffer) _
map.keys.zipWithIndex.foreach { case (key, idx) =>
val mapField = StructField(key.toString, valueType)
fillBufferPartiallyApplied(mapField, subRow(idx), Some(qualifiedName))
}
}

private def fillStruct(
fields: Seq[StructField],
flatSchemaIndexes: Map[String, Int],
buffer: ArrayBuffer[Any],
data: Any,
qualifiedName: String): Unit = {
val subRow = data.asInstanceOf[Row]
val fillBufferPartiallyApplied = fillBuffer(flatSchemaIndexes, buffer) _
fields.zipWithIndex.foreach { case (subField, idx) =>
fillBufferPartiallyApplied(subField, subRow(idx), Some(qualifiedName))
}
}

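// The flattened schema is data dependent (array lengths and map key sets vary per row),
// so a schema is derived from every row and the per-row schemas are then merged.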
def flattenSchema(df: DataFrame): StructType = {
val rowSchemas = rowsToRowSchemas(df)
val mergedSchema = mergeRowSchemas(rowSchemas)
StructType(mergedSchema.map(_.field))
}

private def rowsToRowSchemas(df: DataFrame): Dataset[Seq[FieldWithOrder]] = {
implicit val encoder = org.apache.spark.sql.Encoders.kryo[Seq[FieldWithOrder]]
val originalSchema = df.schema
df.map[Seq[FieldWithOrder]] { row: Row =>
originalSchema.fields.zipWithIndex.foldLeft(Seq.empty[FieldWithOrder]) {
case (acc, (field, index)) => acc ++ flattenField(field, row(index), index :: Nil)
}
}
}

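// Pairwise merge of the per-row schemas: the union of field paths is kept in path order,
// and a field that is missing from one side becomes nullable in the result.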
private def mergeRowSchemas(ds: Dataset[Seq[FieldWithOrder]]): Seq[FieldWithOrder] = ds.reduce {
(first, second) =>
val firstMap = convertRowSchemaToPathToFieldMap(first)
val secondMap = convertRowSchemaToPathToFieldMap(second)
val keys = (firstMap.keySet ++ secondMap.keySet).toSeq.sorted(fieldPathOrdering)
keys.map { key =>
(firstMap.get(key), secondMap.get(key)) match {
case (None, Some(StructField(name, dataType, _, _))) =>
FieldWithOrder(StructField(name, dataType, true), key)
case (Some(StructField(name, dataType, _, _)), None) =>
FieldWithOrder(StructField(name, dataType, true), key)
case (Some(StructField(name, dataType, nullable1, _)), Some(StructField(_, _, nullable2, _))) =>
FieldWithOrder(StructField(name, dataType, nullable1 || nullable2), key)
case (None, None) =>
throw new IllegalStateException(s"There must be a corresponding value for key '$key' in one map at least.")
}
}
}

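// Orders field paths segment by segment: struct and array positions compare as Ints,
// map keys as Strings; mixed segments fall back to their string representation.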
@transient private lazy val fieldPathOrdering = {
val segmentOrdering = new Ordering[Any] {
override def compare(x: Any, y: Any): Int = (x, y) match {
case (a: Int, b: Int) => a.compareTo(b)
case (a: String, b: String) => a.compareTo(b)
case (a, b) => a.toString.compareTo(b.toString)
}
}
Ordering.Iterable(segmentOrdering)
}

private def convertRowSchemaToPathToFieldMap(rowSchema: Seq[FieldWithOrder]): Map[Seq[Any], StructField] = {
rowSchema.map(f => f.order -> f.field).toMap
}

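// Nested fields get underscore-separated qualified names, e.g. `address.city` becomes `address_city`.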
private def getQualifiedName(prefix: Option[String], name: String): String = prefix match {
case None => name
case Some(p) => s"${p}_$name"
}

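// Returns the leaf fields reachable from `data`, each paired with its position path
// (built in reverse while recursing and reversed at the leaf). A null value contributes
// no fields; its column becomes nullable when the per-row schemas are merged.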
private def flattenField(
originalField: StructField,
data: Any,
path: List[Any],
prefix: Option[String] = None,
isParentNullable: Boolean = false): Seq[FieldWithOrder] = {
if (data != null) {
val StructField(name, dataType, nullable, _) = originalField
val qualifiedName = getQualifiedName(prefix, name)
val nullableField = isParentNullable || nullable
dataType match {
case BinaryType =>
flattenArrayType(ByteType, false, data, path, qualifiedName, nullableField)
case MapType(_, valueType, containsNull) =>
flattenMapType(valueType, containsNull, data, path, qualifiedName, nullableField)
case ArrayType(elementType, containsNull) =>
flattenArrayType(elementType, containsNull, data, path, qualifiedName, nullableField)
case StructType(fields) =>
flattenStructType(fields, data, path, qualifiedName, nullableField)
case dt =>
FieldWithOrder(StructField(qualifiedName, dt, nullableField), path.reverse) :: Nil
}
} else {
Nil
}
}

private case class FieldWithOrder(field: StructField, order: Seq[Any])

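// Array (and binary) elements become columns suffixed with their index: name_0, name_1, ...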
private def flattenArrayType(
elementType: DataType,
containsNull: Boolean,
data: Any,
path: List[Any],
qualifiedName: String,
nullableField: Boolean) = {
val values = data match {
  case bytes: Array[Byte] => bytes.toSeq // BinaryType values arrive as Array[Byte], not Seq
  case other => other.asInstanceOf[Seq[Any]]
}
val subRow = Row.fromSeq(values)
(0 until values.size).flatMap { idx =>
val arrayField = StructField(idx.toString(), elementType, containsNull)
flattenField(arrayField, subRow(idx), idx :: path, Some(qualifiedName), nullableField)
}
}

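// Map entries become columns suffixed with their key: name_key.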
private def flattenMapType(
valueType: DataType,
containsNull: Boolean,
data: Any,
path: List[Any],
qualifiedName: String,
nullableField: Boolean) = {
val map = data.asInstanceOf[Map[Any, Any]]
val subRow = Row.fromSeq(map.values.toSeq)
map.keys.zipWithIndex.flatMap { case (key, idx) =>
val mapField = StructField(key.toString, valueType, containsNull)
flattenField(mapField, subRow(idx), key :: path, Some(qualifiedName), nullableField)
}.toSeq
}

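// Struct fields recurse with the struct's qualified name as the prefix.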
private def flattenStructType(
fields: Seq[StructField],
data: Any,
path: List[Any],
qualifiedName: String,
nullableField: Boolean) = {
val subRow = data.asInstanceOf[Row]
fields.zipWithIndex.flatMap { case (subField, idx) =>
flattenField(subField, subRow(idx), idx :: path, Some(qualifiedName), nullableField)
}
}

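// The original struct-only flattening (renamed from flattenSchema); array and map
// columns are left untouched by this method.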
+def flattenStructsInSchema(schema: StructType, prefix: String = null, nullable: Boolean = false): StructType = {

val flattened = schema.fields.flatMap { f =>
val escaped = if (f.name.contains(".")) "`" + f.name + "`" else f.name
val colName = if (prefix == null) escaped else prefix + "." + escaped

f.dataType match {
-case st: StructType => flattenSchema(st, colName, nullable || f.nullable)
+case st: StructType => flattenStructsInSchema(st, colName, nullable || f.nullable)
case _ => Array[StructField](StructField(colName, f.dataType, nullable || f.nullable))
}
}
StructType(flattened)
}

-def flattenDataFrame(df: DataFrame): DataFrame = {
+def flattenStructsInDataFrame(df: DataFrame): DataFrame = {
import org.apache.spark.sql.functions.col
-val flatten = flattenSchema(df.schema)
+val flatten = flattenStructsInSchema(df.schema)
val cols = flatten.map(f => col(f.name).as(f.name.replaceAll("`", "")))
df.select(cols: _*)
}
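
To close, a minimal usage sketch of the new public entry point. This snippet is not part of the diff above; it assumes a running SparkSession named `spark`, and the data and column names are illustrative.

    import org.apache.spark.h2o.utils.H2OSchemaUtils
    import spark.implicits._

    val df = Seq(
      (1, Seq("a", "b"), Map("x" -> 1.0)),
      (2, Seq("c"), Map("y" -> 2.0))
    ).toDF("id", "tags", "scores")

    val flat = H2OSchemaUtils.flattenDataFrame(df)
    flat.show()
    // Per the merging logic above, the flat columns should be
    //   id, tags_0, tags_1, scores_x, scores_y
    // with tags_1, scores_x and scores_y nullable, since each is missing
    // from at least one row.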