Added parallel read functionality
Summary:
Added parallel read functionality
Design doc - https://docs.google.com/document/d/1X11i1dV0V5Mf7G0UwNIuBApjEx0CitdtaQtK9vF-ws8/edit?usp=sharing
Disabled sort/limit pushdown as they may work incorrectly with parallel read.
Task to investigate sort/limit pushdown - https://memsql.atlassian.net/browse/PLAT-5893
**Design doc/spec**:
**Docs impact**: none

Test Plan:
All tests now try to use ReadFromAggregators and, if they can, ReadFromLeaves.
Added a dedicated workflow that tests ReadFromLeaves.
https://webapp.io/memsql/commits?query=repo%3Asinglestore-spark-connector+id%3A26
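
For illustration only (this snippet is not part of the commit): a sketch of how a Spark job might opt into the new parallel read path. The option key strings ("enableParallelRead", "parallelRead.Features") are assumptions inferred from the SinglestoreOptions constants referenced in the diff (ENABLE_PARALLEL_READ, PARALLEL_READ_FEATURES); the values come from the new ParallelReadEnablement and ParallelReadType types added below.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("parallel-read-sketch").getOrCreate()

// Option keys are assumed, not taken from this diff; values match the new enums.
val df = spark.read
  .format("singlestore")
  .option("enableParallelRead", "automatic")                             // disabled | automatic | forced
  .option("parallelRead.Features", "ReadFromAggregators,ReadFromLeaves") // assumed comma-separated ParallelReadType list
  .load("test_db.users")

df.show()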

Reviewers: carl, cchen, pmishchenko-ua

Reviewed By: carl

Subscribers: rob, jprice, engineering-list

JIRA Issues: PLAT-5844

Differential Revision: https://grizzly.internal.memcompute.com/D52565
AdalbertMemSQL committed Nov 23, 2021
1 parent c78a5dc commit 9c5ccf8
Showing 22 changed files with 988 additions and 338 deletions.
4 changes: 2 additions & 2 deletions Layerfile
@@ -35,9 +35,9 @@ SECRET ENV LICENSE_KEY
MEMORY 4G
MEMORY 8G

-# split to 8 states
+# split to 9 states
# each of them will run different version of the singlestore and spark
-SPLIT 8
+SPLIT 9

# copy the entire git repository
COPY . .
6 changes: 6 additions & 0 deletions scripts/define-layerci-matrix.sh
@@ -19,6 +19,12 @@ else
echo 'export SINGLESTORE_PASSWORD="password"'
fi

if [ "$TEST_NUM" == '9' ]
then
echo 'export FORCE_READ_FROM_LEAVES=TRUE'
else
echo 'export FORCE_READ_FROM_LEAVES=FALSE'
fi

if [ "$TEST_NUM" == '0' ] || [ "$TEST_NUM" == '2' ] || [ "$TEST_NUM" == '4' ] || [ "$TEST_NUM" == '6' ]
then
@@ -0,0 +1,9 @@
package org.apache.spark.scheduler

import org.apache.spark.rdd.RDD

object MaxNumConcurrentTasks {
def get(rdd: RDD[_]): Int = {
rdd.sparkContext.maxNumConcurrentTasks()
}
}
@@ -0,0 +1,13 @@
package org.apache.spark.scheduler

import org.apache.spark.rdd.RDD

object MaxNumConcurrentTasks {
def get(rdd: RDD[_]): Int = {
val (_, resourceProfiles) =
rdd.sparkContext.dagScheduler.getShuffleDependenciesAndResourceProfiles(rdd)
val resourceProfile =
rdd.sparkContext.dagScheduler.mergeResourceProfilesForStage(resourceProfiles)
rdd.sparkContext.maxNumConcurrentTasks(resourceProfile)
}
}
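
The two MaxNumConcurrentTasks objects above appear to be per-Spark-version shims (presumably built from version-specific source directories) and are placed in the org.apache.spark.scheduler package so they can reach Spark-internal APIs such as maxNumConcurrentTasks and the DAGScheduler helpers. A minimal sketch of how such a shim could be consulted, assuming a helper that is not part of this diff:

import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.MaxNumConcurrentTasks

// Hypothetical helper: true when the cluster can schedule at least numPartitions
// tasks at once, i.e. when every result-table partition could be read concurrently.
def canReadAllPartitionsConcurrently(rdd: RDD[_], numPartitions: Int): Boolean =
  MaxNumConcurrentTasks.get(rdd) >= numPartitions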
@@ -0,0 +1,145 @@
package com.singlestore.spark

import java.sql.Connection

import com.singlestore.spark.SQLGen.VariableList
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{
SparkListener,
SparkListenerStageCompleted,
SparkListenerStageSubmitted
}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types.StructType

import scala.collection.mutable

class AggregatorParallelReadListener(applicationId: String) extends SparkListener with LazyLogging {
// connectionsMap is a map from the result table name to the connection with which this table was created
private val connectionsMap: mutable.Map[String, Connection] =
new mutable.HashMap[String, Connection]()

// rddInfos is a map from RDD id to the info needed to create result table for this RDD
private val rddInfos: mutable.Map[Int, SingleStoreRDDInfo] =
new mutable.HashMap[Int, SingleStoreRDDInfo]()

// SingleStoreRDDInfo is information needed to create a result table
private case class SingleStoreRDDInfo(query: String,
variables: VariableList,
schema: StructType,
connectionOptions: JDBCOptions,
materialized: Boolean,
needsRepartition: Boolean)

def addRDDInfo(rdd: SinglestoreRDD): Unit = {
rddInfos.synchronized({
rddInfos += (rdd.id -> SingleStoreRDDInfo(
rdd.query,
rdd.variables,
rdd.schema,
JdbcHelpers.getDDLJDBCOptions(rdd.options),
rdd.parallelReadType.contains(ReadFromAggregatorsMaterialized),
rdd.options.parallelReadRepartition
))
})
}

def deleteRDDInfo(rdd: SinglestoreRDD): Unit = {
rddInfos.synchronized({
rddInfos -= rdd.id
})
}

def isEmpty: Boolean = {
rddInfos.synchronized({
rddInfos.isEmpty
})
}

override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
stageSubmitted.stageInfo.rddInfos.foreach(rddInfo => {
if (rddInfo.name == "SinglestoreRDD") {
rddInfos
.synchronized(
rddInfos.get(rddInfo.id)
)
.foreach(singleStoreRDDInfo => {
val stageId = stageSubmitted.stageInfo.stageId
val tableName = JdbcHelpers.getResultTableName(applicationId, stageId, rddInfo.id)

// Create connection and save it in the map
val conn = JdbcUtils.createConnectionFactory(singleStoreRDDInfo.connectionOptions)()
connectionsMap.synchronized(
connectionsMap += (tableName -> conn)
)

// Create result table
JdbcHelpers.createResultTable(
conn,
tableName,
singleStoreRDDInfo.query,
singleStoreRDDInfo.schema,
singleStoreRDDInfo.variables,
singleStoreRDDInfo.materialized,
singleStoreRDDInfo.needsRepartition
)
})
}
})
}

override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
stageCompleted.stageInfo.rddInfos.foreach(rddInfo => {
if (rddInfo.name == "SinglestoreRDD") {
val stageId = stageCompleted.stageInfo.stageId
val tableName = JdbcHelpers.getResultTableName(applicationId, stageId, rddInfo.id)

connectionsMap.synchronized(
connectionsMap
.get(tableName)
.foreach(conn => {
// Drop result table
JdbcHelpers.dropResultTable(conn, tableName)
// Close connection
conn.close()
// Delete connection from map
connectionsMap -= tableName
})
)
}
})
}
}

case object AggregatorParallelReadListenerAdder {
// listeners is a map from SparkContext hash code to the listener associated with this SparkContext
private val listeners = new mutable.HashMap[SparkContext, AggregatorParallelReadListener]()

def addRDD(rdd: SinglestoreRDD): Unit = {
this.synchronized({
val listener = listeners.getOrElse(
rdd.sparkContext, {
val newListener = new AggregatorParallelReadListener(rdd.sparkContext.applicationId)
rdd.sparkContext.addSparkListener(newListener)
listeners += (rdd.sparkContext -> newListener)
newListener
}
)
listener.addRDDInfo(rdd)
})
}

def deleteRDD(rdd: SinglestoreRDD): Unit = {
this.synchronized({
listeners
.get(rdd.sparkContext)
.foreach(listener => {
listener.deleteRDDInfo(rdd)
if (listener.isEmpty) {
listeners -= rdd.sparkContext
rdd.sparkContext.removeSparkListener(listener)
}
})
})
}
}
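
A minimal sketch (the connector's actual call sites are outside this diff) of how the adder above could be scoped around a job that consumes a SinglestoreRDD: register before the job runs so the listener can create the result table on stage submission, and deregister afterwards so the listener is removed once no registered RDDs remain.

// Hypothetical wrapper; SinglestoreRDD and the adder come from this commit.
def withParallelReadLifecycle[T](rdd: SinglestoreRDD)(body: => T): T = {
  AggregatorParallelReadListenerAdder.addRDD(rdd) // ensures a listener exists for rdd.sparkContext
  try body
  finally AggregatorParallelReadListenerAdder.deleteRDD(rdd) // listener is removed once it holds no RDDs
}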
84 changes: 81 additions & 3 deletions src/main/scala/com/singlestore/spark/JdbcHelpers.scala
@@ -1,16 +1,17 @@
package com.singlestore.spark

-import java.sql.{Connection, PreparedStatement, Statement, SQLException}
+import java.sql.{Connection, PreparedStatement, SQLException, Statement}
import java.util.UUID.randomUUID

import com.singlestore.spark.SinglestoreOptions.{TableKey, TableKeyType}
import com.singlestore.spark.SQLGen.{StringVar, VariableList}
import org.apache.spark.sql.{Row, SaveMode}
-import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types.{StringType, StructType}

-import scala.util.Try
+import scala.util.{Failure, Success, Try}

case class SinglestorePartitionInfo(ordinal: Int, name: String, hostport: String)

@@ -343,6 +344,83 @@ object JdbcHelpers extends LazyLogging {
conn.withStatement(stmt => stmt.executeUpdate(sql))
}

def getPartitionsCount(conn: Connection, database: String): Int = {
val sql =
s"SELECT num_partitions FROM information_schema.DISTRIBUTED_DATABASES WHERE database_name = '$database'"
log.trace(s"Executing SQL:\n$sql")
val resultSet = conn.withStatement(stmt => stmt.executeQuery(sql))

if (resultSet.next()) {
resultSet.getInt("num_partitions")
} else {
throw new IllegalArgumentException(
s"Failed to get number of partitions for '$database' database")
}
}

def getResultTableName(applicationId: String, stageId: Int, rddId: Int): String = {
s"rt_${applicationId.replace("-", "")}_${stageId}_${rddId}"
}

def getCreateResultTableQuery(tableName: String,
query: String,
schema: StructType,
materialized: Boolean,
needsRepartition: Boolean): String = {
val materializedStr = { if (materialized) { "MATERIALIZED" } else "" }
if (needsRepartition) {
val randColName = s"randColumn${randomUUID().toString.replace("-", "")}"
s"CREATE $materializedStr RESULT TABLE $tableName PARTITION BY ($randColName) AS SELECT *, RAND() AS $randColName FROM ($query)"
} else {
s"CREATE $materializedStr RESULT TABLE $tableName AS $query"
}
}

def getSelectFromResultTableQuery(tableName: String, partition: Int): String = {
s"SELECT * FROM ::$tableName WHERE partition_id() = $partition"
}

def createResultTable(conn: Connection,
tableName: String,
query: String,
schema: StructType,
variables: VariableList,
materialized: Boolean,
needsRepartition: Boolean): Unit = {
val sql =
getCreateResultTableQuery(tableName, query, schema, materialized, needsRepartition)
log.trace(s"Executing SQL:\n$sql")

conn.withPreparedStatement(sql, stmt => {
JdbcHelpers.fillStatement(stmt, variables)
stmt.executeUpdate()
})
}

def dropResultTable(conn: Connection, tableName: String): Unit = {
val sql = s"DROP RESULT TABLE $tableName"
log.trace(s"Executing SQL:\n$sql")

conn.withStatement(stmt => {
stmt.executeUpdate(sql)
})
}

def isValidQuery(conn: Connection, query: String, variables: VariableList): Boolean = {
val sql = s"EXPLAIN $query"
log.trace(s"Executing SQL:\n$sql")

Try {
conn.withPreparedStatement(sql, stmt => {
JdbcHelpers.fillStatement(stmt, variables)
stmt.execute()
})
} match {
case Success(_) => true
case Failure(_) => false
}
}

def truncateTable(conn: Connection, table: TableIdentifier): Unit = {
val sql = s"TRUNCATE TABLE ${table.quotedString}"
log.trace(s"Executing SQL:\n$sql")
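
To make the new helpers concrete, an illustrative walk-through of the SQL they generate for one parallel read (the application id, stage/RDD ids, schema, and query below are made up; only the helper functions come from this commit):

import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

val schema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))

val tableName = JdbcHelpers.getResultTableName("app-20211123-0001", stageId = 3, rddId = 7)
// "rt_app202111230001_3_7"

val createSql = JdbcHelpers.getCreateResultTableQuery(
  tableName,
  query = "SELECT id, name FROM users",
  schema = schema,
  materialized = false,
  needsRepartition = false
)
// "CREATE  RESULT TABLE rt_app202111230001_3_7 AS SELECT id, name FROM users"

val selectSql = JdbcHelpers.getSelectFromResultTableQuery(tableName, partition = 0)
// "SELECT * FROM ::rt_app202111230001_3_7 WHERE partition_id() = 0"

Each Spark partition runs the second query with its own partition index, so the rows of the result table are split across readers by SingleStore's partition_id().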
26 changes: 26 additions & 0 deletions src/main/scala/com/singlestore/spark/ParallelReadEnablement.scala
@@ -0,0 +1,26 @@
package com.singlestore.spark

sealed trait ParallelReadEnablement

case object Disabled extends ParallelReadEnablement
case object Automatic extends ParallelReadEnablement
case object Forced extends ParallelReadEnablement

object ParallelReadEnablement {
def apply(value: String): ParallelReadEnablement = value.toLowerCase match {
case "disabled" => Disabled
case "automatic" => Automatic
case "forced" => Forced

// These two options are added for compatibility purposes
case "false" => Disabled
case "true" => Automatic

case _ =>
throw new IllegalArgumentException(
s"""Illegal argument for `${SinglestoreOptions.ENABLE_PARALLEL_READ}` option. Valid arguments are:
| - "Disabled"
| - "Automatic"
| - "Forced"""".stripMargin)
}
}
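
A few illustrative parses of the values above; the mapping is taken directly from this file, and matching is case-insensitive:

assert(ParallelReadEnablement("disabled") == Disabled)
assert(ParallelReadEnablement("Automatic") == Automatic)
assert(ParallelReadEnablement("true") == Automatic)   // legacy boolean kept for compatibility
assert(ParallelReadEnablement("false") == Disabled)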
21 changes: 21 additions & 0 deletions src/main/scala/com/singlestore/spark/ParallelReadType.scala
@@ -0,0 +1,21 @@
package com.singlestore.spark

sealed trait ParallelReadType

case object ReadFromLeaves extends ParallelReadType
case object ReadFromAggregators extends ParallelReadType
case object ReadFromAggregatorsMaterialized extends ParallelReadType

object ParallelReadType {
def apply(value: String): ParallelReadType = value.toLowerCase match {
case "readfromleaves" => ReadFromLeaves
case "readfromaggregators" => ReadFromAggregators
case "readfromaggregatorsmaterialized" => ReadFromAggregatorsMaterialized
case _ =>
throw new IllegalArgumentException(
s"""Illegal argument for `${SinglestoreOptions.PARALLEL_READ_FEATURES}` option. Valid arguments are:
| - "ReadFromLeaves"
| - "ReadFromAggregators"
| - "ReadFromAggregatorsMaterialized"""".stripMargin)
}
}
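
The read-type strings parse the same way (case-insensitively), and ReadFromAggregatorsMaterialized is what drives the MATERIALIZED keyword in getCreateResultTableQuery above: the listener sets materialized when the RDD's parallelReadType contains it. For example:

assert(ParallelReadType("ReadFromLeaves") == ReadFromLeaves)
assert(ParallelReadType("readfromaggregators") == ReadFromAggregators)
assert(ParallelReadType("ReadFromAggregatorsMaterialized") == ReadFromAggregatorsMaterialized)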