diff --git a/src/MongoDB.Driver/BulkWriteResult.cs b/src/MongoDB.Driver/BulkWriteResult.cs index 5fa2b3f76c..9383280f96 100644 --- a/src/MongoDB.Driver/BulkWriteResult.cs +++ b/src/MongoDB.Driver/BulkWriteResult.cs @@ -104,10 +104,10 @@ public abstract class BulkWriteResult : BulkWriteResult /// The processed requests. protected BulkWriteResult( int requestCount, - IEnumerable> processedRequests) + IReadOnlyList> processedRequests) : base(requestCount) { - _processedRequests = processedRequests.ToList(); + _processedRequests = processedRequests; } // public properties @@ -130,16 +130,16 @@ internal static BulkWriteResult FromCore(Core.Operations.BulkWriteOpe result.DeletedCount, result.InsertedCount, result.IsModifiedCountAvailable ? (long?)result.ModifiedCount : null, - result.ProcessedRequests.Select(r => WriteModel.FromCore(r)), - result.Upserts.Select(u => BulkWriteUpsert.FromCore(u))); + result.ProcessedRequests.Select(WriteModel.FromCore).ToArray(), + result.Upserts.Select(BulkWriteUpsert.FromCore)); } return new Unacknowledged( result.RequestCount, - result.ProcessedRequests.Select(r => WriteModel.FromCore(r))); + result.ProcessedRequests.Select(WriteModel.FromCore).ToArray()); } - internal static BulkWriteResult FromCore(Core.Operations.BulkWriteOperationResult result, IEnumerable> requests) + internal static BulkWriteResult FromCore(Core.Operations.BulkWriteOperationResult result, IReadOnlyList> requests) { if (result.IsAcknowledged) { @@ -150,7 +150,7 @@ internal static BulkWriteResult FromCore(Core.Operations.BulkWriteOpe result.InsertedCount, result.IsModifiedCountAvailable ? (long?)result.ModifiedCount : null, requests, - result.Upserts.Select(u => BulkWriteUpsert.FromCore(u))); + result.Upserts.Select(BulkWriteUpsert.FromCore)); } return new Unacknowledged( @@ -174,7 +174,7 @@ public class Acknowledged : BulkWriteResult // constructors /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The request count. /// The matched count. @@ -189,7 +189,7 @@ public class Acknowledged : BulkWriteResult long deletedCount, long insertedCount, long? modifiedCount, - IEnumerable> processedRequests, + IReadOnlyList> processedRequests, IEnumerable upserts) : base(requestCount, processedRequests) { @@ -259,13 +259,13 @@ public class Unacknowledged : BulkWriteResult { // constructors /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The request count. /// The processed requests. public Unacknowledged( int requestCount, - IEnumerable> processedRequests) + IReadOnlyList> processedRequests) : base(requestCount, processedRequests) { } diff --git a/src/MongoDB.Driver/MongoBulkWriteException.cs b/src/MongoDB.Driver/MongoBulkWriteException.cs index e05cb891c5..91799a4f56 100644 --- a/src/MongoDB.Driver/MongoBulkWriteException.cs +++ b/src/MongoDB.Driver/MongoBulkWriteException.cs @@ -234,7 +234,7 @@ internal static MongoBulkWriteException FromCore(MongoBulkWriteOperat return new MongoBulkWriteException( ex.ConnectionId, - BulkWriteResult.FromCore(ex.Result, processedRequests), + BulkWriteResult.FromCore(ex.Result, processedRequests.ToArray()), ex.WriteErrors.Select(e => BulkWriteError.FromCore(e)), WriteConcernError.FromCore(ex.WriteConcernError), unprocessedRequests); diff --git a/src/MongoDB.Driver/MongoCollectionImpl.cs b/src/MongoDB.Driver/MongoCollectionImpl.cs index 26a553f612..c83a334831 100644 --- a/src/MongoDB.Driver/MongoCollectionImpl.cs +++ b/src/MongoDB.Driver/MongoCollectionImpl.cs @@ -218,27 +218,30 @@ public override BulkWriteResult BulkWrite(IEnumerable BulkWrite(IClientSessionHandle session, IEnumerable> requests, BulkWriteOptions options, CancellationToken cancellationToken = default(CancellationToken)) { Ensure.IsNotNull(session, nameof(session)); - Ensure.IsNotNull(requests, nameof(requests)); - if (!requests.Any()) + Ensure.IsNotNull((object)requests, nameof(requests)); + + var requestsArray = requests.ToArray(); + if (requestsArray.Length == 0) { - throw new ArgumentException("Must contain at least 1 request.", "requests"); + throw new ArgumentException("Must contain at least 1 request.", nameof(requests)); } - foreach (var request in requests) + + foreach (var request in requestsArray) { request.ThrowIfNotValid(); } options = options ?? new BulkWriteOptions(); - var operation = CreateBulkWriteOperation(session, requests, options); + var operation = CreateBulkWriteOperation(session, requestsArray, options); try { var result = ExecuteWriteOperation(session, operation, cancellationToken); - return BulkWriteResult.FromCore(result, requests); + return BulkWriteResult.FromCore(result, requestsArray); } catch (MongoBulkWriteOperationException ex) { - throw MongoBulkWriteException.FromCore(ex, requests.ToList()); + throw MongoBulkWriteException.FromCore(ex, requestsArray); } } @@ -250,27 +253,30 @@ public override Task> BulkWriteAsync(IEnumerable> BulkWriteAsync(IClientSessionHandle session, IEnumerable> requests, BulkWriteOptions options, CancellationToken cancellationToken = default(CancellationToken)) { Ensure.IsNotNull(session, nameof(session)); - Ensure.IsNotNull(requests, nameof(requests)); - if (!requests.Any()) + Ensure.IsNotNull((object)requests, nameof(requests)); + + var requestsArray = requests.ToArray(); + if (requestsArray.Length == 0) { - throw new ArgumentException("Must contain at least 1 request.", "requests"); + throw new ArgumentException("Must contain at least 1 request.", nameof(requests)); } - foreach (var request in requests) + + foreach (var request in requestsArray) { request.ThrowIfNotValid(); } options = options ?? new BulkWriteOptions(); - var operation = CreateBulkWriteOperation(session, requests, options); + var operation = CreateBulkWriteOperation(session, requestsArray, options); try { var result = await ExecuteWriteOperationAsync(session, operation, cancellationToken).ConfigureAwait(false); - return BulkWriteResult.FromCore(result, requests); + return BulkWriteResult.FromCore(result, requestsArray); } catch (MongoBulkWriteOperationException ex) { - throw MongoBulkWriteException.FromCore(ex, requests.ToList()); + throw MongoBulkWriteException.FromCore(ex, requestsArray); } } diff --git a/tests/MongoDB.Driver.Tests/MongoCollectionImplTests.cs b/tests/MongoDB.Driver.Tests/MongoCollectionImplTests.cs index 776995e074..8c0eab1032 100644 --- a/tests/MongoDB.Driver.Tests/MongoCollectionImplTests.cs +++ b/tests/MongoDB.Driver.Tests/MongoCollectionImplTests.cs @@ -19,6 +19,7 @@ using System.Linq.Expressions; using System.Net; using System.Threading; +using System.Threading.Tasks; using FluentAssertions; using MongoDB.Bson; using MongoDB.Bson.Serialization; @@ -31,6 +32,7 @@ using MongoDB.Driver.Core.Operations; using MongoDB.Driver.Core.Servers; using MongoDB.Driver.Core.TestHelpers.XunitExtensions; +using MongoDB.Driver.TestHelpers; using MongoDB.Driver.Tests; using Moq; using Xunit; @@ -442,6 +444,37 @@ public void Settings_should_be_set() exception.Should().BeOfType(); } + [Theory] + [ParameterAttributeData] + public async Task BulkWrite_should_enumerate_requests_once([Values(false, true)] bool async) + { + var subject = CreateSubject(); + var document = new BsonDocument("_id", 1).Add("a", 1); + var requests = new WriteModel[] + { + new InsertOneModel(document) + }; + var processedRequest = new InsertRequest(document) { CorrelationId = 0 }; + var operationResult = new BulkWriteOperationResult.Acknowledged( + requestCount: 1, + matchedCount: 0, + deletedCount: 0, + insertedCount: 1, + modifiedCount: 0, + processedRequests: new[] { processedRequest }, + upserts: new List()); + _operationExecutor.EnqueueResult(operationResult); + var wrappedRequests = new Mock>>(); + wrappedRequests.Setup(e => e.GetEnumerator()).Returns(((IEnumerable>)requests).GetEnumerator()); + + var result = async ? await subject.BulkWriteAsync(wrappedRequests.Object) : subject.BulkWrite(wrappedRequests.Object); + + wrappedRequests.Verify(e => e.GetEnumerator(), Times.Once); + result.Should().NotBeNull(); + result.RequestCount.Should().Be(1); + result.ProcessedRequests.ShouldBeEquivalentTo(requests); + } + [Theory] [ParameterAttributeData] public void BulkWrite_should_execute_a_BulkMixedWriteOperation(