diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlBulkCopy.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlBulkCopy.cs index 720bed016d5d..3cd94c72aebf 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlBulkCopy.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlBulkCopy.cs @@ -1841,7 +1841,10 @@ private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctok } else { - AsyncHelper.ContinueTask(writeTask, tcs, () => tcs.SetResult(null)); + AsyncHelper.ContinueTaskWithState(writeTask, tcs, + state: tcs, + onSuccess: state => ((TaskCompletionSource)state).SetResult(null) + ); } }, ctoken); // We do not need to propagate exception, etc, from reconnect task, we just need to wait for it to finish. return tcs.Task; @@ -2153,17 +2156,17 @@ private Task CopyColumnsAsync(int col, TaskCompletionSource source = nul private void CopyColumnsAsyncSetupContinuation(TaskCompletionSource source, Task task, int i) { AsyncHelper.ContinueTask(task, source, () => - { - if (i + 1 < _sortedColumnMappings.Count) - { - CopyColumnsAsync(i + 1, source); //continue from the next column - } - else { - source.SetResult(null); + if (i + 1 < _sortedColumnMappings.Count) + { + CopyColumnsAsync(i + 1, source); //continue from the next column + } + else + { + source.SetResult(null); + } } - }, - _connection.GetOpenTdsConnection()); + ); } // The notification logic. @@ -2257,24 +2260,6 @@ private Task CheckForCancellation(CancellationToken cts, TaskCompletionSource ContinueTaskPend(Task task, TaskCompletionSource source, Func> action) - { - if (task == null) - { - return action(); - } - else - { - Debug.Assert(source != null, "source should already be initialized if task is not null"); - AsyncHelper.ContinueTask(task, source, () => - { - TaskCompletionSource newSource = action(); - Debug.Assert(newSource == null, "Shouldn't create a new source when one already exists"); - }); - } - return null; - } - // Copies all the rows in a batch. // Maintains state machine with state variable: rowSoFar. // Returned Task could be null in two cases: (1) _isAsyncBulkCopy == false, or (2) _isAsyncBulkCopy == true but all async writes finished synchronously. @@ -2315,7 +2300,7 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts, } resultTask = source.Task; - AsyncHelper.ContinueTask(readTask, source, () => CopyRowsAsync(i + 1, totalRows, cts, source), connectionToDoom: _connection.GetOpenTdsConnection()); + AsyncHelper.ContinueTask(readTask, source, () => CopyRowsAsync(i + 1, totalRows, cts, source)); return resultTask; // Associated task will be completed when all rows are copied to server/exception/cancelled. } } @@ -2325,19 +2310,20 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts, resultTask = source.Task; AsyncHelper.ContinueTask(task, source, onSuccess: () => - { - CheckAndRaiseNotification(); // Check for notification now as the current row copy is done at this moment. - - Task readTask = ReadFromRowSourceAsync(cts); - if (readTask == null) { - CopyRowsAsync(i + 1, totalRows, cts, source); - } - else - { - AsyncHelper.ContinueTask(readTask, source, onSuccess: () => CopyRowsAsync(i + 1, totalRows, cts, source), connectionToDoom: _connection.GetOpenTdsConnection()); + CheckAndRaiseNotification(); // Check for notification now as the current row copy is done at this moment. + + Task readTask = ReadFromRowSourceAsync(cts); + if (readTask == null) + { + CopyRowsAsync(i + 1, totalRows, cts, source); + } + else + { + AsyncHelper.ContinueTask(readTask, source, onSuccess: () => CopyRowsAsync(i + 1, totalRows, cts, source)); + } } - }, connectionToDoom: _connection.GetOpenTdsConnection()); + ); return resultTask; } } @@ -2406,15 +2392,17 @@ private Task CopyBatchesAsync(BulkCopySimpleResultSet internalResults, string up source = new TaskCompletionSource(); } - AsyncHelper.ContinueTask(commandTask, source, () => - { - Task continuedTask = CopyBatchesAsyncContinued(internalResults, updateBulkCommandText, cts, source); - if (continuedTask == null) + AsyncHelper.ContinueTask(commandTask, source, + () => { - // Continuation finished sync, recall into CopyBatchesAsync to continue - CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + Task continuedTask = CopyBatchesAsyncContinued(internalResults, updateBulkCommandText, cts, source); + if (continuedTask == null) + { + // Continuation finished sync, recall into CopyBatchesAsync to continue + CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + } } - }, _connection.GetOpenTdsConnection()); + ); return source.Task; } } @@ -2462,15 +2450,19 @@ private Task CopyBatchesAsyncContinued(BulkCopySimpleResultSet internalResults, { // First time only source = new TaskCompletionSource(); } - AsyncHelper.ContinueTask(task, source, () => - { - Task continuedTask = CopyBatchesAsyncContinuedOnSuccess(internalResults, updateBulkCommandText, cts, source); - if (continuedTask == null) + AsyncHelper.ContinueTask(task, source, + onSuccess: () => { - // Continuation finished sync, recall into CopyBatchesAsync to continue - CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); - } - }, _connection.GetOpenTdsConnection(), _ => CopyBatchesAsyncContinuedOnError(cleanupParser: false), () => CopyBatchesAsyncContinuedOnError(cleanupParser: true)); + Task continuedTask = CopyBatchesAsyncContinuedOnSuccess(internalResults, updateBulkCommandText, cts, source); + if (continuedTask == null) + { + // Continuation finished sync, recall into CopyBatchesAsync to continue + CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + } + }, + onFailure: (_) => CopyBatchesAsyncContinuedOnError(cleanupParser: false), + onCancellation: () => CopyBatchesAsyncContinuedOnError(cleanupParser: true) + ); return source.Task; } @@ -2517,22 +2509,25 @@ private Task CopyBatchesAsyncContinuedOnSuccess(BulkCopySimpleResultSet internal source = new TaskCompletionSource(); } - AsyncHelper.ContinueTask(writeTask, source, () => - { - try - { - RunParser(); - CommitTransaction(); - } - catch (Exception) + AsyncHelper.ContinueTask(writeTask, source, + onSuccess: () => { - CopyBatchesAsyncContinuedOnError(cleanupParser: false); - throw; - } + try + { + RunParser(); + CommitTransaction(); + } + catch (Exception) + { + CopyBatchesAsyncContinuedOnError(cleanupParser: false); + throw; + } - // Always call back into CopyBatchesAsync - CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); - }, connectionToDoom: _connection.GetOpenTdsConnection(), onFailure: _ => CopyBatchesAsyncContinuedOnError(cleanupParser: false)); + // Always call back into CopyBatchesAsync + CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + }, + onFailure: (_) => CopyBatchesAsyncContinuedOnError(cleanupParser: false) + ); return source.Task; } } @@ -2651,48 +2646,50 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int { source = new TaskCompletionSource(); } - AsyncHelper.ContinueTask(task, source, () => - { - // Bulk copy task is completed at this moment. - if (task.IsCanceled) + AsyncHelper.ContinueTask(task, source, + () => { - _localColumnMappings = null; - try + // Bulk copy task is completed at this moment. + if (task.IsCanceled) { - CleanUpStateObjectOnError(); + _localColumnMappings = null; + try + { + CleanUpStateObjectOnError(); + } + finally + { + source.SetCanceled(); + } } - finally + else if (task.Exception != null) { - source.SetCanceled(); + source.SetException(task.Exception.InnerException); } - } - else if (task.Exception != null) - { - source.SetException(task.Exception.InnerException); - } - else - { - _localColumnMappings = null; - try - { - CleanUpStateObjectOnError(); - } - finally + else { - if (source != null) + _localColumnMappings = null; + try { - if (cts.IsCancellationRequested) - { // We may get cancellation req even after the entire copy. - source.SetCanceled(); - } - else + CleanUpStateObjectOnError(); + } + finally + { + if (source != null) { - source.SetResult(null); + if (cts.IsCancellationRequested) + { // We may get cancellation req even after the entire copy. + source.SetCanceled(); + } + else + { + source.SetResult(null); + } } } } } - }, _connection.GetOpenTdsConnection()); + ); return; } else @@ -2782,12 +2779,15 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio { regReconnectCancel = cts.Register(s => ((TaskCompletionSource)s).TrySetCanceled(), cancellableReconnectTS); } - AsyncHelper.ContinueTask(reconnectTask, cancellableReconnectTS, () => { cancellableReconnectTS.SetResult(null); }); + AsyncHelper.ContinueTaskWithState(reconnectTask, cancellableReconnectTS, + state: cancellableReconnectTS, + onSuccess: (state) => { ((TaskCompletionSource)state).SetResult(null); } + ); // No need to cancel timer since SqlBulkCopy creates specific task source for reconnection. AsyncHelper.SetTimeoutException(cancellableReconnectTS, BulkCopyTimeout, () => { return SQL.BulkLoadInvalidDestinationTable(_destinationTableName, SQL.CR_ReconnectTimeout()); }, CancellationToken.None); AsyncHelper.ContinueTask(cancellableReconnectTS.Task, source, - () => + onSuccess: () => { regReconnectCancel.Dispose(); if (_parserLock != null) @@ -2799,7 +2799,6 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio _parserLock.Wait(canReleaseFromAnyThread: true); WriteToServerInternalRestAsync(cts, source); }, - connectionToAbort: _connection, onFailure: (e) => { regReconnectCancel.Dispose(); }, onCancellation: () => { regReconnectCancel.Dispose(); }, exceptionConverter: (ex) => SQL.BulkLoadInvalidDestinationTable(_destinationTableName, ex)); @@ -2850,7 +2849,7 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio if (internalResultsTask != null) { - AsyncHelper.ContinueTask(internalResultsTask, source, () => WriteToServerInternalRestContinuedAsync(internalResultsTask.Result, cts, source), _connection.GetOpenTdsConnection()); + AsyncHelper.ContinueTask(internalResultsTask, source, () => WriteToServerInternalRestContinuedAsync(internalResultsTask.Result, cts, source)); } else { @@ -2921,17 +2920,19 @@ private Task WriteToServerInternalAsync(CancellationToken ctoken) else { Debug.Assert(_isAsyncBulkCopy, "Read must not return a Task in the Sync mode"); - AsyncHelper.ContinueTask(readTask, source, () => - { - if (!_hasMoreRowToCopy) - { - source.SetResult(null); // No rows to copy! - } - else + AsyncHelper.ContinueTask(readTask, source, + () => { - WriteToServerInternalRestAsync(ctoken, source); // Passing the same completion which will be completed by the Callee. + if (!_hasMoreRowToCopy) + { + source.SetResult(null); // No rows to copy! + } + else + { + WriteToServerInternalRestAsync(ctoken, source); // Passing the same completion which will be completed by the Callee. + } } - }, _connection.GetOpenTdsConnection()); + ); return resultTask; } } diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlCommand.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlCommand.cs index f39d2643c5d4..6c091e174309 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlCommand.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlCommand.cs @@ -1293,7 +1293,10 @@ public IAsyncResult BeginExecuteXmlReader(AsyncCallback callback, object stateOb // Add callback after work is done to avoid overlapping Begin\End methods if (callback != null) { - completion.Task.ContinueWith((t) => callback(t), TaskScheduler.Default); + completion.Task.ContinueWith( + (task,state) => ((AsyncCallback)state)(task), + state: callback + ); } return completion.Task; } @@ -1577,7 +1580,10 @@ public IAsyncResult BeginExecuteReader(AsyncCallback callback, object stateObjec // Add callback after work is done to avoid overlapping Begin\End methods if (callback != null) { - completion.Task.ContinueWith((t) => callback(t), TaskScheduler.Default); + completion.Task.ContinueWith( + (task,state) => ((AsyncCallback)state)(task), + state: callback + ); } return completion.Task; } @@ -2441,9 +2447,13 @@ private void RunExecuteNonQueryTdsSetupReconnnectContinuation(string methodName, } else { - AsyncHelper.ContinueTask(subTask, completion, () => completion.SetResult(null)); + AsyncHelper.ContinueTaskWithState(subTask, completion, + state: completion, + onSuccess: (state) => ((TaskCompletionSource)state).SetResult(null) + ); } - }, connectionToAbort: _activeConnection); + } + ); } internal SqlDataReader RunExecuteReader(CommandBehavior cmdBehavior, RunBehavior runBehavior, bool returnStream, [CallerMemberName] string method = "") @@ -2692,15 +2702,17 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi // This is in its own method to avoid always allocating the lambda in RunExecuteReaderTds private Task RunExecuteReaderTdsSetupContinuation(RunBehavior runBehavior, SqlDataReader ds, string optionSettings, Task writeTask) { - Task task = AsyncHelper.CreateContinuationTask(writeTask, () => - { - _activeConnection.GetOpenTdsConnection(); // it will throw if connection is closed - cachedAsyncState.SetAsyncReaderState(ds, runBehavior, optionSettings); - }, - onFailure: (exc) => - { - _activeConnection.GetOpenTdsConnection().DecrementAsyncCount(); - }); + Task task = AsyncHelper.CreateContinuationTask(writeTask, + onSuccess: () => + { + _activeConnection.GetOpenTdsConnection(); // it will throw if connection is closed + cachedAsyncState.SetAsyncReaderState(ds, runBehavior, optionSettings); + }, + onFailure: (exc) => + { + _activeConnection.GetOpenTdsConnection().DecrementAsyncCount(); + } + ); return task; } @@ -2726,9 +2738,12 @@ private void RunExecuteReaderTdsSetupReconnectContinuation(CommandBehavior cmdBe } else { - AsyncHelper.ContinueTask(subTask, completion, () => completion.SetResult(null)); + AsyncHelper.ContinueTaskWithState(subTask, completion, + state: completion, + onSuccess: (state) => ((TaskCompletionSource)state).SetResult(null) + ); } - }, connectionToAbort: _activeConnection + } ); } diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlUtil.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlUtil.cs index 2882b3ea6d62..461e5c0715ac 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlUtil.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlUtil.cs @@ -17,7 +17,7 @@ namespace System.Data.SqlClient { internal static class AsyncHelper { - internal static Task CreateContinuationTask(Task task, Action onSuccess, SqlInternalConnectionTds connectionToDoom = null, Action onFailure = null) + internal static Task CreateContinuationTask(Task task, Action onSuccess, Action onFailure = null) { if (task == null) { @@ -27,9 +27,22 @@ internal static Task CreateContinuationTask(Task task, Action onSuccess, SqlInte else { TaskCompletionSource completion = new TaskCompletionSource(); - ContinueTask(task, completion, - () => { onSuccess(); completion.SetResult(null); }, - connectionToDoom, onFailure); + ContinueTaskWithState(task, completion, + state: Tuple.Create(onSuccess, onFailure,completion), + onSuccess: (state) => { + var parameters = (Tuple, TaskCompletionSource>)state; + Action success = parameters.Item1; + TaskCompletionSource taskCompletionSource = parameters.Item3; + success(); + taskCompletionSource.SetResult(null); + }, + onFailure: (exception,state) => + { + var parameters = (Tuple, TaskCompletionSource>)state; + Action failure = parameters.Item2; + failure?.Invoke(exception); + } + ); return completion.Task; } } @@ -45,30 +58,30 @@ internal static Task CreateContinuationTaskWithState(Task task, object state, Ac { var completion = new TaskCompletionSource(); ContinueTaskWithState(task, completion, state, - onSuccess: (continueState) => { onSuccess(continueState); completion.SetResult(null); }, + onSuccess: (continueState) => { + onSuccess(continueState); + completion.SetResult(null); + }, onFailure: onFailure ); return completion.Task; } } - internal static Task CreateContinuationTask(Task task, Action onSuccess, T1 arg1, T2 arg2, SqlInternalConnectionTds connectionToDoom = null, Action onFailure = null) + internal static Task CreateContinuationTask(Task task, Action onSuccess, T1 arg1, T2 arg2, Action onFailure = null) { - return CreateContinuationTask(task, () => onSuccess(arg1, arg2), connectionToDoom, onFailure); + return CreateContinuationTask(task, () => onSuccess(arg1, arg2), onFailure); } internal static void ContinueTask(Task task, TaskCompletionSource completion, Action onSuccess, - SqlInternalConnectionTds connectionToDoom = null, Action onFailure = null, Action onCancellation = null, - Func exceptionConverter = null, - SqlConnection connectionToAbort = null + Func exceptionConverter = null ) { - Debug.Assert((connectionToAbort == null) || (connectionToDoom == null), "Should not specify both connectionToDoom and connectionToAbort"); task.ContinueWith( tsk => { @@ -81,10 +94,7 @@ internal static Task CreateContinuationTaskWithState(Task task, object state, Ac } try { - if (onFailure != null) - { - onFailure(exc); - } + onFailure?.Invoke(exc); } finally { @@ -95,10 +105,7 @@ internal static Task CreateContinuationTaskWithState(Task task, object state, Ac { try { - if (onCancellation != null) - { - onCancellation(); - } + onCancellation?.Invoke(); } finally { diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.cs index fc5f1b911989..50e8a934663f 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.cs @@ -7540,10 +7540,11 @@ internal Task TdsExecuteSQLBatch(string text, int timeout, SqlNotificationReques // Take care of releasing the locks if (releaseConnectionLock) { - task.ContinueWith(_ => - { - _connHandler._parserLock.Release(); - }, TaskScheduler.Default); + task.ContinueWith( + (_, state) => ((SqlInternalConnectionTds)state)._parserLock.Release(), + state: _connHandler, + TaskScheduler.Default + ); releaseConnectionLock = false; } @@ -7652,7 +7653,6 @@ private void TDSExecuteRPCParameterSetupWriteCompletion(_SqlRPC[] rpcArray, int startRpc, startParam ), - connectionToDoom: _connHandler, onFailure: exc => TdsExecuteRPC_OnFailure(exc, stateObj) ); } @@ -8684,9 +8684,7 @@ private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaTy } else { - return AsyncHelper.CreateContinuationTask(unterminatedWriteTask, - WriteInt, 0, stateObj, - connectionToDoom: _connHandler); + return AsyncHelper.CreateContinuationTask(unterminatedWriteTask, WriteInt, 0, stateObj); } } else diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs index e6ba9a6674a2..671829dacb2f 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs @@ -3115,8 +3115,7 @@ private Task WriteBytes(ReadOnlySpan b, int len, int offsetBuffer, bool ca private void WriteBytesSetupContinuation(byte[] array, int len, TaskCompletionSource completion, int offset, Task packetTask) { AsyncHelper.ContinueTask(packetTask, completion, - () => WriteBytes(ReadOnlySpan.Empty, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion, array), - connectionToDoom: _parser.Connection + onSuccess: () => WriteBytes(ReadOnlySpan.Empty, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion, array) ); } @@ -3187,7 +3186,7 @@ internal Task WritePacket(byte flushMode, bool canAccumulate = false) if (willCancel) { // If we have been cancelled, then ensure that we write the ATTN packet as well - task = AsyncHelper.CreateContinuationTask(task, CancelWritePacket, _parser.Connection); + task = AsyncHelper.CreateContinuationTask(task, CancelWritePacket); } return task;