From 70c898882c9add74f0c0b6a928772a7a26396189 Mon Sep 17 00:00:00 2001 From: AndriySvyryd Date: Thu, 31 Oct 2019 09:08:06 -0700 Subject: [PATCH] Add back execution retry for query. Move the retry scope for SaveChanges to properly reset the state before retrying. Fixes #18628 --- All.sln.DotSettings | 6 +- .../RelationalDatabaseFacadeExtensions.cs | 4 + .../Migrations/HistoryRepository.cs | 10 +- .../Migrations/Internal/Migrator.cs | 2 + .../Migrations/MigrationCommand.cs | 2 + .../Query/Internal/BufferedDataReader.cs | 1324 +++++++++++++++++ .../Query/Internal/QueryingEnumerable.cs | 148 +- ...dQueryCompilingExpressionVisitorFactory.cs | 4 +- .../Query/QuerySqlGenerator.cs | 4 +- ...elationalCompiledQueryCacheKeyGenerator.cs | 17 +- .../Query/RelationalQueryContext.cs | 9 - .../RelationalQueryContextDependencies.cs | 6 + ...jectionBindingRemovingExpressionVisitor.cs | 65 +- ...alShapedQueryCompilingExpressionVisitor.cs | 42 +- src/EFCore.Relational/Storage/ReaderColumn.cs | 52 + .../Storage/ReaderColumn`.cs | 29 + .../Storage/RelationalCommand.cs | 21 +- .../RelationalCommandParameterObject.cs | 8 + .../Storage/RelationalTypeMapping.cs | 12 +- .../Update/Internal/BatchExecutor.cs | 23 +- .../Update/ReaderModificationCommandBatch.cs | 2 + ...SqlServerCompiledQueryCacheKeyGenerator.cs | 4 +- .../Internal/SqlServerDatabaseCreator.cs | 4 + .../SqlServerSequenceHiLoValueGenerator.cs | 10 +- .../Storage/Internal/SqliteDatabaseCreator.cs | 2 + .../ChangeTracking/Internal/StateManager.cs | 135 +- .../Internal/StateManagerDependencies.cs | 48 + ...piledQueryCacheKeyGeneratorDependencies.cs | 36 +- ...ingExpressionVisitor.ExpressionVisitors.cs | 6 +- src/EFCore/Query/QueryCompilationContext.cs | 2 + .../QueryCompilationContextDependencies.cs | 37 + src/EFCore/Query/QueryContext.cs | 9 + src/EFCore/Query/QueryContextDependencies.cs | 24 +- .../ShapedQueryCompilingExpressionVisitor.cs | 3 + .../Storage/ExecutionStrategyExtensions.cs | 32 +- ...tegy.cs => TestCosmosExecutionStrategy.cs} | 1 + .../CommandInterceptionTestBase.cs | 41 +- .../Query/AsyncFromSqlQueryTestBase.cs | 56 +- .../Query/FromSqlQueryTestBase.cs | 142 +- .../Query/FromSqlSprocQueryTestBase.cs | 12 +- .../Query/GearsOfWarFromSqlQueryTestBase.cs | 10 +- .../Query/InheritanceRelationalTestBase.cs | 12 +- .../Query/NullSemanticsQueryTestBase.cs | 8 +- .../Query/QueryNoClientEvalTestBase.cs | 8 +- .../RelationalDatabaseCleaner.cs | 2 +- .../TestUtilities/RelationalTestStore.cs | 12 +- .../Query/Internal/BufferedDataReaderTest.cs | 217 +++ .../Storage/RelationalCommandTest.cs | 30 +- .../FakeProvider/FakeDbDataReader.cs | 42 +- .../InterceptionTestBase.cs | 2 +- .../TestUtilities/DataGenerator.cs | 47 + .../TestUtilities/TestHelpers.cs | 2 +- .../CommandInterceptionSqlServerTest.cs | 16 + .../ExecutionStrategyTest.cs | 348 +++-- .../Query/AsyncSimpleQuerySqlServerTest.cs | 20 +- .../TestSqlServerRetryingExecutionStrategy.cs | 4 +- ...SqlServerRetryingExecutionStrategyTests.cs | 9 +- ...dQueryCacheKeyGeneratorDependenciesTest.cs | 1 + ...QueryCompilationContextDependenciesTest.cs | 1 + .../Storage/ExecutionStrategyTest.cs | 12 +- 60 files changed, 2626 insertions(+), 571 deletions(-) create mode 100644 src/EFCore.Relational/Query/Internal/BufferedDataReader.cs create mode 100644 src/EFCore.Relational/Storage/ReaderColumn.cs create mode 100644 src/EFCore.Relational/Storage/ReaderColumn`.cs rename test/EFCore.Cosmos.FunctionalTests/TestUtilities/{TestSqlServerRetryingExecutionStrategy.cs => TestCosmosExecutionStrategy.cs} (96%) create mode 100644 test/EFCore.Relational.Tests/Query/Internal/BufferedDataReaderTest.cs create mode 100644 test/EFCore.Specification.Tests/TestUtilities/DataGenerator.cs diff --git a/All.sln.DotSettings b/All.sln.DotSettings index ea433c87d71..af1c1434b20 100644 --- a/All.sln.DotSettings +++ b/All.sln.DotSettings @@ -198,6 +198,7 @@ Licensed under the Apache License, Version 2.0. See License.txt in the project r True True True + True True True True @@ -207,10 +208,13 @@ Licensed under the Apache License, Version 2.0. See License.txt in the project r True True True + True True True + True True True True True - True \ No newline at end of file + True + True \ No newline at end of file diff --git a/src/EFCore.Relational/Extensions/RelationalDatabaseFacadeExtensions.cs b/src/EFCore.Relational/Extensions/RelationalDatabaseFacadeExtensions.cs index b85fe09aa52..3b73238aadd 100644 --- a/src/EFCore.Relational/Extensions/RelationalDatabaseFacadeExtensions.cs +++ b/src/EFCore.Relational/Extensions/RelationalDatabaseFacadeExtensions.cs @@ -222,6 +222,7 @@ public static int ExecuteSqlCommand( new RelationalCommandParameterObject( GetFacadeDependencies(databaseFacade).RelationalConnection, rawSqlCommand.ParameterValues, + null, ((IDatabaseFacadeDependenciesAccessor)databaseFacade).Context, logger)); } @@ -388,6 +389,7 @@ public static async Task ExecuteSqlCommandAsync( new RelationalCommandParameterObject( facadeDependencies.RelationalConnection, rawSqlCommand.ParameterValues, + null, ((IDatabaseFacadeDependenciesAccessor)databaseFacade).Context, logger), cancellationToken); @@ -504,6 +506,7 @@ public static int ExecuteSqlRaw( new RelationalCommandParameterObject( facadeDependencies.RelationalConnection, rawSqlCommand.ParameterValues, + null, ((IDatabaseFacadeDependenciesAccessor)databaseFacade).Context, logger)); } @@ -656,6 +659,7 @@ public static async Task ExecuteSqlRawAsync( new RelationalCommandParameterObject( facadeDependencies.RelationalConnection, rawSqlCommand.ParameterValues, + null, ((IDatabaseFacadeDependenciesAccessor)databaseFacade).Context, logger), cancellationToken); diff --git a/src/EFCore.Relational/Migrations/HistoryRepository.cs b/src/EFCore.Relational/Migrations/HistoryRepository.cs index bd6a27a0503..b9da13008e7 100644 --- a/src/EFCore.Relational/Migrations/HistoryRepository.cs +++ b/src/EFCore.Relational/Migrations/HistoryRepository.cs @@ -127,7 +127,7 @@ protected virtual string ProductVersionColumnName /// /// Checks whether or not the history table exists. /// - /// True if the table already exists, false otherwise. + /// true if the table already exists, false otherwise. public virtual bool Exists() => Dependencies.DatabaseCreator.Exists() && InterpretExistsResult( @@ -135,6 +135,7 @@ public virtual bool Exists() new RelationalCommandParameterObject( Dependencies.Connection, null, + null, Dependencies.CurrentContext.Context, Dependencies.CommandLogger))); @@ -144,7 +145,7 @@ public virtual bool Exists() /// A to observe while waiting for the task to complete. /// /// A task that represents the asynchronous operation. The task result contains - /// True if the table already exists, false otherwise. + /// true if the table already exists, false otherwise. /// public virtual async Task ExistsAsync(CancellationToken cancellationToken = default) => await Dependencies.DatabaseCreator.ExistsAsync(cancellationToken) @@ -153,6 +154,7 @@ await Dependencies.RawSqlCommandBuilder.Build(ExistsSql).ExecuteScalarAsync( new RelationalCommandParameterObject( Dependencies.Connection, null, + null, Dependencies.CurrentContext.Context, Dependencies.CommandLogger), cancellationToken)); @@ -160,7 +162,7 @@ await Dependencies.RawSqlCommandBuilder.Build(ExistsSql).ExecuteScalarAsync( /// /// Interprets the result of executing . /// - /// true if the table exists; otherwise, false. + /// true if the table already exists, false otherwise. protected abstract bool InterpretExistsResult([NotNull] object value); /// @@ -217,6 +219,7 @@ public virtual IReadOnlyList GetAppliedMigrations() new RelationalCommandParameterObject( Dependencies.Connection, null, + null, Dependencies.CurrentContext.Context, Dependencies.CommandLogger))) { @@ -251,6 +254,7 @@ public virtual async Task> GetAppliedMigrationsAsync( new RelationalCommandParameterObject( Dependencies.Connection, null, + null, Dependencies.CurrentContext.Context, Dependencies.CommandLogger), cancellationToken)) diff --git a/src/EFCore.Relational/Migrations/Internal/Migrator.cs b/src/EFCore.Relational/Migrations/Internal/Migrator.cs index 9593c458905..f8dc7514f7d 100644 --- a/src/EFCore.Relational/Migrations/Internal/Migrator.cs +++ b/src/EFCore.Relational/Migrations/Internal/Migrator.cs @@ -117,6 +117,7 @@ public virtual void Migrate(string targetMigration = null) new RelationalCommandParameterObject( _connection, null, + null, _currentContext.Context, _commandLogger)); } @@ -154,6 +155,7 @@ await command.ExecuteNonQueryAsync( new RelationalCommandParameterObject( _connection, null, + null, _currentContext.Context, _commandLogger), cancellationToken); diff --git a/src/EFCore.Relational/Migrations/MigrationCommand.cs b/src/EFCore.Relational/Migrations/MigrationCommand.cs index ddd4a65dfc2..a3d3d98c45d 100644 --- a/src/EFCore.Relational/Migrations/MigrationCommand.cs +++ b/src/EFCore.Relational/Migrations/MigrationCommand.cs @@ -64,6 +64,7 @@ public virtual int ExecuteNonQuery( new RelationalCommandParameterObject( connection, parameterValues, + null, _context, _logger)); @@ -82,6 +83,7 @@ public virtual Task ExecuteNonQueryAsync( new RelationalCommandParameterObject( connection, parameterValues, + null, _context, _logger), cancellationToken); diff --git a/src/EFCore.Relational/Query/Internal/BufferedDataReader.cs b/src/EFCore.Relational/Query/Internal/BufferedDataReader.cs new file mode 100644 index 00000000000..1b2e25ac33f --- /dev/null +++ b/src/EFCore.Relational/Query/Internal/BufferedDataReader.cs @@ -0,0 +1,1324 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Storage; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Query.Internal +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public class BufferedDataReader : DbDataReader + { + private DbDataReader _underlyingReader; + private List _bufferedDataRecords = new List(); + private BufferedDataRecord _currentResultSet; + private int _currentResultSetNumber; + private int _recordsAffected; + private bool _disposed; + private bool _isClosed; + + public BufferedDataReader([NotNull] DbDataReader reader) + { + _underlyingReader = reader; + } + + public override int RecordsAffected => _recordsAffected; + + public override object this[string name] => throw new NotSupportedException(); + + public override object this[int ordinal] => throw new NotSupportedException(); + + public override int Depth => throw new NotSupportedException(); + + public override int FieldCount + { + get + { + AssertReaderIsOpen(); + return _currentResultSet.FieldCount; + } + } + + public override bool HasRows + { + get + { + AssertReaderIsOpen(); + return _currentResultSet.HasRows; + } + } + + public override bool IsClosed => _isClosed; + + [Conditional("DEBUG")] + private void AssertReaderIsOpen() + { + if (_underlyingReader != null) + { + throw new InvalidOperationException("The reader wasn't initialized"); + } + + if (_isClosed) + { + throw new InvalidOperationException("The reader is closed."); + } + } + + [Conditional("DEBUG")] + private void AssertReaderIsOpenWithData() + { + AssertReaderIsOpen(); + + if (!_currentResultSet.IsDataReady) + { + throw new InvalidOperationException("The reader doesn't have any data."); + } + } + + [Conditional("DEBUG")] + private void AssertFieldIsReady(int ordinal) + { + AssertReaderIsOpenWithData(); + + if (0 > ordinal + || ordinal > _currentResultSet.FieldCount) + { + throw new IndexOutOfRangeException(); + } + } + + public virtual BufferedDataReader Initialize([NotNull] IReadOnlyList columns) + { + if (_underlyingReader == null) + { + return this; + } + + try + { + do + { + _bufferedDataRecords.Add(new BufferedDataRecord().Initialize(_underlyingReader, columns)); + } + while (_underlyingReader.NextResult()); + + _recordsAffected = _underlyingReader.RecordsAffected; + _currentResultSet = _bufferedDataRecords[_currentResultSetNumber]; + + return this; + } + finally + { + _underlyingReader.Dispose(); + _underlyingReader = null; + } + } + + public virtual async Task InitializeAsync( + [NotNull] IReadOnlyList columns, CancellationToken cancellationToken) + { + if (_underlyingReader == null) + { + return this; + } + + try + { + do + { + _bufferedDataRecords.Add(await new BufferedDataRecord().InitializeAsync(_underlyingReader, columns, cancellationToken)); + } + while (await _underlyingReader.NextResultAsync(cancellationToken)); + + _recordsAffected = _underlyingReader.RecordsAffected; + _currentResultSet = _bufferedDataRecords[_currentResultSetNumber]; + + return this; + } + finally + { + _underlyingReader.Dispose(); + _underlyingReader = null; + } + } + + public static bool IsSupportedValueType(Type type) + => type == typeof(int) + || type == typeof(bool) + || type == typeof(Guid) + || type == typeof(byte) + || type == typeof(char) + || type == typeof(DateTime) + || type == typeof(DateTimeOffset) + || type == typeof(decimal) + || type == typeof(double) + || type == typeof(float) + || type == typeof(short) + || type == typeof(long) + || type == typeof(uint) + || type == typeof(ushort) + || type == typeof(ulong) + || type == typeof(sbyte); + + public override void Close() + { + _bufferedDataRecords = null; + _isClosed = true; + + var reader = _underlyingReader; + if (reader != null) + { + _underlyingReader = null; + reader.Dispose(); + } + } + + protected override void Dispose(bool disposing) + { + if (!_disposed + && disposing + && !IsClosed) + { + Close(); + } + + _disposed = true; + + base.Dispose(disposing); + } + + public override bool GetBoolean(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetBoolean(ordinal); + } + + public override byte GetByte(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetByte(ordinal); + } + + public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) + { + throw new NotSupportedException(); + } + + public override char GetChar(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetChar(ordinal); + } + + public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) + { + throw new NotSupportedException(); + } + + public override DateTime GetDateTime(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetDateTime(ordinal); + } + + public override decimal GetDecimal(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetDecimal(ordinal); + } + + public override double GetDouble(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetDouble(ordinal); + } + + public override float GetFloat(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetFloat(ordinal); + } + + public override Guid GetGuid(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetGuid(ordinal); + } + + public override short GetInt16(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetInt16(ordinal); + } + + public override int GetInt32(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetInt32(ordinal); + } + + public override long GetInt64(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetInt64(ordinal); + } + + public override string GetString(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetFieldValue(ordinal); + } + + public override T GetFieldValue(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetFieldValue(ordinal); + } + + public override Task GetFieldValueAsync(int ordinal, CancellationToken cancellationToken) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetFieldValueAsync(ordinal, cancellationToken); + } + + public override object GetValue(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.GetValue(ordinal); + } + + public override int GetValues(object[] values) + { + AssertReaderIsOpenWithData(); + return _currentResultSet.GetValues(values); + } + + public override string GetDataTypeName(int ordinal) + { + AssertReaderIsOpen(); + return _currentResultSet.GetDataTypeName(ordinal); + } + + public override Type GetFieldType(int ordinal) + { + AssertReaderIsOpen(); + return _currentResultSet.GetFieldType(ordinal); + } + + public override string GetName(int ordinal) + { + AssertReaderIsOpen(); + return _currentResultSet.GetName(ordinal); + } + + public override int GetOrdinal(string name) + { + Check.NotNull(name, "name"); + AssertReaderIsOpen(); + return _currentResultSet.GetOrdinal(name); + } + + public override bool IsDBNull(int ordinal) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.IsDBNull(ordinal); + } + + public override Task IsDBNullAsync(int ordinal, CancellationToken cancellationToken) + { + AssertFieldIsReady(ordinal); + return _currentResultSet.IsDBNullAsync(ordinal, cancellationToken); + } + + public override IEnumerator GetEnumerator() => throw new NotSupportedException(); + + public override DataTable GetSchemaTable() => throw new NotSupportedException(); + + public override bool NextResult() + { + AssertReaderIsOpen(); + if (++_currentResultSetNumber < _bufferedDataRecords.Count) + { + _currentResultSet = _bufferedDataRecords[_currentResultSetNumber]; + return true; + } + + _currentResultSet = null; + return false; + } + + public override Task NextResultAsync(CancellationToken cancellationToken) + => Task.FromResult(NextResult()); + + public override bool Read() + { + AssertReaderIsOpen(); + return _currentResultSet.Read(); + } + + public override Task ReadAsync(CancellationToken cancellationToken) + { + AssertReaderIsOpen(); + return _currentResultSet.ReadAsync(cancellationToken); + } + + private class BufferedDataRecord + { + private int _currentRowNumber = -1; + private int _rowCount; + private string[] _dataTypeNames; + private Type[] _fieldTypes; + private string[] _columnNames; + private Lazy> _fieldNameLookup; + + private int _rowCapacity = 1; + + // Resizing bool[] is faster than BitArray, but the latter is more efficient for long-term storage. + private BitArray _bools; + private bool[] _tempBools; + private int _boolCount; + private byte[] _bytes; + private int _byteCount; + private char[] _chars; + private int _charCount; + private DateTime[] _dateTimes; + private int _dateTimeCount; + private DateTimeOffset[] _dateTimeOffsets; + private int _dateTimeOffsetCount; + private decimal[] _decimals; + private int _decimalCount; + private double[] _doubles; + private int _doubleCount; + private float[] _floats; + private int _floatCount; + private Guid[] _guids; + private int _guidCount; + private short[] _shorts; + private int _shortCount; + private int[] _ints; + private int _intCount; + private long[] _longs; + private int _longCount; + private sbyte[] _sbytes; + private int _sbyteCount; + private uint[] _uints; + private int _uintCount; + private ushort[] _ushorts; + private int _ushortCount; + private ulong[] _ulongs; + private int _ulongCount; + private object[] _objects; + private int _objectCount; + private int[] _ordinalToIndexMap; + + private BitArray _nulls; + private bool[] _tempNulls; + private int _nullCount; + private int[] _nullOrdinalToIndexMap; + + private TypeCase[] _columnTypeCases; + + private DbDataReader _underlyingReader; + private IReadOnlyList _columns; + private int[] _indexMap; + + public bool IsDataReady { get; private set; } + + public bool HasRows => _rowCount > 0; + + public int FieldCount => _fieldTypes.Length; + + public string GetDataTypeName(int ordinal) => _dataTypeNames[ordinal]; + + public Type GetFieldType(int ordinal) => _fieldTypes[ordinal]; + + public string GetName(int ordinal) => _columnNames[ordinal]; + + public int GetOrdinal(string name) => _fieldNameLookup.Value[name]; + + public bool GetBoolean(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.Bool + ? _bools[_currentRowNumber * _boolCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public byte GetByte(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.Byte + ? _bytes[_currentRowNumber * _byteCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public char GetChar(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.Char + ? _chars[_currentRowNumber * _charCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public DateTime GetDateTime(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.DateTime + ? _dateTimes[_currentRowNumber * _dateTimeCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public DateTimeOffset GetDateTimeOffset(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.DateTimeOffset + ? _dateTimeOffsets[_currentRowNumber * _dateTimeOffsetCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public decimal GetDecimal(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.Decimal + ? _decimals[_currentRowNumber * _decimalCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public double GetDouble(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.Double + ? _doubles[_currentRowNumber * _doubleCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public float GetFloat(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.Float + ? _floats[_currentRowNumber * _floatCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public Guid GetGuid(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.Guid + ? _guids[_currentRowNumber * _guidCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public short GetInt16(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.Short + ? _shorts[_currentRowNumber * _shortCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public int GetInt32(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.Int + ? _ints[_currentRowNumber * _intCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public long GetInt64(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.Long + ? _longs[_currentRowNumber * _longCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public sbyte GetSByte(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.SByte + ? _sbytes[_currentRowNumber * _sbyteCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public ushort GetUInt16(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.UShort + ? _ushorts[_currentRowNumber * _ushortCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public uint GetUInt32(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.UInt + ? _uints[_currentRowNumber * _uintCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public ulong GetUInt64(int ordinal) + => _columnTypeCases[ordinal] == TypeCase.ULong + ? _ulongs[_currentRowNumber * _ulongCount + _ordinalToIndexMap[ordinal]] + : GetFieldValue(ordinal); + + public object GetValue(int ordinal) + => GetFieldValue(ordinal); + + public int GetValues(object[] values) + => throw new NotSupportedException(); + + public T GetFieldValue(int ordinal) + { + switch (_columnTypeCases[ordinal]) + { + case TypeCase.Bool: + return (T)(object)GetBoolean(ordinal); + case TypeCase.Byte: + return (T)(object)GetByte(ordinal); + case TypeCase.Char: + return (T)(object)GetChar(ordinal); + case TypeCase.DateTime: + return (T)(object)GetDateTime(ordinal); + case TypeCase.DateTimeOffset: + return (T)(object)GetDateTimeOffset(ordinal); + case TypeCase.Decimal: + return (T)(object)GetDecimal(ordinal); + case TypeCase.Double: + return (T)(object)GetDouble(ordinal); + case TypeCase.Float: + return (T)(object)GetFloat(ordinal); + case TypeCase.Guid: + return (T)(object)GetGuid(ordinal); + case TypeCase.Short: + return (T)(object)GetInt16(ordinal); + case TypeCase.Int: + return (T)(object)GetInt32(ordinal); + case TypeCase.Long: + return (T)(object)GetInt64(ordinal); + case TypeCase.SByte: + return (T)(object)GetSByte(ordinal); + case TypeCase.UShort: + return (T)(object)GetUInt16(ordinal); + case TypeCase.UInt: + return (T)(object)GetUInt32(ordinal); + case TypeCase.ULong: + return (T)(object)GetUInt64(ordinal); + case TypeCase.Empty: + return default; + default: + return (T)_objects[_currentRowNumber * _objectCount + _ordinalToIndexMap[ordinal]]; + } + } + + public Task GetFieldValueAsync(int ordinal, CancellationToken cancellationToken) + => Task.FromResult(GetFieldValue(ordinal)); + + public bool IsDBNull(int ordinal) => _nulls[_currentRowNumber * _nullCount + _nullOrdinalToIndexMap[ordinal]]; + + public Task IsDBNullAsync(int ordinal, CancellationToken cancellationToken) => Task.FromResult(IsDBNull(ordinal)); + + public bool Read() => IsDataReady = ++_currentRowNumber < _rowCount; + + public Task ReadAsync(CancellationToken cancellationToken) => Task.FromResult(Read()); + + public BufferedDataRecord Initialize([NotNull] DbDataReader reader, [NotNull] IReadOnlyList columns) + { + _underlyingReader = reader; + _columns = columns; + + ReadMetadata(); + InitializeFields(); + + while (reader.Read()) + { + ReadRow(); + } + + _bools = new BitArray(_tempBools); + _tempBools = null; + _nulls = new BitArray(_tempNulls); + _tempNulls = null; + _rowCount = _currentRowNumber + 1; + _currentRowNumber = -1; + _underlyingReader = null; + _columns = null; + + return this; + } + + public async Task InitializeAsync( + [NotNull] DbDataReader reader, [NotNull] IReadOnlyList columns, CancellationToken cancellationToken) + { + _underlyingReader = reader; + _columns = columns; + + ReadMetadata(); + InitializeFields(); + + while (await reader.ReadAsync(cancellationToken)) + { + ReadRow(); + } + + _bools = new BitArray(_tempBools); + _tempBools = null; + _nulls = new BitArray(_tempNulls); + _tempNulls = null; + _rowCount = _currentRowNumber + 1; + _currentRowNumber = -1; + _underlyingReader = null; + _columns = null; + + return this; + } + + private void ReadRow() + { + _currentRowNumber++; + + if (_rowCapacity == _currentRowNumber) + { + DoubleBufferCapacity(); + } + + for (var i = 0; i < FieldCount; i++) + { + var column = _columns[i]; + var nullIndex = _nullOrdinalToIndexMap[i]; + switch (_columnTypeCases[i]) + { + case TypeCase.Bool: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadBool(_underlyingReader, i, column); + } + } + else + { + ReadBool(_underlyingReader, i, column); + } + + break; + case TypeCase.Byte: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadByte(_underlyingReader, i, column); + } + } + else + { + ReadByte(_underlyingReader, i, column); + } + + break; + case TypeCase.Char: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadChar(_underlyingReader, i, column); + } + } + else + { + ReadChar(_underlyingReader, i, column); + } + + break; + case TypeCase.DateTime: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadDateTime(_underlyingReader, i, column); + } + } + else + { + ReadDateTime(_underlyingReader, i, column); + } + + break; + case TypeCase.DateTimeOffset: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadDateTimeOffset(_underlyingReader, i, column); + } + } + else + { + ReadDateTimeOffset(_underlyingReader, i, column); + } + + break; + case TypeCase.Decimal: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadDecimal(_underlyingReader, i, column); + } + } + else + { + ReadDecimal(_underlyingReader, i, column); + } + + break; + case TypeCase.Double: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadDouble(_underlyingReader, i, column); + } + } + else + { + ReadDouble(_underlyingReader, i, column); + } + + break; + case TypeCase.Float: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadFloat(_underlyingReader, i, column); + } + } + else + { + ReadFloat(_underlyingReader, i, column); + } + + break; + case TypeCase.Guid: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadGuid(_underlyingReader, i, column); + } + } + else + { + ReadGuid(_underlyingReader, i, column); + } + + break; + case TypeCase.Short: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadShort(_underlyingReader, i, column); + } + } + else + { + ReadShort(_underlyingReader, i, column); + } + + break; + case TypeCase.Int: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadInt(_underlyingReader, i, column); + } + } + else + { + ReadInt(_underlyingReader, i, column); + } + + break; + case TypeCase.Long: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadLong(_underlyingReader, i, column); + } + } + else + { + ReadLong(_underlyingReader, i, column); + } + + break; + case TypeCase.SByte: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadSByte(_underlyingReader, i, column); + } + } + else + { + ReadSByte(_underlyingReader, i, column); + } + + break; + case TypeCase.UShort: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadUShort(_underlyingReader, i, column); + } + } + else + { + ReadUShort(_underlyingReader, i, column); + } + + break; + case TypeCase.UInt: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadUInt(_underlyingReader, i, column); + } + } + else + { + ReadUInt(_underlyingReader, i, column); + } + + break; + case TypeCase.ULong: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadULong(_underlyingReader, i, column); + } + } + else + { + ReadULong(_underlyingReader, i, column); + } + + break; + case TypeCase.Empty: + break; + default: + if (nullIndex != -1) + { + if (!(_tempNulls[_currentRowNumber * _nullCount + nullIndex] = _underlyingReader.IsDBNull(i))) + { + ReadObject(_underlyingReader, i, column); + } + } + else + { + ReadObject(_underlyingReader, i, column); + } + + break; + } + } + } + + private void ReadMetadata() + { + var fieldCount = _underlyingReader.FieldCount; + var dataTypeNames = new string[fieldCount]; + var columnTypes = new Type[fieldCount]; + var columnNames = new string[fieldCount]; + for (var i = 0; i < fieldCount; i++) + { + dataTypeNames[i] = _underlyingReader.GetDataTypeName(i); + columnTypes[i] = _underlyingReader.GetFieldType(i); + columnNames[i] = _underlyingReader.GetName(i); + } + + _dataTypeNames = dataTypeNames; + _fieldTypes = columnTypes; + _columnNames = columnNames; + _fieldNameLookup = new Lazy>(CreateNameLookup, isThreadSafe: false); + + Dictionary CreateNameLookup() + { + var index = new Dictionary(StringComparer.OrdinalIgnoreCase); + for (var i = 0; i < _columnNames.Length; i++) + { + index[_columnNames[i]] = i; + } + + return index; + } + } + + private void InitializeFields() + { + var fieldCount = FieldCount; + if (FieldCount < _columns.Count) + { + throw new InvalidOperationException("The underlying reader doesn't have as many fields as expected."); + } + + _columnTypeCases = Enumerable.Repeat(TypeCase.Empty, fieldCount).ToArray(); + _ordinalToIndexMap = Enumerable.Repeat(-1, fieldCount).ToArray(); + if (_columns.Count > 0 + && _columns[0].Name != null) + { + // Non-Composed FromSql + var readerColumns = _fieldNameLookup.Value; + + _indexMap = new int[_columns.Count]; + var newColumnMap = new ReaderColumn[fieldCount]; + for (var i = 0; i < _columns.Count; i++) + { + var column = _columns[i]; + if (!readerColumns.TryGetValue(column.Name, out var ordinal)) + { + throw new InvalidOperationException(RelationalStrings.FromSqlMissingColumn(column.Name)); + } + + newColumnMap[ordinal] = column; + _indexMap[i] = ordinal; + } + + _columns = newColumnMap; + } + + if (FieldCount != _columns.Count) + { + var newColumnMap = new ReaderColumn[fieldCount]; + for (var i = 0; i < _columns.Count; i++) + { + newColumnMap[i] = _columns[i]; + } + _columns = newColumnMap; + } + + for (var i = 0; i < fieldCount; i++) + { + var column = _columns[i]; + if (column == null) + { + continue; + } + + var type = column.Type; + if (type == typeof(bool)) + { + _columnTypeCases[i] = TypeCase.Bool; + _ordinalToIndexMap[i] = _boolCount; + _boolCount++; + } + else if (type == typeof(byte)) + { + _columnTypeCases[i] = TypeCase.Byte; + _ordinalToIndexMap[i] = _byteCount; + _byteCount++; + } + else if (type == typeof(char)) + { + _columnTypeCases[i] = TypeCase.Char; + _ordinalToIndexMap[i] = _charCount; + _charCount++; + } + else if (type == typeof(DateTime)) + { + _columnTypeCases[i] = TypeCase.DateTime; + _ordinalToIndexMap[i] = _dateTimeCount; + _dateTimeCount++; + } + else if (type == typeof(DateTimeOffset)) + { + _columnTypeCases[i] = TypeCase.DateTimeOffset; + _ordinalToIndexMap[i] = _dateTimeOffsetCount; + _dateTimeOffsetCount++; + } + else if (type == typeof(decimal)) + { + _columnTypeCases[i] = TypeCase.Decimal; + _ordinalToIndexMap[i] = _decimalCount; + _decimalCount++; + } + else if (type == typeof(double)) + { + _columnTypeCases[i] = TypeCase.Double; + _ordinalToIndexMap[i] = _doubleCount; + _doubleCount++; + } + else if (type == typeof(float)) + { + _columnTypeCases[i] = TypeCase.Float; + _ordinalToIndexMap[i] = _floatCount; + _floatCount++; + } + else if (type == typeof(Guid)) + { + _columnTypeCases[i] = TypeCase.Guid; + _ordinalToIndexMap[i] = _guidCount; + _guidCount++; + } + else if (type == typeof(short)) + { + _columnTypeCases[i] = TypeCase.Short; + _ordinalToIndexMap[i] = _shortCount; + _shortCount++; + } + else if (type == typeof(int)) + { + _columnTypeCases[i] = TypeCase.Int; + _ordinalToIndexMap[i] = _intCount; + _intCount++; + } + else if (type == typeof(long)) + { + _columnTypeCases[i] = TypeCase.Long; + _ordinalToIndexMap[i] = _longCount; + _longCount++; + } + else if (type == typeof(sbyte)) + { + _columnTypeCases[i] = TypeCase.SByte; + _ordinalToIndexMap[i] = _sbyteCount; + _sbyteCount++; + } + else if (type == typeof(ushort)) + { + _columnTypeCases[i] = TypeCase.UShort; + _ordinalToIndexMap[i] = _ushortCount; + _ushortCount++; + } + else if (type == typeof(uint)) + { + _columnTypeCases[i] = TypeCase.UInt; + _ordinalToIndexMap[i] = _uintCount; + _uintCount++; + } + else if (type == typeof(ulong)) + { + _columnTypeCases[i] = TypeCase.ULong; + _ordinalToIndexMap[i] = _ulongCount; + _ulongCount++; + } + else + { + _columnTypeCases[i] = TypeCase.Object; + _ordinalToIndexMap[i] = _objectCount; + _objectCount++; + } + } + + _tempBools = new bool[_rowCapacity * _boolCount]; + _bytes = new byte[_rowCapacity * _byteCount]; + _chars = new char[_rowCapacity * _charCount]; + _dateTimes = new DateTime[_rowCapacity * _dateTimeCount]; + _dateTimeOffsets = new DateTimeOffset[_rowCapacity * _dateTimeOffsetCount]; + _decimals = new decimal[_rowCapacity * _decimalCount]; + _doubles = new double[_rowCapacity * _doubleCount]; + _floats = new float[_rowCapacity * _floatCount]; + _guids = new Guid[_rowCapacity * _guidCount]; + _shorts = new short[_rowCapacity * _shortCount]; + _ints = new int[_rowCapacity * _intCount]; + _longs = new long[_rowCapacity * _longCount]; + _sbytes = new sbyte[_rowCapacity * _sbyteCount]; + _ushorts = new ushort[_rowCapacity * _ushortCount]; + _uints = new uint[_rowCapacity * _uintCount]; + _ulongs = new ulong[_rowCapacity * _ulongCount]; + _objects = new object[_rowCapacity * _objectCount]; + + _nullOrdinalToIndexMap = Enumerable.Repeat(-1, fieldCount).ToArray(); + for (var i = 0; i < fieldCount; i++) + { + if (_columns[i]?.IsNullable == true) + { + _nullOrdinalToIndexMap[i] = _nullCount; + _nullCount++; + } + } + + _tempNulls = new bool[_rowCapacity * _nullCount]; + } + + private void DoubleBufferCapacity() + { + _rowCapacity <<= 1; + + var newBools = new bool[_tempBools.Length << 1]; + Array.Copy(_tempBools, newBools, _tempBools.Length); + _tempBools = newBools; + + var newBytes = new byte[_bytes.Length << 1]; + Array.Copy(_bytes, newBytes, _bytes.Length); + _bytes = newBytes; + + var newChars = new char[_chars.Length << 1]; + Array.Copy(_chars, newChars, _chars.Length); + _chars = newChars; + + var newDateTimes = new DateTime[_dateTimes.Length << 1]; + Array.Copy(_dateTimes, newDateTimes, _dateTimes.Length); + _dateTimes = newDateTimes; + + var newDateTimeOffsets = new DateTimeOffset[_dateTimeOffsets.Length << 1]; + Array.Copy(_dateTimeOffsets, newDateTimeOffsets, _dateTimeOffsets.Length); + _dateTimeOffsets = newDateTimeOffsets; + + var newDecimals = new decimal[_decimals.Length << 1]; + Array.Copy(_decimals, newDecimals, _decimals.Length); + _decimals = newDecimals; + + var newDoubles = new double[_doubles.Length << 1]; + Array.Copy(_doubles, newDoubles, _doubles.Length); + _doubles = newDoubles; + + var newFloats = new float[_floats.Length << 1]; + Array.Copy(_floats, newFloats, _floats.Length); + _floats = newFloats; + + var newGuids = new Guid[_guids.Length << 1]; + Array.Copy(_guids, newGuids, _guids.Length); + _guids = newGuids; + + var newShorts = new short[_shorts.Length << 1]; + Array.Copy(_shorts, newShorts, _shorts.Length); + _shorts = newShorts; + + var newInts = new int[_ints.Length << 1]; + Array.Copy(_ints, newInts, _ints.Length); + _ints = newInts; + + var newLongs = new long[_longs.Length << 1]; + Array.Copy(_longs, newLongs, _longs.Length); + _longs = newLongs; + + var newSBytes = new sbyte[_sbytes.Length << 1]; + Array.Copy(_sbytes, newSBytes, _sbytes.Length); + _sbytes = newSBytes; + + var newUShorts = new ushort[_ushorts.Length << 1]; + Array.Copy(_ushorts, newUShorts, _ushorts.Length); + _ushorts = newUShorts; + + var newUInts = new uint[_uints.Length << 1]; + Array.Copy(_uints, newUInts, _uints.Length); + _uints = newUInts; + + var newULongs = new ulong[_ulongs.Length << 1]; + Array.Copy(_ulongs, newULongs, _ulongs.Length); + _ulongs = newULongs; + + var newObjects = new object[_objects.Length << 1]; + Array.Copy(_objects, newObjects, _objects.Length); + _objects = newObjects; + + var newNulls = new bool[_tempNulls.Length << 1]; + Array.Copy(_tempNulls, newNulls, _tempNulls.Length); + _tempNulls = newNulls; + } + + private void ReadBool(DbDataReader reader, int ordinal, ReaderColumn column) + { + _tempBools[_currentRowNumber * _boolCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadByte(DbDataReader reader, int ordinal, ReaderColumn column) + { + _bytes[_currentRowNumber * _byteCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadChar(DbDataReader reader, int ordinal, ReaderColumn column) + { + _chars[_currentRowNumber * _charCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadDateTime(DbDataReader reader, int ordinal, ReaderColumn column) + { + _dateTimes[_currentRowNumber * _dateTimeCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadDateTimeOffset(DbDataReader reader, int ordinal, ReaderColumn column) + { + _dateTimeOffsets[_currentRowNumber * _dateTimeOffsetCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadDecimal(DbDataReader reader, int ordinal, ReaderColumn column) + { + _decimals[_currentRowNumber * _decimalCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadDouble(DbDataReader reader, int ordinal, ReaderColumn column) + { + _doubles[_currentRowNumber * _doubleCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadFloat(DbDataReader reader, int ordinal, ReaderColumn column) + { + _floats[_currentRowNumber * _floatCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadGuid(DbDataReader reader, int ordinal, ReaderColumn column) + { + _guids[_currentRowNumber * _guidCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadShort(DbDataReader reader, int ordinal, ReaderColumn column) + { + _shorts[_currentRowNumber * _shortCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadInt(DbDataReader reader, int ordinal, ReaderColumn column) + { + _ints[_currentRowNumber * _intCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadLong(DbDataReader reader, int ordinal, ReaderColumn column) + { + _longs[_currentRowNumber * _longCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadSByte(DbDataReader reader, int ordinal, ReaderColumn column) + { + _sbytes[_currentRowNumber * _sbyteCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadUShort(DbDataReader reader, int ordinal, ReaderColumn column) + { + _ushorts[_currentRowNumber * _ushortCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadUInt(DbDataReader reader, int ordinal, ReaderColumn column) + { + _uints[_currentRowNumber * _uintCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadULong(DbDataReader reader, int ordinal, ReaderColumn column) + { + _ulongs[_currentRowNumber * _ulongCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private void ReadObject(DbDataReader reader, int ordinal, ReaderColumn column) + { + _objects[_currentRowNumber * _objectCount + _ordinalToIndexMap[ordinal]] = + ((ReaderColumn)column).GetFieldValue(reader, _indexMap); + } + + private enum TypeCase + { + Empty = 0, + Object, + Bool, + Byte, + Char, + DateTime, + DateTimeOffset, + Decimal, + Double, + Float, + Guid, + SByte, + Short, + Int, + Long, + UInt, + ULong, + UShort + } + } + } +} diff --git a/src/EFCore.Relational/Query/Internal/QueryingEnumerable.cs b/src/EFCore.Relational/Query/Internal/QueryingEnumerable.cs index fabb829eb93..4a58a8c7acb 100644 --- a/src/EFCore.Relational/Query/Internal/QueryingEnumerable.cs +++ b/src/EFCore.Relational/Query/Internal/QueryingEnumerable.cs @@ -24,6 +24,7 @@ public class QueryingEnumerable : IEnumerable, IAsyncEnumerable private readonly RelationalQueryContext _relationalQueryContext; private readonly RelationalCommandCache _relationalCommandCache; private readonly IReadOnlyList _columnNames; + private readonly IReadOnlyList _readerColumns; private readonly Func _shaper; private readonly Type _contextType; private readonly IDiagnosticsLogger _logger; @@ -32,6 +33,7 @@ public QueryingEnumerable( RelationalQueryContext relationalQueryContext, RelationalCommandCache relationalCommandCache, IReadOnlyList columnNames, + IReadOnlyList readerColumns, Func shaper, Type contextType, IDiagnosticsLogger logger) @@ -39,6 +41,7 @@ public QueryingEnumerable( _relationalQueryContext = relationalQueryContext; _relationalCommandCache = relationalCommandCache; _columnNames = columnNames; + _readerColumns = readerColumns; _shaper = shaper; _contextType = contextType; _logger = logger; @@ -50,11 +53,38 @@ public virtual IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancella public virtual IEnumerator GetEnumerator() => new Enumerator(this); IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + public static int[] BuildIndexMap(IReadOnlyList columnNames, DbDataReader dataReader) + { + if (columnNames == null) + { + return null; + } + + // Non-Composed FromSql + var readerColumns = Enumerable.Range(0, dataReader.FieldCount) + .ToDictionary(dataReader.GetName, i => i, StringComparer.OrdinalIgnoreCase); + + var indexMap = new int[columnNames.Count]; + for (var i = 0; i < columnNames.Count; i++) + { + var columnName = columnNames[i]; + if (!readerColumns.TryGetValue(columnName, out var ordinal)) + { + throw new InvalidOperationException(RelationalStrings.FromSqlMissingColumn(columnName)); + } + + indexMap[i] = ordinal; + } + + return indexMap; + } + private sealed class Enumerator : IEnumerator { private readonly RelationalQueryContext _relationalQueryContext; private readonly RelationalCommandCache _relationalCommandCache; private readonly IReadOnlyList _columnNames; + private readonly IReadOnlyList _readerColumns; private readonly Func _shaper; private readonly Type _contextType; private readonly IDiagnosticsLogger _logger; @@ -62,12 +92,14 @@ private sealed class Enumerator : IEnumerator private RelationalDataReader _dataReader; private int[] _indexMap; private ResultCoordinator _resultCoordinator; + private IExecutionStrategy _executionStrategy; public Enumerator(QueryingEnumerable queryingEnumerable) { _relationalQueryContext = queryingEnumerable._relationalQueryContext; _relationalCommandCache = queryingEnumerable._relationalCommandCache; _columnNames = queryingEnumerable._columnNames; + _readerColumns = queryingEnumerable._readerColumns; _shaper = queryingEnumerable._shaper; _contextType = queryingEnumerable._contextType; _logger = queryingEnumerable._logger; @@ -85,41 +117,12 @@ public bool MoveNext() { if (_dataReader == null) { - var relationalCommand = _relationalCommandCache.GetRelationalCommand( - _relationalQueryContext.ParameterValues); - - _dataReader - = relationalCommand.ExecuteReader( - new RelationalCommandParameterObject( - _relationalQueryContext.Connection, - _relationalQueryContext.ParameterValues, - _relationalQueryContext.Context, - _relationalQueryContext.CommandLogger)); - - // Non-Composed FromSql - if (_columnNames != null) - { - var readerColumns = Enumerable.Range(0, _dataReader.DbDataReader.FieldCount) - .ToDictionary(i => _dataReader.DbDataReader.GetName(i), i => i, StringComparer.OrdinalIgnoreCase); - - _indexMap = new int[_columnNames.Count]; - for (var i = 0; i < _columnNames.Count; i++) - { - var columnName = _columnNames[i]; - if (!readerColumns.TryGetValue(columnName, out var ordinal)) - { - throw new InvalidOperationException(RelationalStrings.FromSqlMissingColumn(columnName)); - } - - _indexMap[i] = ordinal; - } - } - else + if (_executionStrategy == null) { - _indexMap = null; + _executionStrategy = _relationalQueryContext.ExecutionStrategyFactory.Create(); } - _resultCoordinator = new ResultCoordinator(); + _executionStrategy.Execute(true, InitializeReader, null); } var hasNext = _resultCoordinator.HasNext ?? _dataReader.Read(); @@ -166,6 +169,26 @@ public bool MoveNext() } } + private bool InitializeReader(DbContext _, bool result) + { + var relationalCommand = _relationalCommandCache.GetRelationalCommand(_relationalQueryContext.ParameterValues); + + _dataReader + = relationalCommand.ExecuteReader( + new RelationalCommandParameterObject( + _relationalQueryContext.Connection, + _relationalQueryContext.ParameterValues, + _readerColumns, + _relationalQueryContext.Context, + _relationalQueryContext.CommandLogger)); + + _indexMap = BuildIndexMap(_columnNames, _dataReader.DbDataReader); + + _resultCoordinator = new ResultCoordinator(); + + return result; + } + public void Dispose() { _dataReader?.Dispose(); @@ -180,6 +203,7 @@ private sealed class AsyncEnumerator : IAsyncEnumerator private readonly RelationalQueryContext _relationalQueryContext; private readonly RelationalCommandCache _relationalCommandCache; private readonly IReadOnlyList _columnNames; + private readonly IReadOnlyList _readerColumns; private readonly Func _shaper; private readonly Type _contextType; private readonly IDiagnosticsLogger _logger; @@ -188,6 +212,7 @@ private sealed class AsyncEnumerator : IAsyncEnumerator private RelationalDataReader _dataReader; private int[] _indexMap; private ResultCoordinator _resultCoordinator; + private IExecutionStrategy _executionStrategy; public AsyncEnumerator( QueryingEnumerable queryingEnumerable, @@ -196,6 +221,7 @@ public AsyncEnumerator( _relationalQueryContext = queryingEnumerable._relationalQueryContext; _relationalCommandCache = queryingEnumerable._relationalCommandCache; _columnNames = queryingEnumerable._columnNames; + _readerColumns = queryingEnumerable._readerColumns; _shaper = queryingEnumerable._shaper; _contextType = queryingEnumerable._contextType; _logger = queryingEnumerable._logger; @@ -212,42 +238,12 @@ public async ValueTask MoveNextAsync() { if (_dataReader == null) { - var relationalCommand = _relationalCommandCache.GetRelationalCommand( - _relationalQueryContext.ParameterValues); - - _dataReader - = await relationalCommand.ExecuteReaderAsync( - new RelationalCommandParameterObject( - _relationalQueryContext.Connection, - _relationalQueryContext.ParameterValues, - _relationalQueryContext.Context, - _relationalQueryContext.CommandLogger), - _cancellationToken); - - // Non-Composed FromSql - if (_columnNames != null) + if (_executionStrategy == null) { - var readerColumns = Enumerable.Range(0, _dataReader.DbDataReader.FieldCount) - .ToDictionary(i => _dataReader.DbDataReader.GetName(i), i => i, StringComparer.OrdinalIgnoreCase); - - _indexMap = new int[_columnNames.Count]; - for (var i = 0; i < _columnNames.Count; i++) - { - var columnName = _columnNames[i]; - if (!readerColumns.TryGetValue(columnName, out var ordinal)) - { - throw new InvalidOperationException(RelationalStrings.FromSqlMissingColumn(columnName)); - } - - _indexMap[i] = ordinal; - } - } - else - { - _indexMap = null; + _executionStrategy = _relationalQueryContext.ExecutionStrategyFactory.Create(); } - _resultCoordinator = new ResultCoordinator(); + await _executionStrategy.ExecuteAsync(true, InitializeReaderAsync, null, _cancellationToken); } var hasNext = _resultCoordinator.HasNext ?? await _dataReader.ReadAsync(_cancellationToken); @@ -294,6 +290,28 @@ public async ValueTask MoveNextAsync() } } + private async Task InitializeReaderAsync(DbContext _, bool result, CancellationToken cancellationToken) + { + var relationalCommand = _relationalCommandCache.GetRelationalCommand( + _relationalQueryContext.ParameterValues); + + _dataReader + = await relationalCommand.ExecuteReaderAsync( + new RelationalCommandParameterObject( + _relationalQueryContext.Connection, + _relationalQueryContext.ParameterValues, + _readerColumns, + _relationalQueryContext.Context, + _relationalQueryContext.CommandLogger), + cancellationToken); + + _indexMap = BuildIndexMap(_columnNames, _dataReader.DbDataReader); + + _resultCoordinator = new ResultCoordinator(); + + return result; + } + public ValueTask DisposeAsync() { if (_dataReader != null) diff --git a/src/EFCore.Relational/Query/Internal/RelationalShapedQueryCompilingExpressionVisitorFactory.cs b/src/EFCore.Relational/Query/Internal/RelationalShapedQueryCompilingExpressionVisitorFactory.cs index 1dde1662282..b1db95792fb 100644 --- a/src/EFCore.Relational/Query/Internal/RelationalShapedQueryCompilingExpressionVisitorFactory.cs +++ b/src/EFCore.Relational/Query/Internal/RelationalShapedQueryCompilingExpressionVisitorFactory.cs @@ -29,11 +29,9 @@ public RelationalShapedQueryCompilingExpressionVisitorFactory( } public virtual ShapedQueryCompilingExpressionVisitor Create(QueryCompilationContext queryCompilationContext) - { - return new RelationalShapedQueryCompilingExpressionVisitor( + => new RelationalShapedQueryCompilingExpressionVisitor( _dependencies, _relationalDependencies, queryCompilationContext); - } } } diff --git a/src/EFCore.Relational/Query/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/QuerySqlGenerator.cs index 3e6aba88f5d..4133df125a6 100644 --- a/src/EFCore.Relational/Query/QuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/QuerySqlGenerator.cs @@ -15,7 +15,7 @@ namespace Microsoft.EntityFrameworkCore.Query { public class QuerySqlGenerator : SqlExpressionVisitor { - private static readonly Regex _composibleSql + private static readonly Regex _composableSql = new Regex(@"^\s*?SELECT\b", RegexOptions.IgnoreCase, TimeSpan.FromMilliseconds(value: 1000.0)); private readonly IRelationalCommandBuilderFactory _relationalCommandBuilderFactory; @@ -319,7 +319,7 @@ protected override Expression VisitFromSql(FromSqlExpression fromSqlExpression) { _relationalCommandBuilder.AppendLine("("); - if (!_composibleSql.IsMatch(fromSqlExpression.Sql)) + if (!_composableSql.IsMatch(fromSqlExpression.Sql)) { throw new InvalidOperationException(RelationalStrings.FromSqlNonComposable); } diff --git a/src/EFCore.Relational/Query/RelationalCompiledQueryCacheKeyGenerator.cs b/src/EFCore.Relational/Query/RelationalCompiledQueryCacheKeyGenerator.cs index d1e1a8115ad..c2a4f3c5a2b 100644 --- a/src/EFCore.Relational/Query/RelationalCompiledQueryCacheKeyGenerator.cs +++ b/src/EFCore.Relational/Query/RelationalCompiledQueryCacheKeyGenerator.cs @@ -66,7 +66,8 @@ public override object GenerateCacheKey(Expression query, bool async) protected new RelationalCompiledQueryCacheKey GenerateCacheKeyCore([NotNull] Expression query, bool async) => new RelationalCompiledQueryCacheKey( base.GenerateCacheKeyCore(query, async), - RelationalOptionsExtension.Extract(RelationalDependencies.ContextOptions).UseRelationalNulls); + RelationalOptionsExtension.Extract(RelationalDependencies.ContextOptions).UseRelationalNulls, + shouldBuffer: Dependencies.IsRetryingExecutionStrategy); /// /// @@ -82,17 +83,20 @@ protected readonly struct RelationalCompiledQueryCacheKey { private readonly CompiledQueryCacheKey _compiledQueryCacheKey; private readonly bool _useRelationalNulls; + private readonly bool _shouldBuffer; /// /// Initializes a new instance of the class. /// /// The non-relational cache key. /// True to use relational null logic. + /// True if the query should be buffered. public RelationalCompiledQueryCacheKey( - CompiledQueryCacheKey compiledQueryCacheKey, bool useRelationalNulls) + CompiledQueryCacheKey compiledQueryCacheKey, bool useRelationalNulls, bool shouldBuffer) { _compiledQueryCacheKey = compiledQueryCacheKey; _useRelationalNulls = useRelationalNulls; + _shouldBuffer = shouldBuffer; } /// @@ -106,12 +110,13 @@ public RelationalCompiledQueryCacheKey( /// public override bool Equals(object obj) => !(obj is null) - && obj is RelationalCompiledQueryCacheKey - && Equals((RelationalCompiledQueryCacheKey)obj); + && obj is RelationalCompiledQueryCacheKey key + && Equals(key); private bool Equals(RelationalCompiledQueryCacheKey other) => _compiledQueryCacheKey.Equals(other._compiledQueryCacheKey) - && _useRelationalNulls == other._useRelationalNulls; + && _useRelationalNulls == other._useRelationalNulls + && _shouldBuffer == other._shouldBuffer; /// /// Gets the hash code for the key. @@ -119,7 +124,7 @@ private bool Equals(RelationalCompiledQueryCacheKey other) /// /// The hash code for the key. /// - public override int GetHashCode() => HashCode.Combine(_compiledQueryCacheKey, _useRelationalNulls); + public override int GetHashCode() => HashCode.Combine(_compiledQueryCacheKey, _useRelationalNulls, _shouldBuffer); } } } diff --git a/src/EFCore.Relational/Query/RelationalQueryContext.cs b/src/EFCore.Relational/Query/RelationalQueryContext.cs index 6601923169c..14ccb558bb1 100644 --- a/src/EFCore.Relational/Query/RelationalQueryContext.cs +++ b/src/EFCore.Relational/Query/RelationalQueryContext.cs @@ -46,14 +46,5 @@ public RelationalQueryContext( /// public virtual IRelationalConnection Connection => RelationalDependencies.RelationalConnection; - - /// - /// The execution strategy factory. - /// - /// - /// The execution strategy factory. - /// - public virtual IExecutionStrategyFactory ExecutionStrategyFactory - => RelationalDependencies.ExecutionStrategyFactory; } } diff --git a/src/EFCore.Relational/Query/RelationalQueryContextDependencies.cs b/src/EFCore.Relational/Query/RelationalQueryContextDependencies.cs index ed335407559..7ad601c4ca8 100644 --- a/src/EFCore.Relational/Query/RelationalQueryContextDependencies.cs +++ b/src/EFCore.Relational/Query/RelationalQueryContextDependencies.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Storage; @@ -62,7 +63,9 @@ public RelationalQueryContextDependencies( Check.NotNull(executionStrategyFactory, nameof(executionStrategyFactory)); RelationalConnection = relationalConnection; +#pragma warning disable 618 ExecutionStrategyFactory = executionStrategyFactory; +#pragma warning restore 618 } /// @@ -73,6 +76,7 @@ public RelationalQueryContextDependencies( /// /// The execution strategy. /// + [Obsolete("Moved to QueryContextDependencies")] public IExecutionStrategyFactory ExecutionStrategyFactory { get; } /// @@ -81,7 +85,9 @@ public RelationalQueryContextDependencies( /// A replacement for the current dependency of this type. /// A new parameter object with the given service replaced. public RelationalQueryContextDependencies With([NotNull] IRelationalConnection relationalConnection) +#pragma warning disable 618 => new RelationalQueryContextDependencies(relationalConnection, ExecutionStrategyFactory); +#pragma warning restore 618 /// /// Clones this dependency parameter object with one service replaced. diff --git a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.RelationalProjectionBindingRemovingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.RelationalProjectionBindingRemovingExpressionVisitor.cs index 03d90077a1f..e7eb6c00454 100644 --- a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.RelationalProjectionBindingRemovingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.RelationalProjectionBindingRemovingExpressionVisitor.cs @@ -9,6 +9,7 @@ using System.Reflection; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; using Microsoft.EntityFrameworkCore.Storage; @@ -23,15 +24,33 @@ private class RelationalProjectionBindingRemovingExpressionVisitor : ExpressionV private readonly SelectExpression _selectExpression; private readonly ParameterExpression _dbDataReaderParameter; + private readonly ParameterExpression _indexMapParameter; private readonly IDictionary> _materializationContextBindings = new Dictionary>(); public RelationalProjectionBindingRemovingExpressionVisitor( - SelectExpression selectExpression, ParameterExpression dbDataReaderParameter) + SelectExpression selectExpression, + ParameterExpression dbDataReaderParameter, + ParameterExpression indexMapParameter, + bool buffer) { _selectExpression = selectExpression; _dbDataReaderParameter = dbDataReaderParameter; + _indexMapParameter = indexMapParameter; + if (buffer) + { + ProjectionColumns = new ReaderColumn[selectExpression.Projection.Count]; + } + } + + private ReaderColumn[] ProjectionColumns { get; } + + public virtual Expression Visit(Expression node, out IReadOnlyList projectionColumns) + { + var result = Visit(node); + projectionColumns = ProjectionColumns; + return result; } protected override Expression VisitBinary(BinaryExpression binaryExpression) @@ -118,8 +137,8 @@ private object GetProjectionIndex(ProjectionBindingExpression projectionBindingE private static bool IsNullableProjection(ProjectionExpression projection) => !(projection.Expression is ColumnExpression column) || column.IsNullable; - private static Expression CreateGetValueExpression( - Expression dbDataReader, + private Expression CreateGetValueExpression( + ParameterExpression dbDataReader, int index, bool nullable, RelationalTypeMapping typeMapping, @@ -127,16 +146,52 @@ private static Expression CreateGetValueExpression( { var getMethod = typeMapping.GetDataReaderMethod(); - var indexExpression = Expression.Constant(index); + Expression indexExpression = Expression.Constant(index); + if (_indexMapParameter != null) + { + indexExpression = Expression.ArrayIndex(_indexMapParameter, indexExpression); + } Expression valueExpression = Expression.Call( getMethod.DeclaringType != typeof(DbDataReader) ? Expression.Convert(dbDataReader, getMethod.DeclaringType) - : dbDataReader, + : (Expression)dbDataReader, getMethod, indexExpression); + if (ProjectionColumns != null) + { + var columnType = valueExpression.Type; + if (!columnType.IsValueType + || !BufferedDataReader.IsSupportedValueType(columnType)) + { + columnType = typeof(object); + valueExpression = Expression.Convert(valueExpression, typeof(object)); + } + + if (ProjectionColumns[index] == null) + { + ProjectionColumns[index] = ReaderColumn.Create( + columnType, + nullable, + _indexMapParameter != null ? ((ColumnExpression)_selectExpression.Projection[index].Expression).Name : null, + Expression.Lambda( + valueExpression, + dbDataReader, + _indexMapParameter ?? Expression.Parameter(typeof(int[]))).Compile()); + } + + if (getMethod.DeclaringType != typeof(DbDataReader)) + { + valueExpression + = Expression.Call( + dbDataReader, + RelationalTypeMapping.GetDataReaderMethod(columnType), + indexExpression); + } + } + valueExpression = typeMapping.CustomizeDataReaderExpression(valueExpression); var converter = typeMapping.Converter; diff --git a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs index 6eb38022f00..0ff1e0f4339 100644 --- a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs @@ -10,6 +10,7 @@ using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using Microsoft.EntityFrameworkCore.Storage; namespace Microsoft.EntityFrameworkCore.Query { @@ -54,16 +55,20 @@ protected override Expression VisitShapedQueryExpression(ShapedQueryExpression s shaper = InjectEntityMaterializers(shaper); - shaper = new RelationalProjectionBindingRemovingExpressionVisitor(selectExpression, dataReaderParameter) - .Visit(shaper); - shaper = new CustomShaperCompilingExpressionVisitor( - dataReaderParameter, resultCoordinatorParameter, IsTracking) + var isNonComposedFromSql = selectExpression.IsNonComposedFromSql(); + shaper = new RelationalProjectionBindingRemovingExpressionVisitor( + selectExpression, + dataReaderParameter, + isNonComposedFromSql ? indexMapParameter : null, + IsBuffering) + .Visit(shaper, out var projectionColumns); + + shaper = new CustomShaperCompilingExpressionVisitor(dataReaderParameter, resultCoordinatorParameter, IsTracking) .Visit(shaper); IReadOnlyList columnNames = null; - if (selectExpression.IsNonComposedFromSql()) + if (isNonComposedFromSql) { - shaper = new IndexMapInjectingExpressionVisitor(indexMapParameter).Visit(shaper); columnNames = selectExpression.Projection.Select(pe => ((ColumnExpression)pe.Expression).Name).ToList(); } @@ -82,33 +87,10 @@ protected override Expression VisitShapedQueryExpression(ShapedQueryExpression s Expression.Convert(QueryCompilationContext.QueryContextParameter, typeof(RelationalQueryContext)), Expression.Constant(relationalCommandCache), Expression.Constant(columnNames, typeof(IReadOnlyList)), + Expression.Constant(projectionColumns, typeof(IReadOnlyList)), Expression.Constant(shaperLambda.Compile()), Expression.Constant(_contextType), Expression.Constant(_logger)); } - - private class IndexMapInjectingExpressionVisitor : ExpressionVisitor - { - private readonly ParameterExpression _indexMapParameter; - - public IndexMapInjectingExpressionVisitor(ParameterExpression indexMapParameter) - { - _indexMapParameter = indexMapParameter; - } - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - if (methodCallExpression.Object != null - && typeof(DbDataReader).IsAssignableFrom(methodCallExpression.Object.Type)) - { - var indexArgument = methodCallExpression.Arguments[0]; - return methodCallExpression.Update( - methodCallExpression.Object, - new[] { Expression.ArrayIndex(_indexMapParameter, indexArgument) }); - } - - return base.VisitMethodCall(methodCallExpression); - } - } } } diff --git a/src/EFCore.Relational/Storage/ReaderColumn.cs b/src/EFCore.Relational/Storage/ReaderColumn.cs new file mode 100644 index 00000000000..15136a143cb --- /dev/null +++ b/src/EFCore.Relational/Storage/ReaderColumn.cs @@ -0,0 +1,52 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Reflection; +using JetBrains.Annotations; + +namespace Microsoft.EntityFrameworkCore.Storage +{ + /// + /// + /// An expected column in the relational data reader. + /// + /// + /// This type is typically used by database providers (and other extensions). It is generally + /// not used in application code. + /// + /// + public abstract class ReaderColumn + { + private static readonly ConcurrentDictionary _constructors + = new ConcurrentDictionary(); + + protected ReaderColumn([NotNull] Type type, bool nullable, [CanBeNull] string name) + { + Type = type; + IsNullable = nullable; + Name = name; + } + + public virtual Type Type { get; } + public virtual bool IsNullable { get; } + public virtual string Name { get; } + + /// + /// Creates an instance of . + /// + /// The type of the column. + /// Whether the column can contain null values. + /// The column name if it is used to access the column values, null otherwise. + /// + /// A used to get the field value for this column. + /// + /// An instance of . + public static ReaderColumn Create([NotNull] Type type, bool nullable, [CanBeNull] string columnName, [NotNull] object readFunc) + => (ReaderColumn)GetConstructor(type).Invoke(new[] { nullable, columnName, readFunc }); + + private static ConstructorInfo GetConstructor(Type type) + => _constructors.GetOrAdd(type, t => typeof(ReaderColumn<>).MakeGenericType(t).GetConstructors()[0]); + } +} diff --git a/src/EFCore.Relational/Storage/ReaderColumn`.cs b/src/EFCore.Relational/Storage/ReaderColumn`.cs new file mode 100644 index 00000000000..bbb0e47f5f0 --- /dev/null +++ b/src/EFCore.Relational/Storage/ReaderColumn`.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Data.Common; +using JetBrains.Annotations; + +namespace Microsoft.EntityFrameworkCore.Storage +{ + /// + /// + /// An expected column in the relational data reader. + /// + /// + /// This type is typically used by database providers (and other extensions). It is generally + /// not used in application code. + /// + /// + public class ReaderColumn : ReaderColumn + { + public ReaderColumn(bool nullable, [CanBeNull] string name, [NotNull] Func getFieldValue) + : base(typeof(T), nullable, name) + { + GetFieldValue = getFieldValue; + } + + public virtual Func GetFieldValue { get; } + } +} diff --git a/src/EFCore.Relational/Storage/RelationalCommand.cs b/src/EFCore.Relational/Storage/RelationalCommand.cs index ca9c6572d3e..9a653c17004 100644 --- a/src/EFCore.Relational/Storage/RelationalCommand.cs +++ b/src/EFCore.Relational/Storage/RelationalCommand.cs @@ -9,6 +9,7 @@ using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Utilities; namespace Microsoft.EntityFrameworkCore.Storage @@ -375,7 +376,10 @@ await logger.CommandErrorAsync( /// The result of the command. public virtual RelationalDataReader ExecuteReader(RelationalCommandParameterObject parameterObject) { - var (connection, context, logger) = (parameterObject.Connection, parameterObject.Context, parameterObject.Logger); + var connection = parameterObject.Connection; + var context = parameterObject.Context; + var readerColumns = parameterObject.ReaderColumns; + var logger = parameterObject.Logger; var commandId = Guid.NewGuid(); var command = CreateCommand(parameterObject, commandId, DbCommandMethod.ExecuteReader); @@ -414,6 +418,11 @@ public virtual RelationalDataReader ExecuteReader(RelationalCommandParameterObje stopwatch.Elapsed); } + if (readerColumns != null) + { + reader = new BufferedDataReader(reader).Initialize(readerColumns); + } + var result = new RelationalDataReader( connection, command, @@ -461,7 +470,10 @@ public virtual async Task ExecuteReaderAsync( RelationalCommandParameterObject parameterObject, CancellationToken cancellationToken = default) { - var (connection, context, logger) = (parameterObject.Connection, parameterObject.Context, parameterObject.Logger); + var connection = parameterObject.Connection; + var context = parameterObject.Context; + var readerColumns = parameterObject.ReaderColumns; + var logger = parameterObject.Logger; var commandId = Guid.NewGuid(); var command = CreateCommand(parameterObject, commandId, DbCommandMethod.ExecuteReader); @@ -503,6 +515,11 @@ public virtual async Task ExecuteReaderAsync( cancellationToken); } + if (readerColumns != null) + { + reader = await new BufferedDataReader(reader).InitializeAsync(readerColumns, cancellationToken); + } + var result = new RelationalDataReader( connection, command, diff --git a/src/EFCore.Relational/Storage/RelationalCommandParameterObject.cs b/src/EFCore.Relational/Storage/RelationalCommandParameterObject.cs index 26ebcb9656a..6aea922e97c 100644 --- a/src/EFCore.Relational/Storage/RelationalCommandParameterObject.cs +++ b/src/EFCore.Relational/Storage/RelationalCommandParameterObject.cs @@ -30,11 +30,13 @@ public readonly struct RelationalCommandParameterObject /// /// The connection on which the command will execute. /// The SQL parameter values to use, or null if none. + /// The expected columns if the reader needs to be buffered, or null otherwise. /// The current instance, or null if it is not known. /// A logger, or null if no logger is available. public RelationalCommandParameterObject( [NotNull] IRelationalConnection connection, [CanBeNull] IReadOnlyDictionary parameterValues, + [CanBeNull] IReadOnlyList readerColumns, [CanBeNull] DbContext context, [CanBeNull] IDiagnosticsLogger logger) { @@ -42,6 +44,7 @@ public RelationalCommandParameterObject( Connection = connection; ParameterValues = parameterValues; + ReaderColumns = readerColumns; Context = context; Logger = logger; } @@ -56,6 +59,11 @@ public RelationalCommandParameterObject( /// public IReadOnlyDictionary ParameterValues { get; } + /// + /// The expected columns if the reader needs to be buffered, or null otherwise. + /// + public IReadOnlyList ReaderColumns { get; } + /// /// The current instance, or null if it is not known. /// diff --git a/src/EFCore.Relational/Storage/RelationalTypeMapping.cs b/src/EFCore.Relational/Storage/RelationalTypeMapping.cs index 4bdaf5933c9..61038d56262 100644 --- a/src/EFCore.Relational/Storage/RelationalTypeMapping.cs +++ b/src/EFCore.Relational/Storage/RelationalTypeMapping.cs @@ -523,10 +523,18 @@ public virtual MethodInfo GetDataReaderMethod() { var type = (Converter?.ProviderClrType ?? ClrType).UnwrapNullableType(); - return _getXMethods.TryGetValue(type, out var method) + return GetDataReaderMethod(type); + } + + /// + /// The method to use when reading values of the given type. The method must be defined + /// on . + /// + /// The method to use to read the value. + public static MethodInfo GetDataReaderMethod([NotNull] Type type) + => _getXMethods.TryGetValue(type, out var method) ? method : _getFieldValueMethod.MakeGenericMethod(type); - } /// /// Gets a custom expression tree for reading the value from the input data reader diff --git a/src/EFCore.Relational/Update/Internal/BatchExecutor.cs b/src/EFCore.Relational/Update/Internal/BatchExecutor.cs index 18c8a293574..3e35e31e650 100644 --- a/src/EFCore.Relational/Update/Internal/BatchExecutor.cs +++ b/src/EFCore.Relational/Update/Internal/BatchExecutor.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -37,7 +38,9 @@ public class BatchExecutor : IBatchExecutor public BatchExecutor([NotNull] ICurrentDbContext currentContext, [NotNull] IExecutionStrategyFactory executionStrategyFactory) { CurrentContext = currentContext; +#pragma warning disable 618 ExecutionStrategyFactory = executionStrategyFactory; +#pragma warning restore 618 } /// @@ -54,6 +57,7 @@ public BatchExecutor([NotNull] ICurrentDbContext currentContext, [NotNull] IExec /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// + [Obsolete("This isn't used anymore")] protected virtual IExecutionStrategyFactory ExecutionStrategyFactory { get; } /// @@ -65,14 +69,7 @@ public BatchExecutor([NotNull] ICurrentDbContext currentContext, [NotNull] IExec public virtual int Execute( IEnumerable commandBatches, IRelationalConnection connection) - => CurrentContext.Context.Database.AutoTransactionsEnabled - ? ExecutionStrategyFactory.Create().Execute((commandBatches, connection), Execute, null) - : Execute(CurrentContext.Context, (commandBatches, connection)); - - private int Execute(DbContext _, (IEnumerable, IRelationalConnection) parameters) { - var commandBatches = parameters.Item1; - var connection = parameters.Item2; var rowsAffected = 0; IDbContextTransaction startedTransaction = null; try @@ -118,21 +115,11 @@ private int Execute(DbContext _, (IEnumerable, IRelati /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual Task ExecuteAsync( + public virtual async Task ExecuteAsync( IEnumerable commandBatches, IRelationalConnection connection, CancellationToken cancellationToken = default) - => CurrentContext.Context.Database.AutoTransactionsEnabled - ? ExecutionStrategyFactory.Create().ExecuteAsync((commandBatches, connection), ExecuteAsync, null, cancellationToken) - : ExecuteAsync(CurrentContext.Context, (commandBatches, connection), cancellationToken); - - private async Task ExecuteAsync( - DbContext _, - (IEnumerable, IRelationalConnection) parameters, - CancellationToken cancellationToken = default) { - var commandBatches = parameters.Item1; - var connection = parameters.Item2; var rowsAffected = 0; IDbContextTransaction startedTransaction = null; try diff --git a/src/EFCore.Relational/Update/ReaderModificationCommandBatch.cs b/src/EFCore.Relational/Update/ReaderModificationCommandBatch.cs index 75d8f1d01d8..59dd5ed4d0c 100644 --- a/src/EFCore.Relational/Update/ReaderModificationCommandBatch.cs +++ b/src/EFCore.Relational/Update/ReaderModificationCommandBatch.cs @@ -239,6 +239,7 @@ public override void Execute(IRelationalConnection connection) new RelationalCommandParameterObject( connection, storeCommand.ParameterValues, + null, Dependencies.CurrentContext.Context, Dependencies.Logger))) { @@ -276,6 +277,7 @@ public override async Task ExecuteAsync( new RelationalCommandParameterObject( connection, storeCommand.ParameterValues, + null, Dependencies.CurrentContext.Context, Dependencies.Logger), cancellationToken)) diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerCompiledQueryCacheKeyGenerator.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerCompiledQueryCacheKeyGenerator.cs index e0b6324b7cc..3a5be3c0086 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerCompiledQueryCacheKeyGenerator.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerCompiledQueryCacheKeyGenerator.cs @@ -64,8 +64,8 @@ public SqlServerCompiledQueryCacheKey( public override bool Equals(object obj) => !(obj is null) - && obj is SqlServerCompiledQueryCacheKey - && Equals((SqlServerCompiledQueryCacheKey)obj); + && obj is SqlServerCompiledQueryCacheKey key + && Equals(key); private bool Equals(SqlServerCompiledQueryCacheKey other) => _relationalCompiledQueryCacheKey.Equals(other._relationalCompiledQueryCacheKey) diff --git a/src/EFCore.SqlServer/Storage/Internal/SqlServerDatabaseCreator.cs b/src/EFCore.SqlServer/Storage/Internal/SqlServerDatabaseCreator.cs index f952b4f131d..81d20e5bfd7 100644 --- a/src/EFCore.SqlServer/Storage/Internal/SqlServerDatabaseCreator.cs +++ b/src/EFCore.SqlServer/Storage/Internal/SqlServerDatabaseCreator.cs @@ -121,6 +121,7 @@ public override bool HasTables() new RelationalCommandParameterObject( connection, null, + null, Dependencies.CurrentContext.Context, Dependencies.CommandLogger)) != 0); @@ -139,6 +140,7 @@ public override Task HasTablesAsync(CancellationToken cancellationToken = new RelationalCommandParameterObject( connection, null, + null, Dependencies.CurrentContext.Context, Dependencies.CommandLogger), cancellationToken: ct) @@ -198,6 +200,7 @@ private bool Exists(bool retryOnNotExists) new RelationalCommandParameterObject( _connection, null, + null, Dependencies.CurrentContext.Context, Dependencies.CommandLogger)); @@ -258,6 +261,7 @@ await _rawSqlCommandBuilder new RelationalCommandParameterObject( _connection, null, + null, Dependencies.CurrentContext.Context, Dependencies.CommandLogger), ct); diff --git a/src/EFCore.SqlServer/ValueGeneration/Internal/SqlServerSequenceHiLoValueGenerator.cs b/src/EFCore.SqlServer/ValueGeneration/Internal/SqlServerSequenceHiLoValueGenerator.cs index 4e7030c4626..fe284511fe4 100644 --- a/src/EFCore.SqlServer/ValueGeneration/Internal/SqlServerSequenceHiLoValueGenerator.cs +++ b/src/EFCore.SqlServer/ValueGeneration/Internal/SqlServerSequenceHiLoValueGenerator.cs @@ -63,8 +63,9 @@ protected override long GetNewLowValue() .ExecuteScalar( new RelationalCommandParameterObject( _connection, - null, - null, + parameterValues: null, + readerColumns: null, + context: null, _commandLogger)), typeof(long), CultureInfo.InvariantCulture); @@ -82,8 +83,9 @@ await _rawSqlCommandBuilder .ExecuteScalarAsync( new RelationalCommandParameterObject( _connection, - null, - null, + parameterValues: null, + readerColumns: null, + context: null, _commandLogger), cancellationToken), typeof(long), diff --git a/src/EFCore.Sqlite.Core/Storage/Internal/SqliteDatabaseCreator.cs b/src/EFCore.Sqlite.Core/Storage/Internal/SqliteDatabaseCreator.cs index 0513775c33c..03076df737c 100644 --- a/src/EFCore.Sqlite.Core/Storage/Internal/SqliteDatabaseCreator.cs +++ b/src/EFCore.Sqlite.Core/Storage/Internal/SqliteDatabaseCreator.cs @@ -64,6 +64,7 @@ public override void Create() Dependencies.Connection, null, null, + null, Dependencies.CommandLogger)); Dependencies.Connection.Close(); @@ -114,6 +115,7 @@ public override bool HasTables() Dependencies.Connection, null, null, + null, Dependencies.CommandLogger)); return count != 0; diff --git a/src/EFCore/ChangeTracking/Internal/StateManager.cs b/src/EFCore/ChangeTracking/Internal/StateManager.cs index 26dbdba0bd2..f8531d68bf2 100644 --- a/src/EFCore/ChangeTracking/Internal/StateManager.cs +++ b/src/EFCore/ChangeTracking/Internal/StateManager.cs @@ -884,47 +884,6 @@ public virtual IEntityFinder CreateEntityFinder(IEntityType entityType) /// public virtual int ChangedCount { get; set; } - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public virtual int SaveChanges(bool acceptAllChangesOnSuccess) - { - if (ChangedCount == 0) - { - return 0; - } - - var entriesToSave = GetEntriesToSave(cascadeChanges: true); - if (entriesToSave.Count == 0) - { - return 0; - } - - try - { - var result = SaveChanges(entriesToSave); - - if (acceptAllChangesOnSuccess) - { - AcceptAllChanges((IReadOnlyList)entriesToSave); - } - - return result; - } - catch - { - foreach (var entry in entriesToSave) - { - ((InternalEntityEntry)entry).DiscardStoreGeneratedValues(); - } - - throw; - } - } - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -1071,14 +1030,51 @@ private static bool KeyValuesEqual(IProperty property, object value, object curr ?.Equals(currentValue, value) ?? Equals(currentValue, value); + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual async Task SaveChangesAsync( - bool acceptAllChangesOnSuccess, CancellationToken cancellationToken = default) + protected virtual int SaveChanges( + [NotNull] IList entriesToSave) + { + using (_concurrencyDetector.EnterCriticalSection()) + { + return _database.SaveChanges(entriesToSave); + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected virtual async Task SaveChangesAsync( + [NotNull] IList entriesToSave, + CancellationToken cancellationToken = default) + { + using (_concurrencyDetector.EnterCriticalSection()) + { + return await _database.SaveChangesAsync(entriesToSave, cancellationToken); + } + } + + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual int SaveChanges(bool acceptAllChangesOnSuccess) + => Context.Database.AutoTransactionsEnabled + ? Dependencies.ExecutionStrategyFactory.Create().Execute(acceptAllChangesOnSuccess, SaveChanges, null) + : SaveChanges(Context, acceptAllChangesOnSuccess); + + private int SaveChanges(DbContext _, bool acceptAllChangesOnSuccess) { if (ChangedCount == 0) { @@ -1093,7 +1089,7 @@ public virtual async Task SaveChangesAsync( try { - var result = await SaveChangesAsync(entriesToSave, cancellationToken); + var result = SaveChanges(entriesToSave); if (acceptAllChangesOnSuccess) { @@ -1119,28 +1115,45 @@ public virtual async Task SaveChangesAsync( /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - protected virtual int SaveChanges( - [NotNull] IList entriesToSave) + public virtual Task SaveChangesAsync( + bool acceptAllChangesOnSuccess, CancellationToken cancellationToken = default) + => Context.Database.AutoTransactionsEnabled + ? Dependencies.ExecutionStrategyFactory.Create().ExecuteAsync(acceptAllChangesOnSuccess, SaveChangesAsync, null, cancellationToken) + : SaveChangesAsync(Context, acceptAllChangesOnSuccess, cancellationToken); + + private async Task SaveChangesAsync( + DbContext _, bool acceptAllChangesOnSuccess, CancellationToken cancellationToken) { - using (_concurrencyDetector.EnterCriticalSection()) + if (ChangedCount == 0) { - return _database.SaveChanges(entriesToSave); + return 0; } - } - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - protected virtual async Task SaveChangesAsync( - [NotNull] IList entriesToSave, - CancellationToken cancellationToken = default) - { - using (_concurrencyDetector.EnterCriticalSection()) + var entriesToSave = GetEntriesToSave(cascadeChanges: true); + if (entriesToSave.Count == 0) { - return await _database.SaveChangesAsync(entriesToSave, cancellationToken); + return 0; + } + + try + { + var result = await SaveChangesAsync(entriesToSave, cancellationToken); + + if (acceptAllChangesOnSuccess) + { + AcceptAllChanges((IReadOnlyList)entriesToSave); + } + + return result; + } + catch + { + foreach (var entry in entriesToSave) + { + ((InternalEntityEntry)entry).DiscardStoreGeneratedValues(); + } + + throw; } } diff --git a/src/EFCore/ChangeTracking/Internal/StateManagerDependencies.cs b/src/EFCore/ChangeTracking/Internal/StateManagerDependencies.cs index 570f25c2862..55d89aafab8 100644 --- a/src/EFCore/ChangeTracking/Internal/StateManagerDependencies.cs +++ b/src/EFCore/ChangeTracking/Internal/StateManagerDependencies.cs @@ -75,6 +75,7 @@ public StateManagerDependencies( [NotNull] IEntityFinderSource entityFinderSource, [NotNull] IDbSetSource setSource, [NotNull] IEntityMaterializerSource entityMaterializerSource, + [NotNull] IExecutionStrategyFactory executionStrategyFactory, [NotNull] ILoggingOptions loggingOptions, [NotNull] IDiagnosticsLogger updateLogger, [NotNull] IDiagnosticsLogger changeTrackingLogger) @@ -90,6 +91,7 @@ public StateManagerDependencies( EntityFinderSource = entityFinderSource; SetSource = setSource; EntityMaterializerSource = entityMaterializerSource; + ExecutionStrategyFactory = executionStrategyFactory; LoggingOptions = loggingOptions; UpdateLogger = updateLogger; ChangeTrackingLogger = changeTrackingLogger; @@ -183,6 +185,14 @@ public StateManagerDependencies( /// public IEntityMaterializerSource EntityMaterializerSource { get; } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public IExecutionStrategyFactory ExecutionStrategyFactory { get; } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -225,6 +235,7 @@ public StateManagerDependencies With([NotNull] IInternalEntityEntryFactory inter EntityFinderSource, SetSource, EntityMaterializerSource, + ExecutionStrategyFactory, LoggingOptions, UpdateLogger, ChangeTrackingLogger); @@ -247,6 +258,7 @@ public StateManagerDependencies With([NotNull] IInternalEntityEntrySubscriber in EntityFinderSource, SetSource, EntityMaterializerSource, + ExecutionStrategyFactory, LoggingOptions, UpdateLogger, ChangeTrackingLogger); @@ -269,6 +281,7 @@ public StateManagerDependencies With([NotNull] IInternalEntityEntryNotifier inte EntityFinderSource, SetSource, EntityMaterializerSource, + ExecutionStrategyFactory, LoggingOptions, UpdateLogger, ChangeTrackingLogger); @@ -291,6 +304,7 @@ public StateManagerDependencies With([NotNull] ValueGenerationManager valueGener EntityFinderSource, SetSource, EntityMaterializerSource, + ExecutionStrategyFactory, LoggingOptions, UpdateLogger, ChangeTrackingLogger); @@ -313,6 +327,7 @@ public StateManagerDependencies With([NotNull] IModel model) EntityFinderSource, SetSource, EntityMaterializerSource, + ExecutionStrategyFactory, LoggingOptions, UpdateLogger, ChangeTrackingLogger); @@ -335,6 +350,7 @@ public StateManagerDependencies With([NotNull] IDatabase database) EntityFinderSource, SetSource, EntityMaterializerSource, + ExecutionStrategyFactory, LoggingOptions, UpdateLogger, ChangeTrackingLogger); @@ -357,6 +373,7 @@ public StateManagerDependencies With([NotNull] IConcurrencyDetector concurrencyD EntityFinderSource, SetSource, EntityMaterializerSource, + ExecutionStrategyFactory, LoggingOptions, UpdateLogger, ChangeTrackingLogger); @@ -379,6 +396,7 @@ public StateManagerDependencies With([NotNull] ICurrentDbContext currentContext) EntityFinderSource, SetSource, EntityMaterializerSource, + ExecutionStrategyFactory, LoggingOptions, UpdateLogger, ChangeTrackingLogger); @@ -401,6 +419,7 @@ public StateManagerDependencies With([NotNull] IEntityFinderSource entityFinderS entityFinderSource, SetSource, EntityMaterializerSource, + ExecutionStrategyFactory, LoggingOptions, UpdateLogger, ChangeTrackingLogger); @@ -423,6 +442,7 @@ public StateManagerDependencies With([NotNull] IDbSetSource setSource) EntityFinderSource, setSource, EntityMaterializerSource, + ExecutionStrategyFactory, LoggingOptions, UpdateLogger, ChangeTrackingLogger); @@ -445,6 +465,31 @@ public StateManagerDependencies With([NotNull] IEntityMaterializerSource entityM EntityFinderSource, SetSource, entityMaterializerSource, + ExecutionStrategyFactory, + LoggingOptions, + UpdateLogger, + ChangeTrackingLogger); + + + /// + /// Clones this dependency parameter object with one service replaced. + /// + /// A replacement for the current dependency of this type. + /// A new parameter object with the given service replaced. + public StateManagerDependencies With([NotNull] IExecutionStrategyFactory executionStrategyFactory) + => new StateManagerDependencies( + InternalEntityEntryFactory, + InternalEntityEntrySubscriber, + InternalEntityEntryNotifier, + ValueGenerationManager, + Model, + Database, + ConcurrencyDetector, + CurrentContext, + EntityFinderSource, + SetSource, + EntityMaterializerSource, + executionStrategyFactory, LoggingOptions, UpdateLogger, ChangeTrackingLogger); @@ -467,6 +512,7 @@ public StateManagerDependencies With([NotNull] ILoggingOptions loggingOptions) EntityFinderSource, SetSource, EntityMaterializerSource, + ExecutionStrategyFactory, loggingOptions, UpdateLogger, ChangeTrackingLogger); @@ -489,6 +535,7 @@ public StateManagerDependencies With([NotNull] IDiagnosticsLogger public sealed class CompiledQueryCacheKeyGeneratorDependencies { + private readonly IExecutionStrategyFactory _executionStrategyFactory; + /// /// /// Creates the service dependencies parameter object for a . @@ -51,20 +54,26 @@ public sealed class CompiledQueryCacheKeyGeneratorDependencies /// the constructor at any point in this process. /// /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// The service lifetime is . This means that each + /// instance will use its own instance of this service. + /// The implementation may depend on other services registered with any lifetime. + /// The implementation does not need to be thread-safe. /// /// [EntityFrameworkInternal] - public CompiledQueryCacheKeyGeneratorDependencies([NotNull] IModel model, [NotNull] ICurrentDbContext currentContext) + public CompiledQueryCacheKeyGeneratorDependencies( + [NotNull] IModel model, + [NotNull] ICurrentDbContext currentContext, + [NotNull] IExecutionStrategyFactory executionStrategyFactory) { Check.NotNull(model, nameof(model)); Check.NotNull(currentContext, nameof(currentContext)); + Check.NotNull(executionStrategyFactory, nameof(executionStrategyFactory)); Model = model; CurrentContext = currentContext; + _executionStrategyFactory = executionStrategyFactory; + IsRetryingExecutionStrategy = executionStrategyFactory.Create().RetriesOnFailure; } /// @@ -77,13 +86,18 @@ public CompiledQueryCacheKeyGeneratorDependencies([NotNull] IModel model, [NotNu /// public ICurrentDbContext CurrentContext { get; } + /// + /// Whether the configured execution strategy can retry. + /// + public bool IsRetryingExecutionStrategy { get; } + /// /// Clones this dependency parameter object with one service replaced. /// /// A replacement for the current dependency of this type. /// A new parameter object with the given service replaced. public CompiledQueryCacheKeyGeneratorDependencies With([NotNull] IModel model) - => new CompiledQueryCacheKeyGeneratorDependencies(model, CurrentContext); + => new CompiledQueryCacheKeyGeneratorDependencies(model, CurrentContext, _executionStrategyFactory); /// /// Clones this dependency parameter object with one service replaced. @@ -91,6 +105,14 @@ public CompiledQueryCacheKeyGeneratorDependencies With([NotNull] IModel model) /// A replacement for the current dependency of this type. /// A new parameter object with the given service replaced. public CompiledQueryCacheKeyGeneratorDependencies With([NotNull] ICurrentDbContext currentContext) - => new CompiledQueryCacheKeyGeneratorDependencies(Model, currentContext); + => new CompiledQueryCacheKeyGeneratorDependencies(Model, currentContext, _executionStrategyFactory); + + /// + /// Clones this dependency parameter object with one service replaced. + /// + /// A replacement for the current dependency of this type. + /// A new parameter object with the given service replaced. + public CompiledQueryCacheKeyGeneratorDependencies With([NotNull] IExecutionStrategyFactory executionStrategyFactory) + => new CompiledQueryCacheKeyGeneratorDependencies(Model, CurrentContext, executionStrategyFactory); } } diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs index 1cfea0a06b8..5497686ecff 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs @@ -165,8 +165,8 @@ protected Expression ExpandNavigation( innerEntityReference.SetIncludePaths(innerIncludeTreeNode); } - var innerSoureSequenceType = innerSource.Type.GetSequenceType(); - var innerParameter = Expression.Parameter(innerSoureSequenceType, "i"); + var innerSourceSequenceType = innerSource.Type.GetSequenceType(); + var innerParameter = Expression.Parameter(innerSourceSequenceType, "i"); Expression outerKey; if (root is NavigationExpansionExpression innerNavigationExpansionExpression && innerNavigationExpansionExpression.CardinalityReducingGenericMethodInfo != null) @@ -228,7 +228,7 @@ protected Expression ExpandNavigation( : Expression.Equal(outerKey, innerKey); var subquery = Expression.Call( - QueryableMethods.Where.MakeGenericMethod(innerSoureSequenceType), + QueryableMethods.Where.MakeGenericMethod(innerSourceSequenceType), innerSource, Expression.Quote( Expression.Lambda( diff --git a/src/EFCore/Query/QueryCompilationContext.cs b/src/EFCore/Query/QueryCompilationContext.cs index 689c0c07004..71d60801dc0 100644 --- a/src/EFCore/Query/QueryCompilationContext.cs +++ b/src/EFCore/Query/QueryCompilationContext.cs @@ -36,6 +36,7 @@ public QueryCompilationContext( IsAsync = async; IsTracking = context.ChangeTracker.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll; + IsBuffering = dependencies.IsRetryingExecutionStrategy; Model = dependencies.Model; ContextOptions = dependencies.ContextOptions; ContextType = context.GetType(); @@ -51,6 +52,7 @@ public QueryCompilationContext( public virtual IModel Model { get; } public virtual IDbContextOptions ContextOptions { get; } public virtual bool IsTracking { get; internal set; } + public virtual bool IsBuffering { get; } public virtual bool IgnoreQueryFilters { get; internal set; } public virtual ISet Tags { get; } = new HashSet(); public virtual IDiagnosticsLogger Logger { get; } diff --git a/src/EFCore/Query/QueryCompilationContextDependencies.cs b/src/EFCore/Query/QueryCompilationContextDependencies.cs index 943b21787a7..95101aa3623 100644 --- a/src/EFCore/Query/QueryCompilationContextDependencies.cs +++ b/src/EFCore/Query/QueryCompilationContextDependencies.cs @@ -5,6 +5,7 @@ using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; using Microsoft.Extensions.DependencyInjection; @@ -35,6 +36,8 @@ namespace Microsoft.EntityFrameworkCore.Query /// public sealed class QueryCompilationContextDependencies { + private readonly IExecutionStrategyFactory _executionStrategyFactory; + /// /// /// Creates the service dependencies parameter object for a . @@ -61,6 +64,7 @@ public QueryCompilationContextDependencies( [NotNull] IQueryableMethodTranslatingExpressionVisitorFactory queryableMethodTranslatingExpressionVisitorFactory, [NotNull] IQueryTranslationPostprocessorFactory queryTranslationPostprocessorFactory, [NotNull] IShapedQueryCompilingExpressionVisitorFactory shapedQueryCompilingExpressionVisitorFactory, + [NotNull] IExecutionStrategyFactory executionStrategyFactory, [NotNull] ICurrentDbContext currentContext, [NotNull] IDbContextOptions contextOptions, [NotNull] IDiagnosticsLogger logger) @@ -70,6 +74,7 @@ public QueryCompilationContextDependencies( Check.NotNull(queryableMethodTranslatingExpressionVisitorFactory, nameof(queryableMethodTranslatingExpressionVisitorFactory)); Check.NotNull(queryTranslationPostprocessorFactory, nameof(queryTranslationPostprocessorFactory)); Check.NotNull(shapedQueryCompilingExpressionVisitorFactory, nameof(shapedQueryCompilingExpressionVisitorFactory)); + Check.NotNull(executionStrategyFactory, nameof(executionStrategyFactory)); Check.NotNull(currentContext, nameof(currentContext)); Check.NotNull(contextOptions, nameof(contextOptions)); Check.NotNull(logger, nameof(logger)); @@ -80,6 +85,8 @@ public QueryCompilationContextDependencies( QueryableMethodTranslatingExpressionVisitorFactory = queryableMethodTranslatingExpressionVisitorFactory; QueryTranslationPostprocessorFactory = queryTranslationPostprocessorFactory; ShapedQueryCompilingExpressionVisitorFactory = shapedQueryCompilingExpressionVisitorFactory; + _executionStrategyFactory = executionStrategyFactory; + IsRetryingExecutionStrategy = executionStrategyFactory.Create().RetriesOnFailure; ContextOptions = contextOptions; Logger = logger; } @@ -114,6 +121,11 @@ public QueryCompilationContextDependencies( /// public IShapedQueryCompilingExpressionVisitorFactory ShapedQueryCompilingExpressionVisitorFactory { get; } + /// + /// Whether the configured execution strategy can retry. + /// + public bool IsRetryingExecutionStrategy { get; } + /// /// The context options. /// @@ -136,6 +148,7 @@ public QueryCompilationContextDependencies With([NotNull] IModel model) QueryableMethodTranslatingExpressionVisitorFactory, QueryTranslationPostprocessorFactory, ShapedQueryCompilingExpressionVisitorFactory, + _executionStrategyFactory, CurrentContext, ContextOptions, Logger); @@ -152,6 +165,7 @@ public QueryCompilationContextDependencies With([NotNull] IQueryTranslationPrepr QueryableMethodTranslatingExpressionVisitorFactory, QueryTranslationPostprocessorFactory, ShapedQueryCompilingExpressionVisitorFactory, + _executionStrategyFactory, CurrentContext, ContextOptions, Logger); @@ -169,6 +183,7 @@ public QueryCompilationContextDependencies With( queryableMethodTranslatingExpressionVisitorFactory, QueryTranslationPostprocessorFactory, ShapedQueryCompilingExpressionVisitorFactory, + _executionStrategyFactory, CurrentContext, ContextOptions, Logger); @@ -186,6 +201,7 @@ public QueryCompilationContextDependencies With( QueryableMethodTranslatingExpressionVisitorFactory, queryTranslationPostprocessorFactory, ShapedQueryCompilingExpressionVisitorFactory, + _executionStrategyFactory, CurrentContext, ContextOptions, Logger); @@ -203,6 +219,24 @@ public QueryCompilationContextDependencies With( QueryableMethodTranslatingExpressionVisitorFactory, QueryTranslationPostprocessorFactory, shapedQueryCompilingExpressionVisitorFactory, + _executionStrategyFactory, + CurrentContext, + ContextOptions, + Logger); + + /// + /// Clones this dependency parameter object with one service replaced. + /// + /// A replacement for the current dependency of this type. + /// A new parameter object with the given service replaced. + public QueryCompilationContextDependencies With([NotNull] IExecutionStrategyFactory executionStrategyFactory) + => new QueryCompilationContextDependencies( + Model, + QueryTranslationPreprocessorFactory, + QueryableMethodTranslatingExpressionVisitorFactory, + QueryTranslationPostprocessorFactory, + ShapedQueryCompilingExpressionVisitorFactory, + executionStrategyFactory, CurrentContext, ContextOptions, Logger); @@ -219,6 +253,7 @@ public QueryCompilationContextDependencies With([NotNull] ICurrentDbContext curr QueryableMethodTranslatingExpressionVisitorFactory, QueryTranslationPostprocessorFactory, ShapedQueryCompilingExpressionVisitorFactory, + _executionStrategyFactory, currentContext, ContextOptions, Logger); @@ -235,6 +270,7 @@ public QueryCompilationContextDependencies With([NotNull] IDbContextOptions cont QueryableMethodTranslatingExpressionVisitorFactory, QueryTranslationPostprocessorFactory, ShapedQueryCompilingExpressionVisitorFactory, + _executionStrategyFactory, CurrentContext, contextOptions, Logger); @@ -251,6 +287,7 @@ public QueryCompilationContextDependencies With([NotNull] IDiagnosticsLogger Dependencies.QueryProvider; + /// + /// The execution strategy factory. + /// + /// + /// The execution strategy factory. + /// + public virtual IExecutionStrategyFactory ExecutionStrategyFactory + => Dependencies.ExecutionStrategyFactory; + /// /// Gets the concurrency detector. /// diff --git a/src/EFCore/Query/QueryContextDependencies.cs b/src/EFCore/Query/QueryContextDependencies.cs index 8fb976b5ce7..b9c57db4867 100644 --- a/src/EFCore/Query/QueryContextDependencies.cs +++ b/src/EFCore/Query/QueryContextDependencies.cs @@ -7,6 +7,7 @@ using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; using Microsoft.Extensions.DependencyInjection; @@ -59,16 +60,19 @@ public sealed class QueryContextDependencies [EntityFrameworkInternal] public QueryContextDependencies( [NotNull] ICurrentDbContext currentContext, + [NotNull] IExecutionStrategyFactory executionStrategyFactory, [NotNull] IConcurrencyDetector concurrencyDetector, [NotNull] IDiagnosticsLogger commandLogger, [NotNull] IDiagnosticsLogger queryLogger) { Check.NotNull(currentContext, nameof(currentContext)); + Check.NotNull(executionStrategyFactory, nameof(executionStrategyFactory)); Check.NotNull(concurrencyDetector, nameof(concurrencyDetector)); Check.NotNull(commandLogger, nameof(commandLogger)); Check.NotNull(queryLogger, nameof(queryLogger)); CurrentContext = currentContext; + ExecutionStrategyFactory = executionStrategyFactory; ConcurrencyDetector = concurrencyDetector; CommandLogger = commandLogger; QueryLogger = queryLogger; @@ -93,6 +97,11 @@ public QueryContextDependencies( /// public IQueryProvider QueryProvider => CurrentContext.GetDependencies().QueryProvider; + /// + /// The execution strategy. + /// + public IExecutionStrategyFactory ExecutionStrategyFactory { get; } + /// /// Gets the concurrency detector. /// @@ -114,7 +123,14 @@ public QueryContextDependencies( /// A replacement for the current dependency of this type. /// A new parameter object with the given service replaced. public QueryContextDependencies With([NotNull] ICurrentDbContext currentContext) - => new QueryContextDependencies(currentContext, ConcurrencyDetector, CommandLogger, QueryLogger); + => new QueryContextDependencies(currentContext, ExecutionStrategyFactory, ConcurrencyDetector, CommandLogger, QueryLogger); + /// + /// Clones this dependency parameter object with one service replaced. + /// + /// A replacement for the current dependency of this type. + /// A new parameter object with the given service replaced. + public QueryContextDependencies With([NotNull] IExecutionStrategyFactory executionStrategyFactor) + => new QueryContextDependencies(CurrentContext, executionStrategyFactor, ConcurrencyDetector, CommandLogger, QueryLogger); /// /// Clones this dependency parameter object with one service replaced. @@ -122,7 +138,7 @@ public QueryContextDependencies With([NotNull] ICurrentDbContext currentContext) /// A replacement for the current dependency of this type. /// A new parameter object with the given service replaced. public QueryContextDependencies With([NotNull] IConcurrencyDetector concurrencyDetector) - => new QueryContextDependencies(CurrentContext, concurrencyDetector, CommandLogger, QueryLogger); + => new QueryContextDependencies(CurrentContext, ExecutionStrategyFactory, concurrencyDetector, CommandLogger, QueryLogger); /// /// Clones this dependency parameter object with one service replaced. @@ -130,7 +146,7 @@ public QueryContextDependencies With([NotNull] IConcurrencyDetector concurrencyD /// A replacement for the current dependency of this type. /// A new parameter object with the given service replaced. public QueryContextDependencies With([NotNull] IDiagnosticsLogger commandLogger) - => new QueryContextDependencies(CurrentContext, ConcurrencyDetector, commandLogger, QueryLogger); + => new QueryContextDependencies(CurrentContext, ExecutionStrategyFactory, ConcurrencyDetector, commandLogger, QueryLogger); /// /// Clones this dependency parameter object with one service replaced. @@ -138,6 +154,6 @@ public QueryContextDependencies With([NotNull] IDiagnosticsLogger A replacement for the current dependency of this type. /// A new parameter object with the given service replaced. public QueryContextDependencies With([NotNull] IDiagnosticsLogger queryLogger) - => new QueryContextDependencies(CurrentContext, ConcurrencyDetector, CommandLogger, queryLogger); + => new QueryContextDependencies(CurrentContext, ExecutionStrategyFactory, ConcurrencyDetector, CommandLogger, queryLogger); } } diff --git a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs index 97a4ff7b8d1..94c67f2a4cd 100644 --- a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs @@ -41,6 +41,7 @@ protected ShapedQueryCompilingExpressionVisitor( _constantVerifyingExpressionVisitor = new ConstantVerifyingExpressionVisitor(dependencies.TypeMappingSource); + IsBuffering = queryCompilationContext.IsBuffering; IsAsync = queryCompilationContext.IsAsync; if (queryCompilationContext.IsAsync) @@ -55,6 +56,8 @@ protected ShapedQueryCompilingExpressionVisitor( protected virtual bool IsTracking { get; } + public virtual bool IsBuffering { get; internal set; } + protected virtual bool IsAsync { get; } protected override Expression VisitExtension(Expression extensionExpression) diff --git a/src/EFCore/Storage/ExecutionStrategyExtensions.cs b/src/EFCore/Storage/ExecutionStrategyExtensions.cs index 9cbcd076c9e..c1c13774605 100644 --- a/src/EFCore/Storage/ExecutionStrategyExtensions.cs +++ b/src/EFCore/Storage/ExecutionStrategyExtensions.cs @@ -283,7 +283,7 @@ public static TResult Execute( [NotNull] this IExecutionStrategy strategy, [CanBeNull] TState state, [NotNull] Func operation) - => strategy.Execute(operation, verifySucceeded: null, state: state); + => strategy.Execute(state, operation, verifySucceeded: null); /// /// Executes the specified asynchronous operation and returns the result. @@ -330,14 +330,40 @@ public static Task ExecuteAsync( /// public static TResult Execute( [NotNull] this IExecutionStrategy strategy, + [CanBeNull] TState state, [NotNull] Func operation, - [CanBeNull] Func> verifySucceeded, - [CanBeNull] TState state) + [CanBeNull] Func> verifySucceeded) => Check.NotNull(strategy, nameof(strategy)).Execute( state, (c, s) => operation(s), verifySucceeded == null ? (Func>)null : (c, s) => verifySucceeded(s)); + /// + /// Executes the specified operation and returns the result. + /// + /// The strategy that will be used for the execution. + /// + /// A delegate representing an executable operation that returns the result of type . + /// + /// A delegate that tests whether the operation succeeded even though an exception was thrown. + /// The state that will be passed to the operation. + /// The type of the state. + /// The return type of . + /// The result from the operation. + /// + /// The operation has not succeeded after the configured number of retries. + /// + [Obsolete("Use overload that takes the state first")] + public static TResult Execute( + [NotNull] this IExecutionStrategy strategy, + [NotNull] Func operation, + [CanBeNull] Func> verifySucceeded, + [CanBeNull] TState state) + => strategy.Execute( + state, + operation, + verifySucceeded); + /// /// Executes the specified asynchronous operation and returns the result. /// diff --git a/test/EFCore.Cosmos.FunctionalTests/TestUtilities/TestSqlServerRetryingExecutionStrategy.cs b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/TestCosmosExecutionStrategy.cs similarity index 96% rename from test/EFCore.Cosmos.FunctionalTests/TestUtilities/TestSqlServerRetryingExecutionStrategy.cs rename to test/EFCore.Cosmos.FunctionalTests/TestUtilities/TestCosmosExecutionStrategy.cs index f61dcd5be22..e56eea6d9c4 100644 --- a/test/EFCore.Cosmos.FunctionalTests/TestUtilities/TestSqlServerRetryingExecutionStrategy.cs +++ b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/TestCosmosExecutionStrategy.cs @@ -6,6 +6,7 @@ using Microsoft.EntityFrameworkCore.Cosmos.TestUtilities; using Microsoft.EntityFrameworkCore.Storage; +// ReSharper disable once CheckNamespace namespace Microsoft.EntityFrameworkCore.TestUtilities { public class TestCosmosExecutionStrategy : CosmosExecutionStrategy diff --git a/test/EFCore.Relational.Specification.Tests/CommandInterceptionTestBase.cs b/test/EFCore.Relational.Specification.Tests/CommandInterceptionTestBase.cs index f0de0b127a3..419d9b95e8f 100644 --- a/test/EFCore.Relational.Specification.Tests/CommandInterceptionTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/CommandInterceptionTestBase.cs @@ -14,6 +14,7 @@ using Microsoft.EntityFrameworkCore.Storage; using Xunit; +// ReSharper disable InconsistentNaming namespace Microsoft.EntityFrameworkCore { public abstract class CommandInterceptionTestBase : InterceptionTestBase @@ -78,7 +79,7 @@ public virtual async Task Intercept_scalar_passively(bool async, bool inject) var connection = context.GetService(); var logger = context.GetService>(); - var commandParameterObject = new RelationalCommandParameterObject(connection, null, context, logger); + var commandParameterObject = new RelationalCommandParameterObject(connection, null, null, context, logger); using (var listener = Fixture.SubscribeToDiagnosticListener(context.ContextId)) { @@ -292,7 +293,7 @@ public virtual async Task Intercept_scalar_to_suppress_execution(bool async, boo var connection = context.GetService(); var logger = context.GetService>(); - var commandParameterObject = new RelationalCommandParameterObject(connection, null, context, logger); + var commandParameterObject = new RelationalCommandParameterObject(connection, null, null, context, logger); using (var listener = Fixture.SubscribeToDiagnosticListener(context.ContextId)) { @@ -483,7 +484,7 @@ public virtual async Task Intercept_scalar_to_mutate_command(bool async, bool in var connection = context.GetService(); var logger = context.GetService>(); - var commandParameterObject = new RelationalCommandParameterObject(connection, null, context, logger); + var commandParameterObject = new RelationalCommandParameterObject(connection, null, null, context, logger); using (var listener = Fixture.SubscribeToDiagnosticListener(context.ContextId)) { @@ -683,7 +684,7 @@ public virtual async Task Intercept_scalar_to_replace_execution(bool async, bool var connection = context.GetService(); var logger = context.GetService>(); - var commandParameterObject = new RelationalCommandParameterObject(connection, null, context, logger); + var commandParameterObject = new RelationalCommandParameterObject(connection, null, null, context, logger); using (var listener = Fixture.SubscribeToDiagnosticListener(context.ContextId)) { @@ -891,6 +892,20 @@ public CompositeFakeDbDataReader(DbDataReader firstReader, DbDataReader secondRe _secondReader = secondReader; } + public override int FieldCount => _firstReader.FieldCount; + public override int RecordsAffected => _firstReader.RecordsAffected + _secondReader.RecordsAffected; + public override bool HasRows => _firstReader.HasRows || _secondReader.HasRows; + public override bool IsClosed => _firstReader.IsClosed; + public override int Depth => _firstReader.Depth; + + public override string GetDataTypeName(int ordinal) => _firstReader.GetDataTypeName(ordinal); + public override Type GetFieldType(int ordinal) => _firstReader.GetFieldType(ordinal); + public override string GetName(int ordinal) => _firstReader.GetName(ordinal); + public override bool NextResult() => _firstReader.NextResult() || _secondReader.NextResult(); + + public override async Task NextResultAsync(CancellationToken cancellationToken) + => await _firstReader.NextResultAsync(cancellationToken) || await _secondReader.NextResultAsync(cancellationToken); + public override void Close() { _firstReader.Close(); @@ -948,7 +963,7 @@ public virtual async Task Intercept_scalar_to_replace_result(bool async, bool in var connection = context.GetService(); var logger = context.GetService>(); - var commandParameterObject = new RelationalCommandParameterObject(connection, null, context, logger); + var commandParameterObject = new RelationalCommandParameterObject(connection, null, null, context, logger); using (var listener = Fixture.SubscribeToDiagnosticListener(context.ContextId)) { @@ -1104,7 +1119,7 @@ public virtual async Task Intercept_scalar_that_throws(bool async, bool inject) var connection = context.GetService(); var logger = context.GetService>(); - var commandParameterObject = new RelationalCommandParameterObject(connection, null, context, logger); + var commandParameterObject = new RelationalCommandParameterObject(connection, null, null, context, logger); try { @@ -1188,7 +1203,7 @@ public virtual async Task Intercept_scalar_to_throw(bool async, bool inject) var connection = context.GetService(); var logger = context.GetService>(); - var commandParameterObject = new RelationalCommandParameterObject(connection, null, context, logger); + var commandParameterObject = new RelationalCommandParameterObject(connection, null, null, context, logger); var exception = async ? await Assert.ThrowsAsync(() => command.ExecuteScalarAsync(commandParameterObject)) @@ -1323,7 +1338,7 @@ private static async Task TestCompositeScalarInterceptors(UniverseContext contex var connection = context.GetService(); var logger = context.GetService>(); - var commandParameterObject = new RelationalCommandParameterObject(connection, null, context, logger); + var commandParameterObject = new RelationalCommandParameterObject(connection, null, null, context, logger); Assert.Equal( ResultReplacingScalarCommandInterceptor.InterceptedResult, @@ -1520,6 +1535,11 @@ private class FakeDbDataReader : DbDataReader private readonly int[] _ints = { 977, 988, 999 }; private readonly string[] _strings = { "<977>", "<988>", "<999>" }; + public override int FieldCount { get; } + public override int RecordsAffected { get; } + public override bool HasRows { get; } + public override bool IsClosed { get; } + public override int Depth { get; } public override bool Read() => _index++ < _ints.Length; @@ -1557,14 +1577,9 @@ public override long GetChars(int ordinal, long dataOffset, char[] buffer, int b public override int GetOrdinal(string name) => throw new NotImplementedException(); public override object GetValue(int ordinal) => throw new NotImplementedException(); public override int GetValues(object[] values) => throw new NotImplementedException(); - public override int FieldCount { get; } public override object this[int ordinal] => throw new NotImplementedException(); public override object this[string name] => throw new NotImplementedException(); - public override int RecordsAffected { get; } - public override bool HasRows { get; } - public override bool IsClosed { get; } public override bool NextResult() => throw new NotImplementedException(); - public override int Depth { get; } public override IEnumerator GetEnumerator() => throw new NotImplementedException(); } diff --git a/test/EFCore.Relational.Specification.Tests/Query/AsyncFromSqlQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/AsyncFromSqlQueryTestBase.cs index e523e716dba..b26d4f52601 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/AsyncFromSqlQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/AsyncFromSqlQueryTestBase.cs @@ -30,7 +30,7 @@ public virtual async Task FromSqlRaw_queryable_simple() using (var context = CreateContext()) { var actual = await context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [ContactName] LIKE '%z%'")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [ContactName] LIKE '%z%'")) .ToArrayAsync(); Assert.Equal(14, actual.Length); @@ -44,7 +44,7 @@ public virtual async Task FromSqlRaw_queryable_simple_columns_out_of_order() using (var context = CreateContext()) { var actual = await context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( "SELECT [Region], [PostalCode], [Phone], [Fax], [CustomerID], [Country], [ContactTitle], [ContactName], [CompanyName], [City], [Address] FROM [Customers]")) .ToArrayAsync(); @@ -59,7 +59,7 @@ public virtual async Task FromSqlRaw_queryable_simple_columns_out_of_order_and_e using (var context = CreateContext()) { var actual = await context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( "SELECT [Region], [PostalCode], [PostalCode] AS [Foo], [Phone], [Fax], [CustomerID], [Country], [ContactTitle], [ContactName], [CompanyName], [City], [Address] FROM [Customers]")) .ToArrayAsync(); @@ -73,7 +73,7 @@ public virtual async Task FromSqlRaw_queryable_composed() { using (var context = CreateContext()) { - var actual = await context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + var actual = await context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) .Where(c => c.ContactName.Contains("z")) .ToArrayAsync(); @@ -87,8 +87,8 @@ public virtual async Task FromSqlRaw_queryable_multiple_composed() using (var context = CreateContext()) { var actual - = await (from c in context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) - from o in context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Orders]")) + = await (from c in context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) + from o in context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Orders]")) where c.CustomerID == o.CustomerID select new { c, o }) .ToArrayAsync(); @@ -106,9 +106,9 @@ public virtual async Task FromSqlRaw_queryable_multiple_composed_with_closure_pa using (var context = CreateContext()) { var actual - = await (from c in context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + = await (from c in context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) from o in context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Orders] WHERE [OrderDate] BETWEEN {0} AND {1}"), startDate, + NormalizeDelimitersInRawString("SELECT * FROM [Orders] WHERE [OrderDate] BETWEEN {0} AND {1}"), startDate, endDate) where c.CustomerID == o.CustomerID select new { c, o }) @@ -129,9 +129,9 @@ public virtual async Task FromSqlRaw_queryable_multiple_composed_with_parameters { var actual = await (from c in context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) from o in context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Orders] WHERE [OrderDate] BETWEEN {0} AND {1}"), startDate, + NormalizeDelimitersInRawString("SELECT * FROM [Orders] WHERE [OrderDate] BETWEEN {0} AND {1}"), startDate, endDate) where c.CustomerID == o.CustomerID select new { c, o }) @@ -147,7 +147,7 @@ public virtual async Task FromSqlRaw_queryable_multiple_line_query() using (var context = CreateContext()) { var actual = await context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT * FROM [Customers] WHERE [City] = 'London'")) @@ -164,7 +164,7 @@ public virtual async Task FromSqlRaw_queryable_composed_multiple_line_query() using (var context = CreateContext()) { var actual = await context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT * FROM [Customers]")) .Where(c => c.City == "London") @@ -184,7 +184,7 @@ public virtual async Task FromSqlRaw_queryable_with_parameters() using (var context = CreateContext()) { var actual = await context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = {0} AND [ContactTitle] = {1}"), city, + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = {0} AND [ContactTitle] = {1}"), city, contactTitle) .ToArrayAsync(); @@ -203,7 +203,7 @@ public virtual async Task FromSqlRaw_queryable_with_parameters_and_closure() using (var context = CreateContext()) { var actual = await context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) .Where(c => c.ContactTitle == contactTitle) .ToArrayAsync(); @@ -219,14 +219,14 @@ public virtual async Task FromSqlRaw_queryable_simple_cache_key_includes_query_s using (var context = CreateContext()) { var actual = await context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = 'London'")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = 'London'")) .ToArrayAsync(); Assert.Equal(6, actual.Length); Assert.True(actual.All(c => c.City == "London")); actual = await context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = 'Seattle'")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = 'Seattle'")) .ToArrayAsync(); Assert.Single(actual); @@ -243,7 +243,7 @@ public virtual async Task FromSqlRaw_queryable_with_parameters_cache_key_include using (var context = CreateContext()) { - var actual = await context.Set().FromSqlRaw(NormalizeDelimetersInRawString(sql), city, contactTitle) + var actual = await context.Set().FromSqlRaw(NormalizeDelimitersInRawString(sql), city, contactTitle) .ToArrayAsync(); Assert.Equal(3, actual.Length); @@ -253,7 +253,7 @@ public virtual async Task FromSqlRaw_queryable_with_parameters_cache_key_include city = "Madrid"; contactTitle = "Accounting Manager"; - actual = await context.Set().FromSqlRaw(NormalizeDelimetersInRawString(sql), city, contactTitle) + actual = await context.Set().FromSqlRaw(NormalizeDelimitersInRawString(sql), city, contactTitle) .ToArrayAsync(); Assert.Equal(2, actual.Length); @@ -267,7 +267,7 @@ public virtual async Task FromSqlRaw_queryable_simple_as_no_tracking_not_compose { using (var context = CreateContext()) { - var actual = await context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + var actual = await context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) .AsNoTracking() .ToArrayAsync(); @@ -281,7 +281,7 @@ public virtual async Task FromSqlRaw_queryable_simple_projection_not_composed() { using (var context = CreateContext()) { - var actual = await context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + var actual = await context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) .Select( c => new { c.CustomerID, c.City }) .AsNoTracking() @@ -297,7 +297,7 @@ public virtual async Task FromSqlRaw_queryable_simple_include() { using (var context = CreateContext()) { - var actual = await context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + var actual = await context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) .Include(c => c.Orders) .ToArrayAsync(); @@ -310,7 +310,7 @@ public virtual async Task FromSqlRaw_queryable_simple_composed_include() { using (var context = CreateContext()) { - var actual = await context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + var actual = await context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) .Include(c => c.Orders) .Where(c => c.City == "London") .ToArrayAsync(); @@ -325,7 +325,7 @@ public virtual async Task FromSqlRaw_annotations_do_not_affect_successive_calls( using (var context = CreateContext()) { var actual = await context.Customers - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [ContactName] LIKE '%z%'")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [ContactName] LIKE '%z%'")) .ToArrayAsync(); Assert.Equal(14, actual.Length); @@ -342,7 +342,7 @@ public virtual async Task FromSqlRaw_composed_with_nullable_predicate() { using (var context = CreateContext()) { - var actual = await context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + var actual = await context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) .Where(c => c.ContactName == c.CompanyName) .ToArrayAsync(); @@ -400,11 +400,11 @@ public virtual async Task Include_closed_connection_opened_by_it_when_buffering( } } - private string NormalizeDelimetersInRawString(string sql) - => Fixture.TestStore.NormalizeDelimetersInRawString(sql); + private string NormalizeDelimitersInRawString(string sql) + => Fixture.TestStore.NormalizeDelimitersInRawString(sql); - private FormattableString NormalizeDelimetersInInterpolatedString(FormattableString sql) - => Fixture.TestStore.NormalizeDelimetersInInterpolatedString(sql); + private FormattableString NormalizeDelimitersInInterpolatedString(FormattableString sql) + => Fixture.TestStore.NormalizeDelimitersInInterpolatedString(sql); protected NorthwindContext CreateContext() => Fixture.CreateContext(); } diff --git a/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs index b2667c03de5..fc4e2df2515 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs @@ -42,7 +42,7 @@ public virtual void Bad_data_error_handling_invalid_cast_key() Assert.Throws( () => context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT [ProductID] AS [ProductName], [ProductName] AS [ProductID], [SupplierID], [UnitPrice], [UnitsInStock], [Discontinued] FROM [Products]")) .ToList()).Message); @@ -59,7 +59,7 @@ public virtual void Bad_data_error_handling_invalid_cast() Assert.Throws( () => context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT [ProductID], [SupplierID] AS [UnitPrice], [ProductName], [SupplierID], [UnitsInStock], [Discontinued] FROM [Products]")) .ToList()).Message); @@ -76,7 +76,7 @@ public virtual void Bad_data_error_handling_invalid_cast_projection() Assert.Throws( () => context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT [ProductID], [SupplierID] AS [UnitPrice], [ProductName], [UnitsInStock], [Discontinued] FROM [Products]")) .Select(p => p.UnitPrice) @@ -95,7 +95,7 @@ public virtual void Bad_data_error_handling_invalid_cast_no_tracking() () => context.Set() .FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT [ProductID] AS [ProductName], [ProductName] AS [ProductID], [SupplierID], [UnitPrice], [UnitsInStock], [Discontinued] FROM [Products]")).AsNoTracking() .ToList()).Message); @@ -108,7 +108,7 @@ public virtual void Bad_data_error_handling_null() using (var context = CreateContext()) { context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT [ProductID], [ProductName], [SupplierID], [UnitPrice], [UnitsInStock], NULL AS [Discontinued] FROM [Products]")) .ToList(); @@ -117,7 +117,7 @@ public virtual void Bad_data_error_handling_null() Assert.Throws( () => context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT [ProductID], [ProductName], [SupplierID], [UnitPrice], [UnitsInStock], NULL AS [Discontinued] FROM [Products]")) .ToList()).Message); @@ -134,7 +134,7 @@ public virtual void Bad_data_error_handling_null_projection() Assert.Throws( () => context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT [ProductID], [ProductName], [SupplierID], [UnitPrice], [UnitsInStock], NULL AS [Discontinued] FROM [Products]")) .Select(p => p.Discontinued) @@ -153,7 +153,7 @@ public virtual void Bad_data_error_handling_null_no_tracking() () => context.Set() .FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT [ProductID], [ProductName], [SupplierID], [UnitPrice], [UnitsInStock], NULL AS [Discontinued] FROM [Products]")).AsNoTracking() .ToList()).Message); @@ -166,7 +166,7 @@ public virtual void FromSqlRaw_queryable_simple() using (var context = CreateContext()) { var actual = context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [ContactName] LIKE '%z%'")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [ContactName] LIKE '%z%'")) .ToArray(); Assert.Equal(14, actual.Length); @@ -180,7 +180,7 @@ public virtual void FromSqlRaw_queryable_simple_columns_out_of_order() using (var context = CreateContext()) { var actual = context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( "SELECT [Region], [PostalCode], [Phone], [Fax], [CustomerID], [Country], [ContactTitle], [ContactName], [CompanyName], [City], [Address] FROM [Customers]")) .ToArray(); @@ -195,7 +195,7 @@ public virtual void FromSqlRaw_queryable_simple_columns_out_of_order_and_extra_c using (var context = CreateContext()) { var actual = context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( "SELECT [Region], [PostalCode], [PostalCode] AS [Foo], [Phone], [Fax], [CustomerID], [Country], [ContactTitle], [ContactName], [CompanyName], [City], [Address] FROM [Customers]")) .ToArray(); @@ -213,7 +213,7 @@ public virtual void FromSqlRaw_queryable_simple_columns_out_of_order_and_not_eno RelationalStrings.FromSqlMissingColumn("Region"), Assert.Throws( () => context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( "SELECT [PostalCode], [Phone], [Fax], [CustomerID], [Country], [ContactTitle], [ContactName], [CompanyName], [City], [Address] FROM [Customers]")) .ToArray() ).Message); @@ -225,7 +225,7 @@ public virtual void FromSqlRaw_queryable_composed() { using (var context = CreateContext()) { - var actual = context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + var actual = context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) .Where(c => c.ContactName.Contains("z")) .ToArray(); @@ -239,7 +239,7 @@ public virtual void FromSqlRaw_queryable_composed_after_removing_whitespaces() using (var context = CreateContext()) { var actual = context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( _eol + " " + _eol + _eol + _eol + "SELECT" + _eol + "* FROM [Customers]")) .Where(c => c.ContactName.Contains("z")) .ToArray(); @@ -253,7 +253,7 @@ public virtual void FromSqlRaw_queryable_composed_compiled() { var query = EF.CompileQuery( (NorthwindContext context) => context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) .Where(c => c.ContactName.Contains("z"))); using (var context = CreateContext()) @@ -270,7 +270,7 @@ public virtual void FromSqlRaw_queryable_composed_compiled_with_parameter() var query = EF.CompileQuery( (NorthwindContext context) => context.Set() .FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = {0}"), "CONSH") + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = {0}"), "CONSH") .Where(c => c.ContactName.Contains("z"))); using (var context = CreateContext()) @@ -287,7 +287,7 @@ public virtual void FromSqlRaw_queryable_composed_compiled_with_DbParameter() var query = EF.CompileQuery( (NorthwindContext context) => context.Set() .FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = @customer"), + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = @customer"), CreateDbParameter("customer", "CONSH")) .Where(c => c.ContactName.Contains("z"))); @@ -305,7 +305,7 @@ public virtual void FromSqlRaw_queryable_composed_compiled_with_nameless_DbParam var query = EF.CompileQuery( (NorthwindContext context) => context.Set() .FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = {0}"), + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = {0}"), CreateDbParameter(null, "CONSH")) .Where(c => c.ContactName.Contains("z"))); @@ -324,7 +324,7 @@ public virtual void FromSqlRaw_composed_contains() { var actual = (from c in context.Set() - where context.Orders.FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Orders]")) + where context.Orders.FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Orders]")) .Select(o => o.CustomerID) .Contains(c.CustomerID) select c) @@ -343,7 +343,7 @@ var actual = (from c in context.Set() where c.CustomerID == "ALFKI" - && context.Orders.FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Orders]")) + && context.Orders.FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Orders]")) .Select(o => o.CustomerID) .Contains(c.CustomerID) select c) @@ -359,8 +359,8 @@ public virtual void FromSqlRaw_queryable_multiple_composed() using (var context = CreateContext()) { var actual - = (from c in context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) - from o in context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Orders]")) + = (from c in context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) + from o in context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Orders]")) where c.CustomerID == o.CustomerID select new { c, o }) .ToArray(); @@ -378,9 +378,9 @@ public virtual void FromSqlRaw_queryable_multiple_composed_with_closure_paramete using (var context = CreateContext()) { var actual - = (from c in context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + = (from c in context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) from o in context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Orders] WHERE [OrderDate] BETWEEN {0} AND {1}"), + NormalizeDelimitersInRawString("SELECT * FROM [Orders] WHERE [OrderDate] BETWEEN {0} AND {1}"), startDate, endDate) where c.CustomerID == o.CustomerID @@ -402,9 +402,9 @@ public virtual void FromSqlRaw_queryable_multiple_composed_with_parameters_and_c { var actual = (from c in context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) from o in context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Orders] WHERE [OrderDate] BETWEEN {0} AND {1}"), + NormalizeDelimitersInRawString("SELECT * FROM [Orders] WHERE [OrderDate] BETWEEN {0} AND {1}"), startDate, endDate) where c.CustomerID == o.CustomerID @@ -419,9 +419,9 @@ from o in context.Set().FromSqlRaw( actual = (from c in context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) from o in context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Orders] WHERE [OrderDate] BETWEEN {0} AND {1}"), + NormalizeDelimitersInRawString("SELECT * FROM [Orders] WHERE [OrderDate] BETWEEN {0} AND {1}"), startDate, endDate) where c.CustomerID == o.CustomerID @@ -438,7 +438,7 @@ public virtual void FromSqlRaw_queryable_multiple_line_query() using (var context = CreateContext()) { var actual = context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT * FROM [Customers] WHERE [City] = 'London'")) @@ -455,7 +455,7 @@ public virtual void FromSqlRaw_queryable_composed_multiple_line_query() using (var context = CreateContext()) { var actual = context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT * FROM [Customers]")) .Where(c => c.City == "London") @@ -475,7 +475,7 @@ public virtual void FromSqlRaw_queryable_with_parameters() using (var context = CreateContext()) { var actual = context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = {0} AND [ContactTitle] = {1}"), city, + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = {0} AND [ContactTitle] = {1}"), city, contactTitle) .ToArray(); @@ -491,7 +491,7 @@ public virtual void FromSqlRaw_queryable_with_parameters_inline() using (var context = CreateContext()) { var actual = context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = {0} AND [ContactTitle] = {1}"), "London", + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = {0} AND [ContactTitle] = {1}"), "London", "Sales Representative") .ToArray(); @@ -510,7 +510,7 @@ public virtual void FromSqlInterpolated_queryable_with_parameters_interpolated() using (var context = CreateContext()) { var actual = context.Set().FromSqlInterpolated( - NormalizeDelimetersInInterpolatedString( + NormalizeDelimitersInInterpolatedString( $"SELECT * FROM [Customers] WHERE [City] = {city} AND [ContactTitle] = {contactTitle}")) .ToArray(); @@ -526,7 +526,7 @@ public virtual void FromSqlInterpolated_queryable_with_parameters_inline_interpo using (var context = CreateContext()) { var actual = context.Set().FromSqlInterpolated( - NormalizeDelimetersInInterpolatedString( + NormalizeDelimitersInInterpolatedString( $"SELECT * FROM [Customers] WHERE [City] = {"London"} AND [ContactTitle] = {"Sales Representative"}")) .ToArray(); @@ -547,9 +547,9 @@ public virtual void FromSqlInterpolated_queryable_multiple_composed_with_paramet { var actual = (from c in context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) from o in context.Set().FromSqlInterpolated( - NormalizeDelimetersInInterpolatedString( + NormalizeDelimitersInInterpolatedString( $"SELECT * FROM [Orders] WHERE [OrderDate] BETWEEN {startDate} AND {endDate}")) where c.CustomerID == o.CustomerID select new { c, o }) @@ -563,9 +563,9 @@ from o in context.Set().FromSqlInterpolated( actual = (from c in context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) from o in context.Set().FromSqlInterpolated( - NormalizeDelimetersInInterpolatedString( + NormalizeDelimitersInInterpolatedString( $"SELECT * FROM [Orders] WHERE [OrderDate] BETWEEN {startDate} AND {endDate}")) where c.CustomerID == o.CustomerID select new { c, o }) @@ -583,7 +583,7 @@ public virtual void FromSqlRaw_queryable_with_null_parameter() using (var context = CreateContext()) { var actual = context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( // ReSharper disable once ExpressionIsAlwaysNull "SELECT * FROM [Employees] WHERE [ReportsTo] = {0} OR ([ReportsTo] IS NULL AND {0} IS NULL)"), reportsTo) .ToArray(); @@ -601,7 +601,7 @@ public virtual void FromSqlRaw_queryable_with_parameters_and_closure() using (var context = CreateContext()) { var actual = context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = {0}"), city) .Where(c => c.ContactTitle == contactTitle) .ToArray(); @@ -617,14 +617,14 @@ public virtual void FromSqlRaw_queryable_simple_cache_key_includes_query_string( using (var context = CreateContext()) { var actual = context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = 'London'")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = 'London'")) .ToArray(); Assert.Equal(6, actual.Length); Assert.True(actual.All(c => c.City == "London")); actual = context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = 'Seattle'")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = 'Seattle'")) .ToArray(); Assert.Single(actual); @@ -641,7 +641,7 @@ public virtual void FromSqlRaw_queryable_with_parameters_cache_key_includes_para using (var context = CreateContext()) { - var actual = context.Set().FromSqlRaw(NormalizeDelimetersInRawString(sql), city, contactTitle) + var actual = context.Set().FromSqlRaw(NormalizeDelimitersInRawString(sql), city, contactTitle) .ToArray(); Assert.Equal(3, actual.Length); @@ -651,7 +651,7 @@ public virtual void FromSqlRaw_queryable_with_parameters_cache_key_includes_para city = "Madrid"; contactTitle = "Accounting Manager"; - actual = context.Set().FromSqlRaw(NormalizeDelimetersInRawString(sql), city, contactTitle) + actual = context.Set().FromSqlRaw(NormalizeDelimitersInRawString(sql), city, contactTitle) .ToArray(); Assert.Equal(2, actual.Length); @@ -665,7 +665,7 @@ public virtual void FromSqlRaw_queryable_simple_as_no_tracking_not_composed() { using (var context = CreateContext()) { - var actual = context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + var actual = context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) .AsNoTracking() .ToArray(); @@ -681,7 +681,7 @@ public virtual void FromSqlRaw_queryable_simple_projection_composed() { var boolMapping = (RelationalTypeMapping)context.GetService().FindMapping(typeof(bool)); var actual = context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( @"SELECT * FROM [Products] WHERE [Discontinued] <> " @@ -700,7 +700,7 @@ public virtual void FromSqlRaw_queryable_simple_include() { using (var context = CreateContext()) { - var actual = context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + var actual = context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) .Include(c => c.Orders) .ToArray(); @@ -713,7 +713,7 @@ public virtual void FromSqlRaw_queryable_simple_composed_include() { using (var context = CreateContext()) { - var actual = context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + var actual = context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) .Include(c => c.Orders) .Where(c => c.City == "London") .ToArray(); @@ -728,7 +728,7 @@ public virtual void FromSqlRaw_annotations_do_not_affect_successive_calls() using (var context = CreateContext()) { var actual = context.Customers - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [ContactName] LIKE '%z%'")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [ContactName] LIKE '%z%'")) .ToArray(); Assert.Equal(14, actual.Length); @@ -745,7 +745,7 @@ public virtual void FromSqlRaw_composed_with_nullable_predicate() { using (var context = CreateContext()) { - var actual = context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers]")) + var actual = context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers]")) .Where(c => c.ContactName == c.CompanyName) .ToArray(); @@ -761,7 +761,7 @@ public virtual void FromSqlRaw_with_dbParameter() var parameter = CreateDbParameter("@city", "London"); var actual = context.Customers.FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = @city"), parameter) + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = @city"), parameter) .ToArray(); Assert.Equal(6, actual.Length); @@ -777,7 +777,7 @@ public virtual void FromSqlRaw_with_dbParameter_without_name_prefix() var parameter = CreateDbParameter("city", "London"); var actual = context.Customers.FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = @city"), parameter) + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = @city"), parameter) .ToArray(); Assert.Equal(6, actual.Length); @@ -796,7 +796,7 @@ public virtual void FromSqlRaw_with_dbParameter_mixed() var titleParameter = CreateDbParameter("@title", title); var actual = context.Customers.FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( "SELECT * FROM [Customers] WHERE [City] = {0} AND [ContactTitle] = @title"), city, titleParameter) .ToArray(); @@ -807,7 +807,7 @@ public virtual void FromSqlRaw_with_dbParameter_mixed() var cityParameter = CreateDbParameter("@city", city); actual = context.Customers.FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( "SELECT * FROM [Customers] WHERE [City] = @city AND [ContactTitle] = {1}"), cityParameter, title) .ToArray(); @@ -855,7 +855,7 @@ public virtual void FromSqlRaw_with_db_parameters_called_multiple_times() var parameter = CreateDbParameter("@id", "ALFKI"); var query = context.Customers.FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = @id"), parameter); + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = @id"), parameter); // ReSharper disable PossibleMultipleEnumeration var result1 = query.ToList(); @@ -875,9 +875,9 @@ public virtual void FromSqlRaw_with_SelectMany_and_include() using (var context = CreateContext()) { var query = from c1 in context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = 'ALFKI'")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = 'ALFKI'")) from c2 in context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = 'AROUT'")) + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = 'AROUT'")) .Include(c => c.Orders) select new { c1, c2 }; @@ -904,9 +904,9 @@ public virtual void FromSqlRaw_with_join_and_include() using (var context = CreateContext()) { var query = from c in context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = 'ALFKI'")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = 'ALFKI'")) join o in context.Set().FromSqlRaw( - NormalizeDelimetersInRawString("SELECT * FROM [Orders] WHERE [OrderID] <> 1")) + NormalizeDelimitersInRawString("SELECT * FROM [Orders] WHERE [OrderID] <> 1")) .Include(o => o.OrderDetails) on c.CustomerID equals o.CustomerID select new { c, o }; @@ -952,7 +952,7 @@ public virtual void FromSqlInterpolated_with_inlined_db_parameter() var actual = context.Customers .FromSqlInterpolated( - NormalizeDelimetersInInterpolatedString($"SELECT * FROM [Customers] WHERE [CustomerID] = {parameter}")) + NormalizeDelimitersInInterpolatedString($"SELECT * FROM [Customers] WHERE [CustomerID] = {parameter}")) .ToList(); Assert.Single(actual); @@ -969,7 +969,7 @@ public virtual void FromSqlInterpolated_with_inlined_db_parameter_without_name_p var actual = context.Customers .FromSqlInterpolated( - NormalizeDelimetersInInterpolatedString($"SELECT * FROM [Customers] WHERE [CustomerID] = {parameter}")) + NormalizeDelimitersInInterpolatedString($"SELECT * FROM [Customers] WHERE [CustomerID] = {parameter}")) .ToList(); Assert.Single(actual); @@ -986,7 +986,7 @@ public virtual void FromSqlInterpolated_parameterization_issue_12213() var max = 10400; var query1 = context.Orders - .FromSqlInterpolated(NormalizeDelimetersInInterpolatedString($"SELECT * FROM [Orders] WHERE [OrderID] >= {min}")) + .FromSqlInterpolated(NormalizeDelimitersInInterpolatedString($"SELECT * FROM [Orders] WHERE [OrderID] >= {min}")) .Select(i => i.OrderID); query1.ToList(); @@ -1000,7 +1000,7 @@ public virtual void FromSqlInterpolated_parameterization_issue_12213() o => o.OrderID <= max && context.Orders .FromSqlInterpolated( - NormalizeDelimetersInInterpolatedString($"SELECT * FROM [Orders] WHERE [OrderID] >= {min}")) + NormalizeDelimitersInInterpolatedString($"SELECT * FROM [Orders] WHERE [OrderID] >= {min}")) .Select(i => i.OrderID) .Contains(o.OrderID)) .Select(o => o.OrderID); @@ -1016,7 +1016,7 @@ public virtual void FromSqlRaw_does_not_parameterize_interpolated_string() var tableName = "Orders"; var max = 10250; var query = context.Orders.FromSqlRaw( - NormalizeDelimetersInRawString($"SELECT * FROM [{tableName}] WHERE [OrderID] < {{0}}"), max) + NormalizeDelimitersInRawString($"SELECT * FROM [{tableName}] WHERE [OrderID] < {{0}}"), max) .ToList(); Assert.Equal(2, query.Count); @@ -1029,7 +1029,7 @@ public virtual void Entity_equality_through_fromsql() using (var context = CreateContext()) { var actual = context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Orders]")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Orders]")) .Where(o => o.Customer == new Customer { CustomerID = "VINET" }) .ToArray(); @@ -1043,20 +1043,20 @@ public virtual void FromSqlRaw_with_set_operation() using var context = CreateContext(); var actual = context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = 'London'")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = 'London'")) .Concat( context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Customers] WHERE [City] = 'Berlin'"))) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [City] = 'Berlin'"))) .ToArray(); Assert.Equal(7, actual.Length); } - protected string NormalizeDelimetersInRawString(string sql) - => Fixture.TestStore.NormalizeDelimetersInRawString(sql); + protected string NormalizeDelimitersInRawString(string sql) + => Fixture.TestStore.NormalizeDelimitersInRawString(sql); - protected FormattableString NormalizeDelimetersInInterpolatedString(FormattableString sql) - => Fixture.TestStore.NormalizeDelimetersInInterpolatedString(sql); + protected FormattableString NormalizeDelimitersInInterpolatedString(FormattableString sql) + => Fixture.TestStore.NormalizeDelimitersInInterpolatedString(sql); protected abstract DbParameter CreateDbParameter(string name, object value); diff --git a/test/EFCore.Relational.Specification.Tests/Query/FromSqlSprocQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/FromSqlSprocQueryTestBase.cs index ad293f32261..33ed6b13d65 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/FromSqlSprocQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/FromSqlSprocQueryTestBase.cs @@ -370,7 +370,7 @@ public virtual async Task From_sql_queryable_stored_procedure_and_select(bool as var query = from mep in context.Set() .FromSqlRaw(TenMostExpensiveProductsSproc, GetTenMostExpensiveProductsParameters()) from p in context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Products]")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Products]")) where mep.TenMostExpensiveProducts == p.ProductName select new { mep, p }; @@ -390,7 +390,7 @@ public virtual async Task From_sql_queryable_stored_procedure_and_select_on_clie var query1 = context.Set() .FromSqlRaw(TenMostExpensiveProductsSproc, GetTenMostExpensiveProductsParameters()); var query2 = context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Products]")); + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Products]")); var results1 = async ? await query1.ToListAsync() : query1.ToList(); var results2 = async ? await query2.ToListAsync() : query2.ToList(); @@ -409,7 +409,7 @@ from p in results2 public virtual async Task From_sql_queryable_select_and_stored_procedure(bool async) { using var context = CreateContext(); - var query = from p in context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Products]")) + var query = from p in context.Set().FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Products]")) from mep in context.Set() .FromSqlRaw(TenMostExpensiveProductsSproc, GetTenMostExpensiveProductsParameters()) where mep.TenMostExpensiveProducts == p.ProductName @@ -430,7 +430,7 @@ public virtual async Task From_sql_queryable_select_and_stored_procedure_on_clie using var context = CreateContext(); var query1 = context.Set() - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Products]")); + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Products]")); var query2 = context.Set() .FromSqlRaw(TenMostExpensiveProductsSproc, GetTenMostExpensiveProductsParameters()); @@ -445,8 +445,8 @@ from mep in results2 Assert.Equal(10, actual.Length); } - private string NormalizeDelimetersInRawString(string sql) - => Fixture.TestStore.NormalizeDelimetersInRawString(sql); + private string NormalizeDelimitersInRawString(string sql) + => Fixture.TestStore.NormalizeDelimitersInRawString(sql); protected virtual object[] GetTenMostExpensiveProductsParameters() => Array.Empty(); diff --git a/test/EFCore.Relational.Specification.Tests/Query/GearsOfWarFromSqlQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/GearsOfWarFromSqlQueryTestBase.cs index c23cc5c575f..fc550c6b24f 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/GearsOfWarFromSqlQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/GearsOfWarFromSqlQueryTestBase.cs @@ -22,7 +22,7 @@ public virtual void From_sql_queryable_simple_columns_out_of_order() using (var context = CreateContext()) { var actual = context.Set().FromSqlRaw( - NormalizeDelimetersInRawString( + NormalizeDelimitersInRawString( "SELECT [Id], [Name], [IsAutomatic], [AmmunitionType], [OwnerFullName], [SynergyWithId] FROM [Weapons] ORDER BY [Name]")) .ToArray(); @@ -35,11 +35,11 @@ public virtual void From_sql_queryable_simple_columns_out_of_order() } } - private string NormalizeDelimetersInRawString(string sql) - => Fixture.TestStore.NormalizeDelimetersInRawString(sql); + private string NormalizeDelimitersInRawString(string sql) + => Fixture.TestStore.NormalizeDelimitersInRawString(sql); - private FormattableString NormalizeDelimetersInInterpolatedString(FormattableString sql) - => Fixture.TestStore.NormalizeDelimetersInInterpolatedString(sql); + private FormattableString NormalizeDelimitersInInterpolatedString(FormattableString sql) + => Fixture.TestStore.NormalizeDelimitersInInterpolatedString(sql); protected GearsOfWarContext CreateContext() => Fixture.CreateContext(); diff --git a/test/EFCore.Relational.Specification.Tests/Query/InheritanceRelationalTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/InheritanceRelationalTestBase.cs index 797dddd616b..eb742f683f2 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/InheritanceRelationalTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/InheritanceRelationalTestBase.cs @@ -23,7 +23,7 @@ public virtual void FromSql_on_root() { using (var context = CreateContext()) { - context.Set().FromSqlRaw(NormalizeDelimetersInRawString("select * from [Animal]")).ToList(); + context.Set().FromSqlRaw(NormalizeDelimitersInRawString("select * from [Animal]")).ToList(); } } @@ -32,14 +32,14 @@ public virtual void FromSql_on_derived() { using (var context = CreateContext()) { - context.Set().FromSqlRaw(NormalizeDelimetersInRawString("select * from [Animal]")).ToList(); + context.Set().FromSqlRaw(NormalizeDelimitersInRawString("select * from [Animal]")).ToList(); } } - private string NormalizeDelimetersInRawString(string sql) - => ((RelationalTestStore)Fixture.TestStore).NormalizeDelimetersInRawString(sql); + private string NormalizeDelimitersInRawString(string sql) + => ((RelationalTestStore)Fixture.TestStore).NormalizeDelimitersInRawString(sql); - private FormattableString NormalizeDelimetersInInterpolatedString(FormattableString sql) - => ((RelationalTestStore)Fixture.TestStore).NormalizeDelimetersInInterpolatedString(sql); + private FormattableString NormalizeDelimitersInInterpolatedString(FormattableString sql) + => ((RelationalTestStore)Fixture.TestStore).NormalizeDelimitersInInterpolatedString(sql); } } diff --git a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs index 8470e2a762d..4887d1391d1 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs @@ -745,7 +745,7 @@ public virtual void From_sql_composed_with_relational_null_comparison() using (var context = CreateContext(useRelationalNulls: true)) { var actual = context.Entities1 - .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Entities1]")) + .FromSqlRaw(NormalizeDelimitersInRawString("SELECT * FROM [Entities1]")) .Where(c => c.StringA == c.StringB) .ToArray(); @@ -1065,11 +1065,11 @@ protected static TResult Maybe(object caller, Func expression) return caller == null ? null : expression(); } - private string NormalizeDelimetersInRawString(string sql) - => Fixture.TestStore.NormalizeDelimetersInRawString(sql); + private string NormalizeDelimitersInRawString(string sql) + => Fixture.TestStore.NormalizeDelimitersInRawString(sql); private FormattableString NormalizeDelimetersInInterpolatedString(FormattableString sql) - => Fixture.TestStore.NormalizeDelimetersInInterpolatedString(sql); + => Fixture.TestStore.NormalizeDelimitersInInterpolatedString(sql); protected abstract NullSemanticsContext CreateContext(bool useRelationalNulls = false); diff --git a/test/EFCore.Relational.Specification.Tests/Query/QueryNoClientEvalTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/QueryNoClientEvalTestBase.cs index 8ef7b5112d8..451f5c30918 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/QueryNoClientEvalTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/QueryNoClientEvalTestBase.cs @@ -80,7 +80,7 @@ public virtual void Throws_when_from_sql_composed() { AssertTranslationFailed( () => context.Customers - .FromSqlRaw(NormalizeDelimetersInRawString("select * from [Customers]")) + .FromSqlRaw(NormalizeDelimitersInRawString("select * from [Customers]")) .Where(c => c.IsLondon) .ToList()); } @@ -93,7 +93,7 @@ public virtual void Doesnt_throw_when_from_sql_not_composed() { var customers = context.Customers - .FromSqlRaw(NormalizeDelimetersInRawString("select * from [Customers]")) + .FromSqlRaw(NormalizeDelimitersInRawString("select * from [Customers]")) .ToList(); Assert.Equal(91, customers.Count); @@ -232,8 +232,8 @@ public virtual void Throws_when_single_or_default() } } - private string NormalizeDelimetersInRawString(string sql) - => Fixture.TestStore.NormalizeDelimetersInRawString(sql); + private string NormalizeDelimitersInRawString(string sql) + => Fixture.TestStore.NormalizeDelimitersInRawString(sql); private void AssertTranslationFailed(Action testCode) { diff --git a/test/EFCore.Relational.Specification.Tests/TestUtilities/RelationalDatabaseCleaner.cs b/test/EFCore.Relational.Specification.Tests/TestUtilities/RelationalDatabaseCleaner.cs index 85b981d5e87..709f36b83fb 100644 --- a/test/EFCore.Relational.Specification.Tests/TestUtilities/RelationalDatabaseCleaner.cs +++ b/test/EFCore.Relational.Specification.Tests/TestUtilities/RelationalDatabaseCleaner.cs @@ -129,7 +129,7 @@ private static void ExecuteScript(IRelationalConnection connection, IRawSqlComma } sqlBuilder.Build(batches[i]) - .ExecuteNonQuery(new RelationalCommandParameterObject(connection, null, null, null)); + .ExecuteNonQuery(new RelationalCommandParameterObject(connection, null, null, null, null)); } } diff --git a/test/EFCore.Relational.Specification.Tests/TestUtilities/RelationalTestStore.cs b/test/EFCore.Relational.Specification.Tests/TestUtilities/RelationalTestStore.cs index f458093cd1c..eceb4d9a564 100644 --- a/test/EFCore.Relational.Specification.Tests/TestUtilities/RelationalTestStore.cs +++ b/test/EFCore.Relational.Specification.Tests/TestUtilities/RelationalTestStore.cs @@ -43,14 +43,14 @@ public override void Dispose() base.Dispose(); } - public virtual string NormalizeDelimetersInRawString(string sql) - => sql.Replace("[", OpenDelimeter).Replace("]", CloseDelimeter); + public virtual string NormalizeDelimitersInRawString(string sql) + => sql.Replace("[", OpenDelimiter).Replace("]", CloseDelimiter); - public virtual FormattableString NormalizeDelimetersInInterpolatedString(FormattableString sql) - => new TestFormattableString(NormalizeDelimetersInRawString(sql.Format), sql.GetArguments()); + public virtual FormattableString NormalizeDelimitersInInterpolatedString(FormattableString sql) + => new TestFormattableString(NormalizeDelimitersInRawString(sql.Format), sql.GetArguments()); - protected virtual string OpenDelimeter => "\""; + protected virtual string OpenDelimiter => "\""; - protected virtual string CloseDelimeter => "\""; + protected virtual string CloseDelimiter => "\""; } } diff --git a/test/EFCore.Relational.Tests/Query/Internal/BufferedDataReaderTest.cs b/test/EFCore.Relational.Tests/Query/Internal/BufferedDataReaderTest.cs new file mode 100644 index 00000000000..825d3cf2e6e --- /dev/null +++ b/test/EFCore.Relational.Tests/Query/Internal/BufferedDataReaderTest.cs @@ -0,0 +1,217 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore.Storage; +using Microsoft.EntityFrameworkCore.TestUtilities.FakeProvider; +using Xunit; + +// ReSharper disable InconsistentNaming +namespace Microsoft.EntityFrameworkCore.Query.Internal +{ + public class BufferedDataReaderTest + { + public static IEnumerable IsAsyncData = new[] { new object[] { false }, new object[] { true } }; + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task Metadata_methods_return_expected_results(bool async) + { + var reader = new FakeDbDataReader(new[] { "columnName" }, new[] { new[] { new object() }, new[] { new object() } }); + var columns = new ReaderColumn[] { new ReaderColumn(true, null, (r, _) => r.GetValue(0)) }; + var bufferedDataReader = new BufferedDataReader(reader); + if (async) + { + await bufferedDataReader.InitializeAsync(columns, CancellationToken.None); + } + else + { + bufferedDataReader.Initialize(columns); + } + + Assert.Equal(1, bufferedDataReader.FieldCount); + Assert.Equal(0, bufferedDataReader.GetOrdinal("columnName")); + Assert.Equal(typeof(object).Name, bufferedDataReader.GetDataTypeName(0)); + Assert.Equal(typeof(object), bufferedDataReader.GetFieldType(0)); + Assert.Equal("columnName", bufferedDataReader.GetName(0)); + Assert.Equal(2, bufferedDataReader.RecordsAffected); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task Manipulation_methods_perform_expected_actions(bool async) + { + var reader = new FakeDbDataReader( + new[] { "id", "name" }, + new List> { new[] { new object[] { 1, "a" } }, new object[0][] }); + var columns = new ReaderColumn[] + { + new ReaderColumn(false, null, (r, _) => r.GetInt32(0)), new ReaderColumn(true, null, (r, _) => r.GetValue(1)) + }; + + var bufferedDataReader = new BufferedDataReader(reader); + + Assert.False(bufferedDataReader.IsClosed); + if (async) + { + await bufferedDataReader.InitializeAsync(columns, CancellationToken.None); + } + else + { + bufferedDataReader.Initialize(columns); + } + + Assert.False(bufferedDataReader.IsClosed); + + Assert.True(bufferedDataReader.HasRows); + + if (async) + { + Assert.True(await bufferedDataReader.ReadAsync()); + Assert.False(await bufferedDataReader.ReadAsync()); + } + else + { + Assert.True(bufferedDataReader.Read()); + Assert.False(bufferedDataReader.Read()); + } + + Assert.True(bufferedDataReader.HasRows); + + if (async) + { + Assert.True(await bufferedDataReader.NextResultAsync()); + } + else + { + Assert.True(bufferedDataReader.NextResult()); + } + + Assert.False(bufferedDataReader.HasRows); + + if (async) + { + Assert.False(await bufferedDataReader.ReadAsync()); + Assert.False(await bufferedDataReader.NextResultAsync()); + } + else + { + Assert.False(bufferedDataReader.Read()); + Assert.False(bufferedDataReader.NextResult()); + } + + Assert.False(bufferedDataReader.IsClosed); + bufferedDataReader.Close(); + Assert.True(bufferedDataReader.IsClosed); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task Initialize_is_idempotent(bool isAsync) + { + var reader = new FakeDbDataReader(new[] { "name" }, new[] { new[] { new object() } }); + var columns = new ReaderColumn[] { new ReaderColumn(true, null, (r, _) => r.GetValue(0)) }; + var bufferedReader = new BufferedDataReader(reader); + + Assert.False(reader.IsClosed); + if (isAsync) + { + await bufferedReader.InitializeAsync(columns, CancellationToken.None); + } + else + { + bufferedReader.Initialize(columns); + } + + Assert.True(reader.IsClosed); + + if (isAsync) + { + await bufferedReader.InitializeAsync(columns, CancellationToken.None); + } + else + { + bufferedReader.Initialize(columns); + } + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task Data_methods_return_expected_results(bool async) + { + await Verify_get_method_returns_supplied_value(true, async); + await Verify_get_method_returns_supplied_value((byte)1, async); + await Verify_get_method_returns_supplied_value((short)1, async); + await Verify_get_method_returns_supplied_value(1, async); + await Verify_get_method_returns_supplied_value(1L, async); + await Verify_get_method_returns_supplied_value(1F, async); + await Verify_get_method_returns_supplied_value(1D, async); + await Verify_get_method_returns_supplied_value(1M, async); + await Verify_get_method_returns_supplied_value('a', async); + await Verify_get_method_returns_supplied_value("a", async); + await Verify_get_method_returns_supplied_value(DateTime.Now, async); + await Verify_get_method_returns_supplied_value(Guid.NewGuid(), async); + var obj = new object(); + await Verify_method_result(r => r.GetValue(0), async, obj, new[] { obj }); + await Verify_method_result(r => r.GetFieldValue(0), async, obj, new[] { obj }); + await Verify_method_result(r => r.GetFieldValueAsync(0).Result, async, obj, new[] { obj }); + await Verify_method_result(r => r.IsDBNull(0), async, true, new object[] { DBNull.Value }); + await Verify_method_result(r => r.IsDBNull(0), async, false, new object[] { true }); + await Verify_method_result(r => r.IsDBNullAsync(0).Result, async, true, new object[] { DBNull.Value }); + await Verify_method_result(r => r.IsDBNullAsync(0).Result, async, false, new object[] { true }); + + await Assert.ThrowsAsync( + () => Verify_method_result(r => r.GetBytes(0, 0, new byte[0], 0, 0), async, 0, new object[] { 1L })); + await Assert.ThrowsAsync( + () => Verify_method_result(r => r.GetChars(0, 0, new char[0], 0, 0), async, 0, new object[] { 1L })); + } + + private async Task Verify_method_result( + Func method, bool async, T expectedResult, + params object[][] dataReaderContents) + { + var reader = new FakeDbDataReader(new[] { "name" }, dataReaderContents); + var columnType = typeof(T); + if (!columnType.IsValueType) + { + columnType = typeof(object); + } + + var columns = new[] + { + ReaderColumn.Create(columnType, true, null, (Func)((r, _) => r.GetFieldValue(0))) + }; + + var bufferedReader = new BufferedDataReader(reader); + if (async) + { + await bufferedReader.InitializeAsync(columns, CancellationToken.None); + + Assert.True(await bufferedReader.ReadAsync()); + } + else + { + bufferedReader.Initialize(columns); + + Assert.True(bufferedReader.Read()); + } + + Assert.Equal(expectedResult, method(bufferedReader)); + } + + private Task Verify_get_method_returns_supplied_value(T value, bool async) + { + // use the specific reader.GetXXX method + var readerMethod = GetReaderMethod(typeof(T)); + return Verify_method_result( + r => (T)readerMethod.Invoke(r, new object[] { 0 }), async, value, new object[] { value }); + } + + private static MethodInfo GetReaderMethod(Type type) => RelationalTypeMapping.GetDataReaderMethod(type); + } +} diff --git a/test/EFCore.Relational.Tests/Storage/RelationalCommandTest.cs b/test/EFCore.Relational.Tests/Storage/RelationalCommandTest.cs index fe76bc30dd8..de1a836b2e6 100644 --- a/test/EFCore.Relational.Tests/Storage/RelationalCommandTest.cs +++ b/test/EFCore.Relational.Tests/Storage/RelationalCommandTest.cs @@ -44,7 +44,7 @@ public void Configures_DbCommand() var relationalCommand = CreateRelationalCommand(commandText: "CommandText"); relationalCommand.ExecuteNonQuery( - new RelationalCommandParameterObject(fakeConnection, null, null, null)); + new RelationalCommandParameterObject(fakeConnection, null, null, null, null)); Assert.Equal(1, fakeConnection.DbConnections.Count); Assert.Equal(1, fakeConnection.DbConnections[0].DbCommands.Count); @@ -66,7 +66,7 @@ public void Configures_DbCommand_with_transaction() var relationalCommand = CreateRelationalCommand(); relationalCommand.ExecuteNonQuery( - new RelationalCommandParameterObject(fakeConnection, null, null, null)); + new RelationalCommandParameterObject(fakeConnection, null, null, null, null)); Assert.Equal(1, fakeConnection.DbConnections.Count); Assert.Equal(1, fakeConnection.DbConnections[0].DbCommands.Count); @@ -88,7 +88,7 @@ public void Configures_DbCommand_with_timeout() var relationalCommand = CreateRelationalCommand(); relationalCommand.ExecuteNonQuery( - new RelationalCommandParameterObject(fakeConnection, null, null, null)); + new RelationalCommandParameterObject(fakeConnection, null, null, null, null)); Assert.Equal(1, fakeConnection.DbConnections.Count); Assert.Equal(1, fakeConnection.DbConnections[0].DbCommands.Count); @@ -122,7 +122,7 @@ public void Can_ExecuteNonQuery() var result = relationalCommand.ExecuteNonQuery( new RelationalCommandParameterObject( - new FakeRelationalConnection(options), null, null, null)); + new FakeRelationalConnection(options), null, null, null, null)); Assert.Equal(1, result); @@ -162,7 +162,7 @@ public virtual async Task Can_ExecuteNonQueryAsync() var result = await relationalCommand.ExecuteNonQueryAsync( new RelationalCommandParameterObject( - new FakeRelationalConnection(options), null, null, null)); + new FakeRelationalConnection(options), null, null, null, null)); Assert.Equal(1, result); @@ -202,7 +202,7 @@ public void Can_ExecuteScalar() var result = (string)relationalCommand.ExecuteScalar( new RelationalCommandParameterObject( - new FakeRelationalConnection(options), null, null, null)); + new FakeRelationalConnection(options), null, null, null, null)); Assert.Equal("ExecuteScalar Result", result); @@ -242,7 +242,7 @@ public async Task Can_ExecuteScalarAsync() var result = (string)await relationalCommand.ExecuteScalarAsync( new RelationalCommandParameterObject( - new FakeRelationalConnection(options), null, null, null)); + new FakeRelationalConnection(options), null, null, null, null)); Assert.Equal("ExecuteScalar Result", result); @@ -284,7 +284,7 @@ public void Can_ExecuteReader() var result = relationalCommand.ExecuteReader( new RelationalCommandParameterObject( - new FakeRelationalConnection(options), null, null, null)); + new FakeRelationalConnection(options), null, null, null, null)); Assert.Same(dbDataReader, result.DbDataReader); Assert.Equal(0, fakeDbConnection.CloseCount); @@ -333,7 +333,7 @@ public async Task Can_ExecuteReaderAsync() var result = await relationalCommand.ExecuteReaderAsync( new RelationalCommandParameterObject( - new FakeRelationalConnection(options), null, null, null)); + new FakeRelationalConnection(options), null, null, null, null)); Assert.Same(dbDataReader, result.DbDataReader); Assert.Equal(0, fakeDbConnection.CloseCount); @@ -363,42 +363,42 @@ public static TheoryData CommandActions new CommandAction( (connection, command, parameterValues, logger) => command.ExecuteNonQuery( - new RelationalCommandParameterObject(connection, parameterValues, null, logger))), + new RelationalCommandParameterObject(connection, parameterValues, null, null, logger))), DbCommandMethod.ExecuteNonQuery, false }, { new CommandAction( (connection, command, parameterValues, logger) => command.ExecuteScalar( - new RelationalCommandParameterObject(connection, parameterValues, null, logger))), + new RelationalCommandParameterObject(connection, parameterValues, null, null, logger))), DbCommandMethod.ExecuteScalar, false }, { new CommandAction( (connection, command, parameterValues, logger) => command.ExecuteReader( - new RelationalCommandParameterObject(connection, parameterValues, null, logger))), + new RelationalCommandParameterObject(connection, parameterValues, null, null, logger))), DbCommandMethod.ExecuteReader, false }, { new CommandFunc( (connection, command, parameterValues, logger) => command.ExecuteNonQueryAsync( - new RelationalCommandParameterObject(connection, parameterValues, null, logger))), + new RelationalCommandParameterObject(connection, parameterValues, null, null, logger))), DbCommandMethod.ExecuteNonQuery, true }, { new CommandFunc( (connection, command, parameterValues, logger) => command.ExecuteScalarAsync( - new RelationalCommandParameterObject(connection, parameterValues, null, logger))), + new RelationalCommandParameterObject(connection, parameterValues, null, null, logger))), DbCommandMethod.ExecuteScalar, true }, { new CommandFunc( (connection, command, parameterValues, logger) => command.ExecuteReaderAsync( - new RelationalCommandParameterObject(connection, parameterValues, null, logger))), + new RelationalCommandParameterObject(connection, parameterValues, null, null, logger))), DbCommandMethod.ExecuteReader, true } }; diff --git a/test/EFCore.Relational.Tests/TestUtilities/FakeProvider/FakeDbDataReader.cs b/test/EFCore.Relational.Tests/TestUtilities/FakeProvider/FakeDbDataReader.cs index cf430d3b8d9..48957c7523f 100644 --- a/test/EFCore.Relational.Tests/TestUtilities/FakeProvider/FakeDbDataReader.cs +++ b/test/EFCore.Relational.Tests/TestUtilities/FakeProvider/FakeDbDataReader.cs @@ -5,6 +5,7 @@ using System.Collections; using System.Collections.Generic; using System.Data.Common; +using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -13,15 +14,26 @@ namespace Microsoft.EntityFrameworkCore.TestUtilities.FakeProvider public class FakeDbDataReader : DbDataReader { private readonly string[] _columnNames; - private readonly IList _results; + private IList _results; + private readonly IList> _resultSets; + private int _currentResultSet; private object[] _currentRow; private int _rowIndex; + private bool _closed; public FakeDbDataReader(string[] columnNames = null, IList results = null) { _columnNames = columnNames ?? Array.Empty(); _results = results ?? new List(); + _resultSets = new List> { _results }; + } + + public FakeDbDataReader(string[] columnNames, IList> resultSets) + { + _columnNames = columnNames ?? Array.Empty(); + _resultSets = resultSets ?? new List> { new List() }; + _results = _resultSets[0]; } public override bool Read() @@ -51,6 +63,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) public override void Close() { CloseCount++; + _closed = true; } public int DisposeCount { get; private set; } @@ -63,6 +76,8 @@ protected override void Dispose(bool disposing) base.Dispose(true); } + + _closed = true; } public override int FieldCount => _columnNames.Length; @@ -88,11 +103,11 @@ public override int GetInt32(int ordinal) public override int Depth => throw new NotImplementedException(); - public override bool HasRows => throw new NotImplementedException(); + public override bool HasRows => _results.Count != 0; - public override bool IsClosed => throw new NotImplementedException(); + public override bool IsClosed => _closed; - public override int RecordsAffected => 0; + public override int RecordsAffected => _resultSets.Aggregate(0, (a, r) => a + r.Count); public override bool GetBoolean(int ordinal) => (bool)_currentRow[ordinal]; @@ -110,10 +125,7 @@ public override long GetChars(int ordinal, long dataOffset, char[] buffer, int b throw new NotImplementedException(); } - public override string GetDataTypeName(int ordinal) - { - throw new NotImplementedException(); - } + public override string GetDataTypeName(int ordinal) => GetFieldType(ordinal).Name; public override DateTime GetDateTime(int ordinal) => (DateTime)_currentRow[ordinal]; @@ -127,9 +139,9 @@ public override IEnumerator GetEnumerator() } public override Type GetFieldType(int ordinal) - { - throw new NotImplementedException(); - } + => _results.Count > 0 + ? _results[0][ordinal]?.GetType() ?? typeof(object) + : typeof(object); public override float GetFloat(int ordinal) => (float)_currentRow[ordinal]; @@ -153,7 +165,13 @@ public override int GetValues(object[] values) public override bool NextResult() { - throw new NotImplementedException(); + var hasResult = _resultSets.Count > ++_currentResultSet; + if (hasResult) + { + _results = _resultSets[_currentResultSet]; + } + + return hasResult; } } } diff --git a/test/EFCore.Specification.Tests/InterceptionTestBase.cs b/test/EFCore.Specification.Tests/InterceptionTestBase.cs index e678eed898d..941a4a69b78 100644 --- a/test/EFCore.Specification.Tests/InterceptionTestBase.cs +++ b/test/EFCore.Specification.Tests/InterceptionTestBase.cs @@ -177,7 +177,7 @@ public virtual ITestDiagnosticListener SubscribeToDiagnosticListener(DbContextId public virtual DbContextOptions CreateOptions( IEnumerable appInterceptors, IEnumerable injectedInterceptors) - => base.AddOptions( + => AddOptions( TestStore .AddProviderOptions( new DbContextOptionsBuilder() diff --git a/test/EFCore.Specification.Tests/TestUtilities/DataGenerator.cs b/test/EFCore.Specification.Tests/TestUtilities/DataGenerator.cs new file mode 100644 index 00000000000..91c12789524 --- /dev/null +++ b/test/EFCore.Specification.Tests/TestUtilities/DataGenerator.cs @@ -0,0 +1,47 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Linq; + +namespace Microsoft.EntityFrameworkCore.TestUtilities +{ + public static class DataGenerator + { + private static readonly ConcurrentDictionary _boolCombinations + = new ConcurrentDictionary(); + + public static object[][] GetBoolCombinations(int length) + => _boolCombinations.GetOrAdd(length, l => GetCombinations(new object[] { false, true }, l)); + + public static object[][] GetCombinations(object[] set, int length) + { + var sets = new object[length][]; + Array.Fill(sets, set); + return GetCombinations(sets); + } + + public static object[][] GetCombinations(object[][] sets) + { + var numberOfCombinations = sets.Aggregate(1L, (current, set) => current * set.Length); + var combinations = new object[numberOfCombinations][]; + + for (var i = 0L; i < numberOfCombinations; i++) + { + var combination = new object[sets.Length]; + var temp = i; + for (var j = 0; j < sets.Length; j++) + { + var set = sets[j]; + combination[j] = set[(int)(temp % set.Length)]; + temp /= set.Length; + } + + combinations[i] = combination; + } + + return combinations; + } + } +} diff --git a/test/EFCore.Specification.Tests/TestUtilities/TestHelpers.cs b/test/EFCore.Specification.Tests/TestUtilities/TestHelpers.cs index dd02d3347f3..db8b0b2482f 100644 --- a/test/EFCore.Specification.Tests/TestUtilities/TestHelpers.cs +++ b/test/EFCore.Specification.Tests/TestUtilities/TestHelpers.cs @@ -75,7 +75,7 @@ public void TestDependenciesClone(params string[] ignorePropertie } else { - Assert.Same(property.GetValue(clone), property.GetValue(dependencies)); + Assert.Equal(property.GetValue(clone), property.GetValue(dependencies)); } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/CommandInterceptionSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/CommandInterceptionSqlServerTest.cs index c84c66240d0..6a5a347c193 100644 --- a/test/EFCore.SqlServer.FunctionalTests/CommandInterceptionSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/CommandInterceptionSqlServerTest.cs @@ -4,6 +4,8 @@ using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.SqlServer.Storage.Internal; using Microsoft.EntityFrameworkCore.TestUtilities; using Microsoft.Extensions.DependencyInjection; using Xunit; @@ -66,6 +68,13 @@ public CommandInterceptionSqlServerTest(InterceptionSqlServerFixture fixture) public class InterceptionSqlServerFixture : InterceptionSqlServerFixtureBase { protected override bool ShouldSubscribeToDiagnosticListener => false; + + public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder) + { + new SqlServerDbContextOptionsBuilder(base.AddOptions(builder)) + .ExecutionStrategy(d => new SqlServerExecutionStrategy(d)); + return builder; + } } } @@ -81,6 +90,13 @@ public CommandInterceptionWithDiagnosticsSqlServerTest(InterceptionSqlServerFixt public class InterceptionSqlServerFixture : InterceptionSqlServerFixtureBase { protected override bool ShouldSubscribeToDiagnosticListener => true; + + public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder) + { + new SqlServerDbContextOptionsBuilder(base.AddOptions(builder)) + .ExecutionStrategy(d => new SqlServerExecutionStrategy(d)); + return builder; + } } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/ExecutionStrategyTest.cs b/test/EFCore.SqlServer.FunctionalTests/ExecutionStrategyTest.cs index 772d4caa7ea..55ef0ee8b5b 100644 --- a/test/EFCore.SqlServer.FunctionalTests/ExecutionStrategyTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/ExecutionStrategyTest.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Data; using System.Linq; using System.Threading; @@ -31,65 +32,56 @@ public ExecutionStrategyTest(ExecutionStrategyFixture fixture) protected ExecutionStrategyFixture Fixture { get; } - [ConditionalFact] - public void Does_not_throw_or_retry_on_false_commit_failure() - { - Test_commit_failure(false); - } - - [ConditionalFact] - public void Retries_on_true_commit_failure() - { - Test_commit_failure(true); - } - - private void Test_commit_failure(bool realFailure) + [ConditionalTheory] + [MemberData(nameof(DataGenerator.GetBoolCombinations), 1, MemberType = typeof(DataGenerator))] + public void Handles_commit_failure(bool realFailure) { + // Use all overloads of ExecuteInTransaction Test_commit_failure( realFailure, (e, db) => e.ExecuteInTransaction( - () => db.SaveChanges(acceptAllChangesOnSuccess: false), + () => { db.SaveChanges(false); }, () => db.Products.AsNoTracking().Any())); Test_commit_failure( realFailure, (e, db) => e.ExecuteInTransaction( - () => db.SaveChanges(acceptAllChangesOnSuccess: false), + () => db.SaveChanges(false), () => db.Products.AsNoTracking().Any())); Test_commit_failure( realFailure, (e, db) => e.ExecuteInTransaction( db, - c => c.SaveChanges(acceptAllChangesOnSuccess: false), + c => { c.SaveChanges(false); }, c => c.Products.AsNoTracking().Any())); Test_commit_failure( realFailure, (e, db) => e.ExecuteInTransaction( db, - c => c.SaveChanges(acceptAllChangesOnSuccess: false), + c => c.SaveChanges(false), c => c.Products.AsNoTracking().Any())); Test_commit_failure( realFailure, (e, db) => e.ExecuteInTransaction( - () => db.SaveChanges(acceptAllChangesOnSuccess: false), + () => { db.SaveChanges(false); }, () => db.Products.AsNoTracking().Any(), IsolationLevel.Serializable)); Test_commit_failure( realFailure, (e, db) => e.ExecuteInTransaction( - () => db.SaveChanges(acceptAllChangesOnSuccess: false), + () => db.SaveChanges(false), () => db.Products.AsNoTracking().Any(), IsolationLevel.Serializable)); Test_commit_failure( realFailure, (e, db) => e.ExecuteInTransaction( db, - c => c.SaveChanges(acceptAllChangesOnSuccess: false), + c => { c.SaveChanges(false); }, c => c.Products.AsNoTracking().Any(), IsolationLevel.Serializable)); Test_commit_failure( realFailure, (e, db) => e.ExecuteInTransaction( db, - c => c.SaveChanges(acceptAllChangesOnSuccess: false), + c => c.SaveChanges(false), c => c.Products.AsNoTracking().Any(), IsolationLevel.Serializable)); } @@ -133,87 +125,77 @@ private void Test_commit_failure(bool realFailure, Action e.ExecuteInTransactionAsync( - () => db.SaveChangesAsync(acceptAllChangesOnSuccess: false), + () => db.SaveChangesAsync(false), () => db.Products.AsNoTracking().AnyAsync())); - var cancellationToken = CancellationToken.None; await Test_commit_failure_async( realFailure, (e, db) => e.ExecuteInTransactionAsync( - async ct => await db.SaveChangesAsync(acceptAllChangesOnSuccess: false), + async ct => { await db.SaveChangesAsync(false); }, ct => db.Products.AsNoTracking().AnyAsync(), - cancellationToken)); + CancellationToken.None)); await Test_commit_failure_async( realFailure, (e, db) => e.ExecuteInTransactionAsync( - ct => db.SaveChangesAsync(acceptAllChangesOnSuccess: false), + ct => db.SaveChangesAsync(false, ct), ct => db.Products.AsNoTracking().AnyAsync(), - cancellationToken)); + CancellationToken.None)); await Test_commit_failure_async( realFailure, (e, db) => e.ExecuteInTransactionAsync( db, - async (c, ct) => await c.SaveChangesAsync(acceptAllChangesOnSuccess: false), + async (c, ct) => { await c.SaveChangesAsync(false, ct); }, (c, ct) => c.Products.AsNoTracking().AnyAsync(), - cancellationToken)); + CancellationToken.None)); await Test_commit_failure_async( realFailure, (e, db) => e.ExecuteInTransactionAsync( db, - (c, ct) => c.SaveChangesAsync(acceptAllChangesOnSuccess: false), + (c, ct) => c.SaveChangesAsync(false, ct), (c, ct) => c.Products.AsNoTracking().AnyAsync(), - cancellationToken)); + CancellationToken.None)); await Test_commit_failure_async( realFailure, (e, db) => e.ExecuteInTransactionAsync( - () => db.SaveChangesAsync(acceptAllChangesOnSuccess: false), + () => db.SaveChangesAsync(false), () => db.Products.AsNoTracking().AnyAsync(), IsolationLevel.Serializable)); await Test_commit_failure_async( realFailure, (e, db) => e.ExecuteInTransactionAsync( - async ct => await db.SaveChangesAsync(acceptAllChangesOnSuccess: false), - ct => db.Products.AsNoTracking().AnyAsync(), + async ct => { await db.SaveChangesAsync(false, ct); }, + ct => db.Products.AsNoTracking().AnyAsync(ct), IsolationLevel.Serializable, - cancellationToken)); + CancellationToken.None)); await Test_commit_failure_async( realFailure, (e, db) => e.ExecuteInTransactionAsync( - ct => db.SaveChangesAsync(acceptAllChangesOnSuccess: false), - ct => db.Products.AsNoTracking().AnyAsync(), + ct => db.SaveChangesAsync(false, ct), + ct => db.Products.AsNoTracking().AnyAsync(ct), IsolationLevel.Serializable, - cancellationToken)); + CancellationToken.None)); await Test_commit_failure_async( realFailure, (e, db) => e.ExecuteInTransactionAsync( db, - async (c, ct) => await c.SaveChangesAsync(acceptAllChangesOnSuccess: false), - (c, ct) => c.Products.AsNoTracking().AnyAsync(), + async (c, ct) => { await c.SaveChangesAsync(false, ct); }, + (c, ct) => c.Products.AsNoTracking().AnyAsync(ct), IsolationLevel.Serializable, - cancellationToken)); + CancellationToken.None)); await Test_commit_failure_async( realFailure, (e, db) => e.ExecuteInTransactionAsync( db, - (c, ct) => c.SaveChangesAsync(acceptAllChangesOnSuccess: false), - (c, ct) => c.Products.AsNoTracking().AnyAsync(), + (c, ct) => c.SaveChangesAsync(false, ct), + (c, ct) => c.Products.AsNoTracking().AnyAsync(ct), IsolationLevel.Serializable, - cancellationToken)); + CancellationToken.None)); } private async Task Test_commit_failure_async( @@ -226,11 +208,27 @@ private async Task Test_commit_failure_async( var connection = (TestSqlServerConnection)context.GetService(); connection.CommitFailures.Enqueue(new bool?[] { realFailure }); + Fixture.TestSqlLoggerFactory.Clear(); context.Products.Add(new Product()); await execute(new TestSqlServerRetryingExecutionStrategy(context), context); context.ChangeTracker.AcceptAllChanges(); + var retryMessage = + "A transient exception has been encountered during execution and the operation will be retried after 0ms." + + Environment.NewLine + + "Microsoft.Data.SqlClient.SqlException (0x80131904): Bang!"; + if (realFailure) + { + var logEntry = Fixture.TestSqlLoggerFactory.Log.Single(l => l.Id == CoreEventId.ExecutionStrategyRetrying); + Assert.Contains(retryMessage, logEntry.Message); + Assert.Equal(LogLevel.Information, logEntry.Level); + } + else + { + Assert.Empty(Fixture.TestSqlLoggerFactory.Log.Where(l => l.Id == CoreEventId.ExecutionStrategyRetrying)); + } + Assert.Equal(realFailure ? 3 : 2, connection.OpenCount); } @@ -240,19 +238,9 @@ private async Task Test_commit_failure_async( } } - [ConditionalFact] - public void Does_not_throw_or_retry_on_false_commit_failure_multiple_SaveChanges() - { - Test_commit_failure_multiple_SaveChanges(false); - } - - [ConditionalFact] - public void Retries_on_true_commit_failure_multiple_SaveChanges() - { - Test_commit_failure_multiple_SaveChanges(true); - } - - private void Test_commit_failure_multiple_SaveChanges(bool realFailure) + [ConditionalTheory] + [MemberData(nameof(DataGenerator.GetBoolCombinations), 1, MemberType = typeof(DataGenerator))] + public void Handles_commit_failure_multiple_SaveChanges(bool realFailure) { CleanContext(); @@ -274,9 +262,9 @@ private void Test_commit_failure_multiple_SaveChanges(bool realFailure) context2.Database.UseTransaction(null); context2.Database.UseTransaction(context1.Database.CurrentTransaction.GetDbTransaction()); - c1.SaveChanges(acceptAllChangesOnSuccess: false); + c1.SaveChanges(false); - return context2.SaveChanges(acceptAllChangesOnSuccess: false); + return context2.SaveChanges(false); }, c => c.Products.AsNoTracking().Any()); @@ -292,15 +280,9 @@ private void Test_commit_failure_multiple_SaveChanges(bool realFailure) } [ConditionalTheory] - [InlineData(false, false, false)] - [InlineData(true, false, false)] - [InlineData(false, true, false)] - [InlineData(true, true, false)] - [InlineData(false, false, true)] - [InlineData(true, false, true)] - [InlineData(false, true, true)] - [InlineData(true, true, true)] - public async Task Retries_only_on_true_execution_failure(bool realFailure, bool openConnection, bool async) + [MemberData(nameof(DataGenerator.GetBoolCombinations), 4, MemberType = typeof(DataGenerator))] + public async Task Retries_SaveChanges_on_execution_failure( + bool realFailure, bool externalStrategy, bool openConnection, bool async) { CleanContext(); @@ -331,31 +313,45 @@ public async Task Retries_only_on_true_execution_failure(bool realFailure, bool if (async) { - await new TestSqlServerRetryingExecutionStrategy(context).ExecuteInTransactionAsync( - context, - (c, _) => c.SaveChangesAsync(acceptAllChangesOnSuccess: false), - (c, _) => - { - // This shouldn't be called if SaveChanges failed - Assert.True(false); - return Task.FromResult(false); - }); + if (externalStrategy) + { + await new TestSqlServerRetryingExecutionStrategy(context).ExecuteInTransactionAsync( + context, + (c, ct) => c.SaveChangesAsync(false, ct), + (c, _) => + { + Assert.True(false); + return Task.FromResult(false); + }); + + context.ChangeTracker.AcceptAllChanges(); + } + else + { + await context.SaveChangesAsync(); + } } else { - new TestSqlServerRetryingExecutionStrategy(context).ExecuteInTransaction( - context, - c => c.SaveChanges(acceptAllChangesOnSuccess: false), - c => - { - // This shouldn't be called if SaveChanges failed - Assert.True(false); - return false; - }); + if (externalStrategy) + { + new TestSqlServerRetryingExecutionStrategy(context).ExecuteInTransaction( + context, + c => c.SaveChanges(false), + c => + { + Assert.True(false); + return false; + }); + + context.ChangeTracker.AcceptAllChanges(); + } + else + { + context.SaveChanges(); + } } - context.ChangeTracker.AcceptAllChanges(); - Assert.Equal(2, connection.OpenCount); Assert.Equal(4, connection.ExecutionCount); @@ -386,9 +382,8 @@ public async Task Retries_only_on_true_execution_failure(bool realFailure, bool } [ConditionalTheory] - [InlineData(false)] - //[InlineData(true)] - public async Task Retries_query_on_execution_failure(bool async) + [MemberData(nameof(DataGenerator.GetBoolCombinations), 2, MemberType = typeof(DataGenerator))] + public async Task Retries_query_on_execution_failure(bool externalStrategy, bool async) { CleanContext(); @@ -408,36 +403,100 @@ public async Task Retries_query_on_execution_failure(bool async) Assert.Equal(ConnectionState.Closed, context.Database.GetDbConnection().State); + List list; if (async) { - var list = await new TestSqlServerRetryingExecutionStrategy(context).ExecuteInTransactionAsync( - context, - (c, _) => context.Products.ToListAsync(), - (c, _) => - { - // This shouldn't be called if query failed - Assert.True(false); - return Task.FromResult(false); - }); - - Assert.Equal(2, list.Count); + if (externalStrategy) + { + list = await new TestSqlServerRetryingExecutionStrategy(context) + .ExecuteAsync(context, (c, ct) => c.Products.ToListAsync(ct), null); + } + else + { + list = await context.Products.ToListAsync(); + } } else { - var list = new TestSqlServerRetryingExecutionStrategy(context).ExecuteInTransaction( - context, - c => context.Products.ToList(), - c => - { - // This shouldn't be called if query failed - Assert.True(false); - return false; - }); + if (externalStrategy) + { + list = new TestSqlServerRetryingExecutionStrategy(context) + .Execute(context, c => c.Products.ToList(), null); + } + else + { + list = context.Products.ToList(); + } + } + + Assert.Equal(2, list.Count); + Assert.Equal(1, connection.OpenCount); + Assert.Equal(2, connection.ExecutionCount); + + Assert.Equal(ConnectionState.Closed, context.Database.GetDbConnection().State); + } + } + + [ConditionalTheory] + [MemberData(nameof(DataGenerator.GetBoolCombinations), 2, MemberType = typeof(DataGenerator))] + public async Task Retries_FromSqlRaw_on_execution_failure(bool externalStrategy, bool async) + { + CleanContext(); + + using (var context = CreateContext()) + { + context.Products.Add(new Product()); + context.Products.Add(new Product()); + + context.SaveChanges(); + } + + using (var context = CreateContext()) + { + var connection = (TestSqlServerConnection)context.GetService(); + + connection.ExecutionFailures.Enqueue(new bool?[] { true }); + + Assert.Equal(ConnectionState.Closed, context.Database.GetDbConnection().State); - Assert.Equal(2, list.Count); + List list; + if (async) + { + if (externalStrategy) + { + list = await new TestSqlServerRetryingExecutionStrategy(context) + .ExecuteAsync( + context, (c, ct) => c.Set().FromSqlRaw( + @"SELECT [ID], [name] + FROM [Products]").ToListAsync(ct), null); + } + else + { + list = await context.Set().FromSqlRaw( + @"SELECT [ID], [name] + FROM [Products]").ToListAsync(); + } + } + else + { + if (externalStrategy) + { + list = new TestSqlServerRetryingExecutionStrategy(context) + .Execute( + context, c => c.Set().FromSqlRaw( + @"SELECT [ID], [name] + FROM [Products]").ToList(), null); + } + else + { + list = context.Set().FromSqlRaw( + @"SELECT [ID], [name] + FROM [Products]").ToList(); + } } - Assert.Equal(2, connection.OpenCount); + Assert.Equal(2, list.Count); + Assert.Equal(1, connection.OpenCount); Assert.Equal(2, connection.ExecutionCount); Assert.Equal(ConnectionState.Closed, context.Database.GetDbConnection().State); @@ -445,9 +504,8 @@ public async Task Retries_query_on_execution_failure(bool async) } [ConditionalTheory] - [InlineData(false)] - [InlineData(true)] - public async Task Retries_OpenConnection_on_execution_failure(bool async) + [MemberData(nameof(DataGenerator.GetBoolCombinations), 2, MemberType = typeof(DataGenerator))] + public async Task Retries_OpenConnection_on_execution_failure(bool externalStrategy, bool async) { using (var context = CreateContext()) { @@ -459,15 +517,29 @@ public async Task Retries_OpenConnection_on_execution_failure(bool async) if (async) { - await new TestSqlServerRetryingExecutionStrategy(context).ExecuteAsync( - context, - c => context.Database.OpenConnectionAsync()); + if (externalStrategy) + { + await new TestSqlServerRetryingExecutionStrategy(context).ExecuteAsync( + context, + c => c.Database.OpenConnectionAsync()); + } + else + { + await context.Database.OpenConnectionAsync(); + } } else { - new TestSqlServerRetryingExecutionStrategy(context).Execute( - context, - c => context.Database.OpenConnection()); + if (externalStrategy) + { + new TestSqlServerRetryingExecutionStrategy(context).Execute( + context, + c => c.Database.OpenConnection()); + } + else + { + context.Database.OpenConnection(); + } } Assert.Equal(2, connection.OpenCount); @@ -539,7 +611,7 @@ public void Verification_is_retried_using_same_retry_limit() new TestSqlServerRetryingExecutionStrategy(context, TimeSpan.FromMilliseconds(100)) .ExecuteInTransaction( context, - c => c.SaveChanges(acceptAllChangesOnSuccess: false), + c => c.SaveChanges(false), c => false)); context.ChangeTracker.AcceptAllChanges(); @@ -594,12 +666,10 @@ public class ExecutionStrategyFixture : SharedStoreFixtureBase protected override Type ContextType { get; } = typeof(ExecutionStrategyContext); protected override IServiceCollection AddServices(IServiceCollection serviceCollection) - { - return base.AddServices(serviceCollection) + => base.AddServices(serviceCollection) .AddSingleton() .AddScoped() .AddSingleton(); - } public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder) { diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/AsyncSimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/AsyncSimpleQuerySqlServerTest.cs index 6be9d5d268d..d7566603f4e 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/AsyncSimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/AsyncSimpleQuerySqlServerTest.cs @@ -124,22 +124,10 @@ public async Task Concurrent_async_queries_when_raw_query() { while (await asyncEnumerator.MoveNextAsync()) { - if (!context.GetService().IsMultipleActiveResultSetsEnabled) - { - // Not supported, we could make it work by triggering buffering - // from RelationalCommand. - - await Assert.ThrowsAsync( - () => context.Database.ExecuteSqlRawAsync( - "[dbo].[CustOrderHist] @CustomerID = {0}", - asyncEnumerator.Current.CustomerID)); - } - else - { - await context.Database.ExecuteSqlRawAsync( - "[dbo].[CustOrderHist] @CustomerID = {0}", - asyncEnumerator.Current.CustomerID); - } + // Outer query is buffered by default + await context.Database.ExecuteSqlRawAsync( + "[dbo].[CustOrderHist] @CustomerID = {0}", + asyncEnumerator.Current.CustomerID); } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/TestUtilities/TestSqlServerRetryingExecutionStrategy.cs b/test/EFCore.SqlServer.FunctionalTests/TestUtilities/TestSqlServerRetryingExecutionStrategy.cs index c1b1477a3ee..9ee9e32f08f 100644 --- a/test/EFCore.SqlServer.FunctionalTests/TestUtilities/TestSqlServerRetryingExecutionStrategy.cs +++ b/test/EFCore.SqlServer.FunctionalTests/TestUtilities/TestSqlServerRetryingExecutionStrategy.cs @@ -66,9 +66,7 @@ protected override bool ShouldRetryOn(Exception exception) } return exception is InvalidOperationException invalidOperationException - && invalidOperationException.Message == "Internal .Net Framework Data Provider error 6." - ? true - : false; + && invalidOperationException.Message == "Internal .Net Framework Data Provider error 6."; } public new virtual TimeSpan? GetNextDelay(Exception lastException) diff --git a/test/EFCore.SqlServer.Tests/Storage/SqlServerRetryingExecutionStrategyTests.cs b/test/EFCore.SqlServer.Tests/Storage/SqlServerRetryingExecutionStrategyTests.cs index 195eeb26d8d..e6b4dbd267b 100644 --- a/test/EFCore.SqlServer.Tests/Storage/SqlServerRetryingExecutionStrategyTests.cs +++ b/test/EFCore.SqlServer.Tests/Storage/SqlServerRetryingExecutionStrategyTests.cs @@ -6,6 +6,7 @@ using Microsoft.EntityFrameworkCore.TestUtilities; using Xunit; +// ReSharper disable InconsistentNaming namespace Microsoft.EntityFrameworkCore.Storage { public class SqlServerRetryingExecutionStrategyTests @@ -14,13 +15,13 @@ public class SqlServerRetryingExecutionStrategyTests public void GetNextDelay_returns_shorter_delay_for_InMemory_transient_errors() { var strategy = new TestSqlServerRetryingExecutionStrategy(CreateContext()); - var inMemoryOltpError = SqlExceptionFactory.CreateSqlException(41302); + var inMemoryError = SqlExceptionFactory.CreateSqlException(41302); var delays = new List(); - var delay = strategy.GetNextDelay(inMemoryOltpError); + var delay = strategy.GetNextDelay(inMemoryError); while (delay != null) { delays.Add(delay.Value); - delay = strategy.GetNextDelay(inMemoryOltpError); + delay = strategy.GetNextDelay(inMemoryError); } var expectedDelays = new List @@ -39,7 +40,7 @@ public void GetNextDelay_returns_shorter_delay_for_InMemory_transient_errors() Assert.True( Math.Abs((delays[i] - expectedDelays[i]).TotalMilliseconds) <= expectedDelays[i].TotalMilliseconds * 0.1 + 1, - string.Format("Expected: {0}; Actual: {1}", expectedDelays[i], delays[i])); + $"Expected: {expectedDelays[i]}; Actual: {delays[i]}"); } } diff --git a/test/EFCore.Tests/Query/CompiledQueryCacheKeyGeneratorDependenciesTest.cs b/test/EFCore.Tests/Query/CompiledQueryCacheKeyGeneratorDependenciesTest.cs index 2a764195436..ddd60403e24 100644 --- a/test/EFCore.Tests/Query/CompiledQueryCacheKeyGeneratorDependenciesTest.cs +++ b/test/EFCore.Tests/Query/CompiledQueryCacheKeyGeneratorDependenciesTest.cs @@ -4,6 +4,7 @@ using Microsoft.EntityFrameworkCore.TestUtilities; using Xunit; +// ReSharper disable InconsistentNaming namespace Microsoft.EntityFrameworkCore.Query { public class CompiledQueryCacheKeyGeneratorDependenciesTest diff --git a/test/EFCore.Tests/Query/QueryCompilationContextDependenciesTest.cs b/test/EFCore.Tests/Query/QueryCompilationContextDependenciesTest.cs index 0f54c5a15f1..83cfe5c25ca 100644 --- a/test/EFCore.Tests/Query/QueryCompilationContextDependenciesTest.cs +++ b/test/EFCore.Tests/Query/QueryCompilationContextDependenciesTest.cs @@ -4,6 +4,7 @@ using Microsoft.EntityFrameworkCore.TestUtilities; using Xunit; +// ReSharper disable InconsistentNaming namespace Microsoft.EntityFrameworkCore.Query { public class QueryCompilationContextDependenciesTest diff --git a/test/EFCore.Tests/Storage/ExecutionStrategyTest.cs b/test/EFCore.Tests/Storage/ExecutionStrategyTest.cs index 6cd9632e233..718d2d68fc5 100644 --- a/test/EFCore.Tests/Storage/ExecutionStrategyTest.cs +++ b/test/EFCore.Tests/Storage/ExecutionStrategyTest.cs @@ -65,7 +65,7 @@ public void GetNextDelay_returns_the_expected_default_sequence() Assert.True( Math.Abs((delays[i] - expectedDelays[i]).TotalMilliseconds) <= expectedDelays[i].TotalMilliseconds * 0.1 + 1, - string.Format("Expected: {0}; Actual: {1}", expectedDelays[i], delays[i])); + $"Expected: {expectedDelays[i]}; Actual: {delays[i]}"); } } @@ -281,18 +281,18 @@ private void Execute_retries_until_successful(Action e.Execute(() => f())); + Execute_retries_until_not_retriable_exception_is_thrown((e, f) => e.Execute(() => f())); } [ConditionalFact] - public void Execute_Func_retries_until_not_retrieable_exception_is_thrown() + public void Execute_Func_retries_until_not_retriable_exception_is_thrown() { - Execute_retries_until_not_retrieable_exception_is_thrown((e, f) => e.Execute(f)); + Execute_retries_until_not_retriable_exception_is_thrown((e, f) => e.Execute(f)); } - private void Execute_retries_until_not_retrieable_exception_is_thrown(Action> execute) + private void Execute_retries_until_not_retriable_exception_is_thrown(Action> execute) { var executionStrategyMock = new TestExecutionStrategy( Context,