Changed LoadDataWriter to load data in batches #81

Merged
merged 1 commit on Feb 23, 2024
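
This PR adds an insertBatchSize write option: rows are streamed to SingleStore in batches, with a separate LOAD DATA statement per batch instead of a single statement per partition. A minimal usage sketch, based on the test added below (the DataFrame df, the table name, and the batch size are placeholders, not part of this PR):

import org.apache.spark.sql.SaveMode

df.write
  .format("singlestore")              // DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT
  .option("insertBatchSize", 10000)   // max rows sent per LOAD DATA statement
  .mode(SaveMode.Append)
  .save("testdb.loadDataBatches")
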
src/main/scala/com/singlestore/spark/SinglestoreLoadDataWriter.scala (135 changes: 97 additions & 38 deletions)
@@ -51,26 +51,11 @@ class LoadDataWriterFactory(table: TableIdentifier, conf: SinglestoreOptions)
def setNextLocalInfileInputStream(input: InputStream)
}

- def createDataWriter(schema: StructType,
- partitionId: Int,
- attemptNumber: Int,
- isReferenceTable: Boolean,
- mode: SaveMode): DataWriter[Row] = {
- val basestream = new PipedOutputStream
- val inputstream = new PipedInputStream(basestream, BUFFER_SIZE)
-
- val (ext, outputstream) = conf.loadDataCompression match {
- case CompressionType.GZip =>
- // With gzip default 1 we get a 50% improvement in bandwidth
- // (up to 16 Mps) over gzip default 6 on customer workload.
- //
- ("gz", new GZIPOutputStream(basestream) { { `def`.setLevel(1) } })
-
- case CompressionType.LZ4 =>
- ("lz4", new LZ4FrameOutputStream(basestream))
-
- case CompressionType.Skip =>
- ("tsv", basestream)
private def createLoadDataQuery(schema: StructType, mode: SaveMode, avroSchema: Schema): String = {
val ext = conf.loadDataCompression match {
case CompressionType.GZip => "gz"
case CompressionType.LZ4 => "lz4"
case CompressionType.Skip => "tsv"
}

def tempColName(colName: String) = s"@${colName}_tmp"
@@ -108,10 +93,8 @@ class LoadDataWriterFactory(table: TableIdentifier, conf: SinglestoreOptions)
}
}
val maxErrorsPart = s"MAX_ERRORS ${conf.maxErrors}"
- var avroSchema: Schema = null
val queryPrefix = s"LOAD DATA LOCAL INFILE '###.$ext'"
val queryEnding = if (loadDataFormat == SinglestoreOptions.LoadDataFormat.Avro) {
- avroSchema = SchemaConverters.toAvroType(schema)
val nullableSchemas = for ((field, index) <- schema.fields.zipWithIndex)
yield
AvroSchemaHelper.resolveNullableType(avroSchema.getFields.get(index).schema(),
@@ -136,13 +119,32 @@ class LoadDataWriterFactory(table: TableIdentifier, conf: SinglestoreOptions)
.filter(s => !s.isEmpty)
.mkString(" ")

- val conn = SinglestoreConnectionPool.getConnection(if (isReferenceTable) {
- getDDLConnProperties(conf, isOnExecutor = true)
- } else {
- getDMLConnProperties(conf, isOnExecutor = true)
- })
query
}

private def createStreams(): (InputStream, OutputStream) = {
val basestream = new PipedOutputStream
val inputstream = new PipedInputStream(basestream, BUFFER_SIZE)

val (ext, outputstream) = conf.loadDataCompression match {
case CompressionType.GZip =>
// With gzip default 1 we get a 50% improvement in bandwidth
// (up to 16 Mps) over gzip default 6 on customer workload.
//
("gz", new GZIPOutputStream(basestream) { { `def`.setLevel(1) } })

- val writer = Future[Long] {
case CompressionType.LZ4 =>
("lz4", new LZ4FrameOutputStream(basestream))

case CompressionType.Skip =>
("tsv", basestream)
}

(inputstream, outputstream)
}

private def startStatementExecution(conn: Connection, query: String, inputstream: InputStream): Future[Long] = {
Future[Long] {
try {
val stmt = conn.createStatement()
try {
@@ -159,22 +161,59 @@ class LoadDataWriterFactory(table: TableIdentifier, conf: SinglestoreOptions)
}
} finally {
inputstream.close()
- conn.close()
}
}
}

def createDataWriter(schema: StructType,
partitionId: Int,
attemptNumber: Int,
isReferenceTable: Boolean,
mode: SaveMode): DataWriter[Row] = {
val avroSchema: Schema = if (conf.loadDataFormat == SinglestoreOptions.LoadDataFormat.Avro) {
SchemaConverters.toAvroType(schema)
} else {
null
}
val query = createLoadDataQuery(schema, mode, avroSchema)

- if (loadDataFormat == SinglestoreOptions.LoadDataFormat.Avro) {
- new AvroDataWriter(avroSchema, outputstream, writer, conn)
val conn = SinglestoreConnectionPool.getConnection(if (isReferenceTable) {
getDDLConnProperties(conf, isOnExecutor = true)
} else {
- new LoadDataWriter(outputstream, writer, conn)
getDMLConnProperties(conf, isOnExecutor = true)
})
conn.setAutoCommit(false);

val createDatabaseWriter: () => (OutputStream, Future[Long]) = () => {
val (inputstream, outputstream) = createStreams()
val writer = startStatementExecution(conn, query, inputstream)
(outputstream, writer)
}

if (conf.loadDataFormat == SinglestoreOptions.LoadDataFormat.Avro) {
new AvroDataWriter(avroSchema, createDatabaseWriter, conn, conf.insertBatchSize)
} else {
new LoadDataWriter(createDatabaseWriter, conn, conf.insertBatchSize)
}
}
}

- class LoadDataWriter(outputstream: OutputStream, writeFuture: Future[Long], conn: Connection)
class LoadDataWriter(createDatabaseWriter: () => (OutputStream, Future[Long]), conn: Connection, batchSize: Int)
extends DataWriter[Row] {

private var (outputstream, writeFuture) = createDatabaseWriter()
private var rowsInBatch = 0

override def write(row: Row): Unit = {
if (rowsInBatch >= batchSize) {
Try(outputstream.close())
Await.result(writeFuture, Duration.Inf)
val (newOutputStream, newWriteFuture) = createDatabaseWriter()
outputstream = newOutputStream
writeFuture = newWriteFuture
rowsInBatch = 0
}

val rowLength = row.size
for (i <- 0 until rowLength) {
// We tried using off the shelf CSVWriter, but found it qualitatively slower.
@@ -205,11 +244,14 @@ class LoadDataWriter(outputstream: OutputStream, writeFuture: Future[Long], conn
outputstream.write(value)
outputstream.write(if (i < rowLength - 1) '\t' else '\n')
}

rowsInBatch += 1
}

override def commit(): WriterCommitMessage = {
Try(outputstream.close())
Await.result(writeFuture, Duration.Inf)
conn.commit()
new WriteSuccess
}

@@ -235,15 +277,18 @@ class LoadDataWriter(outputstream: OutputStream, writeFuture: Future[Long], conn
}

class AvroDataWriter(avroSchema: Schema,
- outputstream: OutputStream,
- writeFuture: Future[Long],
- conn: Connection)
createDatabaseWriter: () => (OutputStream, Future[Long]),
conn: Connection,
batchSize: Int)
extends DataWriter[Row] {

private var (outputstream, writeFuture) = createDatabaseWriter()
private var rowsInBatch = 0
private var encoder = EncoderFactory.get().binaryEncoder(outputstream, null)

val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema)
- val encoder = EncoderFactory.get().binaryEncoder(outputstream, null)
val record = new GenericData.Record(avroSchema)

val conversionFunctions: Any => Any = {
case d: java.math.BigDecimal =>
d.toString
@@ -257,17 +302,31 @@ class AvroDataWriter(avroSchema: Schema,
}

override def write(row: Row): Unit = {
if (rowsInBatch >= batchSize) {
encoder.flush()
Try(outputstream.close())
Await.result(writeFuture, Duration.Inf)
val (newOutputStream, newWriteFuture) = createDatabaseWriter()
outputstream = newOutputStream
writeFuture = newWriteFuture
encoder = EncoderFactory.get().binaryEncoder(outputstream, null)
rowsInBatch = 0
}

val rowLength = row.size
for (i <- 0 until rowLength) {
record.put(i, conversionFunctions(row(i)))
}
datumWriter.write(record, encoder)

rowsInBatch += 1
}

override def commit(): WriterCommitMessage = {
encoder.flush()
Try(outputstream.close())
Await.result(writeFuture, Duration.Inf)
conn.commit()
new WriteSuccess
}

src/test/scala/com/singlestore/spark/LoadDataTest.scala (48 changes: 48 additions & 0 deletions)
@@ -251,4 +251,52 @@ class LoadDataTest extends IntegrationSuiteBase with BeforeAndAfterEach with Bef
case e: Exception if e.getMessage.contains("Unknown column 'age' in 'field list'") =>
}
}

it("load data in batches") {
val df1 = spark.createDF(
List(
(5, "Jack", 20),
(6, "Mark", 30),
(7, "Fred", 15),
(8, "Jany", 40),
(9, "Monica", 5)
),
List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
)
val df2 = spark.createDF(
List(
(10, "Jany", 40),
(11, "Monica", 5)
),
List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
)
val df3 = spark.createDF(
List(),
List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
)

df1.write
.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
.option("insertBatchSize", 2)
.mode(SaveMode.Append)
.save("testdb.loadDataBatches")
df2.write
.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
.option("insertBatchSize", 2)
.mode(SaveMode.Append)
.save("testdb.loadDataBatches")
df3.write
.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
.option("insertBatchSize", 2)
.mode(SaveMode.Append)
.save("testdb.loadDataBatches")

val actualDF =
spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loadDataBatches")
assertSmallDataFrameEquality(
actualDF,
df1.union(df2).union(df3),
orderedComparison = false
)
}
}
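
For readers skimming the diff, the batching cycle that both writers now follow can be distilled into the following self-contained sketch. It is not the connector's code: BatchedWriter, startLoad, and the buffer size are illustrative stand-ins for LoadDataWriter/AvroDataWriter, startStatementExecution, and BUFFER_SIZE.

import java.io.{OutputStream, PipedInputStream, PipedOutputStream}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration

// startLoad stands in for startStatementExecution: it is expected to issue
// LOAD DATA LOCAL INFILE against the piped InputStream and complete the
// Future once the statement finishes.
class BatchedWriter(startLoad: PipedInputStream => Future[Long], batchSize: Int) {
  private def open(): (OutputStream, Future[Long]) = {
    val out = new PipedOutputStream
    val in  = new PipedInputStream(out, 524288) // illustrative buffer size
    (out, startLoad(in))
  }

  private var (out, pending) = open()
  private var rowsInBatch    = 0

  def writeRow(tsvLine: Array[Byte]): Unit = {
    if (rowsInBatch >= batchSize) {   // batch full: finish the running LOAD DATA
      out.close()
      Await.result(pending, Duration.Inf)
      val (o, p) = open()             // start a fresh statement for the next batch
      out = o
      pending = p
      rowsInBatch = 0
    }
    out.write(tsvLine)
    rowsInBatch += 1
  }

  def finish(): Unit = {              // mirrors commit(): flush the last partial batch
    out.close()
    Await.result(pending, Duration.Inf)
  }
}

In the PR itself, the conn.commit() added to commit() pairs with the conn.setAutoCommit(false) call in createDataWriter, so the batches written by a task appear to be committed together when Spark commits that task.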