diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java index 3905f2c09c..f421c8cf00 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java @@ -52,15 +52,10 @@ import java.util.Set; import java.util.SimpleTimeZone; import java.util.TimeZone; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; import java.util.concurrent.SynchronousQueue; -import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; @@ -6192,7 +6187,7 @@ final class TDSReaderMark { final class TDSReader { private final static Logger logger = Logger.getLogger("com.microsoft.sqlserver.jdbc.internals.TDS.Reader"); final private String traceID; - private TimeoutTimer tcpKeepAliveTimeoutTimer; + private TimeoutCommand timeoutCommand; final public String toString() { return traceID; @@ -6236,17 +6231,6 @@ private static int nextReaderID() { this.tdsChannel = tdsChannel; this.con = con; this.command = command; // may be null - if (null != command) { - // if cancelQueryTimeout is set, we should wait for the total amount of queryTimeout + cancelQueryTimeout to - // terminate the connection. - this.tcpKeepAliveTimeoutTimer = (command.getCancelQueryTimeoutSeconds() > 0 - && command.getQueryTimeoutSeconds() > 0) - ? (new TimeoutTimer( - command.getCancelQueryTimeoutSeconds() - + command.getQueryTimeoutSeconds(), - null, con)) - : null; - } // if the logging level is not detailed than fine or more we will not have proper reader IDs. if (logger.isLoggable(Level.FINE)) traceID = "TDSReader@" + nextReaderID() + " (" + con.toString() + ")"; @@ -6351,11 +6335,16 @@ synchronized final boolean readPacket() throws SQLServerException { + " should be less than numMsgsSent:" + tdsChannel.numMsgsSent; TDSPacket newPacket = new TDSPacket(con.getTDSPacketSize()); - if (null != tcpKeepAliveTimeoutTimer) { - if (logger.isLoggable(Level.FINEST)) { - logger.finest(this.toString() + ": starting timer..."); + if (null != command) { + // if cancelQueryTimeout is set, we should wait for the total amount of + // queryTimeout + cancelQueryTimeout to + // terminate the connection. + if ((command.getCancelQueryTimeoutSeconds() > 0 && command.getQueryTimeoutSeconds() > 0)) { + // if a timeout is configured with this object, add it to the timeout poller + int timeout = command.getCancelQueryTimeoutSeconds() + command.getQueryTimeoutSeconds(); + this.timeoutCommand = new TdsTimeoutCommand(timeout, this.command, this.con); + TimeoutPoller.getTimeoutPoller().addTimeoutCommand(this.timeoutCommand); } - tcpKeepAliveTimeoutTimer.start(); } // First, read the packet header. for (int headerBytesRead = 0; headerBytesRead < TDS.PACKET_HEADER_SIZE;) { @@ -6375,11 +6364,8 @@ synchronized final boolean readPacket() throws SQLServerException { } // if execution was subject to timeout then stop timing - if (null != tcpKeepAliveTimeoutTimer) { - if (logger.isLoggable(Level.FINEST)) { - logger.finest(this.toString() + ":stopping timer..."); - } - tcpKeepAliveTimeoutTimer.stop(); + if (this.timeoutCommand != null) { + TimeoutPoller.getTimeoutPoller().remove(this.timeoutCommand); } // Header size is a 2 byte unsigned short integer in big-endian order. int packetLength = Util.readUnsignedShortBigEndian(newPacket.header, TDS.PACKET_HEADER_MESSAGE_LENGTH); @@ -6969,82 +6955,25 @@ final void trySetSensitivityClassification(SensitivityClassification sensitivity /** - * Timer for use with Commands that support a timeout. - * - * Once started, the timer runs for the prescribed number of seconds unless stopped. If the timer runs out, it - * interrupts its associated Command with a reason like "timed out". + * The tds default implementation of a timeout command */ -final class TimeoutTimer implements Runnable { - private static final String threadGroupName = "mssql-jdbc-TimeoutTimer"; - private final int timeoutSeconds; - private final TDSCommand command; - private volatile Future task; - private final SQLServerConnection con; - - private static final ExecutorService executor = Executors.newCachedThreadPool(new ThreadFactory() { - private final AtomicReference tgr = new AtomicReference<>(); - private final AtomicInteger threadNumber = new AtomicInteger(0); - - @Override - public Thread newThread(Runnable r) { - ThreadGroup tg = tgr.get(); - - if (tg == null || tg.isDestroyed()) { - tg = new ThreadGroup(threadGroupName); - tgr.set(tg); - } - - Thread t = new Thread(tg, r, tg.getName() + "-" + threadNumber.incrementAndGet()); - t.setDaemon(true); - return t; - } - }); - - private volatile boolean canceled = false; - - TimeoutTimer(int timeoutSeconds, TDSCommand command, SQLServerConnection con) { - assert timeoutSeconds > 0; - - this.timeoutSeconds = timeoutSeconds; - this.command = command; - this.con = con; - } - - final void start() { - task = executor.submit(this); - } - - final void stop() { - task.cancel(true); - canceled = true; +class TdsTimeoutCommand extends TimeoutCommand { + public TdsTimeoutCommand(int timeout, TDSCommand command, SQLServerConnection sqlServerConnection) { + super(timeout, command, sqlServerConnection); } - public void run() { - int secondsRemaining = timeoutSeconds; + public void interrupt() { + TDSCommand command = getCommand(); + SQLServerConnection sqlServerConnection = getSqlServerConnection(); try { - // Poll every second while time is left on the timer. - // Return if/when the timer is canceled. - do { - if (canceled) - return; - - Thread.sleep(1000); - } while (--secondsRemaining > 0); - } catch (InterruptedException e) { - // re-interrupt the current thread, in order to restore the thread's interrupt status. - Thread.currentThread().interrupt(); - return; - } - - // If the timer wasn't canceled before it ran out of - // time then interrupt the registered command. - try { - // If TCP Connection to server is silently dropped, exceeding the query timeout on the same connection does + // If TCP Connection to server is silently dropped, exceeding the query timeout + // on the same connection does // not throw SQLTimeoutException - // The application stops responding instead until SocketTimeoutException is thrown. In this case, we must + // The application stops responding instead until SocketTimeoutException is + // thrown. In this case, we must // manually terminate the connection. - if (null == command && null != con) { - con.terminate(SQLServerException.DRIVER_ERROR_IO_FAILED, + if (null == command && null != sqlServerConnection) { + sqlServerConnection.terminate(SQLServerException.DRIVER_ERROR_IO_FAILED, SQLServerException.getErrString("R_connectionIsClosed")); } else { // If the timer wasn't canceled before it ran out of @@ -7061,7 +6990,6 @@ public void run() { } } - /** * TDSCommand encapsulates an interruptable TDS conversation. * @@ -7095,10 +7023,6 @@ final void log(Level level, String message) { logger.log(level, toString() + ": " + message); } - // Optional timer that is set if the command was created with a non-zero timeout period. - // When the timer expires, the command is interrupted. - private final TimeoutTimer timeoutTimer; - // TDS channel accessors // These are set/reset at command execution time. // Volatile ensures visibility to execution thread and interrupt thread @@ -7187,6 +7111,7 @@ protected void setProcessedResponse(boolean processedResponse) { private volatile boolean readingResponse; private int queryTimeoutSeconds; private int cancelQueryTimeoutSeconds; + private TdsTimeoutCommand timeoutCommand; protected int getQueryTimeoutSeconds() { return this.queryTimeoutSeconds; @@ -7213,7 +7138,6 @@ final boolean readingResponse() { this.logContext = logContext; this.queryTimeoutSeconds = queryTimeoutSeconds; this.cancelQueryTimeoutSeconds = cancelQueryTimeoutSeconds; - this.timeoutTimer = (queryTimeoutSeconds > 0) ? (new TimeoutTimer(queryTimeoutSeconds, this, null)) : null; } /** @@ -7602,11 +7526,9 @@ final TDSReader startResponse(boolean isAdaptive) throws SQLServerException { // If command execution is subject to timeout then start timing until // the server returns the first response packet. - if (null != timeoutTimer) { - if (logger.isLoggable(Level.FINEST)) - logger.finest(this.toString() + ": Starting timer..."); - - timeoutTimer.start(); + if (queryTimeoutSeconds > 0) { + this.timeoutCommand = new TdsTimeoutCommand(queryTimeoutSeconds, this, null); + TimeoutPoller.getTimeoutPoller().addTimeoutCommand(this.timeoutCommand); } if (logger.isLoggable(Level.FINEST)) @@ -7629,11 +7551,8 @@ final TDSReader startResponse(boolean isAdaptive) throws SQLServerException { } finally { // If command execution was subject to timeout then stop timing as soon // as the server returns the first response packet or errors out. - if (null != timeoutTimer) { - if (logger.isLoggable(Level.FINEST)) - logger.finest(this.toString() + ": Stopping timer..."); - - timeoutTimer.stop(); + if (this.timeoutCommand != null) { + TimeoutPoller.getTimeoutPoller().remove(this.timeoutCommand); } } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java index f45787f6ff..ace8bb6cbd 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java @@ -247,56 +247,16 @@ class BulkColumnMetaData { private int srcColumnCount; /** - * Timer for the bulk copy operation. The other timeout timers in the TDS layer only measure the response of the - * first packet from SQL Server. + * Timeout for the bulk copy command */ - private final class BulkTimeoutTimer implements Runnable { - private final int timeoutSeconds; - private int secondsRemaining; - private final TDSCommand command; - private Thread timerThread; - private volatile boolean canceled = false; - - BulkTimeoutTimer(int timeoutSeconds, TDSCommand command) { - assert timeoutSeconds > 0; - assert null != command; - - this.timeoutSeconds = timeoutSeconds; - this.secondsRemaining = timeoutSeconds; - this.command = command; - } - - final void start() { - timerThread = new Thread(this); - timerThread.setDaemon(true); - timerThread.start(); - } - - final void stop() { - canceled = true; - timerThread.interrupt(); + private final class BulkTimeoutCommand extends TimeoutCommand { + public BulkTimeoutCommand(int timeout, TDSCommand command, SQLServerConnection sqlServerConnection) { + super(timeout, command, sqlServerConnection); } - final boolean expired() { - return (secondsRemaining <= 0); - } - - public void run() { - try { - // Poll every second while time is left on the timer. - // Return if/when the timer is canceled. - do { - if (canceled) - return; - - Thread.sleep(1000); - } while (--secondsRemaining > 0); - } catch (InterruptedException e) { - // re-interrupt the current thread, in order to restore the thread's interrupt status. - Thread.currentThread().interrupt(); - return; - } - + @Override + public void interrupt() { + TDSCommand command = getCommand(); // If the timer wasn't canceled before it ran out of // time then interrupt the registered command. try { @@ -310,7 +270,7 @@ public void run() { } } - private BulkTimeoutTimer timeoutTimer = null; + private BulkTimeoutCommand timeoutCommand; /** * The maximum temporal precision we can send when using varchar(precision) in bulkcommand, to send a @@ -687,15 +647,15 @@ final class InsertBulk extends TDSCommand { InsertBulk() { super("InsertBulk", 0, 0); int timeoutSeconds = copyOptions.getBulkCopyTimeout(); - timeoutTimer = (timeoutSeconds > 0) ? (new BulkTimeoutTimer(timeoutSeconds, this)) : null; + timeoutCommand = timeoutSeconds > 0 ? new BulkTimeoutCommand(timeoutSeconds, this, null) : null; } final boolean doExecute() throws SQLServerException { - if (null != timeoutTimer) { + if (null != timeoutCommand) { if (logger.isLoggable(Level.FINEST)) logger.finest(this.toString() + ": Starting bulk timer..."); - timeoutTimer.start(); + TimeoutPoller.getTimeoutPoller().addTimeoutCommand(timeoutCommand); } // doInsertBulk inserts the rows in one batch. It returns true if there are more rows in @@ -712,18 +672,18 @@ final boolean doExecute() throws SQLServerException { // Check whether it is a timeout exception. if (rootCause instanceof SQLException) { - checkForTimeoutException((SQLException) rootCause, timeoutTimer); + checkForTimeoutException((SQLException) rootCause, timeoutCommand); } // It is not a timeout exception. Re-throw. throw topLevelException; } - if (null != timeoutTimer) { + if (null != timeoutCommand) { if (logger.isLoggable(Level.FINEST)) logger.finest(this.toString() + ": Stopping bulk timer..."); - timeoutTimer.stop(); + TimeoutPoller.getTimeoutPoller().remove(timeoutCommand); } return true; @@ -1188,9 +1148,9 @@ private void writeColumnMetaData(TDSWriter tdsWriter) throws SQLServerException /** * Helper method that throws a timeout exception if the cause of the exception was that the query was cancelled */ - private void checkForTimeoutException(SQLException e, BulkTimeoutTimer timeoutTimer) throws SQLServerException { + private void checkForTimeoutException(SQLException e, BulkTimeoutCommand timeoutCommand) throws SQLServerException { if ((null != e.getSQLState()) && (e.getSQLState().equals(SQLState.STATEMENT_CANCELED.getSQLStateCode())) - && timeoutTimer.expired()) { + && timeoutCommand.canTimeout()) { // If SQLServerBulkCopy is managing the transaction, a rollback is needed. if (copyOptions.isUseInternalTransaction()) { connection.rollback(); diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/TimeoutCommand.java b/src/main/java/com/microsoft/sqlserver/jdbc/TimeoutCommand.java new file mode 100644 index 0000000000..65b1f68b26 --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/TimeoutCommand.java @@ -0,0 +1,41 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ + +package com.microsoft.sqlserver.jdbc; + +/** + * Abstract implementation of a command that can be timed out using the {@link TimeoutPoller} + */ +abstract class TimeoutCommand { + private final long startTime; + private final int timeout; + private final T command; + private final SQLServerConnection sqlServerConnection; + + TimeoutCommand(int timeout, T command, SQLServerConnection sqlServerConnection) { + this.timeout = timeout; + this.command = command; + this.sqlServerConnection = sqlServerConnection; + this.startTime = System.currentTimeMillis(); + } + + public boolean canTimeout() { + long currentTime = System.currentTimeMillis(); + return ((currentTime - startTime) / 1000) >= timeout; + } + + public T getCommand() { + return command; + } + + public SQLServerConnection getSqlServerConnection() { + return sqlServerConnection; + } + + /** + * The implementation for interrupting this timeout command + */ + public abstract void interrupt(); +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/TimeoutPoller.java b/src/main/java/com/microsoft/sqlserver/jdbc/TimeoutPoller.java new file mode 100644 index 0000000000..6c53d4d744 --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/TimeoutPoller.java @@ -0,0 +1,82 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ + +package com.microsoft.sqlserver.jdbc; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.logging.Level; +import java.util.logging.Logger; + + +/** + * Thread that runs in the background while the mssql driver is used that can timeout TDSCommands Checks all registered + * commands every second to see if they can be interrupted + */ +final class TimeoutPoller implements Runnable { + private List> timeoutCommands = new ArrayList<>(); + final static Logger logger = Logger.getLogger("com.microsoft.sqlserver.jdbc.TimeoutPoller"); + private static volatile TimeoutPoller timeoutPoller = null; + + static TimeoutPoller getTimeoutPoller() { + if (timeoutPoller == null) { + synchronized (TimeoutPoller.class) { + if (timeoutPoller == null) { + // initialize the timeout poller thread once + timeoutPoller = new TimeoutPoller(); + // start the timeout polling thread + Thread pollerThread = new Thread(timeoutPoller, "mssql-jdbc-TimeoutPoller"); + pollerThread.setDaemon(true); + pollerThread.start(); + } + } + } + return timeoutPoller; + } + + void addTimeoutCommand(TimeoutCommand timeoutCommand) { + synchronized (timeoutCommands) { + timeoutCommands.add(timeoutCommand); + } + } + + void remove(TimeoutCommand timeoutCommand) { + synchronized (timeoutCommands) { + timeoutCommands.remove(timeoutCommand); + } + } + + private TimeoutPoller() {} + + public void run() { + try { + // Poll every second checking for commands that have timed out and need + // interruption + while (true) { + synchronized (timeoutCommands) { + Iterator> timeoutCommandIterator = timeoutCommands.iterator(); + while (timeoutCommandIterator.hasNext()) { + TimeoutCommand timeoutCommand = timeoutCommandIterator.next(); + try { + if (timeoutCommand.canTimeout()) { + try { + timeoutCommand.interrupt(); + } finally { + timeoutCommandIterator.remove(); + } + } + } catch (Exception e) { + logger.log(Level.WARNING, "Could not timeout command", e); + } + } + } + Thread.sleep(1000); + } + } catch (Exception e) { + logger.log(Level.SEVERE, "Error processing timeout commands", e); + } + } +} diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/bulkCopy/BulkCopyTimeoutTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/bulkCopy/BulkCopyTimeoutTest.java index c3497adf8f..76688ca9e0 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/bulkCopy/BulkCopyTimeoutTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/bulkCopy/BulkCopyTimeoutTest.java @@ -47,13 +47,13 @@ public void testZeroTimeOut() throws SQLException { */ @Test @DisplayName("BulkCopy:test negative timeout") - public void testNegativeTimeOut() throws SQLException { + public void testNegativeTimeOut() { assertThrows(SQLException.class, new org.junit.jupiter.api.function.Executable() { @Override public void execute() throws SQLException { testBulkCopyWithTimeout(-1); } - }); + }, "The timeout argument cannot be negative."); } private void testBulkCopyWithTimeout(int timeout) throws SQLException { diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/timeouts/TimeoutTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/timeouts/TimeoutTest.java new file mode 100644 index 0000000000..a61e097b17 --- /dev/null +++ b/src/test/java/com/microsoft/sqlserver/jdbc/timeouts/TimeoutTest.java @@ -0,0 +1,63 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ + +package com.microsoft.sqlserver.jdbc.timeouts; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.sql.SQLTimeoutException; + +import org.junit.Assert; +import org.junit.jupiter.api.Test; +import org.junit.platform.runner.JUnitPlatform; +import org.junit.runner.RunWith; + +import com.microsoft.sqlserver.testframework.AbstractTest; + + +@RunWith(JUnitPlatform.class) +public class TimeoutTest extends AbstractTest { + @Test + public void testBasicQueryTimeout() { + boolean exceptionThrown = false; + try { + // wait 1 minute and timeout after 10 seconds + Assert.assertTrue("Select succeeded", runQuery("WAITFOR DELAY '00:01'", 10)); + } catch (SQLException e) { + exceptionThrown = true; + Assert.assertTrue("Timeout exception not thrown", e.getClass().equals(SQLTimeoutException.class)); + } + Assert.assertTrue("A SQLTimeoutException was expected", exceptionThrown); + } + + @Test + public void testQueryTimeoutValid() { + boolean exceptionThrown = false; + int timeoutInSeconds = 10; + long start = System.currentTimeMillis(); + try { + // wait 1 minute and timeout after 10 seconds + Assert.assertTrue("Select succeeded", runQuery("WAITFOR DELAY '00:01'", timeoutInSeconds)); + } catch (SQLException e) { + int secondsElapsed = (int) ((System.currentTimeMillis() - start) / 1000); + Assert.assertTrue("Query did not timeout expected, elapsedTime=" + secondsElapsed, + secondsElapsed >= timeoutInSeconds); + exceptionThrown = true; + Assert.assertTrue("Timeout exception not thrown", e.getClass().equals(SQLTimeoutException.class)); + } + Assert.assertTrue("A SQLTimeoutException was expected", exceptionThrown); + } + + private boolean runQuery(String query, int timeout) throws SQLException { + try (Connection con = DriverManager.getConnection(connectionString); + PreparedStatement preparedStatement = con.prepareStatement(query)) { + // set provided timeout + preparedStatement.setQueryTimeout(timeout); + return preparedStatement.execute(); + } + } +}