Skip to content

Commit

Permalink
[CONJ-976] permit pipelining for batching even when `allowLocalInfile…
Browse files Browse the repository at this point in the history
…` option is enable
  • Loading branch information
rusher committed May 25, 2022
1 parent de7ca6f commit e6a29d0
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 67 deletions.
25 changes: 17 additions & 8 deletions src/main/java/org/mariadb/jdbc/ClientPreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@

package org.mariadb.jdbc;

import static org.mariadb.jdbc.util.constants.Capabilities.*;

import java.sql.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.locks.ReentrantLock;
import org.mariadb.jdbc.client.Completion;
import org.mariadb.jdbc.client.result.CompleteResult;
Expand All @@ -20,7 +23,6 @@
import org.mariadb.jdbc.message.server.PrepareResultPacket;
import org.mariadb.jdbc.util.ClientParser;
import org.mariadb.jdbc.util.ParameterList;
import org.mariadb.jdbc.util.constants.Capabilities;
import org.mariadb.jdbc.util.constants.ServerStatus;

/**
Expand Down Expand Up @@ -106,17 +108,24 @@ private void executeInternal() throws SQLException {

private void executeInternalPreparedBatch() throws SQLException {
checkNotClosed();
long serverCapabilities = con.getContext().getServerCapabilities();
if (autoGeneratedKeys != Statement.RETURN_GENERATED_KEYS
&& batchParameters.size() > 1
&& (serverCapabilities & Capabilities.MARIADB_CLIENT_STMT_BULK_OPERATIONS) > 0
&& con.getContext().getConf().useBulkStmts()) {
&& con.getContext().hasClientCapability(STMT_BULK_OPERATIONS)) {
executeBatchBulk();
} else if (!con.getContext().getConf().allowLocalInfile()
|| (serverCapabilities & Capabilities.LOCAL_FILES) == 0) {
executeBatchPipeline();
} else {
executeBatchStd();
boolean possibleLoadLocal = con.getContext().hasClientCapability(LOCAL_FILES);
if (possibleLoadLocal) {
String sqlUpper = sql.toUpperCase(Locale.ROOT);
possibleLoadLocal =
sqlUpper.contains(" LOCAL ")
&& sqlUpper.contains("LOAD")
&& sqlUpper.contains(" INFILE");
}
if (possibleLoadLocal) {
executeBatchStd();
} else {
executeBatchPipeline();
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/mariadb/jdbc/Connection.java
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ public void setReadOnly(boolean readOnly) throws SQLException {
@Override
public String getCatalog() throws SQLException {

if ((client.getContext().getServerCapabilities() & Capabilities.CLIENT_SESSION_TRACK) != 0) {
if (client.getContext().hasClientCapability(Capabilities.CLIENT_SESSION_TRACK)) {
return client.getContext().getDatabase();
}

Expand All @@ -308,7 +308,7 @@ public String getCatalog() throws SQLException {

@Override
public void setCatalog(String catalog) throws SQLException {
if ((client.getContext().getServerCapabilities() & Capabilities.CLIENT_SESSION_TRACK) != 0
if (client.getContext().hasClientCapability(Capabilities.CLIENT_SESSION_TRACK)
&& catalog.equals(client.getContext().getDatabase())) {
return;
}
Expand Down
41 changes: 25 additions & 16 deletions src/main/java/org/mariadb/jdbc/ServerPreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

package org.mariadb.jdbc;

import static org.mariadb.jdbc.util.constants.Capabilities.*;

import java.sql.*;
import java.util.*;
import java.util.concurrent.locks.ReentrantLock;
import java.util.regex.Pattern;
import org.mariadb.jdbc.client.Completion;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.result.CompleteResult;
import org.mariadb.jdbc.client.result.Result;
import org.mariadb.jdbc.client.util.Parameters;
Expand All @@ -21,7 +24,6 @@
import org.mariadb.jdbc.message.server.OkPacket;
import org.mariadb.jdbc.message.server.PrepareResultPacket;
import org.mariadb.jdbc.util.ParameterList;
import org.mariadb.jdbc.util.constants.Capabilities;

/**
* Server prepare statement. command will generate COM_STMT_PREPARE + COM_STMT_EXECUTE (+
Expand Down Expand Up @@ -87,9 +89,7 @@ protected void executeInternal() throws SQLException {
String cmd = escapeTimeout(sql);
if (prepareResult == null) prepareResult = con.getContext().getPrepareCache().get(cmd, this);
try {
long serverCapabilities = con.getContext().getServerCapabilities();
if (prepareResult == null
&& (serverCapabilities & Capabilities.MARIADB_CLIENT_STMT_BULK_OPERATIONS) > 0) {
if (prepareResult == null && con.getContext().hasClientCapability(STMT_BULK_OPERATIONS)) {
try {
executePipeline(cmd);
} catch (BatchUpdateException b) {
Expand Down Expand Up @@ -160,20 +160,29 @@ private void executeStandard(String cmd) throws SQLException {
private void executeInternalPreparedBatch() throws SQLException {
checkNotClosed();
String cmd = escapeTimeout(sql);
long serverCapabilities = con.getContext().getServerCapabilities();
if (batchParameters.size() > 1
&& (serverCapabilities & Capabilities.MARIADB_CLIENT_STMT_BULK_OPERATIONS) > 0
&& (!con.getContext().getConf().allowLocalInfile()
|| (serverCapabilities & Capabilities.LOCAL_FILES) == 0)) {
if (con.getContext().getConf().useBulkStmts()
&& autoGeneratedKeys != Statement.RETURN_GENERATED_KEYS) {
executeBatchBulk(cmd);
} else {
executeBatchPipeline(cmd);
if (batchParameters.size() > 1 && con.getContext().hasServerCapability(STMT_BULK_OPERATIONS)) {

// ensure pipelining is possible (no LOAD DATA/XML INFILE commands)
boolean possibleLoadLocal = con.getContext().hasClientCapability(LOCAL_FILES);
if (possibleLoadLocal) {
String sqlUpper = sql.toUpperCase(Locale.ROOT);
possibleLoadLocal =
sqlUpper.contains(" LOCAL ")
&& sqlUpper.contains("LOAD")
&& sqlUpper.contains(" INFILE");
}

if (!possibleLoadLocal) {
if (con.getContext().getConf().useBulkStmts()
&& autoGeneratedKeys != Statement.RETURN_GENERATED_KEYS) {
executeBatchBulk(cmd);
} else {
executeBatchPipeline(cmd);
}
return;
}
} else {
executeBatchStandard(cmd);
}
executeBatchStandard(cmd);
}

/**
Expand Down
46 changes: 35 additions & 11 deletions src/main/java/org/mariadb/jdbc/Statement.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.locks.ReentrantLock;
import org.mariadb.jdbc.client.Completion;
import org.mariadb.jdbc.client.DataType;
Expand All @@ -17,9 +18,10 @@
import org.mariadb.jdbc.message.server.ColumnDefinitionPacket;
import org.mariadb.jdbc.message.server.OkPacket;
import org.mariadb.jdbc.util.NativeSql;
import org.mariadb.jdbc.util.constants.Capabilities;
import org.mariadb.jdbc.util.constants.ServerStatus;

import static org.mariadb.jdbc.util.constants.Capabilities.LOCAL_FILES;

/** Statement implementation */
public class Statement implements java.sql.Statement {

Expand Down Expand Up @@ -693,12 +695,23 @@ public int[] executeBatch() throws SQLException {
if (batchQueries == null || batchQueries.isEmpty()) return new int[0];
lock.lock();
try {
long serverCapabilities = con.getContext().getServerCapabilities();
// ensure pipelining is possible (no LOAD DATA/XML INFILE commands)
boolean possibleLoadLocal = con.getContext().hasClientCapability(LOCAL_FILES);
if (possibleLoadLocal) {
for (int i = 0; i < batchQueries.size(); i++) {
String sql = batchQueries.get(i).toUpperCase(Locale.ROOT);
if (sql.contains(" LOCAL ") && sql.contains("LOAD") && sql.contains(" INFILE")) {
break;
}
}
possibleLoadLocal = false;
}

List<Completion> res =
(!con.getContext().getConf().allowLocalInfile()
|| (serverCapabilities & Capabilities.LOCAL_FILES) == 0)
? executeInternalBatchPipeline()
: executeInternalBatchStandard();
possibleLoadLocal
? executeInternalBatchStandard()
: executeInternalBatchPipeline();

results = res;

int[] updates = new int[res.size()];
Expand Down Expand Up @@ -1430,12 +1443,23 @@ public long[] executeLargeBatch() throws SQLException {

lock.lock();
try {
long serverCapabilities = con.getContext().getServerCapabilities();
// ensure pipelining is possible (no LOAD DATA/XML INFILE commands)
boolean possibleLoadLocal = con.getContext().hasClientCapability(LOCAL_FILES);
if (possibleLoadLocal) {
for (int i = 0; i < batchQueries.size(); i++) {
String sql = batchQueries.get(i).toUpperCase(Locale.ROOT);
if (sql.contains(" LOCAL ") && sql.contains("LOAD") && sql.contains(" INFILE")) {
break;
}
}
possibleLoadLocal = false;
}

List<Completion> res =
(!con.getContext().getConf().allowLocalInfile()
|| (serverCapabilities & Capabilities.LOCAL_FILES) == 0)
? executeInternalBatchPipeline()
: executeInternalBatchStandard();
possibleLoadLocal
? executeInternalBatchStandard()
: executeInternalBatchPipeline();

results = res;
long[] updates = new long[res.size()];
for (int i = 0; i < res.size(); i++) {
Expand Down
13 changes: 10 additions & 3 deletions src/main/java/org/mariadb/jdbc/client/Context.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,18 @@ public interface Context {
byte[] getSeed();

/**
* Get server capabilities.
* has server capability
*
* @return server capabilities
* @return true if server has capability
*/
long getServerCapabilities();
boolean hasServerCapability(long flag);

/**
* has client capability
*
* @return true if client has capability
*/
boolean hasClientCapability(long flag);

/**
* Get server connection state
Expand Down
16 changes: 11 additions & 5 deletions src/main/java/org/mariadb/jdbc/client/context/BaseContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public class BaseContext implements Context {

private final long threadId;
private final long serverCapabilities;
private final long clientCapabilities;
private final byte[] seed;
private final ServerVersion version;
private final boolean eofDeprecated;
Expand Down Expand Up @@ -63,9 +64,10 @@ public BaseContext(
this.serverCapabilities = handshake.getCapabilities();
this.serverStatus = handshake.getServerStatus();
this.version = handshake.getVersion();
this.eofDeprecated = (clientCapabilities & Capabilities.CLIENT_DEPRECATE_EOF) > 0;
this.skipMeta = (clientCapabilities & Capabilities.MARIADB_CLIENT_CACHE_METADATA) > 0;
this.extendedInfo = (serverCapabilities & Capabilities.MARIADB_CLIENT_EXTENDED_TYPE_INFO) > 0;
this.clientCapabilities = clientCapabilities;
this.eofDeprecated = hasClientCapability(Capabilities.CLIENT_DEPRECATE_EOF);
this.skipMeta = hasClientCapability(Capabilities.CACHE_METADATA);
this.extendedInfo = hasClientCapability(Capabilities.EXTENDED_TYPE_INFO);
this.conf = conf;
this.database = conf.database();
this.exceptionFactory = exceptionFactory;
Expand All @@ -80,8 +82,12 @@ public byte[] getSeed() {
return seed;
}

public long getServerCapabilities() {
return serverCapabilities;
public boolean hasServerCapability(long flag) {
return (serverCapabilities & flag) > 0;
}

public boolean hasClientCapability(long flag) {
return (clientCapabilities & flag) > 0;
}

public int getServerStatus() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,14 @@ public static long initializeClientCapabilities(
| Capabilities.CONNECT_ATTRS
| Capabilities.PLUGIN_AUTH_LENENC_CLIENT_DATA
| Capabilities.CLIENT_SESSION_TRACK
| Capabilities.MARIADB_CLIENT_EXTENDED_TYPE_INFO;
| Capabilities.EXTENDED_TYPE_INFO;

// since skipping metadata is only available when using binary protocol,
// only set it when server permit it and using binary protocol
if (configuration.useServerPrepStmts()
&& Boolean.parseBoolean(
configuration.nonMappedOptions().getProperty("enableSkipMeta", "true"))) {
capabilities |= Capabilities.MARIADB_CLIENT_CACHE_METADATA;
capabilities |= Capabilities.CACHE_METADATA;
}

// remains for compatibility
Expand All @@ -204,7 +204,7 @@ public static long initializeClientCapabilities(
}

if (configuration.useBulkStmts()) {
capabilities |= Capabilities.MARIADB_CLIENT_STMT_BULK_OPERATIONS;
capabilities |= Capabilities.STMT_BULK_OPERATIONS;
}

if (!configuration.useAffectedRows()) {
Expand Down Expand Up @@ -367,7 +367,7 @@ public static SSLSocket sslWrapper(
Configuration conf = context.getConf();
if (conf.sslMode() != SslMode.DISABLE) {

if ((context.getServerCapabilities() & Capabilities.SSL) == 0) {
if (!context.hasServerCapability(Capabilities.SSL)) {
throw context
.getExceptionFactory()
.create("Trying to connect with ssl, but ssl not enabled in the server", "08000");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

package org.mariadb.jdbc.message.client;

import static org.mariadb.jdbc.util.constants.Capabilities.*;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.StringTokenizer;
Expand All @@ -15,7 +17,6 @@
import org.mariadb.jdbc.plugin.Credential;
import org.mariadb.jdbc.plugin.authentication.standard.NativePasswordPlugin;
import org.mariadb.jdbc.util.VersionFactory;
import org.mariadb.jdbc.util.constants.Capabilities;

/**
* Server handshake response builder. see
Expand Down Expand Up @@ -133,7 +134,7 @@ public int encode(Writer writer, Context context) throws IOException {

final byte[] authData;
if ("mysql_clear_password".equals(authenticationPluginType)) {
if ((clientCapabilities & Capabilities.SSL) == 0) {
if (!context.hasClientCapability(SSL)) {
throw new IllegalStateException("Cannot send password in clear if SSL is not enabled.");
}
authData =
Expand All @@ -153,28 +154,28 @@ public int encode(Writer writer, Context context) throws IOException {
writer.writeString(username != null ? username : System.getProperty("user.name"));
writer.writeByte(0x00);

if ((context.getServerCapabilities() & Capabilities.PLUGIN_AUTH_LENENC_CLIENT_DATA) != 0) {
if (context.hasServerCapability(PLUGIN_AUTH_LENENC_CLIENT_DATA)) {
writer.writeLength(authData.length);
writer.writeBytes(authData);
} else if ((context.getServerCapabilities() & Capabilities.SECURE_CONNECTION) != 0) {
} else if (context.hasServerCapability(SECURE_CONNECTION)) {
writer.writeByte((byte) authData.length);
writer.writeBytes(authData);
} else {
writer.writeBytes(authData);
writer.writeByte(0x00);
}

if ((clientCapabilities & Capabilities.CONNECT_WITH_DB) != 0) {
if (context.hasClientCapability(CONNECT_WITH_DB)) {
writer.writeString(database);
writer.writeByte(0x00);
}

if ((context.getServerCapabilities() & Capabilities.PLUGIN_AUTH) != 0) {
if (context.hasServerCapability(PLUGIN_AUTH)) {
writer.writeString(authenticationPluginType);
writer.writeByte(0x00);
}

if ((context.getServerCapabilities() & Capabilities.CONNECT_ATTRS) != 0) {
if (context.hasServerCapability(CONNECT_ATTRS)) {
writeConnectAttributes(writer, connectionAttributes, host);
}
writer.flush();
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/org/mariadb/jdbc/message/server/OkPacket.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ public OkPacket(ReadableByteBuf buf, Context context) {
context.setServerStatus(buf.readUnsignedShort());
context.setWarning(buf.readUnsignedShort());

if ((context.getServerCapabilities() & Capabilities.CLIENT_SESSION_TRACK) != 0
&& buf.readableBytes() > 0) {
if (context.hasClientCapability(Capabilities.CLIENT_SESSION_TRACK) && buf.readableBytes() > 0) {
buf.skip(buf.readIntLengthEncodedNotNull()); // skip info
while (buf.readableBytes() > 0) {
if (buf.readIntLengthEncodedNotNull() > 0) {
Expand Down
Loading

0 comments on commit e6a29d0

Please sign in to comment.