diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index 29ce7ec2e5..6ae23e11e6 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -70,6 +70,8 @@ public sealed partial class SqlCommand : DbCommand, ICloneable private static readonly DiagnosticListener _diagnosticListener = new DiagnosticListener(SqlClientDiagnosticListenerExtensions.DiagnosticListenerName); private bool _parentOperationStarted = false; + internal static readonly Action s_cancelIgnoreFailure = CancelIgnoreFailureCallback; + // Prepare // Against 7.0 Serve a prepare/unprepare requires an extra roundtrip to the server. // @@ -2137,7 +2139,7 @@ public override Task ExecuteNonQueryAsync(CancellationToken cancellationTok source.SetCanceled(); return source.Task; } - registration = cancellationToken.Register(s => ((SqlCommand)s).CancelIgnoreFailure(), this); + registration = cancellationToken.Register(s_cancelIgnoreFailure, this); } Task returnedTask = source.Task; @@ -2225,7 +2227,7 @@ protected override Task ExecuteDbDataReaderAsync(CommandBehavior b source.SetCanceled(); return source.Task; } - registration = cancellationToken.Register(s => ((SqlCommand)s).CancelIgnoreFailure(), this); + registration = cancellationToken.Register(s_cancelIgnoreFailure, this); } Task returnedTask = source.Task; @@ -2373,7 +2375,7 @@ public Task ExecuteXmlReaderAsync(CancellationToken cancellationToken source.SetCanceled(); return source.Task; } - registration = cancellationToken.Register(s => ((SqlCommand)s).CancelIgnoreFailure(), this); + registration = cancellationToken.Register(s_cancelIgnoreFailure, this); } Task returnedTask = source.Task; @@ -5947,6 +5949,11 @@ internal void CompletePendingReadWithFailure(int errorCode, bool resetForcePendi } } #endif + internal static void CancelIgnoreFailureCallback(object state) + { + SqlCommand command = (SqlCommand)state; + command.CancelIgnoreFailure(); + } internal void CancelIgnoreFailure() { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 5fadbae100..c3e2ed9e27 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -97,6 +97,8 @@ internal class SharedState private SqlSequentialStream _currentStream; private SqlSequentialTextReader _currentTextReader; + private IsDBNullAsyncCallContext _cachedIsDBNullContext; + private ReadAsyncCallContext _cachedReadAsyncContext; internal SqlDataReader(SqlCommand command, CommandBehavior behavior) { @@ -4263,7 +4265,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken) source.SetCanceled(); return source.Task; } - registration = cancellationToken.Register(s => ((SqlCommand)s).CancelIgnoreFailure(), _command); + registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); } Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); @@ -4281,35 +4283,31 @@ public override Task NextResultAsync(CancellationToken cancellationToken) return source.Task; } - PrepareAsyncInvocation(useSnapshot: true); - - Func> moreFunc = null; + return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registration)); + } - moreFunc = (t) => + private static Task NextResultAsyncExecute(Task task, object state) + { + HasNextResultAsyncCallContext context = (HasNextResultAsyncCallContext)state; + if (task != null) { - if (t != null) - { - PrepareForAsyncContinuation(); - } - - bool more; - if (TryNextResult(out more)) - { - // completed - return more ? ADP.TrueTask : ADP.FalseTask; - } + context._reader.PrepareForAsyncContinuation(); + } - return ContinueRetryable(moreFunc); - }; + if (context._reader.TryNextResult(out bool more)) + { + // completed + return more ? ADP.TrueTask : ADP.FalseTask; + } - return InvokeRetryable(moreFunc, source, registration); + return context._reader.ExecuteAsyncCall(context); } // NOTE: This will return null if it completed sequentially // If this returns null, then you can use bytesRead to see how many bytes were read - otherwise bytesRead should be ignored - internal Task GetBytesAsync(int i, byte[] buffer, int index, int length, int timeout, CancellationToken cancellationToken, out int bytesRead) + internal Task GetBytesAsync(int columnIndex, byte[] buffer, int index, int length, int timeout, CancellationToken cancellationToken, out int bytesRead) { - AssertReaderState(requireData: true, permitAsync: true, columnIndex: i, enforceSequentialAccess: true); + AssertReaderState(requireData: true, permitAsync: true, columnIndex: columnIndex, enforceSequentialAccess: true); Debug.Assert(IsCommandBehavior(CommandBehavior.SequentialAccess)); bytesRead = 0; @@ -4331,6 +4329,16 @@ internal Task GetBytesAsync(int i, byte[] buffer, int index, int length, in } } + var context = new GetBytesAsyncCallContext(this) + { + columnIndex = columnIndex, + buffer = buffer, + index = index, + length = length, + timeout = timeout, + cancellationToken = cancellationToken, + }; + // Check if we need to skip columns Debug.Assert(_sharedState._nextColumnDataToRead <= _lastColumnWithDataChunkRead, "Non sequential access"); if ((_sharedState._nextColumnHeaderToRead <= _lastColumnWithDataChunkRead) || (_sharedState._nextColumnDataToRead < _lastColumnWithDataChunkRead)) @@ -4343,10 +4351,6 @@ internal Task GetBytesAsync(int i, byte[] buffer, int index, int length, in return source.Task; } - PrepareAsyncInvocation(useSnapshot: true); - - Func> moreFunc = null; - // Timeout CancellationToken timeoutToken = CancellationToken.None; CancellationTokenSource timeoutCancellationSource = null; @@ -4357,65 +4361,25 @@ internal Task GetBytesAsync(int i, byte[] buffer, int index, int length, in timeoutToken = timeoutCancellationSource.Token; } - moreFunc = (t) => - { - if (t != null) - { - PrepareForAsyncContinuation(); - } - - // Prepare for stateObj timeout - SetTimeout(_defaultTimeoutMilliseconds); + context._disposable = timeoutCancellationSource; + context.timeoutToken = timeoutToken; + context._source = source; - if (TryReadColumnHeader(i)) - { - // Only once we have read up to where we need to be can we check the cancellation tokens (otherwise we will be in an unknown state) - - if (cancellationToken.IsCancellationRequested) - { - // User requested cancellation - return Task.FromCanceled(cancellationToken); - } - else if (timeoutToken.IsCancellationRequested) - { - // Timeout - return Task.FromException(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout()))); - } - else - { - // Up to the correct column - continue to read - SwitchToAsyncWithoutSnapshot(); - int totalBytesRead; - var readTask = GetBytesAsyncReadDataStage(i, buffer, index, length, timeout, true, cancellationToken, timeoutToken, out totalBytesRead); - if (readTask == null) - { - // Completed synchronously - return Task.FromResult(totalBytesRead); - } - else - { - return readTask; - } - } - } - else - { - return ContinueRetryable(moreFunc); - } - }; + PrepareAsyncInvocation(useSnapshot: true); - return InvokeRetryable(moreFunc, source, timeoutCancellationSource); + return InvokeAsyncCall(context); } else { // We're already at the correct column, just read the data + context.mode = GetBytesAsyncCallContext.OperationMode.Read; // Switch to async PrepareAsyncInvocation(useSnapshot: false); try { - return GetBytesAsyncReadDataStage(i, buffer, index, length, timeout, false, cancellationToken, CancellationToken.None, out bytesRead); + return GetBytesAsyncReadDataStage(context, false, out bytesRead); } catch { @@ -4425,17 +4389,117 @@ internal Task GetBytesAsync(int i, byte[] buffer, int index, int length, in } } - private Task GetBytesAsyncReadDataStage(int i, byte[] buffer, int index, int length, int timeout, bool isContinuation, CancellationToken cancellationToken, CancellationToken timeoutToken, out int bytesRead) + private static Task GetBytesAsyncSeekExecute(Task task, object state) + { + GetBytesAsyncCallContext context = (GetBytesAsyncCallContext)state; + SqlDataReader reader = context._reader; + + Debug.Assert(context.mode == GetBytesAsyncCallContext.OperationMode.Seek, "context.mode must be Seek to check if seeking can resume"); + + if (task != null) + { + reader.PrepareForAsyncContinuation(); + } + + // Prepare for stateObj timeout + reader.SetTimeout(reader._defaultTimeoutMilliseconds); + + if (reader.TryReadColumnHeader(context.columnIndex)) + { + // Only once we have read up to where we need to be can we check the cancellation tokens (otherwise we will be in an unknown state) + + if (context.cancellationToken.IsCancellationRequested) + { + // User requested cancellation + return Task.FromCanceled(context.cancellationToken); + } + else if (context.timeoutToken.IsCancellationRequested) + { + // Timeout + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout()))); + } + else + { + // Up to the correct column - continue to read + context.mode = GetBytesAsyncCallContext.OperationMode.Read; + reader.SwitchToAsyncWithoutSnapshot(); + int totalBytesRead; + var readTask = reader.GetBytesAsyncReadDataStage(context, true, out totalBytesRead); + if (readTask == null) + { + // Completed synchronously + return Task.FromResult(totalBytesRead); + } + else + { + return readTask; + } + } + } + else + { + return reader.ExecuteAsyncCall(context); + } + } + + private static Task GetBytesAsyncReadExecute(Task task, object state) { - _lastColumnWithDataChunkRead = i; + var context = (GetBytesAsyncCallContext)state; + SqlDataReader reader = context._reader; + + Debug.Assert(context.mode == GetBytesAsyncCallContext.OperationMode.Read, "context.mode must be Read to check if read can resume"); + + reader.PrepareForAsyncContinuation(); + + if (context.cancellationToken.IsCancellationRequested) + { + // User requested cancellation + return Task.FromCanceled(context.cancellationToken); + } + else if (context.timeoutToken.IsCancellationRequested) + { + // Timeout + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout()))); + } + else + { + // Prepare for stateObj timeout + reader.SetTimeout(reader._defaultTimeoutMilliseconds); + + int bytesReadThisIteration; + bool result = reader.TryGetBytesInternalSequential( + context.columnIndex, + context.buffer, + context.index + context.totalBytesRead, + context.length - context.totalBytesRead, + out bytesReadThisIteration + ); + context.totalBytesRead += bytesReadThisIteration; + Debug.Assert(context.totalBytesRead <= context.length, "Read more bytes than required"); + + if (result) + { + return Task.FromResult(context.totalBytesRead); + } + else + { + return reader.ExecuteAsyncCall(context); + } + } + } + + private Task GetBytesAsyncReadDataStage(GetBytesAsyncCallContext context, bool isContinuation, out int bytesRead) + { + Debug.Assert(context.mode == GetBytesAsyncCallContext.OperationMode.Read, "context.Mode must be Read to read data"); + + _lastColumnWithDataChunkRead = context.columnIndex; TaskCompletionSource source = null; - CancellationTokenSource timeoutCancellationSource = null; // Prepare for stateObj timeout SetTimeout(_defaultTimeoutMilliseconds); // Try to read without any continuations (all the data may already be in the stateObj's buffer) - if (!TryGetBytesInternalSequential(i, buffer, index, length, out bytesRead)) + if (!TryGetBytesInternalSequential(context.columnIndex, context.buffer, context.index, context.length, out bytesRead)) { // This will be the 'state' for the callback int totalBytesRead = bytesRead; @@ -4460,52 +4524,18 @@ private Task GetBytesAsyncReadDataStage(int i, byte[] buffer, int index, in } // Timeout - Debug.Assert(timeoutToken == CancellationToken.None, "TimeoutToken is set when GetBytesAsyncReadDataStage is not a continuation"); - if (timeout > 0) + Debug.Assert(context.timeoutToken == CancellationToken.None, "TimeoutToken is set when GetBytesAsyncReadDataStage is not a continuation"); + if (context.timeout > 0) { - timeoutCancellationSource = new CancellationTokenSource(); - timeoutCancellationSource.CancelAfter(timeout); - timeoutToken = timeoutCancellationSource.Token; + CancellationTokenSource timeoutCancellationSource = new CancellationTokenSource(); + timeoutCancellationSource.CancelAfter(context.timeout); + Debug.Assert(context._disposable is null, "setting context.disposable would lose the previous dispoable"); + context._disposable = timeoutCancellationSource; + context.timeoutToken = timeoutCancellationSource.Token; } } - Func> moreFunc = null; - moreFunc = (_ => - { - PrepareForAsyncContinuation(); - - if (cancellationToken.IsCancellationRequested) - { - // User requested cancellation - return Task.FromCanceled(cancellationToken); - } - else if (timeoutToken.IsCancellationRequested) - { - // Timeout - return Task.FromException(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout()))); - } - else - { - // Prepare for stateObj timeout - SetTimeout(_defaultTimeoutMilliseconds); - - int bytesReadThisIteration; - bool result = TryGetBytesInternalSequential(i, buffer, index + totalBytesRead, length - totalBytesRead, out bytesReadThisIteration); - totalBytesRead += bytesReadThisIteration; - Debug.Assert(totalBytesRead <= length, "Read more bytes than required"); - - if (result) - { - return Task.FromResult(totalBytesRead); - } - else - { - return ContinueRetryable(moreFunc); - } - } - }); - - Task retryTask = ContinueRetryable(moreFunc); + Task retryTask = ExecuteAsyncCall(context); if (isContinuation) { // Let the caller handle cleanup\completing @@ -4513,8 +4543,13 @@ private Task GetBytesAsyncReadDataStage(int i, byte[] buffer, int index, in } else { + Debug.Assert(context._source != null, "context.source shuld not be null when continuing"); // setup for cleanup\completing - retryTask.ContinueWith((t) => CompleteRetryable(t, source, timeoutCancellationSource), TaskScheduler.Default); + retryTask.ContinueWith( + continuationAction: AAsyncCallContext.s_completeCallback, + state: context, + TaskScheduler.Default + ); return source.Task; } } @@ -4637,54 +4672,71 @@ public override Task ReadAsync(CancellationToken cancellationToken) IDisposable registration = null; if (cancellationToken.CanBeCanceled) { - registration = cancellationToken.Register(s => ((SqlCommand)s).CancelIgnoreFailure(), _command); + registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); } + var context = Interlocked.Exchange(ref _cachedReadAsyncContext, null) ?? new ReadAsyncCallContext(); + + Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ReadAsyncCallContext was not properly disposed"); + + context.Set(this, source, registration); + context._hasMoreData = more; + context._hasReadRowToken = rowTokenRead; + PrepareAsyncInvocation(useSnapshot: true); - Func> moreFunc = null; - moreFunc = (t) => + return InvokeAsyncCall(context); + } + + private static Task ReadAsyncExecute(Task task, object state) + { + var context = (ReadAsyncCallContext)state; + SqlDataReader reader = context._reader; + ref bool hasMoreData = ref context._hasMoreData; + ref bool hasReadRowToken = ref context._hasReadRowToken; + + if (task != null) { - if (t != null) + reader.PrepareForAsyncContinuation(); + } + + if (hasReadRowToken || reader.TryReadInternal(true, out hasMoreData)) + { + // If there are no more rows, or this is Sequential Access, then we are done + if (!hasMoreData || (reader._commandBehavior & CommandBehavior.SequentialAccess) == CommandBehavior.SequentialAccess) { - PrepareForAsyncContinuation(); + // completed + return hasMoreData ? ADP.TrueTask : ADP.FalseTask; } - - if (rowTokenRead || TryReadInternal(true, out more)) + else { - // If there are no more rows, or this is Sequential Access, then we are done - if (!more || (_commandBehavior & CommandBehavior.SequentialAccess) == CommandBehavior.SequentialAccess) - { - // completed - return more ? ADP.TrueTask : ADP.FalseTask; - } - else + // First time reading the row token - update the snapshot + if (!hasReadRowToken) { - // First time reading the row token - update the snapshot - if (!rowTokenRead) + hasReadRowToken = true; + if (reader._cachedSnapshot is null) { - rowTokenRead = true; - if (_cachedSnapshot is null) - { - _cachedSnapshot = _snapshot; - } - _snapshot = null; - PrepareAsyncInvocation(useSnapshot: true); + reader._cachedSnapshot = reader._snapshot; } + reader._snapshot = null; + reader.PrepareAsyncInvocation(useSnapshot: true); + } - // if non-sequentialaccess then read entire row before returning - if (TryReadColumn(_metaData.Length - 1, true)) - { - // completed - return ADP.TrueTask; - } + // if non-sequentialaccess then read entire row before returning + if (reader.TryReadColumn(reader._metaData.Length - 1, true)) + { + // completed + return ADP.TrueTask; } } + } - return ContinueRetryable(moreFunc); - }; + return reader.ExecuteAsyncCall(context); + } - return InvokeRetryable(moreFunc, source, registration); + private void SetCachedReadAsyncCallContext(ReadAsyncCallContext instance) + { + Interlocked.CompareExchange(ref _cachedReadAsyncContext, instance, null); } /// @@ -4782,36 +4834,48 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo IDisposable registration = null; if (cancellationToken.CanBeCanceled) { - registration = cancellationToken.Register(s => ((SqlCommand)s).CancelIgnoreFailure(), _command); + registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); } + IsDBNullAsyncCallContext context = Interlocked.Exchange(ref _cachedIsDBNullContext, null) ?? new IsDBNullAsyncCallContext(); + + Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ISDBNullAsync context not properly disposed"); + + context.Set(this, source, registration); + context._columnIndex = i; + // Setup async PrepareAsyncInvocation(useSnapshot: true); - // Setup the retryable function - Func> moreFunc = null; - moreFunc = (t) => - { - if (t != null) - { - PrepareForAsyncContinuation(); - } + return InvokeAsyncCall(context); + } + } - if (TryReadColumnHeader(i)) - { - return _data[i].IsNull ? ADP.TrueTask : ADP.FalseTask; - } - else - { - return ContinueRetryable(moreFunc); - } - }; + private static Task IsDBNullAsyncExecute(Task task, object state) + { + IsDBNullAsyncCallContext context = (IsDBNullAsyncCallContext)state; + SqlDataReader reader = context._reader; + + if (task != null) + { + reader.PrepareForAsyncContinuation(); + } - // Go! - return InvokeRetryable(moreFunc, source, registration); + if (reader.TryReadColumnHeader(context._columnIndex)) + { + return reader._data[context._columnIndex].IsNull ? ADP.TrueTask : ADP.FalseTask; + } + else + { + return reader.ExecuteAsyncCall(context); } } + private void SetCachedIDBNullAsyncCallContext(IsDBNullAsyncCallContext instance) + { + Interlocked.CompareExchange(ref _cachedIsDBNullContext, instance, null); + } + /// override public Task GetFieldValueAsync(int i, CancellationToken cancellationToken) { @@ -4912,27 +4976,27 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat // Setup async PrepareAsyncInvocation(useSnapshot: true); - // Setup the retryable function - Func> moreFunc = null; - moreFunc = (t) => - { - if (t != null) - { - PrepareForAsyncContinuation(); - } + return InvokeAsyncCall(new GetFieldValueAsyncCallContext(this, source, registration, i)); + } - if (TryReadColumn(i, setTimeout: false)) - { - return Task.FromResult(GetFieldValueFromSqlBufferInternal(_data[i], _metaData[i])); - } - else - { - return ContinueRetryable(moreFunc); - } - }; + private static Task GetFieldValueAsyncExecute(Task task, object state) + { + GetFieldValueAsyncCallContext context = (GetFieldValueAsyncCallContext)state; + SqlDataReader reader = context._reader; + int columnIndex = context._columnIndex; + if (task != null) + { + reader.PrepareForAsyncContinuation(); + } - // Go! - return InvokeRetryable(moreFunc, source, registration); + if (reader.TryReadColumn(columnIndex, setTimeout: false)) + { + return Task.FromResult(reader.GetFieldValueFromSqlBufferInternal(reader._data[columnIndex], reader._metaData[columnIndex])); + } + else + { + return reader.ExecuteAsyncCall(context); + } } #if DEBUG @@ -4978,79 +5042,174 @@ private class Snapshot public SqlSequentialTextReader _currentTextReader; } - private Task ContinueRetryable(Func> moreFunc) + private abstract class AAsyncCallContext : IDisposable { - // _networkPacketTaskSource could be null if the connection was closed - // while an async invocation was outstanding. - TaskCompletionSource completionSource = _stateObj._networkPacketTaskSource; - if (_cancelAsyncOnCloseToken.IsCancellationRequested || completionSource == null) + internal static readonly Action, object> s_completeCallback = SqlDataReader.CompleteAsyncCallCallback; + + internal static readonly Func> s_executeCallback = SqlDataReader.ExecuteAsyncCallCallback; + + internal SqlDataReader _reader; + internal TaskCompletionSource _source; + internal IDisposable _disposable; + + protected AAsyncCallContext() { - // Cancellation requested due to datareader being closed - return Task.FromException(ADP.ExceptionWithStackTrace(ADP.ClosedConnectionError())); } - else + + protected AAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable = null) { - return completionSource.Task.ContinueWith((retryTask) => - { - if (retryTask.IsFaulted) - { - // Somehow the network task faulted - return the exception - return Task.FromException(retryTask.Exception.InnerException); - } - else if (!_cancelAsyncOnCloseToken.IsCancellationRequested) - { - TdsParserStateObject stateObj = _stateObj; - if (stateObj != null) - { - // protect continuations against concurrent - // close and cancel - lock (stateObj) - { - if (_stateObj != null) - { // reader not closed while we waited for the lock - if (retryTask.IsCanceled) - { - if (_parser != null) - { - _parser.State = TdsParserState.Broken; // We failed to respond to attention, we have to quit! - _parser.Connection.BreakConnection(); - _parser.ThrowExceptionAndWarning(_stateObj); - } - } - else - { - if (!IsClosed) - { - try - { - return moreFunc(retryTask); - } - catch (Exception) - { - CleanupAfterAsyncInvocation(); - throw; - } - } - } - } - } - } - } - // if stateObj is null, or we closed the connection or the connection was already closed, - // then mark this operation as cancelled. - return Task.FromException(ADP.ExceptionWithStackTrace(ADP.ClosedConnectionError())); - }, TaskScheduler.Default).Unwrap(); + Set(reader, source, disposable); + } + + internal void Set(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable = null) + { + this._reader = reader ?? throw new ArgumentNullException(nameof(reader)); + this._source = source ?? throw new ArgumentNullException(nameof(source)); + this._disposable = disposable; + } + + internal void Clear() + { + _source = null; + _reader = null; + IDisposable copyDisposable = _disposable; + _disposable = null; + copyDisposable?.Dispose(); + } + + internal abstract Func> Execute { get; } + + public virtual void Dispose() + { + Clear(); + } + } + + private sealed class ReadAsyncCallContext : AAsyncCallContext + { + internal static readonly Func> s_execute = SqlDataReader.ReadAsyncExecute; + + internal bool _hasMoreData; + internal bool _hasReadRowToken; + + internal ReadAsyncCallContext() + { + } + + internal override Func> Execute => s_execute; + + public override void Dispose() + { + SqlDataReader reader = this._reader; + base.Dispose(); + reader.SetCachedReadAsyncCallContext(this); } } - private Task InvokeRetryable(Func> moreFunc, TaskCompletionSource source, IDisposable objectToDispose = null) + private sealed class IsDBNullAsyncCallContext : AAsyncCallContext { + internal static readonly Func> s_execute = SqlDataReader.IsDBNullAsyncExecute; + + internal int _columnIndex; + + internal IsDBNullAsyncCallContext() { } + + internal override Func> Execute => s_execute; + + public override void Dispose() + { + SqlDataReader reader = this._reader; + base.Dispose(); + reader.SetCachedIDBNullAsyncCallContext(this); + } + } + + private sealed class HasNextResultAsyncCallContext : AAsyncCallContext + { + private static readonly Func> s_execute = SqlDataReader.NextResultAsyncExecute; + + public HasNextResultAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable) + : base(reader, source, disposable) + { + } + + internal override Func> Execute => s_execute; + } + + private sealed class GetBytesAsyncCallContext : AAsyncCallContext + { + internal enum OperationMode + { + Seek = 0, + Read = 1 + } + + private static readonly Func> s_executeSeek = SqlDataReader.GetBytesAsyncSeekExecute; + private static readonly Func> s_executeRead = SqlDataReader.GetBytesAsyncReadExecute; + + internal int columnIndex; + internal byte[] buffer; + internal int index; + internal int length; + internal int timeout; + internal CancellationToken cancellationToken; + internal CancellationToken timeoutToken; + internal int totalBytesRead; + + internal OperationMode mode; + + internal GetBytesAsyncCallContext(SqlDataReader reader) + { + this._reader = reader ?? throw new ArgumentNullException(nameof(reader)); + } + + internal override Func> Execute => mode == OperationMode.Seek ? s_executeSeek : s_executeRead; + + public override void Dispose() + { + buffer = null; + cancellationToken = default; + timeoutToken = default; + base.Dispose(); + } + } + + private sealed class GetFieldValueAsyncCallContext : AAsyncCallContext + { + private static readonly Func> s_execute = SqlDataReader.GetFieldValueAsyncExecute; + + internal readonly int _columnIndex; + + internal GetFieldValueAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable, int columnIndex) + : base(reader, source, disposable) + { + _columnIndex = columnIndex; + } + + internal override Func> Execute => s_execute; + } + + private static Task ExecuteAsyncCallCallback(Task task, object state) + { + AAsyncCallContext context = (AAsyncCallContext)state; + return context._reader.ExecuteAsyncCall(task, context); + } + + private static void CompleteAsyncCallCallback(Task task, object state) + { + AAsyncCallContext context = (AAsyncCallContext)state; + context._reader.CompleteAsyncCall(task, context); + } + + private Task InvokeAsyncCall(AAsyncCallContext context) + { + TaskCompletionSource source = context._source; try { Task task; try { - task = moreFunc(null); + task = context.Execute(null, context); } catch (Exception ex) { @@ -5059,11 +5218,15 @@ private Task InvokeRetryable(Func> moreFunc, TaskCompletionS if (task.IsCompleted) { - CompleteRetryable(task, source, objectToDispose); + CompleteAsyncCall(task, context); } else { - task.ContinueWith((t) => CompleteRetryable(t, source, objectToDispose), TaskScheduler.Default); + task.ContinueWith( + continuationAction: AAsyncCallContext.s_completeCallback, + state: context, + TaskScheduler.Default + ); } } catch (AggregateException e) @@ -5079,17 +5242,88 @@ private Task InvokeRetryable(Func> moreFunc, TaskCompletionS return source.Task; } - private void CompleteRetryable(Task task, TaskCompletionSource source, IDisposable objectToDispose) + private Task ExecuteAsyncCall(AAsyncCallContext context) + { + // _networkPacketTaskSource could be null if the connection was closed + // while an async invocation was outstanding. + TaskCompletionSource completionSource = _stateObj._networkPacketTaskSource; + if (_cancelAsyncOnCloseToken.IsCancellationRequested || completionSource == null) + { + // Cancellation requested due to datareader being closed + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.ClosedConnectionError())); + } + else + { + return completionSource.Task.ContinueWith( + continuationFunction: AAsyncCallContext.s_executeCallback, + state: context, + TaskScheduler.Default + ).Unwrap(); + } + } + + private Task ExecuteAsyncCall(Task task, AAsyncCallContext context) { - if (objectToDispose != null) + // this function must be an instance function called from the static callback because otherwise a compiler error + // is caused by accessing the _cancelAsyncOnCloseToken field of a MarchalByRefObject derived class + if (task.IsFaulted) { - objectToDispose.Dispose(); + // Somehow the network task faulted - return the exception + return Task.FromException(task.Exception.InnerException); + } + else if (!_cancelAsyncOnCloseToken.IsCancellationRequested) + { + TdsParserStateObject stateObj = _stateObj; + if (stateObj != null) + { + // protect continuations against concurrent + // close and cancel + lock (stateObj) + { + if (_stateObj != null) + { // reader not closed while we waited for the lock + if (task.IsCanceled) + { + if (_parser != null) + { + _parser.State = TdsParserState.Broken; // We failed to respond to attention, we have to quit! + _parser.Connection.BreakConnection(); + _parser.ThrowExceptionAndWarning(_stateObj); + } + } + else + { + if (!IsClosed) + { + try + { + return context.Execute(task, context); + } + catch (Exception) + { + CleanupAfterAsyncInvocation(); + throw; + } + } + } + } + } + } } + // if stateObj is null, or we closed the connection or the connection was already closed, + // then mark this operation as cancelled. + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.ClosedConnectionError())); + } + + private void CompleteAsyncCall(Task task, AAsyncCallContext context) + { + TaskCompletionSource source = context._source; + context.Dispose(); // If something has forced us to switch to SyncOverAsync mode while in an async task then we need to guarantee that we do the cleanup // This avoids us replaying non-replayable data (such as DONE or ENV_CHANGE tokens) var stateObj = _stateObj; - bool ignoreCloseToken = ((stateObj != null) && (stateObj._syncOverAsync)); + bool ignoreCloseToken = (stateObj != null) && (stateObj._syncOverAsync); CleanupAfterAsyncInvocation(ignoreCloseToken); Task current = Interlocked.CompareExchange(ref _currentTask, null, source.Task);