Skip to content

Commit

Permalink
[CONJ-986] adding Statement.setLocalInfileInputStream(<InputStream>) …
Browse files Browse the repository at this point in the history
…for 2.x compatibility
  • Loading branch information
rusher committed Jul 20, 2022
1 parent 418550b commit fb57f50
Show file tree
Hide file tree
Showing 10 changed files with 214 additions and 32 deletions.
9 changes: 6 additions & 3 deletions src/main/java/org/mariadb/jdbc/ClientPreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ private void executeInternal() throws SQLException {
lock.lock();
try {
QueryWithParametersPacket query =
new QueryWithParametersPacket(preSqlCmd(), parser, parameters);
new QueryWithParametersPacket(preSqlCmd(), parser, parameters, localInfileInputStream);
results =
con.getClient()
.execute(
Expand All @@ -102,6 +102,7 @@ private void executeInternal() throws SQLException {
closeOnCompletion,
false);
} finally {
localInfileInputStream = null;
lock.unlock();
}
}
Expand Down Expand Up @@ -187,7 +188,7 @@ private void executeBatchBulk() throws SQLException {
private void executeBatchPipeline() throws SQLException {
ClientMessage[] packets = new ClientMessage[batchParameters.size()];
for (int i = 0; i < batchParameters.size(); i++) {
packets[i] = new QueryWithParametersPacket(preSqlCmd(), parser, batchParameters.get(i));
packets[i] = new QueryWithParametersPacket(preSqlCmd(), parser, batchParameters.get(i), null);
}
try {
results =
Expand Down Expand Up @@ -220,7 +221,8 @@ private void executeBatchStd() throws SQLException {
results.addAll(
con.getClient()
.execute(
new QueryWithParametersPacket(preSqlCmd(), parser, batchParameters.get(i)),
new QueryWithParametersPacket(
preSqlCmd(), parser, batchParameters.get(i), localInfileInputStream),
this,
0,
maxRows,
Expand All @@ -233,6 +235,7 @@ private void executeBatchStd() throws SQLException {
BatchUpdateException exception =
exceptionFactory().createBatchUpdate(results, batchParameters.size(), bue);
results = null;
localInfileInputStream = null;
throw exception;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/mariadb/jdbc/Configuration.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public class Configuration {

// protocol
private boolean allowMultiQueries = false;
private boolean allowLocalInfile = false;
private boolean allowLocalInfile = true;
private boolean useCompression = false;
private boolean useAffectedRows = false;
private boolean useBulkStmts = true;
Expand Down
17 changes: 12 additions & 5 deletions src/main/java/org/mariadb/jdbc/ServerPreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ protected void executeInternal() throws SQLException {
executeStandard(cmd);
}
} finally {
localInfileInputStream = null;
lock.unlock();
}
}
Expand All @@ -112,7 +113,7 @@ private void executePipeline(String cmd) throws SQLException {
// server is 10.2+, permitting to execute last prepare with (-1) statement id.
// Server send prepare, followed by execute, in one exchange.
PreparePacket prepare = new PreparePacket(cmd);
ExecutePacket execute = new ExecutePacket(null, parameters, cmd, this);
ExecutePacket execute = new ExecutePacket(null, parameters, cmd, this, localInfileInputStream);
try {
List<Completion> res =
con.getClient()
Expand Down Expand Up @@ -142,7 +143,8 @@ private void executeStandard(String cmd) throws SQLException {
}
validParameters();
// send COM_STMT_EXECUTE
ExecutePacket execute = new ExecutePacket(prepareResult, parameters, cmd, this);
ExecutePacket execute =
new ExecutePacket(prepareResult, parameters, cmd, this, localInfileInputStream);
results =
con.getClient()
.execute(
Expand Down Expand Up @@ -273,7 +275,9 @@ private List<Completion> executeBunch(String cmd, int index, int maxCmd) throws
int maxCmdToSend = Math.min(batchParameters.size() - index, maxCmd);
ClientMessage[] packets = new ClientMessage[maxCmdToSend];
for (int i = index; i < index + maxCmdToSend; i++) {
packets[i - index] = new ExecutePacket(prepareResult, batchParameters.get(i), cmd, this);
packets[i - index] =
new ExecutePacket(
prepareResult, batchParameters.get(i), cmd, this, localInfileInputStream);
}
return con.getClient()
.executePipeline(
Expand All @@ -293,7 +297,8 @@ private List<Completion> executeBunchPrepare(String cmd, int index, int maxCmd)
ClientMessage[] packets = new ClientMessage[maxCmdToSend + 1];
packets[0] = new PreparePacket(cmd);
for (int i = index; i < index + maxCmdToSend; i++) {
packets[i + 1 - index] = new ExecutePacket(null, batchParameters.get(i), cmd, this);
packets[i + 1 - index] =
new ExecutePacket(null, batchParameters.get(i), cmd, this, localInfileInputStream);
}
List<Completion> res =
con.getClient()
Expand Down Expand Up @@ -334,7 +339,8 @@ private void executeBatchStandard(String cmd) throws SQLException {
}
}
try {
ExecutePacket execute = new ExecutePacket(prepareResult, batchParameter, cmd, this);
ExecutePacket execute =
new ExecutePacket(prepareResult, batchParameter, cmd, this, localInfileInputStream);
tmpResults.addAll(con.getClient().execute(execute, this, false));
} catch (SQLException e) {
if (error == null) error = e;
Expand Down Expand Up @@ -619,6 +625,7 @@ public int[] executeBatch() throws SQLException {
return updates;

} finally {
localInfileInputStream = null;
batchParameters.clear();
lock.unlock();
}
Expand Down
15 changes: 13 additions & 2 deletions src/main/java/org/mariadb/jdbc/Statement.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import static org.mariadb.jdbc.util.constants.Capabilities.LOCAL_FILES;

import java.io.InputStream;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -57,6 +58,8 @@ public class Statement implements java.sql.Statement {
protected List<Completion> results;
/** current results */
protected Completion currResult;
/** streaming load data infile data */
protected InputStream localInfileInputStream;

/**
* Constructor
Expand Down Expand Up @@ -93,6 +96,11 @@ private ExceptionFactory exceptionFactory() {
return con.getExceptionFactory().of(this);
}

public void setLocalInfileInputStream(InputStream inputStream) throws SQLException {
checkNotClosed();
localInfileInputStream = inputStream;
}

/**
* Executes the given SQL statement, which returns a single <code>ResultSet</code> object.
*
Expand Down Expand Up @@ -919,7 +927,7 @@ private void executeInternal(String sql, int autoGeneratedKeys) throws SQLExcept
results =
con.getClient()
.execute(
new QueryPacket(cmd),
new QueryPacket(cmd, localInfileInputStream),
this,
fetchSize,
maxRows,
Expand All @@ -928,6 +936,7 @@ private void executeInternal(String sql, int autoGeneratedKeys) throws SQLExcept
closeOnCompletion,
false);
} finally {
localInfileInputStream = null;
lock.unlock();
}
}
Expand Down Expand Up @@ -1513,7 +1522,7 @@ public List<Completion> executeInternalBatchStandard() throws SQLException {
results.addAll(
con.getClient()
.execute(
new QueryPacket(batchQuery),
new QueryPacket(batchQuery, localInfileInputStream),
this,
0,
0L,
Expand All @@ -1531,6 +1540,8 @@ public List<Completion> executeInternalBatchStandard() throws SQLException {
completion instanceof OkPacket ? (int) ((OkPacket) completion).getAffectedRows() : 0;
}
throw new BatchUpdateException(sqle.getMessage(), updateCounts, sqle);
} finally {
localInfileInputStream = null;
}
}
}
53 changes: 35 additions & 18 deletions src/main/java/org/mariadb/jdbc/message/ClientMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -138,31 +138,44 @@ default Completion readPacket(
errorPacket.getMessage(), errorPacket.getSqlState(), errorPacket.getErrorCode());
case 0xfb:
buf.skip(1); // skip header
String fileName = buf.readStringNullEnd();

SQLException exception = null;
if (!message.validateLocalFileName(fileName, context)) {
exception =
exceptionFactory
.withSql(this.description())
.create(
String.format(
"LOAD DATA LOCAL INFILE asked for file '%s' that doesn't correspond to initial query %s. Possible malicious proxy changing server answer ! Command interrupted",
fileName, this.description()),
"HY000");
} else {
try (InputStream is = new FileInputStream(fileName)) {

InputStream is = getLocalInfileInputStream();
if (is == null) {
String fileName = buf.readStringNullEnd();
if (!message.validateLocalFileName(fileName, context)) {
exception =
exceptionFactory
.withSql(this.description())
.create(
String.format(
"LOAD DATA LOCAL INFILE asked for file '%s' that doesn't correspond to initial query %s. Possible malicious proxy changing server answer ! Command interrupted",
fileName, this.description()),
"HY000");
} else {

try {
is = new FileInputStream(fileName);
} catch (FileNotFoundException f) {
exception =
exceptionFactory
.withSql(this.description())
.create("Could not send file : " + f.getMessage(), "HY000", f);
}
}
}

// sending stream
if (is != null) {
try {
byte[] fileBuf = new byte[8192];
int len;
while ((len = is.read(fileBuf)) > 0) {
writer.writeBytes(fileBuf, 0, len);
writer.flush();
}
} catch (FileNotFoundException f) {
exception =
exceptionFactory
.withSql(this.description())
.create("Could not send file : " + f.getMessage(), "HY000", f);
} finally {
is.close();
}
}

Expand Down Expand Up @@ -261,6 +274,10 @@ default Completion readPacket(
}
}

default InputStream getLocalInfileInputStream() {
return null;
}

/**
* Request for local file to be validated from current query.
*
Expand Down
13 changes: 12 additions & 1 deletion src/main/java/org/mariadb/jdbc/message/client/ExecutePacket.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.mariadb.jdbc.message.client;

import java.io.IOException;
import java.io.InputStream;
import java.sql.SQLException;
import org.mariadb.jdbc.ServerPreparedStatement;
import org.mariadb.jdbc.client.Context;
Expand All @@ -25,6 +26,7 @@ public final class ExecutePacket implements RedoableWithPrepareClientMessage {
private final String command;
private final ServerPreparedStatement prep;
private Prepare prepareResult;
private InputStream localInfileInputStream;

/**
* Constructor
Expand All @@ -35,11 +37,16 @@ public final class ExecutePacket implements RedoableWithPrepareClientMessage {
* @param prep prepared statement
*/
public ExecutePacket(
Prepare prepareResult, Parameters parameters, String command, ServerPreparedStatement prep) {
Prepare prepareResult,
Parameters parameters,
String command,
ServerPreparedStatement prep,
InputStream localInfileInputStream) {
this.parameters = parameters;
this.prepareResult = prepareResult;
this.command = command;
this.prep = prep;
this.localInfileInputStream = localInfileInputStream;
}

public void saveParameters() {
Expand Down Expand Up @@ -131,6 +138,10 @@ public String getCommand() {
return command;
}

public InputStream getLocalInfileInputStream() {
return localInfileInputStream;
}

public ServerPreparedStatement prep() {
return prep;
}
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/org/mariadb/jdbc/message/client/QueryPacket.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.mariadb.jdbc.message.client;

import java.io.IOException;
import java.io.InputStream;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.socket.Writer;
import org.mariadb.jdbc.message.ClientMessage;
Expand All @@ -13,6 +14,7 @@
public final class QueryPacket implements RedoableClientMessage {

private final String sql;
private final InputStream localInfileInputStream;

/**
* Constructor
Expand All @@ -21,6 +23,12 @@ public final class QueryPacket implements RedoableClientMessage {
*/
public QueryPacket(String sql) {
this.sql = sql;
this.localInfileInputStream = null;
}

public QueryPacket(String sql, InputStream localInfileInputStream) {
this.sql = sql;
this.localInfileInputStream = localInfileInputStream;
}

public int batchUpdateLength() {
Expand Down Expand Up @@ -49,6 +57,10 @@ public boolean validateLocalFileName(String fileName, Context context) {
return ClientMessage.validateLocalFileName(sql, null, fileName, context);
}

public InputStream getLocalInfileInputStream() {
return localInfileInputStream;
}

public String description() {
return sql;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.mariadb.jdbc.message.client;

import java.io.IOException;
import java.io.InputStream;
import java.sql.SQLException;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.socket.Writer;
Expand All @@ -23,6 +24,7 @@ public final class QueryWithParametersPacket implements RedoableClientMessage {
private final String preSqlCmd;
private final ClientParser parser;
private Parameters parameters;
private InputStream localInfileInputStream;

/**
* Constructor
Expand All @@ -31,10 +33,15 @@ public final class QueryWithParametersPacket implements RedoableClientMessage {
* @param parser command parser result
* @param parameters parameters
*/
public QueryWithParametersPacket(String preSqlCmd, ClientParser parser, Parameters parameters) {
public QueryWithParametersPacket(
String preSqlCmd,
ClientParser parser,
Parameters parameters,
InputStream localInfileInputStream) {
this.preSqlCmd = preSqlCmd;
this.parser = parser;
this.parameters = parameters;
this.localInfileInputStream = localInfileInputStream;
}

@Override
Expand Down Expand Up @@ -83,6 +90,10 @@ public boolean validateLocalFileName(String fileName, Context context) {
return ClientMessage.validateLocalFileName(parser.getSql(), parameters, fileName, context);
}

public InputStream getLocalInfileInputStream() {
return localInfileInputStream;
}

@Override
public String description() {
return parser.getSql();
Expand Down
Loading

0 comments on commit fb57f50

Please sign in to comment.