Add support for batched queries to nf-sqldb plugin
pditommaso committed Apr 9, 2022
1 parent 6a6a6ea commit 3a9dad8
Showing 6 changed files with 237 additions and 36 deletions.
plugins/nf-sqldb/src/main/nextflow/sql/ChannelSqlExtension.groovy
@@ -26,14 +26,20 @@ import nextflow.util.CheckHelper
@Scoped('sql')
class ChannelSqlExtension extends ChannelExtensionPoint {

private static final Map QUERY_PARAMS = [db: CharSequence, emitColumns: Boolean]
private static final Map QUERY_PARAMS = [
db: CharSequence,
emitColumns: Boolean,
batchSize: Integer,
batchDelay: Integer
]

private static final Map INSERT_PARAMS = [
db: CharSequence,
into: CharSequence,
columns: [CharSequence, List],
statement: CharSequence,
batch: Integer,
batch: Integer, // deprecated
batchSize: Integer,
setup: CharSequence
]

@@ -61,7 +67,7 @@ class ChannelSqlExtension extends ChannelExtensionPoint {
.withDataSource(dataSource)
.withStatement(query)
.withTarget(channel)
.withEmitColumns( (opts?.emitColumns ?: false) as boolean )
.withOpts(opts)
if(NF.dsl2) {
session.addIgniter {-> handler.perform(true) }
}
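For context, the sketch below shows how the new query options are intended to be used from a pipeline script. It is illustrative only and not part of the commit: the 'foo' data source and the SAMPLE table are hypothetical placeholders, and it assumes the extension is reached through the 'sql' channel scope declared by the @Scoped('sql') annotation above.

// Nextflow script (illustrative sketch, hypothetical db/table names)
nextflow.enable.dsl = 2

channel
    .sql
    // fetch rows in pages of 100, pausing 500 ms between paginated queries
    .fromQuery('select id, alpha from SAMPLE', db: 'foo', batchSize: 100, batchDelay: 500)
    .view()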
3 changes: 3 additions & 0 deletions plugins/nf-sqldb/src/main/nextflow/sql/InsertHandler.groovy
@@ -58,7 +58,10 @@ class InsertHandler implements Closeable {
this.columns = cols0(this.opts.columns)
this.sqlStatement = this.opts.statement
this.batchSize = this.opts.batch ? this.opts.batch as int : DEFAULT_BATCH_SIZE
this.batchSize = this.opts.batchSize ? this.opts.batchSize as int : DEFAULT_BATCH_SIZE
this.setupStatement = this.opts.setup
if( this.opts.batch )
log.warn "The option 'batch' for the 'sqlInsert' operator has been deprecated - Use 'batchSize' instead"
if( batchSize<1 )
throw new IllegalArgumentException("SQL batch option must be greater than zero: $batchSize")
}
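Similarly, a minimal sketch of the renamed insert option (illustrative only; the 'foo' data source and the SAMPLE table are hypothetical placeholders):

// Nextflow script (illustrative sketch)
channel
    .of( [1, 'alpha'], [2, 'beta'], [3, 'delta'] )
    // rows are accumulated and flushed in SQL batches of 2 (formerly the 'batch' option)
    .sqlInsert( into: 'SAMPLE', columns: 'id, name', db: 'foo', batchSize: 2 )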
132 changes: 106 additions & 26 deletions plugins/nf-sqldb/src/main/nextflow/sql/QueryHandler.groovy
@@ -18,6 +18,7 @@
package nextflow.sql

import java.sql.Connection
import java.sql.PreparedStatement
import java.sql.ResultSet
import java.sql.Statement
import java.util.concurrent.CompletableFuture
@@ -30,15 +31,14 @@ import nextflow.Channel
import nextflow.Global
import nextflow.Session
import nextflow.sql.config.SqlDataSource

/**
* Implement the logic to query a DB in an async manner
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
@Slf4j
@CompileStatic
class QueryHandler implements QueryOp {
class QueryHandler implements QueryOp<QueryHandler> {

private static Map<String,Class<?>> type_mapping = [:]

@@ -68,6 +68,9 @@ class QueryHandler implements QueryOp {
private String statement
private SqlDataSource dataSource
private boolean emitColumns = false
private Integer batchSize
private long batchDelayMillis = 100
private int queryCount

@Override
QueryOp withStatement(String stm) {
@@ -87,19 +90,32 @@
return this
}

@Override
QueryOp withEmitColumns(boolean emitColumns) {
this.emitColumns = emitColumns
QueryOp withOpts(Map opts) {
if( opts.emitColumns )
this.emitColumns = opts.emitColumns as boolean
if( opts.batchSize )
this.batchSize = opts.batchSize as Integer
if( opts.batchDelay )
this.batchDelayMillis = opts.batchDelay as long
return this
}

int batchSize() {
return batchSize
}

int queryCount() {
return queryCount
}

@Override
void perform(boolean async=false) {
QueryHandler perform(boolean async=false) {
final conn = connect(dataSource ?: SqlDataSource.DEFAULT)
if( async )
queryAsync(conn)
else
query0(conn)
queryExec(conn)
return this
}

protected Connection connect(SqlDataSource ds) {
@@ -117,7 +133,7 @@
}

protected queryAsync(Connection conn) {
def future = CompletableFuture.runAsync ({ query0(conn) })
def future = CompletableFuture.runAsync ({ queryExec(conn) })
future.exceptionally(this.&handlerException)
}

@@ -128,11 +144,22 @@
session?.abort(error)
}

protected void queryExec(Connection conn) {
if( batchSize ) {
query1(conn)
}
else {
query0(conn)
}
}

protected void query0(Connection conn) {
try {
try (Statement stm = conn.createStatement()) {
try( def rs = stm.executeQuery(normalize(statement)) ) {
emitRows0(rs)
if( emitColumns )
emitColumns(rs)
emitRowsAndClose(rs)
}
}
}
@@ -141,28 +168,81 @@
}
}

protected emitRows0(ResultSet rs) {
protected void query1(Connection conn) {
try {
final meta = rs.getMetaData()
final cols = meta.getColumnCount()

if( emitColumns ){
def item = new ArrayList(cols)
for( int i=0; i<cols; i++) {
item[i] = meta.getColumnName(i+1)
// create the query adding the `offset` and `limit` params
final query = makePaginationStm(statement)
// create the prepared statement
try (PreparedStatement stm = conn.prepareStatement(query)) {
int count = 0
int len = 0
do {
final offset = (count++) * batchSize
final limit = batchSize

stm.setInt(1, limit)
stm.setInt(2, offset)
queryCount++
try ( def rs = stm.executeQuery() ) {
if( emitColumns && count==1 )
emitColumns(rs)
len = emitRows(rs)
sleep(batchDelayMillis)
}
}
// emit the value
target.bind(item)
while( len==batchSize )
}
finally {
// close the channel
target.bind(Channel.STOP)
}
}
finally {
conn.close()
}
}

while( rs.next() ) {
def item = new ArrayList(cols)
for( int i=0; i<cols; i++) {
item[i] = rs.getObject(i+1)
}
// emit the value
target.bind(item)
protected String makePaginationStm(String sql) {
if( sql.toUpperCase().contains('LIMIT') )
throw new IllegalArgumentException("Sql query should not include the LIMIT statement when pageSize is specified: $sql")
if( sql.toUpperCase().contains('OFFSET') )
throw new IllegalArgumentException("Sql query should not include the OFFSET statement when pageSize is specified: $sql")

return sql.stripEnd(' ;') + " LIMIT ? OFFSET ?;"
}

protected emitColumns(ResultSet rs) {
final meta = rs.getMetaData()
final cols = meta.getColumnCount()

def item = new ArrayList(cols)
for( int i=0; i<cols; i++) {
item[i] = meta.getColumnName(i+1)
}
// emit the value
target.bind(item)
}

protected int emitRows(ResultSet rs) {
final meta = rs.getMetaData()
final cols = meta.getColumnCount()

int count=0
while( rs.next() ) {
count++
def item = new ArrayList(cols)
for( int i=0; i<cols; i++) {
item[i] = rs.getObject(i+1)
}
// emit the value
target.bind(item)
}
return count
}

protected int emitRowsAndClose(ResultSet rs) {
try {
emitRows(rs)
}
finally {
// close the channel
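To make the pagination mechanics above easier to follow, here is a self-contained Groovy sketch (not part of the commit) that reproduces the LIMIT/OFFSET loop of query1 with plain JDBC. It assumes the H2 driver can be grabbed from Maven Central; the table name and sizes are arbitrary.

// Illustrative sketch of the batched query strategy
@Grab('com.h2database:h2:1.4.200')
import java.sql.DriverManager

def conn = DriverManager.getConnection('jdbc:h2:mem:demo')
conn.createStatement().execute('create table FOO(id int primary key)')
(1..13).each { conn.createStatement().execute("insert into FOO(id) values ($it)") }

int batchSize = 5
// the original statement is rewritten as done by makePaginationStm()
def stm = conn.prepareStatement('select id from FOO order by id LIMIT ? OFFSET ?')
int count = 0
int len = batchSize
while( len == batchSize ) {               // a short page means the result set is exhausted
    stm.setInt(1, batchSize)              // limit
    stm.setInt(2, (count++) * batchSize)  // offset
    def rs = stm.executeQuery()
    len = 0
    while( rs.next() ) { len++; println rs.getInt(1) }
}
conn.close()
// prints ids 1..13 using three queries (5 + 5 + 3 rows)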
9 changes: 5 additions & 4 deletions plugins/nf-sqldb/src/main/nextflow/sql/QueryOp.groovy
@@ -25,12 +25,13 @@ import nextflow.sql.config.SqlDataSource
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
interface QueryOp {
interface QueryOp<T extends QueryOp> {

QueryOp withStatement(String stm)
QueryOp withTarget(DataflowWriteChannel channel)
QueryOp withDataSource(SqlDataSource ds)
QueryOp withEmitColumns(boolean headers)
void perform()
void perform(boolean async)
QueryOp withOpts(Map options)

T perform()
T perform(boolean async)
}
83 changes: 81 additions & 2 deletions plugins/nf-sqldb/src/test/nextflow/sql/QueryHandlerTest.groovy
@@ -2,11 +2,11 @@ package nextflow.sql

import java.nio.file.Files

import groovy.sql.Sql
import groovyx.gpars.dataflow.DataflowQueue
import nextflow.Channel
import nextflow.sql.config.SqlDataSource
import spock.lang.Specification

/**
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
@@ -81,7 +81,7 @@ class QueryHandlerTest extends Specification {
new QueryHandler()
.withTarget(result)
.withStatement(query)
.withEmitColumns(true)
.withOpts(emitColumns: true)
.perform()

then:
@@ -93,4 +93,83 @@
cleanup:
folder.deleteDir()
}

def 'should append limit and offset' () {
given:
def ext = new QueryHandler()

expect:
ext.makePaginationStm('select * from FOO') == 'select * from FOO LIMIT ? OFFSET ?;'
ext.makePaginationStm('select * from FOO ; ') == 'select * from FOO LIMIT ? OFFSET ?;'

when:
ext.makePaginationStm('select * from offset')
then:
thrown(IllegalArgumentException)

when:
ext.makePaginationStm('select * from limit')
then:
thrown(IllegalArgumentException)
}

def 'should test paginated query' () {
given:
def JDBC_URL = 'jdbc:h2:mem:test_' + Random.newInstance().nextInt(10_000)
def TABLE = 'create table FOO(id int primary key, alpha varchar(255));'
def ds = new SqlDataSource([url:JDBC_URL])
and:
def sql = Sql.newInstance(JDBC_URL, 'sa', null)
sql.execute(TABLE)
for( int x : 1..13 ) {
def params = [x, "Hello $x".toString()]
sql.execute("insert into FOO (id, alpha) values (?,?);", params)
}

when:
def result = new DataflowQueue()
def query = "SELECT id, alpha from FOO order by id "
def handler = new QueryHandler()
.withTarget(result)
.withStatement(query)
.withDataSource(ds)
.withOpts(batchSize: 5)
.perform()
then:
handler.batchSize() == 5
handler.queryCount() == 3
and:
result.length() == 14 // <-- 13 + the stop signal value
and:
result.getVal() == [1, 'Hello 1']
result.getVal() == [2, 'Hello 2']
result.getVal() == [3, 'Hello 3']
result.getVal() == [4, 'Hello 4']
result.getVal() == [5, 'Hello 5']
result.getVal() == [6, 'Hello 6']
result.getVal() == [7, 'Hello 7']
result.getVal() == [8, 'Hello 8']
result.getVal() == [9, 'Hello 9']
result.getVal() == [10, 'Hello 10']
result.getVal() == [11, 'Hello 11']
result.getVal() == [12, 'Hello 12']
result.getVal() == [13, 'Hello 13']
result.getVal() == Channel.STOP

when:
def result2 = new DataflowQueue()
def query2 = "SELECT id, alpha from FOO order by id "
new QueryHandler()
.withTarget(result2)
.withStatement(query2)
.withDataSource(ds)
.withOpts(batchSize: 5, emitColumns: true)
.perform()
then:
result2.length() == 15 // <-- 13 + columns name tuple + the stop signal value
and:
result2.getVal() == ['ID', 'ALPHA']
result2.getVal() == [1, 'Hello 1']
result2.getVal() == [2, 'Hello 2']
}
}
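Finally, running the query examples shown earlier requires the named data source to be declared in the pipeline configuration. A hedged sketch of the corresponding nextflow.config fragment, following the sql.db configuration scope used by this plugin (the 'foo' alias, URL and credentials are placeholder values):

// nextflow.config (placeholder values)
sql {
    db {
        foo {
            url = 'jdbc:h2:mem:demo'
            user = 'sa'
            password = ''
        }
    }
}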
