Merge pull request #60 from killbill/jdbi-BatchExecuteAndGenerateKeys

jdbi: batch execute and generate keys

pierre committed Feb 13, 2019
2 parents ac58ee0 + 89d1322 commit 50d8f2c
Showing 3 changed files with 166 additions and 59 deletions.
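For context, a minimal sketch of what this change enables at the SQL-object layer; the AccountDao interface, accounts table, and LongMapper choice below are illustrative assumptions, not part of this commit:

import java.util.List;

import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.GetGeneratedKeys;
import org.skife.jdbi.v2.sqlobject.SqlBatch;
import org.skife.jdbi.v2.util.LongMapper;

// Hypothetical DAO: with this commit, an @SqlBatch method can also carry
// @GetGeneratedKeys, so the batch insert returns database-generated keys
// instead of the usual int[] of update counts.
public interface AccountDao
{
    @SqlBatch("insert into accounts (name) values (:name)")
    @GetGeneratedKeys(LongMapper.class)
    List<Long> insertAccounts(@Bind("name") List<String> names);
}
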
6 changes: 5 additions & 1 deletion jdbi/src/main/java/org/skife/jdbi/v2/GeneratedKeys.java
@@ -98,7 +98,11 @@ public <T> T first(Class<T> containerType)
public <ContainerType> ContainerType list(Class<ContainerType> containerType)
{
// return containerFactoryRegistry.lookup(containerType).create(Arrays.asList(list()));
throw new UnsupportedOperationException("Not Yet Implemented!");
if (containerType.isAssignableFrom(List.class)) {
return (ContainerType) list();
} else {
throw new UnsupportedOperationException("Not Yet Implemented!");
}
}

@Override
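A short sketch of what the GeneratedKeys change above permits; the H2 setup and Update-based insert are assumptions for illustration, and only the list(Class) behavior comes from this diff:

import java.util.List;

import org.skife.jdbi.v2.DBI;
import org.skife.jdbi.v2.GeneratedKeys;
import org.skife.jdbi.v2.Handle;
import org.skife.jdbi.v2.util.LongMapper;

public class GeneratedKeysListSketch
{
    public static void main(String[] args)
    {
        // Hypothetical in-memory H2 database; any driver that supports generated keys works.
        DBI dbi = new DBI("jdbc:h2:mem:test");
        Handle handle = dbi.open();
        try {
            handle.execute("create table accounts (id bigint auto_increment primary key, name varchar(255))");

            GeneratedKeys<Long> keys = handle.createStatement("insert into accounts (name) values (:name)")
                                             .bind("name", "alice")
                                             .executeAndReturnGeneratedKeys(new LongMapper());

            // Before this commit, list(Class) always threw UnsupportedOperationException;
            // it now delegates to list() whenever a plain List satisfies the requested type.
            List<Long> ids = keys.list(List.class);
            System.out.println(ids);
        }
        finally {
            handle.close();
        }
    }
}
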
52 changes: 43 additions & 9 deletions jdbi/src/main/java/org/skife/jdbi/v2/PreparedBatch.java
@@ -1,6 +1,4 @@
/*
* Copyright (C) 2004 - 2014 Brian McCallister
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@@ -17,15 +15,18 @@

import org.skife.jdbi.v2.exceptions.UnableToCreateStatementException;
import org.skife.jdbi.v2.exceptions.UnableToExecuteStatementException;
import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.skife.jdbi.v2.tweak.RewrittenStatement;
import org.skife.jdbi.v2.tweak.SQLLog;
import org.skife.jdbi.v2.tweak.StatementBuilder;
import org.skife.jdbi.v2.tweak.StatementCustomizer;
import org.skife.jdbi.v2.tweak.StatementLocator;
import org.skife.jdbi.v2.tweak.StatementRewriter;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
@@ -93,10 +94,32 @@ public PreparedBatch define(final Map<String, ? extends Object> values)
*
* @return the number of rows modified or inserted per batch part.
*/
public int[] execute()
{
public int[] execute() {
return (int[]) internalBatchExecute(null, null);
}

@SuppressWarnings("unchecked")
public <GeneratedKeyType> GeneratedKeys<GeneratedKeyType> executeAndGenerateKeys(final ResultSetMapper<GeneratedKeyType> mapper,
String... columnNames) {
return (GeneratedKeys<GeneratedKeyType>) internalBatchExecute(new QueryResultMunger<GeneratedKeys<GeneratedKeyType>>() {
public GeneratedKeys<GeneratedKeyType> munge(Statement results) throws SQLException {
return new GeneratedKeys<GeneratedKeyType>(mapper,
PreparedBatch.this,
results,
getContext(),
getContainerMapperRegistry());
}
}, columnNames);

}

private <Result> Object internalBatchExecute(QueryResultMunger<Result> munger, String[] columnNames) {
boolean generateKeys = munger != null;
// short circuit empty batch
if (parts.size() == 0) {
if (generateKeys) {
throw new IllegalArgumentException("Unable to generate keys for an empty batch");
}
return new int[]{};
}

@@ -109,11 +132,20 @@ public int[] execute()
throw new UnableToCreateStatementException(String.format("Exception while locating statement for [%s]",
getSql()), e, getContext());
}
final RewrittenStatement rewritten = getRewriter().rewrite(my_sql, current.getParameters(), getContext());
final RewrittenStatement rewritten = getRewriter().rewrite(my_sql, current.getParams(), getContext());
PreparedStatement stmt = null;
try {
try {
stmt = getHandle().getConnection().prepareStatement(rewritten.getSql());
Connection connection = getHandle().getConnection();
if (generateKeys) {
if (columnNames != null) {
stmt = connection.prepareStatement(rewritten.getSql(), columnNames);
} else {
stmt = connection.prepareStatement(rewritten.getSql(), Statement.RETURN_GENERATED_KEYS);
}
} else {
stmt = connection.prepareStatement(rewritten.getSql(), Statement.NO_GENERATED_KEYS);
}
addCleanable(Cleanables.forStatement(stmt));
}
catch (SQLException e) {
@@ -123,7 +155,7 @@ public int[] execute()

try {
for (PreparedBatchPart part : parts) {
rewritten.bind(part.getParameters(), stmt);
rewritten.bind(part.getParams(), stmt);
stmt.addBatch();
}
}
@@ -142,15 +174,17 @@ public int[] execute()

afterExecution(stmt);

return rs;
return generateKeys ? munger.munge(stmt) : rs;
}
catch (SQLException e) {
throw new UnableToExecuteStatementException(e, getContext());
}
}
finally {
try {
cleanup();
if (!generateKeys) {
cleanup();
}
}
finally {
this.parts.clear();
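At the core-API level, the new PreparedBatch.executeAndGenerateKeys added above could be used roughly as follows; the H2 database, accounts table, and "id" key column are assumptions for illustration:

import java.util.List;

import org.skife.jdbi.v2.DBI;
import org.skife.jdbi.v2.GeneratedKeys;
import org.skife.jdbi.v2.Handle;
import org.skife.jdbi.v2.PreparedBatch;
import org.skife.jdbi.v2.util.LongMapper;

public class BatchGeneratedKeysSketch
{
    public static void main(String[] args)
    {
        // Hypothetical in-memory H2 database.
        DBI dbi = new DBI("jdbc:h2:mem:test");
        Handle handle = dbi.open();
        try {
            handle.execute("create table accounts (id bigint auto_increment primary key, name varchar(255))");

            PreparedBatch batch = handle.prepareBatch("insert into accounts (name) values (:name)");
            batch.add().bind("name", "alice");
            batch.add().bind("name", "bob");

            // New in this commit: execute the whole batch and read back the generated keys,
            // optionally restricting them to named columns (here the hypothetical "id" column).
            GeneratedKeys<Long> keys = batch.executeAndGenerateKeys(new LongMapper(), "id");
            List<Long> ids = keys.list();
            System.out.println(ids);
        }
        finally {
            handle.close();
        }
    }
}

Note that, per the diff above, cleanup() is skipped when keys are being generated, so the underlying statement stays open for the GeneratedKeys result until the handle or its cleanables are closed.
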
167 changes: 118 additions & 49 deletions jdbi/src/main/java/org/skife/jdbi/v2/sqlobject/BatchHandler.java
@@ -1,6 +1,4 @@
/*
* Copyright (C) 2004 - 2014 Brian McCallister
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@@ -19,37 +17,84 @@
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import net.sf.cglib.proxy.MethodProxy;

import org.skife.jdbi.v2.ConcreteStatementContext;
import org.skife.jdbi.v2.GeneratedKeys;
import org.skife.jdbi.v2.Handle;
import org.skife.jdbi.v2.PreparedBatch;
import org.skife.jdbi.v2.PreparedBatchPart;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.TransactionCallback;
import org.skife.jdbi.v2.TransactionStatus;
import org.skife.jdbi.v2.exceptions.DBIException;
import org.skife.jdbi.v2.exceptions.UnableToCreateStatementException;
import org.skife.jdbi.v2.exceptions.UnableToExecuteStatementException;
import org.skife.jdbi.v2.sqlobject.customizers.BatchChunkSize;
import org.skife.jdbi.v2.tweak.ResultSetMapper;

import com.fasterxml.classmate.members.ResolvedMethod;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import net.sf.cglib.proxy.MethodProxy;

class BatchHandler extends CustomizingStatementHandler
{
private final String sql;
private final boolean transactional;
private final ChunkSizeFunction batchChunkSize;
private final Returner returner;

public BatchHandler(Class<?> sqlObjectType, ResolvedMethod method)
{
BatchHandler(Class<?> sqlObjectType, ResolvedMethod method) {
super(sqlObjectType, method);

Method raw_method = method.getRawMember();
SqlBatch anno = raw_method.getAnnotation(SqlBatch.class);
this.sql = SqlObject.getSql(anno, raw_method);
this.transactional = anno.transactional();
this.batchChunkSize = determineBatchChunkSize(sqlObjectType, raw_method);
final GetGeneratedKeys getGeneratedKeys = raw_method.getAnnotation(GetGeneratedKeys.class);
if (getGeneratedKeys == null) {
if (!returnTypeIsValid(method.getRawMember().getReturnType())) {
throw new DBIException(invalidReturnTypeMessage(method)) {};
}
returner = new Returner() {
@Override
public Object value(PreparedBatch batch, HandleDing baton)
{
return batch.execute();
}
};
} else {
final ResultSetMapper mapper;
try {
mapper = getGeneratedKeys.value().newInstance();
} catch (Exception e) {
throw new UnableToCreateStatementException("Unable to instantiate result set mapper for statement", e);
}
final ResultReturnThing magic = ResultReturnThing.forType(method);
if (getGeneratedKeys.columnName().isEmpty()) {
returner = new Returner() {
@Override
public Object value(PreparedBatch batch, HandleDing baton)
{
GeneratedKeys o = batch.executeAndGenerateKeys(mapper);
return magic.result(o, baton);
}
};
} else {
returner = new Returner() {
@Override
public Object value(PreparedBatch batch, HandleDing baton)
{
String columnName = getGeneratedKeys.columnName();
GeneratedKeys o = batch.executeAndGenerateKeys(mapper, columnName);
return magic.result(o, baton);
}
};
}
}
}

private ChunkSizeFunction determineBatchChunkSize(Class<?> sqlObjectType, Method raw_method)
@@ -95,50 +140,57 @@ private int findBatchChunkSizeFromParam(Method raw_method)
@Override
public Object invoke(HandleDing h, Object target, Object[] args, MethodProxy mp)
{
boolean foundIterator = false;
Handle handle = h.getHandle();

List<Iterator> extras = new ArrayList<Iterator>();
for (final Object arg : args) {
if (arg instanceof Iterable) {
extras.add(((Iterable) arg).iterator());
foundIterator = true;
}
else if (arg instanceof Iterator) {
extras.add((Iterator) arg);
foundIterator = true;
}
else if (arg.getClass().isArray()) {
extras.add(Arrays.asList((Object[])arg).iterator());
foundIterator = true;
}
else {
extras.add(new Iterator()
{
@Override
public boolean hasNext()
{
return true;
}
{
@Override
public boolean hasNext()
{
return true;
}

@Override
@SuppressFBWarnings("IT_NO_SUCH_ELEMENT")
public Object next()
{
return arg;
}
@Override
@SuppressFBWarnings("IT_NO_SUCH_ELEMENT")
public Object next()
{
return arg;
}

@Override
public void remove()
{
// NOOP
}
}
);
@Override
public void remove()
{
// NOOP
}
}
);
}
}

if (!foundIterator) {
throw new UnableToExecuteStatementException("@SqlBatch must have at least one iterable parameter", (StatementContext)null);
}

int processed = 0;
List<int[]> rs_parts = new ArrayList<int[]>();
List<Object> results = new LinkedList<Object>();

PreparedBatch batch = handle.prepareBatch(sql);
populateSqlObjectData((ConcreteStatementContext) batch.getContext());
applyCustomizers(batch, args);
Object[] _args;
int chunk_size = batchChunkSize.call(args);
@@ -150,47 +202,45 @@ public void remove()
if (++processed == chunk_size) {
// execute this chunk
processed = 0;
rs_parts.add(executeBatch(handle, batch));
executeBatch(results, h, handle, batch);
batch = handle.prepareBatch(sql);
populateSqlObjectData((ConcreteStatementContext) batch.getContext());
applyCustomizers(batch, args);
}
}

//execute the rest
rs_parts.add(executeBatch(handle, batch));

// combine results
int end_size = 0;
for (int[] rs_part : rs_parts) {
end_size += rs_part.length;
}
int[] rs = new int[end_size];
int offset = 0;
for (int[] rs_part : rs_parts) {
System.arraycopy(rs_part, 0, rs, offset, rs_part.length);
offset += rs_part.length;
if (batch.getSize() > 0) {
executeBatch(results, h, handle, batch);
}

return rs;
return results;
}

private void executeBatch(final Collection<Object> results, final HandleDing h, final Handle handle, final PreparedBatch batch) {
final Object res = executeBatch(h, handle, batch);
if (res instanceof Collection) {
results.addAll((Collection) res);
} else {
results.add(res);
}
}

private int[] executeBatch(final Handle handle, final PreparedBatch batch)
private Object executeBatch(final HandleDing h, final Handle handle, final PreparedBatch batch)
{
if (!handle.isInTransaction() && transactional) {
// it is safe to use same prepared batch as the inTransaction passes in the same
// Handle instance.
return handle.inTransaction(new TransactionCallback<int[]>()
return handle.inTransaction(new TransactionCallback<Object>()
{
@Override
public int[] inTransaction(Handle conn, TransactionStatus status) throws Exception
public Object inTransaction(Handle conn, TransactionStatus status) throws Exception
{
return batch.execute();
return returner.value(batch, h);
}
});
}
else {
return batch.execute();
return returner.value(batch, h);
}
}

@@ -208,6 +258,11 @@ private static Object[] next(List<Iterator> args)
return rs.toArray();
}

private interface Returner
{
Object value(PreparedBatch batch, HandleDing baton);
}

private interface ChunkSizeFunction
{
int call(Object[] args);
@@ -242,4 +297,18 @@ public int call(Object[] args)
return (Integer)args[index];
}
}

private static boolean returnTypeIsValid(Class<?> type) {
if (type.equals(Void.TYPE) || type.isArray() && type.getComponentType().equals(Integer.TYPE)) {
return true;
}

return false;
}

private static String invalidReturnTypeMessage(ResolvedMethod method) {
return method.getDeclaringType() + "." + method +
" method is annotated with @SqlBatch so should return void or int[] but is returning: " +
method.getReturnType();
}
}
