diff --git a/src/Npgsql/NpgsqlCommand.cs b/src/Npgsql/NpgsqlCommand.cs index 88267749af..a159812bb6 100644 --- a/src/Npgsql/NpgsqlCommand.cs +++ b/src/Npgsql/NpgsqlCommand.cs @@ -660,6 +660,7 @@ Task Prepare(bool async, CancellationToken cancellationToken = default) ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand); needToPrepare = batchCommand.ExplicitPrepare(connector) || needToPrepare; + batchCommand.ConnectorPreparedOn = connector; } if (logger.IsEnabled(LogLevel.Debug) && needToPrepare) diff --git a/test/Npgsql.Tests/PrepareTests.cs b/test/Npgsql.Tests/PrepareTests.cs index bb539e6e04..a0b9c96a39 100644 --- a/test/Npgsql.Tests/PrepareTests.cs +++ b/test/Npgsql.Tests/PrepareTests.cs @@ -5,6 +5,8 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Npgsql.BackendMessages; +using Npgsql.Tests.Support; using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Tests.TestUtil; @@ -13,6 +15,8 @@ namespace Npgsql.Tests; public class PrepareTests: TestBase { + const int Int4Oid = 23; + [Test] public void Basic() { @@ -793,6 +797,127 @@ public async Task Explicit_prepare_unprepare_many_queries() await cmd.UnprepareAsync(); } + [Test] + public async Task Explicitly_prepared_batch_sends_prepared_queries() + { + await using var postmaster = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = NpgsqlDataSource.Create(postmaster.ConnectionString); + + await using var conn = await dataSource.OpenConnectionAsync(); + var server = await postmaster.WaitForServerConnection(); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new("SELECT 1"), new("SELECT 2") } + }; + + var prepareTask = batch.PrepareAsync(); + + await server.ExpectMessages( + FrontendMessageCode.Parse, FrontendMessageCode.Describe, + FrontendMessageCode.Parse, FrontendMessageCode.Describe, + FrontendMessageCode.Sync); + + await server + .WriteParseComplete() + .WriteParameterDescription(new FieldDescription(Int4Oid)) + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteParseComplete() + .WriteParameterDescription(new FieldDescription(Int4Oid)) + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteReadyForQuery() + .FlushAsync(); + + await prepareTask; + + for (var i = 0; i < 2; i++) + await ExecutePreparedBatch(batch, server); + + async Task ExecutePreparedBatch(NpgsqlBatch batch, PgServerMock server) + { + var executeBatchTask = batch.ExecuteNonQueryAsync(); + + await server.ExpectMessages( + FrontendMessageCode.Bind, FrontendMessageCode.Execute, + FrontendMessageCode.Bind, FrontendMessageCode.Execute, + FrontendMessageCode.Sync); + + await server + .WriteBindComplete() + .WriteCommandComplete() + .WriteBindComplete() + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + + await executeBatchTask; + } + } + + [Test] + public async Task Auto_prepared_batch_sends_prepared_queries() + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) + { + AutoPrepareMinUsages = 1, + MaxAutoPrepare = 10 + }; + await using var postmaster = PgPostmasterMock.Start(csb.ConnectionString); + await using var dataSource = NpgsqlDataSource.Create(postmaster.ConnectionString); + + await using var conn = await dataSource.OpenConnectionAsync(); + var server = await postmaster.WaitForServerConnection(); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new("SELECT 1"), new("SELECT 2") } + }; + + var firstBatchExecuteTask = batch.ExecuteNonQueryAsync(); + + await server.ExpectMessages( + FrontendMessageCode.Parse, FrontendMessageCode.Bind, FrontendMessageCode.Describe, FrontendMessageCode.Execute, + FrontendMessageCode.Parse, FrontendMessageCode.Bind, FrontendMessageCode.Describe, FrontendMessageCode.Execute, + FrontendMessageCode.Sync); + + await server + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteCommandComplete() + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + + await firstBatchExecuteTask; + + for (var i = 0; i < 2; i++) + await ExecutePreparedBatch(batch, server); + + async Task ExecutePreparedBatch(NpgsqlBatch batch, PgServerMock server) + { + var executeBatchTask = batch.ExecuteNonQueryAsync(); + + await server.ExpectMessages( + FrontendMessageCode.Bind, FrontendMessageCode.Execute, + FrontendMessageCode.Bind, FrontendMessageCode.Execute, + FrontendMessageCode.Sync); + + await server + .WriteBindComplete() + .WriteCommandComplete() + .WriteBindComplete() + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + + await executeBatchTask; + } + } + NpgsqlConnection OpenConnectionAndUnprepare(string? connectionString = null) { var conn = OpenConnection(connectionString); diff --git a/test/Npgsql.Tests/Support/PgServerMock.cs b/test/Npgsql.Tests/Support/PgServerMock.cs index 6a83cc0248..c9b61d8226 100644 --- a/test/Npgsql.Tests/Support/PgServerMock.cs +++ b/test/Npgsql.Tests/Support/PgServerMock.cs @@ -225,6 +225,20 @@ internal PgServerMock WriteRowDescription(params FieldDescription[] fields) return this; } + internal PgServerMock WriteParameterDescription(params FieldDescription[] fields) + { + CheckDisposed(); + + _writeBuffer.WriteByte((byte)BackendMessageCode.ParameterDescription); + _writeBuffer.WriteInt32(1 + 4 + 2 + fields.Length * 4); + _writeBuffer.WriteUInt16((ushort)fields.Length); + + foreach (var field in fields) + _writeBuffer.WriteUInt32(field.TypeOID); + + return this; + } + internal PgServerMock WriteNoData() { CheckDisposed();