Skip to content

Commit

Permalink
Merge pull request from GHSA-x9vc-6hfv-hg8c
Browse files Browse the repository at this point in the history
  • Loading branch information
NinoFloris committed May 9, 2024
1 parent 12e5ec2 commit f7e7ead
Show file tree
Hide file tree
Showing 7 changed files with 272 additions and 30 deletions.
56 changes: 38 additions & 18 deletions src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ internal Task WriteDescribe(StatementOrPortal statementOrPortal, byte[] asciiNam
(asciiName.Length + 1); // Statement/portal name

var writeBuffer = WriteBuffer;
writeBuffer.StartMessage(len);
if (writeBuffer.WriteSpaceLeft < len)
return FlushAndWrite(len, statementOrPortal, asciiName, async, cancellationToken);

Expand Down Expand Up @@ -48,6 +49,7 @@ internal Task WriteSync(bool async, CancellationToken cancellationToken = defaul
sizeof(int); // Length

var writeBuffer = WriteBuffer;
writeBuffer.StartMessage(len);
if (writeBuffer.WriteSpaceLeft < len)
return FlushAndWrite(async, cancellationToken);

Expand Down Expand Up @@ -79,6 +81,7 @@ internal Task WriteExecute(int maxRows, bool async, CancellationToken cancellati
sizeof(int); // Max number of rows

var writeBuffer = WriteBuffer;
writeBuffer.StartMessage(len);
if (writeBuffer.WriteSpaceLeft < len)
return FlushAndWrite(maxRows, async, cancellationToken);

Expand Down Expand Up @@ -118,9 +121,6 @@ internal async Task WriteParse(string sql, byte[] asciiName, List<NpgsqlParamete
}

var writeBuffer = WriteBuffer;
if (writeBuffer.WriteSpaceLeft < 1 + 4 + asciiName.Length + 1)
await Flush(async, cancellationToken).ConfigureAwait(false);

var messageLength =
sizeof(byte) + // Message code
sizeof(int) + // Length
Expand All @@ -130,9 +130,14 @@ internal async Task WriteParse(string sql, byte[] asciiName, List<NpgsqlParamete
sizeof(ushort) + // Number of parameters
inputParameters.Count * sizeof(int); // Parameter OIDs

writeBuffer.WriteByte(FrontendMessageCode.Parse);
writeBuffer.WriteInt32(messageLength - 1);
writeBuffer.WriteNullTerminatedString(asciiName);

WriteBuffer.StartMessage(messageLength);
if (WriteBuffer.WriteSpaceLeft < 1 + 4 + asciiName.Length + 1)
await Flush(async, cancellationToken).ConfigureAwait(false);

WriteBuffer.WriteByte(FrontendMessageCode.Parse);
WriteBuffer.WriteInt32(messageLength - 1);
WriteBuffer.WriteNullTerminatedString(asciiName);

await writeBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false);

Expand Down Expand Up @@ -171,12 +176,6 @@ internal async Task WriteParse(string sql, byte[] asciiName, List<NpgsqlParamete
sizeof(ushort); // Number of parameter format codes that follow

var writeBuffer = WriteBuffer;
if (writeBuffer.WriteSpaceLeft < headerLength)
{
Debug.Assert(writeBuffer.Size >= headerLength, "Write buffer too small for Bind header");
await Flush(async, cancellationToken).ConfigureAwait(false);
}

var formatCodesSum = 0;
var paramsLength = 0;
for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++)
Expand All @@ -197,8 +196,15 @@ internal async Task WriteParse(string sql, byte[] asciiName, List<NpgsqlParamete
sizeof(short) + // Number of result format codes
sizeof(short) * (unknownResultTypeList?.Length ?? 1); // Result format codes

writeBuffer.WriteByte(FrontendMessageCode.Bind);
writeBuffer.WriteInt32(messageLength - 1);
WriteBuffer.StartMessage(messageLength);
if (WriteBuffer.WriteSpaceLeft < headerLength)
{
Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
await Flush(async, cancellationToken).ConfigureAwait(false);
}

WriteBuffer.WriteByte(FrontendMessageCode.Bind);
WriteBuffer.WriteInt32(messageLength - 1);
Debug.Assert(portal == string.Empty);
writeBuffer.WriteByte(0); // Portal is always empty

Expand Down Expand Up @@ -269,6 +275,7 @@ internal Task WriteClose(StatementOrPortal type, byte[] asciiName, bool async, C
asciiName.Length + sizeof(byte); // Statement or portal name plus null terminator

var writeBuffer = WriteBuffer;
writeBuffer.StartMessage(len);
if (writeBuffer.WriteSpaceLeft < len)
return FlushAndWrite(len, type, asciiName, async, cancellationToken);

Expand Down Expand Up @@ -296,14 +303,17 @@ internal async Task WriteQuery(string sql, bool async, CancellationToken cancell
{
var queryByteLen = TextEncoding.GetByteCount(sql);

var len = sizeof(byte) +
sizeof(int) + // Message length (including self excluding code)
queryByteLen + // Query byte length
sizeof(byte);

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < 1 + 4)
await Flush(async, cancellationToken).ConfigureAwait(false);

WriteBuffer.WriteByte(FrontendMessageCode.Query);
WriteBuffer.WriteInt32(
sizeof(int) + // Message length (including self excluding code)
queryByteLen + // Query byte length
sizeof(byte)); // Null terminator
WriteBuffer.WriteInt32(len - 1);

await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false);
if (WriteBuffer.WriteSpaceLeft < 1)
Expand All @@ -316,6 +326,7 @@ internal async Task WriteCopyDone(bool async, CancellationToken cancellationToke
const int len = sizeof(byte) + // Message code
sizeof(int); // Length

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
await Flush(async, cancellationToken).ConfigureAwait(false);

Expand All @@ -331,6 +342,7 @@ internal async Task WriteCopyFail(bool async, CancellationToken cancellationToke
sizeof(int) + // Length
sizeof(byte); // Error message is always empty (only a null terminator)

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
await Flush(async, cancellationToken).ConfigureAwait(false);

Expand All @@ -348,6 +360,7 @@ internal void WriteCancelRequest(int backendProcessId, int backendSecretKey)

Debug.Assert(backendProcessId != 0);

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
Flush(false).GetAwaiter().GetResult();

Expand All @@ -362,6 +375,7 @@ internal void WriteTerminate()
const int len = sizeof(byte) + // Message code
sizeof(int); // Length

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
Flush(false).GetAwaiter().GetResult();

Expand All @@ -374,6 +388,7 @@ internal void WriteSslRequest()
const int len = sizeof(int) + // Length
sizeof(int); // SSL request code

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
Flush(false).GetAwaiter().GetResult();

Expand All @@ -394,6 +409,7 @@ internal void WriteStartup(Dictionary<string, string> parameters)
NpgsqlWriteBuffer.UTF8Encoding.GetByteCount(kvp.Value) + 1;

// Should really never happen, just in case
WriteBuffer.StartMessage(len);
if (len > WriteBuffer.Size)
throw new Exception("Startup message bigger than buffer");

Expand All @@ -417,8 +433,10 @@ internal void WriteStartup(Dictionary<string, string> parameters)

internal async Task WritePassword(byte[] payload, int offset, int count, bool async, CancellationToken cancellationToken = default)
{
WriteBuffer.StartMessage(sizeof(byte) + sizeof(int) + count);
if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int))
await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false);

WriteBuffer.WriteByte(FrontendMessageCode.Password);
WriteBuffer.WriteInt32(sizeof(int) + count);

Expand All @@ -441,6 +459,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
sizeof(int) + // Initial response length
(initialResponse?.Length ?? 0); // Initial response payload

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false);

Expand All @@ -464,6 +483,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes

internal Task WritePregenerated(byte[] data, bool async = false, CancellationToken cancellationToken = default)
{
WriteBuffer.StartMessage(data.Length);
if (WriteBuffer.WriteSpaceLeft < data.Length)
return FlushAndWrite(data, async, cancellationToken);

Expand Down
61 changes: 59 additions & 2 deletions src/Npgsql/Internal/NpgsqlWriteBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ sealed class NpgsqlWriteBuffer : IDisposable
internal Stream Underlying { private get; set; }

readonly Socket? _underlyingSocket;
internal bool MessageLengthValidation { get; set; } = true;

readonly ResettableCancellationTokenSource _timeoutCts;
readonly MetricsReporter? _metricsReporter;

Expand Down Expand Up @@ -76,6 +78,9 @@ internal PgWriter GetWriter(NpgsqlDatabaseInfo typeCatalog, FlushMode flushMode

internal int WritePosition;

int _messageBytesFlushed;
int? _messageLength;

bool _disposed;
readonly PgWriter _pgWriter;

Expand Down Expand Up @@ -131,6 +136,8 @@ public async Task Flush(bool async, CancellationToken cancellationToken = defaul
WritePosition = pos;
} else if (WritePosition == 0)
return;
else
AdvanceMessageBytesFlushed(WritePosition);

var finalCt = async && Timeout > TimeSpan.Zero
? _timeoutCts.Start(cancellationToken)
Expand Down Expand Up @@ -197,15 +204,19 @@ internal void DirectWrite(ReadOnlySpan<byte> buffer)
Debug.Assert(WritePosition == 5);

WritePosition = 1;
WriteInt32(buffer.Length + 4);
WriteInt32(checked(buffer.Length + 4));
WritePosition = 5;
_copyMode = false;
StartMessage(5);
Flush();
_copyMode = true;
WriteCopyDataHeader(); // And ready the buffer after the direct write completes
}
else
{
Debug.Assert(WritePosition == 0);
AdvanceMessageBytesFlushed(buffer.Length);
}

try
{
Expand All @@ -228,15 +239,19 @@ internal async Task DirectWrite(ReadOnlyMemory<byte> memory, bool async, Cancell
Debug.Assert(WritePosition == 5);

WritePosition = 1;
WriteInt32(memory.Length + 4);
WriteInt32(checked(memory.Length + 4));
WritePosition = 5;
_copyMode = false;
StartMessage(5);
await Flush(async, cancellationToken).ConfigureAwait(false);
_copyMode = true;
WriteCopyDataHeader(); // And ready the buffer after the direct write completes
}
else
{
Debug.Assert(WritePosition == 0);
AdvanceMessageBytesFlushed(memory.Length);
}

try
{
Expand Down Expand Up @@ -534,9 +549,51 @@ public void Dispose()

#region Misc

internal void StartMessage(int messageLength)
{
if (!MessageLengthValidation)
return;

if (_messageLength is not null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength)
Throw();

// Add negative WritePosition to compensate for previous message(s) written without flushing.
_messageBytesFlushed = -WritePosition;
_messageLength = messageLength;

void Throw()
{
throw Connector.Break(new OverflowException("Did not write the amount of bytes the message length specified"));
}
}

void AdvanceMessageBytesFlushed(int count)
{
if (!MessageLengthValidation)
return;

if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength)
Throw();

_messageBytesFlushed += count;

void Throw()
{
if (count < 0)
throw new ArgumentOutOfRangeException(nameof(count), "Can't advance by a negative count");

if (_messageLength is null)
throw Connector.Break(new InvalidOperationException("No message was started"));

if ((long)_messageBytesFlushed + count > _messageLength)
throw Connector.Break(new OverflowException("Tried to write more bytes than the message length specified"));
}
}

internal void Clear()
{
WritePosition = 0;
_messageLength = null;
}

/// <summary>
Expand Down
11 changes: 1 addition & 10 deletions src/Npgsql/NpgsqlTransaction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,7 @@ public override void Save(string name)

// Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters.
// Since we are prepending, we assume below that the statement will always fit in the buffer.
_connector.WriteBuffer.WriteByte(FrontendMessageCode.Query);
_connector.WriteBuffer.WriteInt32(
sizeof(int) + // Message length (including self excluding code)
_connector.TextEncoding.GetByteCount("SAVEPOINT ") +
_connector.TextEncoding.GetByteCount(name) +
sizeof(byte)); // Null terminator

_connector.WriteBuffer.WriteString("SAVEPOINT ");
_connector.WriteBuffer.WriteString(name);
_connector.WriteBuffer.WriteByte(0);
_connector.WriteQuery("SAVEPOINT " + name, async: false).GetAwaiter().GetResult();

_connector.PendingPrependedResponses += 2;
}
Expand Down

0 comments on commit f7e7ead

Please sign in to comment.