diff --git a/src/main/scala/com/singlestore/spark/SinglestoreLoadDataWriter.scala b/src/main/scala/com/singlestore/spark/SinglestoreLoadDataWriter.scala
index 05c19354..360a1d34 100644
--- a/src/main/scala/com/singlestore/spark/SinglestoreLoadDataWriter.scala
+++ b/src/main/scala/com/singlestore/spark/SinglestoreLoadDataWriter.scala
@@ -51,26 +51,11 @@ class LoadDataWriterFactory(table: TableIdentifier, conf: SinglestoreOptions)
     def setNextLocalInfileInputStream(input: InputStream)
   }
 
-  def createDataWriter(schema: StructType,
-                       partitionId: Int,
-                       attemptNumber: Int,
-                       isReferenceTable: Boolean,
-                       mode: SaveMode): DataWriter[Row] = {
-    val basestream  = new PipedOutputStream
-    val inputstream = new PipedInputStream(basestream, BUFFER_SIZE)
-
-    val (ext, outputstream) = conf.loadDataCompression match {
-      case CompressionType.GZip =>
-        // With gzip default 1 we get a 50% improvement in bandwidth
-        // (up to 16 Mps) over gzip default 6 on customer workload.
-        //
-        ("gz", new GZIPOutputStream(basestream) { { `def`.setLevel(1) } })
-
-      case CompressionType.LZ4 =>
-        ("lz4", new LZ4FrameOutputStream(basestream))
-
-      case CompressionType.Skip =>
-        ("tsv", basestream)
+  private def createLoadDataQuery(schema: StructType, mode: SaveMode, avroSchema: Schema): String = {
+    val ext = conf.loadDataCompression match {
+      case CompressionType.GZip => "gz"
+      case CompressionType.LZ4  => "lz4"
+      case CompressionType.Skip => "tsv"
     }
 
     def tempColName(colName: String) = s"@${colName}_tmp"
@@ -108,10 +93,8 @@ class LoadDataWriterFactory(table: TableIdentifier, conf: SinglestoreOptions)
         }
       }
     val maxErrorsPart = s"MAX_ERRORS ${conf.maxErrors}"
-    var avroSchema: Schema = null
     val queryPrefix = s"LOAD DATA LOCAL INFILE '###.$ext'"
     val queryEnding = if (loadDataFormat == SinglestoreOptions.LoadDataFormat.Avro) {
-      avroSchema = SchemaConverters.toAvroType(schema)
       val nullableSchemas = for ((field, index) <- schema.fields.zipWithIndex)
         yield
           AvroSchemaHelper.resolveNullableType(avroSchema.getFields.get(index).schema(),
@@ -136,13 +119,32 @@ class LoadDataWriterFactory(table: TableIdentifier, conf: SinglestoreOptions)
         .filter(s => !s.isEmpty)
         .mkString(" ")
 
-    val conn = SinglestoreConnectionPool.getConnection(if (isReferenceTable) {
-      getDDLConnProperties(conf, isOnExecutor = true)
-    } else {
-      getDMLConnProperties(conf, isOnExecutor = true)
-    })
+    query
+  }
+
+  private def createStreams(): (InputStream, OutputStream) = {
+    val basestream  = new PipedOutputStream
+    val inputstream = new PipedInputStream(basestream, BUFFER_SIZE)
+
+    val (ext, outputstream) = conf.loadDataCompression match {
+      case CompressionType.GZip =>
+        // With gzip default 1 we get a 50% improvement in bandwidth
+        // (up to 16 Mps) over gzip default 6 on customer workload.
+        //
+        ("gz", new GZIPOutputStream(basestream) { { `def`.setLevel(1) } })
 
-    val writer = Future[Long] {
+      case CompressionType.LZ4 =>
+        ("lz4", new LZ4FrameOutputStream(basestream))
+
+      case CompressionType.Skip =>
+        ("tsv", basestream)
+    }
+
+    (inputstream, outputstream)
+  }
+
+  private def startStatementExecution(conn: Connection, query: String, inputstream: InputStream): Future[Long] = {
+    Future[Long] {
       try {
         val stmt = conn.createStatement()
         try {
@@ -159,22 +161,59 @@ class LoadDataWriterFactory(table: TableIdentifier, conf: SinglestoreOptions)
         }
       } finally {
         inputstream.close()
-        conn.close()
       }
     }
+  }
+
+  def createDataWriter(schema: StructType,
+                       partitionId: Int,
+                       attemptNumber: Int,
+                       isReferenceTable: Boolean,
+                       mode: SaveMode): DataWriter[Row] = {
+    val avroSchema: Schema = if (conf.loadDataFormat == SinglestoreOptions.LoadDataFormat.Avro) {
+      SchemaConverters.toAvroType(schema)
+    } else {
+      null
+    }
+    val query = createLoadDataQuery(schema, mode, avroSchema)
 
-    if (loadDataFormat == SinglestoreOptions.LoadDataFormat.Avro) {
-      new AvroDataWriter(avroSchema, outputstream, writer, conn)
+    val conn = SinglestoreConnectionPool.getConnection(if (isReferenceTable) {
+      getDDLConnProperties(conf, isOnExecutor = true)
     } else {
-      new LoadDataWriter(outputstream, writer, conn)
+      getDMLConnProperties(conf, isOnExecutor = true)
+    })
+    conn.setAutoCommit(false);
+
+    val createDatabaseWriter: () => (OutputStream, Future[Long]) = () => {
+      val (inputstream, outputstream) = createStreams()
+      val writer = startStatementExecution(conn, query, inputstream)
+      (outputstream, writer)
+    }
+
+    if (conf.loadDataFormat == SinglestoreOptions.LoadDataFormat.Avro) {
+      new AvroDataWriter(avroSchema, createDatabaseWriter, conn, conf.insertBatchSize)
+    } else {
+      new LoadDataWriter(createDatabaseWriter, conn, conf.insertBatchSize)
     }
   }
 }
 
-class LoadDataWriter(outputstream: OutputStream, writeFuture: Future[Long], conn: Connection)
+class LoadDataWriter(createDatabaseWriter: () => (OutputStream, Future[Long]), conn: Connection, batchSize: Int)
     extends DataWriter[Row] {
 
+  private var (outputstream, writeFuture) = createDatabaseWriter()
+  private var rowsInBatch = 0
+
   override def write(row: Row): Unit = {
+    if (rowsInBatch >= batchSize) {
+      Try(outputstream.close())
+      Await.result(writeFuture, Duration.Inf)
+      val (newOutputStream, newWriteFuture) = createDatabaseWriter()
+      outputstream = newOutputStream
+      writeFuture = newWriteFuture
+      rowsInBatch = 0
+    }
+
     val rowLength = row.size
     for (i <- 0 until rowLength) {
       // We tried using off the shelf CSVWriter, but found it qualitatively slower.
@@ -205,11 +244,14 @@ class LoadDataWriter(outputstream: OutputStream, writeFuture: Future[Long], conn
       outputstream.write(value)
       outputstream.write(if (i < rowLength - 1) '\t' else '\n')
     }
+
+    rowsInBatch += 1
   }
 
   override def commit(): WriterCommitMessage = {
     Try(outputstream.close())
     Await.result(writeFuture, Duration.Inf)
+    conn.commit()
     new WriteSuccess
   }
 
@@ -235,15 +277,18 @@ class LoadDataWriter(outputstream: OutputStream, writeFuture: Future[Long], conn
 }
 
 class AvroDataWriter(avroSchema: Schema,
-                     outputstream: OutputStream,
-                     writeFuture: Future[Long],
-                     conn: Connection)
+                     createDatabaseWriter: () => (OutputStream, Future[Long]),
+                     conn: Connection,
+                     batchSize: Int)
     extends DataWriter[Row] {
 
+  private var (outputstream, writeFuture) = createDatabaseWriter()
+  private var rowsInBatch = 0
+  private var encoder = EncoderFactory.get().binaryEncoder(outputstream, null)
+
   val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema)
-  val encoder = EncoderFactory.get().binaryEncoder(outputstream, null)
   val record = new GenericData.Record(avroSchema)
-
+
   val conversionFunctions: Any => Any = {
     case d: java.math.BigDecimal =>
       d.toString
@@ -257,17 +302,31 @@ class AvroDataWriter(avroSchema: Schema,
   }
 
   override def write(row: Row): Unit = {
+    if (rowsInBatch >= batchSize) {
+      encoder.flush()
+      Try(outputstream.close())
+      Await.result(writeFuture, Duration.Inf)
+      val (newOutputStream, newWriteFuture) = createDatabaseWriter()
+      outputstream = newOutputStream
+      writeFuture = newWriteFuture
+      encoder = EncoderFactory.get().binaryEncoder(outputstream, null)
+      rowsInBatch = 0
+    }
+
     val rowLength = row.size
     for (i <- 0 until rowLength) {
      record.put(i, conversionFunctions(row(i)))
     }
     datumWriter.write(record, encoder)
+
+    rowsInBatch += 1
   }
 
   override def commit(): WriterCommitMessage = {
     encoder.flush()
     Try(outputstream.close())
     Await.result(writeFuture, Duration.Inf)
+    conn.commit()
     new WriteSuccess
   }
 
diff --git a/src/test/scala/com/singlestore/spark/LoadDataTest.scala b/src/test/scala/com/singlestore/spark/LoadDataTest.scala
index e20a991d..b6c6a1d1 100644
--- a/src/test/scala/com/singlestore/spark/LoadDataTest.scala
+++ b/src/test/scala/com/singlestore/spark/LoadDataTest.scala
@@ -251,4 +251,52 @@ class LoadDataTest extends IntegrationSuiteBase with BeforeAndAfterEach with Bef
       case e: Exception if e.getMessage.contains("Unknown column 'age' in 'field list'") =>
     }
   }
+
+  it("load data in batches") {
+    val df1 = spark.createDF(
+      List(
+        (5, "Jack", 20),
+        (6, "Mark", 30),
+        (7, "Fred", 15),
+        (8, "Jany", 40),
+        (9, "Monica", 5)
+      ),
+      List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
+    )
+    val df2 = spark.createDF(
+      List(
+        (10, "Jany", 40),
+        (11, "Monica", 5)
+      ),
+      List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
+    )
+    val df3 = spark.createDF(
+      List(),
+      List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
+    )
+
+    df1.write
+      .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
+      .option("insertBatchSize", 2)
+      .mode(SaveMode.Append)
+      .save("testdb.loadDataBatches")
+    df2.write
+      .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
+      .option("insertBatchSize", 2)
+      .mode(SaveMode.Append)
+      .save("testdb.loadDataBatches")
+    df3.write
+      .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
+      .option("insertBatchSize", 2)
+      .mode(SaveMode.Append)
+      .save("testdb.loadDataBatches")
+
+    val actualDF =
+      spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loadDataBatches")
+    assertSmallDataFrameEquality(
+      actualDF,
+      df1.union(df2).union(df3),
+      orderedComparison = false
+    )
+  }
 }
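
Usage sketch (not part of the patch): a minimal example of how the insertBatchSize option exercised by the test above might be set from a Spark application. The option name and SaveMode.Append come from the diff; the format string "singlestore", the session setup, the DataFrame, and the target table name are illustrative assumptions.

// Minimal sketch, assuming the connector is on the classpath and the SingleStore
// connection options (endpoint, user, password) are configured elsewhere.
import org.apache.spark.sql.{SaveMode, SparkSession}

object BatchedLoadExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("batched-load-example").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "Alice", 25), (2, "Bob", 31)).toDF("id", "name", "age")

    df.write
      .format("singlestore")           // assumed value of DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT
      .option("insertBatchSize", 1000) // rows streamed per LOAD DATA batch by each partition writer
      .mode(SaveMode.Append)
      .save("testdb.people")           // illustrative target table
  }
}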