Zero copy import when the schema is known #261

Merged
merged 11 commits on Aug 24, 2022
4 changes: 2 additions & 2 deletions app/com/lynxanalytics/biggraph/controllers/Operation.scala
@@ -483,10 +483,10 @@ abstract class SmartOperation(context: Operation.Context) extends SimpleOperatio
table.schema.fieldNames.toList.map(n => FEOption(n, n))
}

protected def splitParam(param: String): Seq[String] = {
protected def splitParam(param: String, delimiter: String = ","): Seq[String] = {
val p = params(param)
if (p.trim.isEmpty) Seq()
else p.split(",", -1).map(_.trim)
else p.split(delimiter, -1).map(_.trim)
}
}
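
For illustration, the new optional delimiter only changes how the raw parameter string is split; below is a minimal sketch of the underlying semantics (the parameter value is made up):

    // Hypothetical parameter value, e.g. params("schema") == "name: string; age: long"
    val p = "name: string; age: long"
    p.split(",", -1).map(_.trim).toSeq   // Seq("name: string; age: long"), the default "," finds nothing to split
    p.split(";", -1).map(_.trim).toSeq   // Seq("name: string", "age: long")
    // An empty or whitespace-only parameter still yields Seq().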

@@ -308,16 +308,8 @@ class ImportOperations(env: SparkFreeEnvironment) extends ProjectOperations(env)
}
})

abstract class FileWithSchema(context: Context) extends ImportOperation(context) {
abstract class FileWithSchemaBase(context: Context) extends ImportOperation(context) {
val format: String
params ++= List(
FileParam("filename", "File"),
Param("imported_columns", "Columns to import"),
Param("limit", "Limit"),
Code("sql", "SQL", language = "sql"),
ImportedDataParam(),
new DummyParam("last_settings", ""),
)

override def summary = {
val fn = simpleFileName(params("filename"))
@@ -332,7 +324,18 @@ class ImportOperations(env: SparkFreeEnvironment) extends ProjectOperations(env)
}
}

registerImport("Import Parquet")(new FileWithSchema(_) { val format = "parquet" })
// Same as FileWithSchemaBase, but also adds the standard import parameters.
abstract class FileWithSchema(context: Context) extends FileWithSchemaBase(context) {
params ++= List(
FileParam("filename", "File"),
Param("imported_columns", "Columns to import"),
Param("limit", "Limit"),
Code("sql", "SQL", language = "sql"),
ImportedDataParam(),
new DummyParam("last_settings", ""),
)
}

registerImport("Import ORC")(new FileWithSchema(_) { val format = "orc" })
registerImport("Import JSON")(new FileWithSchema(_) { val format = "json" })
registerImport("Import AVRO")(new FileWithSchema(_) { val format = "avro" })
@@ -355,6 +358,45 @@ class ImportOperations(env: SparkFreeEnvironment) extends ProjectOperations(env)
}
})

registerImport("Import Parquet")(new FileWithSchemaBase(_) {
val format = "parquet"
params ++= List(
FileParam("filename", "File"),
Param("imported_columns", "Columns to import"),
Param("limit", "Limit"),
Code("sql", "SQL", language = "sql"),
Choice(
"eager",
"Import now or provide schema",
List(FEOption("yes", "Import now"), FEOption("no", "Provide schema"))),
new DummyParam("last_settings", ""),
)
params ++= (params("eager") match { // Hide/show import button and schema parameter.
case "yes" => List(
ImportedDataParam(),
new DummyParam("schema", ""),
)
case "no" => List(
new DummyParam("imported_table", ""),
Param("schema", "Schema"),
)
})

override def getOutputs(): Map[BoxOutput, BoxOutputState] = {
params.validate()
params("eager") match {
case "yes" =>
assert(params("imported_table").nonEmpty, "You have to import the data first.")
assert(!areSettingsStale, areSettingsStaleReplyMessage)
makeOutput(tableFromGuid(params("imported_table")))
case "no" =>
makeOutput(graph_operations.ReadParquetWithSchema.run(
params("filename"),
splitParam("schema", delimiter = ";")))
}
}
})
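
When "Provide schema" is selected, the "Schema" parameter is a ";"-separated list of "column: type" entries, split with the new splitParam delimiter and parsed by ReadParquetWithSchema.parseSchema (shown further down). A hedged sketch of a value this field could take; the column names are made up:

    // Hypothetical "Schema" parameter value:
    val exampleSchema = "name: string; age: long; visits: array of timestamp"
    // splitParam("schema", delimiter = ";") turns it into
    //   Seq("name: string", "age: long", "visits: array of timestamp"),
    // which parseSchema converts into a Spark StructType (see ReadParquetWithSchema below).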

registerImport("Import from BigQuery (raw table)")(new ImportOperation(_) {
params ++= List(
Param("parent_project_id", "GCP project ID for billing"),
33 changes: 25 additions & 8 deletions app/com/lynxanalytics/biggraph/graph_api/JsonSerialization.scala
@@ -104,7 +104,9 @@ object SerializableType {
val long = P("Long").map(_ => SerializableType.long)
val id = P("ID").map(_ => SerializableType.id)
val int = P("Int").map(_ => SerializableType.int)
val primitive = P(string | double | long | int | id)
val timestamp = P("Timestamp").map(_ => SerializableType.timestamp)
val date = P("Date").map(_ => SerializableType.date)
val primitive = P(string | double | long | int | id | timestamp | date)
val vector: PS = P("Vector[" ~ stype ~ "]").map {
inner => SerializableType.vector(inner)
}
@@ -127,14 +129,29 @@ object SerializableType {
val id = new SerializableType[com.lynxanalytics.biggraph.graph_api.ID]("ID")
val long = new SerializableType[Long]("Long")
val int = new SerializableType[Int]("Int")
val timestamp = {
implicit val o = new TimestampOrdering
implicit val f = TypeTagToFormat.formatTimestamp
new SerializableType[java.sql.Timestamp]("Timestamp")
}
val date = {
implicit val o = new DateOrdering
implicit val f = TypeTagToFormat.formatDate
new SerializableType[java.sql.Date]("Date")
}

// Every serializable type defines an ordering here, but we never use it for vectors.
class MockVectorOrdering[T: TypeTag] extends Ordering[Vector[T]] with Serializable {
// Custom orderings.
class TimestampOrdering extends Ordering[java.sql.Timestamp] with Serializable {
def compare(x: java.sql.Timestamp, y: java.sql.Timestamp): Int = x.compareTo(y)
}
class DateOrdering extends Ordering[java.sql.Date] with Serializable {
def compare(x: java.sql.Date, y: java.sql.Date): Int = x.compareTo(y)
}
// Orderings for types whose values cannot be compared; any comparison throws.
class UnsupportedVectorOrdering[T: TypeTag] extends Ordering[Vector[T]] with Serializable {
def compare(x: Vector[T], y: Vector[T]): Int = ???
}

// Every serializable type defines an ordering here, but we never use it for tuple2s.
class MockTuple2Ordering[T1: TypeTag, T2: TypeTag] extends Ordering[(T1, T2)] with Serializable {
class UnsupportedTuple2Ordering[T1: TypeTag, T2: TypeTag] extends Ordering[(T1, T2)] with Serializable {
def compare(x: (T1, T2), y: (T1, T2)): Int = ???
}

@@ -194,12 +211,12 @@ class VectorSerializableType[T: TypeTag] private[graph_api] (
extends SerializableType[Vector[T]](typename)(
classTag = RuntimeSafeCastable.classTagFromTypeTag(typeTag),
format = TypeTagToFormat.vectorToFormat(typeTag),
ordering = new SerializableType.MockVectorOrdering()(typeTag),
ordering = new SerializableType.UnsupportedVectorOrdering()(typeTag),
typeTag = TypeTagUtil.vectorTypeTag(typeTag)) {}
class Tuple2SerializableType[T1: TypeTag, T2: TypeTag] private[graph_api] (
typename: String)
extends SerializableType[(T1, T2)](typename)(
classTag = RuntimeSafeCastable.classTagFromTypeTag(typeTag),
format = TypeTagToFormat.pairToFormat(typeTag[T1], typeTag[T2]),
ordering = new SerializableType.MockTuple2Ordering()(typeTag[T1], typeTag[T2]),
ordering = new SerializableType.UnsupportedTuple2Ordering()(typeTag[T1], typeTag[T2]),
typeTag = TypeTagUtil.tuple2TypeTag(typeTag[T1], typeTag[T2])) {}
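
As a side note (not part of the diff), the new orderings simply delegate to compareTo, so Timestamp and Date attributes can be compared, while vector and tuple2 orderings remain intentionally unsupported; the rename from Mock* to Unsupported* just makes that explicit. A minimal sketch, assuming the ordering classes stay reachable on the SerializableType companion object:

    val ord = new SerializableType.TimestampOrdering
    val earlier = java.sql.Timestamp.valueOf("2022-08-24 10:00:00")
    val later = java.sql.Timestamp.valueOf("2022-08-24 11:00:00")
    assert(ord.compare(earlier, later) < 0)
    // new SerializableType.UnsupportedVectorOrdering[Int]().compare(Vector(1), Vector(2))
    // throws scala.NotImplementedError because of the ??? body.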
20 changes: 20 additions & 0 deletions app/com/lynxanalytics/biggraph/graph_api/TypeTagToFormat.scala
@@ -22,6 +22,8 @@ object TypeTagToFormat {
implicit val formatUIAttributeFilter = json.Json.format[UIAttributeFilter]
implicit val formatUICenterRequest = json.Json.format[UICenterRequest]
implicit val formatUIStatus = json.Json.format[UIStatus]
implicit val formatDate = new DateFormat
implicit val formatTimestamp = new TimestampFormat

implicit object ToJsonFormat extends json.Format[ToJson] {
def writes(t: ToJson): JsValue = {
@@ -82,6 +84,22 @@ object TypeTagToFormat {
}
}

class DateFormat extends json.Format[java.sql.Date] {
def reads(j: json.JsValue): json.JsResult[java.sql.Date] = {
val s = j.as[String]
json.JsResult.fromTry(util.Try(java.sql.Date.valueOf(s)))
}
def writes(v: java.sql.Date): json.JsValue = json.JsString(v.toString)
}

class TimestampFormat extends json.Format[java.sql.Timestamp] {
def reads(j: json.JsValue): json.JsResult[java.sql.Timestamp] = {
val s = j.as[String]
json.JsResult.fromTry(util.Try(java.sql.Timestamp.valueOf(s)))
}
def writes(v: java.sql.Timestamp): json.JsValue = json.JsString(v.toString)
}
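
A minimal round-trip sketch of the new Play JSON formats (assuming DateFormat is reachable from the call site; the values are examples):

    import play.api.libs.{json => pjson}

    val fmt = new DateFormat
    val js: pjson.JsValue = fmt.writes(java.sql.Date.valueOf("2022-08-24"))  // JsString("2022-08-24")
    val ok = fmt.reads(js)                                                   // JsSuccess(2022-08-24)
    val bad = fmt.reads(pjson.JsString("not a date"))                        // JsError, because the Try fails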

class IDBucketsFormat[T: json.Format] extends json.Format[IDBuckets[T]] {
implicit val mft1 = new MapFormat[T, Long]
implicit val mft2 = new MapFormat[Long, T]
@@ -155,6 +173,8 @@ object TypeTagToFormat {
else if (TypeTagUtil.isType[DynamicValue](t)) implicitly[json.Format[DynamicValue]]
else if (TypeTagUtil.isType[UIStatus](t)) implicitly[json.Format[UIStatus]]
else if (TypeTagUtil.isType[Edge](t)) implicitly[json.Format[Edge]]
else if (TypeTagUtil.isType[java.sql.Date](t)) implicitly[json.Format[java.sql.Date]]
else if (TypeTagUtil.isType[java.sql.Timestamp](t)) implicitly[json.Format[java.sql.Timestamp]]
else if (TypeTagUtil.isSubtypeOf[ToJson](t)) ToJsonFormat
else if (TypeTagUtil.isOfKind1[Option](t)) {
val innerType = TypeTagUtil.typeArgs(tag).head
@@ -38,3 +38,9 @@ class ScalarOutput[T: TypeTag](implicit instance: MetaGraphOperationInstance)
extends MagicOutput(instance) {
val sc = scalar[T]
}

class TableOutput(schema: org.apache.spark.sql.types.StructType)(
implicit instance: MetaGraphOperationInstance)
extends MagicOutput(instance) {
val t = table(schema)
}
@@ -10,14 +10,15 @@ import org.apache.spark.sql.types

object ImportDataFrame extends OpFromJson {

def fromJson(j: JsValue) = {
def schemaFromJson(j: play.api.libs.json.JsLookupResult): types.StructType = {
// This is meta level, so we may not have a Spark session at this point.
// But we've got to allow reading old schemas for compatibility.
org.apache.spark.sql.internal.SQLConf.get.setConfString("spark.sql.legacy.allowNegativeScaleOfDecimal", "true")
new ImportDataFrame(
types.DataType.fromJson((j \ "schema").as[String]).asInstanceOf[types.StructType],
None,
(j \ "timestamp").as[String])
types.DataType.fromJson(j.as[String]).asInstanceOf[types.StructType]
}

def fromJson(j: JsValue) = {
new ImportDataFrame(schemaFromJson(j \ "schema"), None, (j \ "timestamp").as[String])
}

private def apply(df: DataFrame) = {
@@ -0,0 +1,63 @@
// Zero-copy import. Reads a Parquet file that lives outside of LynxKite and whose schema is known in advance.
package com.lynxanalytics.biggraph.graph_operations

import com.lynxanalytics.biggraph.graph_api._
import com.lynxanalytics.biggraph.graph_util.HadoopFile
import com.lynxanalytics.biggraph.spark_util.SQLHelper

import org.apache.spark.sql.types

object ReadParquetWithSchema extends OpFromJson {
def fromJson(j: JsValue) = ReadParquetWithSchema(
(j \ "filename").as[String],
ImportDataFrame.schemaFromJson(j \ "schema"),
)

def parseSchema(strings: Seq[String]): types.StructType = {
val re = raw"\s*(.*?)\s*:\s*(.*?)\s*".r
val reArray = raw"array of\s+(.*)".r
types.StructType(strings.map {
case re(name, tpe) => types.StructField(
name = name,
dataType = types.DataType.fromJson {
tpe.toLowerCase match {
case "array" => throw new AssertionError(
"For array types you need to specify the element type too. For example: 'array of string'")
case reArray(t) =>
s"""{"type": "array", "containsNull": true, "elementType": "$t"}"""
case t => s""" "$t" """
}
},
)
case x => throw new AssertionError(s"The schema must be listed as 'column: type'. Got '$x'.")
})
}

def run(filename: String, schema: Seq[String])(implicit mm: MetaGraphManager): Table = {
import Scripting._
new ReadParquetWithSchema(filename, parseSchema(schema))().result.t
}
}

case class ReadParquetWithSchema(
filename: String,
schema: types.StructType)
extends SparkOperation[NoInput, TableOutput] {
override val isHeavy = true
@transient override lazy val inputs = new NoInput()
def outputMeta(instance: MetaGraphOperationInstance) = new TableOutput(schema)(instance)
override def toJson = Json.obj(
"filename" -> filename,
"schema" -> schema.prettyJson,
)

def execute(
inputDatas: DataSet,
o: TableOutput,
output: OutputBuilder,
rc: RuntimeContext): Unit = {
val f = HadoopFile(filename)
val df = rc.sparkDomain.sparkSession.read.schema(schema).parquet(f.resolvedName)
output(o.t, df)
}
}
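
For reference, a hedged sketch of what parseSchema produces for a small schema (the column names are invented):

    import org.apache.spark.sql.types

    val schema = ReadParquetWithSchema.parseSchema(Seq(
      "name: string",
      "age: long",
      "tags: array of string"))
    // schema == types.StructType(Seq(
    //   types.StructField("name", types.StringType),
    //   types.StructField("age", types.LongType),
    //   types.StructField("tags", types.ArrayType(types.StringType, containsNull = true))))
    // "tags: array" on its own fails with the assertion asking for an element type.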
25 changes: 19 additions & 6 deletions app/com/lynxanalytics/biggraph/serving/Utils.scala
@@ -1,13 +1,26 @@
package com.lynxanalytics.biggraph.serving

object Utils {
private def rootCause(t: Throwable): Throwable = Option(t.getCause).map(rootCause(_)).getOrElse(t)

private def causes(t: Throwable): List[Throwable] = {
t :: (Option(t.getCause) match {
case None => Nil
case Some(t) => causes(t)
})
}
private val assertionFailed = "^assertion failed: ".r
private val afterFirstLine = "(?s)\n.*".r

def formatThrowable(t: Throwable): String = rootCause(t) match {
// Trim "assertion failed: " from AssertionErrors.
case e: AssertionError => assertionFailed.replaceFirstIn(e.getMessage, "")
case e => e.toString
def formatThrowable(t: Throwable): String = {
val cs = causes(t)
val assertion = cs.collectFirst { case c: AssertionError => c }
assertion.map { t =>
// If we have an assertion, that should explain everything on its own.
assertionFailed.replaceFirstIn(t.getMessage, "")
}.getOrElse {
// Otherwise give a condensed view of the cause chain (first line of each message).
cs.flatMap { t =>
Option(t.getMessage).map { msg => afterFirstLine.replaceFirstIn(msg, "") }
}.mkString("\ncaused by:\n")
}
}
}
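
To illustrate the new behavior (not part of the diff), formatThrowable now either surfaces the first assertion message in the cause chain or joins the first lines of all messages:

    // Hedged sketch; the exception messages are made up.
    val chained = new RuntimeException(
      "outer failure\n  at some.StackTrace(Example.scala:1)",
      new IllegalArgumentException("inner failure"))
    // Utils.formatThrowable(chained) == "outer failure\ncaused by:\ninner failure"

    val asserted = new RuntimeException(
      "wrapper",
      new AssertionError("assertion failed: You have to import the data first."))
    // Utils.formatThrowable(asserted) == "You have to import the data first."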
@@ -379,6 +379,7 @@ class BigGraphKryoRegistrator extends KryoRegistrator {
kryo.register(Class.forName("org.apache.spark.sql.catalyst.expressions.NullsFirst$"))
kryo.register(classOf[Array[Array[List[_]]]])
kryo.register(classOf[Array[Array[Tuple2[_, _]]]])
kryo.register(org.apache.spark.sql.types.DateType.getClass)

// Add new stuff just above this line! Thanks.
// Adding Foo$mcXXX$sp? It is a type specialization. Register the decoded type instead!
1 change: 1 addition & 0 deletions dependency-licenses/scala.md
@@ -24,6 +24,7 @@ Apache | [Apache License v2.0](http://www.apache.org/licenses/LICENSE-2.0.txt) |
Apache | [Apache License, Version 2](http://www.apache.org/licenses/LICENSE-2.0) | org.neo4j.driver # neo4j-java-driver # 4.2.5 | <notextile></notextile>
Apache | [Apache License, Version 2.0](https://aws.amazon.com/apache2.0) | com.amazonaws # aws-java-sdk # 1.7.4 | <notextile></notextile>
Apache | [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0.txt) | com.clearspring.analytics # stream # 2.9.6 | <notextile></notextile>
Apache | [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0.txt) | com.google.cloud.spark # spark-bigquery-with-dependencies_2.12 # 0.25.0 | <notextile></notextile>
Apache | [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0.txt) | com.google.guava # guava # 30.1-android | <notextile></notextile>
Apache | [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0) | com.jamesmurty.utils # java-xmlbuilder # 1.1 | <notextile></notextile>
Apache | [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0.txt) | commons-codec # commons-codec # 1.15 | <notextile></notextile>
@@ -1,5 +1,6 @@
package com.lynxanalytics.biggraph.frontend_operations

import scala.reflect.runtime.universe.TypeTag
import com.lynxanalytics.biggraph.graph_api.Scripting._
import com.lynxanalytics.biggraph.graph_api.GraphTestUtils._

@@ -54,7 +55,7 @@ class AggregateOnNeighborsTest extends OperationsTestBase {

class WeightedAggregateOnNeighborsTest extends OperationsTestBase {
test("all aggregators") {
def agg[T](attribute: String, aggregator: String, weight: String): Map[Long, T] = {
def agg[T: TypeTag](attribute: String, aggregator: String, weight: String): Map[Long, T] = {
val p = box("Create example graph")
.box(
"Weighted aggregate on neighbors",