Skip to content

Commit

Permalink
Support prepared statements for query (#2913)
Browse files Browse the repository at this point in the history
Added RawParameterizedSqlStatement

Co-authored-by: Nathan Voxland <nathan@voxland.net>
  • Loading branch information
fbiville and nvoxland committed Jun 14, 2022
1 parent 9985a44 commit 970f63b
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package liquibase.executor.jvm;

import liquibase.Scope;
import liquibase.sql.SqlConfiguration;
import liquibase.database.DatabaseConnection;
import liquibase.database.OfflineConnection;
import liquibase.database.PreparedStatementFactory;
Expand All @@ -15,19 +14,18 @@
import liquibase.servicelocator.PrioritizedService;
import liquibase.sql.CallableSql;
import liquibase.sql.Sql;
import liquibase.sql.SqlConfiguration;
import liquibase.sql.visitor.SqlVisitor;
import liquibase.sqlgenerator.SqlGeneratorFactory;
import liquibase.statement.CallableSqlStatement;
import liquibase.statement.CompoundStatement;
import liquibase.statement.ExecutablePreparedStatement;
import liquibase.statement.SqlStatement;
import liquibase.statement.core.RawParameterizedSqlStatement;
import liquibase.util.JdbcUtil;
import liquibase.util.StringUtil;

import java.sql.CallableStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand All @@ -41,23 +39,19 @@
public class JdbcExecutor extends AbstractExecutor {

/**
*
* Return the name of the Executor
*
* @return String The Executor name
*
*/
@Override
public String getName() {
return "jdbc";
}

/**
*
* Return the Executor priority
*
* @return int The Executor priority
*
*/
@Override
public int getPriority() {
Expand All @@ -81,8 +75,7 @@ public Object execute(StatementCallback action, List<SqlVisitor> sqlVisitors) th
Statement stmtToUse = stmt;

return action.doInStatement(stmtToUse);
}
catch (SQLException ex) {
} catch (SQLException ex) {
// Release Connection early, to avoid potential connection pool deadlock
// in the case when the exception translator hasn't been initialized yet.
JdbcUtil.closeStatement(stmt);
Expand All @@ -93,9 +86,8 @@ public Object execute(StatementCallback action, List<SqlVisitor> sqlVisitors) th
} else {
url = con.getURL();
}
throw new DatabaseException("Error executing SQL " + StringUtil.join(applyVisitors(action.getStatement(), sqlVisitors), "; on "+ url)+": "+ex.getMessage(), ex);
}
finally {
throw new DatabaseException("Error executing SQL " + StringUtil.join(applyVisitors(action.getStatement(), sqlVisitors), "; on " + url) + ": " + ex.getMessage(), ex);
} finally {
JdbcUtil.closeStatement(stmt);
}
}
Expand All @@ -117,15 +109,13 @@ public Object execute(CallableStatementCallback action, List<SqlVisitor> sqlVisi

stmt = ((JdbcConnection) con).getUnderlyingConnection().prepareCall(sql);
return action.doInCallableStatement(stmt);
}
catch (SQLException ex) {
} catch (SQLException ex) {
// Release Connection early, to avoid potential connection pool deadlock
// in the case when the exception translator hasn't been initialized yet.
JdbcUtil.closeStatement(stmt);
stmt = null;
throw new DatabaseException("Error executing SQL " + StringUtil.join(applyVisitors(action.getStatement(), sqlVisitors), "; on "+ con.getURL())+": "+ex.getMessage(), ex);
}
finally {
throw new DatabaseException("Error executing SQL " + StringUtil.join(applyVisitors(action.getStatement(), sqlVisitors), "; on " + con.getURL()) + ": " + ex.getMessage(), ex);
} finally {
JdbcUtil.closeStatement(stmt);
}
}
Expand All @@ -137,8 +127,27 @@ public void execute(final SqlStatement sql) throws DatabaseException {

@Override
public void execute(final SqlStatement sql, final List<SqlVisitor> sqlVisitors) throws DatabaseException {
if(sql instanceof ExecutablePreparedStatement) {
((ExecutablePreparedStatement) sql).execute(new PreparedStatementFactory((JdbcConnection)database.getConnection()));
if (sql instanceof RawParameterizedSqlStatement) {
PreparedStatementFactory factory = new PreparedStatementFactory((JdbcConnection) database.getConnection());

String finalSql = applyVisitors((RawParameterizedSqlStatement) sql, sqlVisitors);

try (PreparedStatement pstmt = factory.create(finalSql)) {
final List<?> parameters = ((RawParameterizedSqlStatement) sql).getParameters();
for (int i = 0; i < parameters.size(); i++) {
pstmt.setObject(i, parameters.get(i));
}
pstmt.execute();

return;
} catch (SQLException e) {
throw new DatabaseException(e);
}
}


if (sql instanceof ExecutablePreparedStatement) {
((ExecutablePreparedStatement) sql).execute(new PreparedStatementFactory((JdbcConnection) database.getConnection()));
return;
}
if (sql instanceof CompoundStatement) {
Expand All @@ -151,12 +160,40 @@ public void execute(final SqlStatement sql, final List<SqlVisitor> sqlVisitors)
execute(new ExecuteStatementCallback(sql, sqlVisitors), sqlVisitors);
}

private String applyVisitors(RawParameterizedSqlStatement sql, List<SqlVisitor> sqlVisitors) {
String finalSql = sql.getSql();
if (sqlVisitors != null) {
for (SqlVisitor visitor : sqlVisitors) {
if (visitor != null) {
finalSql = visitor.modifySql(finalSql, database);
}
}
}
return finalSql;
}


public Object query(final SqlStatement sql, final ResultSetExtractor rse) throws DatabaseException {
return query(sql, rse, new ArrayList<SqlVisitor>());
}

public Object query(final SqlStatement sql, final ResultSetExtractor rse, final List<SqlVisitor> sqlVisitors) throws DatabaseException {
if (sql instanceof RawParameterizedSqlStatement) {
PreparedStatementFactory factory = new PreparedStatementFactory((JdbcConnection) database.getConnection());

String finalSql = applyVisitors((RawParameterizedSqlStatement) sql, sqlVisitors);

try (PreparedStatement pstmt = factory.create(finalSql);) {
final List<?> parameters = ((RawParameterizedSqlStatement) sql).getParameters();
for (int i = 0; i < parameters.size(); i++) {
pstmt.setObject(i, parameters.get(0));
}
return rse.extractData(pstmt.executeQuery());
} catch (SQLException e) {
throw new DatabaseException(e);
}
}

if (sql instanceof CallableSqlStatement) {
return execute(new QueryCallableStatementCallback(sql, rse), sqlVisitors);
}
Expand All @@ -181,7 +218,7 @@ public Object queryForObject(SqlStatement sql, RowMapper rowMapper, List<SqlVisi
try {
return JdbcUtil.requiredSingleResult(results);
} catch (DatabaseException e) {
throw new DatabaseException("Expected single row from " + sql + " but got "+results.size(), e);
throw new DatabaseException("Expected single row from " + sql + " but got " + results.size(), e);
}
}

Expand Down Expand Up @@ -351,7 +388,7 @@ private void checkCallStatus(ResultSet resultSet, String status) throws SQLExcep

String getErrorCode(Throwable e) {
if (e instanceof SQLException) {
return "(" + ((SQLException)e).getErrorCode() + ") ";
return "(" + ((SQLException) e).getErrorCode() + ") ";
}
return "";
}
Expand Down Expand Up @@ -394,7 +431,7 @@ public Object doInStatement(Statement stmt) throws SQLException, DatabaseExcepti
log.log(sqlLogLevel, stmt.getUpdateCount() + " row(s) affected", null);
}
} catch (Throwable e) {
throw new DatabaseException(e.getMessage()+ " [Failed SQL: " + getErrorCode(e) + statement+"]", e);
throw new DatabaseException(e.getMessage() + " [Failed SQL: " + getErrorCode(e) + statement + "]", e);
}
try {
int updateCount = 0;
Expand All @@ -408,7 +445,7 @@ public Object doInStatement(Statement stmt) throws SQLException, DatabaseExcepti
} while (updateCount != -1);

} catch (Exception e) {
throw new DatabaseException(e.getMessage()+ " [Failed SQL: "+ getErrorCode(e) + statement+"]", e);
throw new DatabaseException(e.getMessage() + " [Failed SQL: " + getErrorCode(e) + statement + "]", e);
}
}
return null;
Expand Down Expand Up @@ -437,9 +474,10 @@ private QueryStatementCallback(SqlStatement sql, ResultSetExtractor rse, List<Sq
* 1. Applies all SqlVisitor to the stmt
* 2. Executes the (possibly modified) stmt
* 3. Reads all data from the java.sql.ResultSet into an Object and returns the Object.
*
* @param stmt A JDBC Statement that is expected to return a ResultSet (e.g. SELECT)
* @return An object representing all data from the result set.
* @throws SQLException If an error occurs during SQL processing
* @throws SQLException If an error occurs during SQL processing
* @throws DatabaseException If an error occurs in the DBMS-specific program code
*/
@Override
Expand All @@ -465,10 +503,9 @@ public Object doInStatement(Statement stmt) throws SQLException, DatabaseExcepti
listener.readSqlWillRun(sqlToExecute[0]);
}
}
}
finally {
} finally {
if (rs != null) {
JdbcUtil.closeResultSet(rs);
JdbcUtil.closeResultSet(rs);
}
}
}
Expand Down Expand Up @@ -497,8 +534,7 @@ public Object doInCallableStatement(CallableStatement cs) throws SQLException, D
try {
rs = cs.executeQuery();
return rse.extractData(rs);
}
finally {
} finally {
JdbcUtil.closeResultSet(rs);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package liquibase.statement.core;

import liquibase.statement.AbstractSqlStatement;
import liquibase.util.StringUtil;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class RawParameterizedSqlStatement extends AbstractSqlStatement {

private final String sql;
private final List parameters = new ArrayList<>();

public RawParameterizedSqlStatement(String sql, Object... parameters) {
this.sql = sql;
if (parameters != null) {
this.parameters.addAll(Arrays.asList(parameters));
}
}

public String getSql() {
return sql;
}

public List<?> getParameters() {
return parameters;
}

public RawParameterizedSqlStatement addParameter(Object parameter) {
this.parameters.add(parameter);

return this;
}

@Override
public String toString() {
return sql + " with " + StringUtil.join(parameters, ",", new StringUtil.ToStringFormatter());
}
}

0 comments on commit 970f63b

Please sign in to comment.