Skip to content

Commit

Permalink
Merge pull request from GHSA-x9vc-6hfv-hg8c
Browse files Browse the repository at this point in the history
(cherry picked from commit 4cce03ae2b0ab61f33d129a4e92620b5964edce5)

Co-authored-by: Nino Floris <mail@ninofloris.com>
  • Loading branch information
roji and NinoFloris committed May 9, 2024
1 parent 72610fa commit 703d9af
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 27 deletions.
47 changes: 33 additions & 14 deletions src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ internal Task WriteDescribe(StatementOrPortal statementOrPortal, string name, bo
sizeof(byte) + // Statement or portal
(name.Length + 1); // Statement/portal name

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
return FlushAndWrite(len, statementOrPortal, name, async, cancellationToken);

Expand Down Expand Up @@ -47,6 +48,7 @@ internal Task WriteSync(bool async, CancellationToken cancellationToken = defaul
const int len = sizeof(byte) + // Message code
sizeof(int); // Length

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
return FlushAndWrite(async, cancellationToken);

Expand Down Expand Up @@ -76,6 +78,7 @@ internal Task WriteExecute(int maxRows, bool async, CancellationToken cancellati
sizeof(byte) + // Null-terminated portal name (always empty for now)
sizeof(int); // Max number of rows

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
return FlushAndWrite(maxRows, async, cancellationToken);

Expand Down Expand Up @@ -113,9 +116,6 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
throw;
}

if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
await Flush(async, cancellationToken);

var messageLength =
sizeof(byte) + // Message code
sizeof(int) + // Length
Expand All @@ -125,6 +125,10 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
sizeof(ushort) + // Number of parameters
inputParameters.Count * sizeof(int); // Parameter OIDs

WriteBuffer.StartMessage(messageLength);
if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
await Flush(async, cancellationToken);

WriteBuffer.WriteByte(FrontendMessageCode.Parse);
WriteBuffer.WriteInt32(messageLength - 1);
WriteBuffer.WriteNullTerminatedString(statementName);
Expand Down Expand Up @@ -164,12 +168,6 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
statement.Length + sizeof(byte) + // Statement name plus null terminator
sizeof(ushort); // Number of parameter format codes that follow

if (WriteBuffer.WriteSpaceLeft < headerLength)
{
Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
await Flush(async, cancellationToken);
}

var formatCodesSum = 0;
var paramsLength = 0;
for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++)
Expand All @@ -190,6 +188,13 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
sizeof(short) + // Number of result format codes
sizeof(short) * (unknownResultTypeList?.Length ?? 1); // Result format codes

WriteBuffer.StartMessage(messageLength);
if (WriteBuffer.WriteSpaceLeft < headerLength)
{
Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
await Flush(async, cancellationToken);
}

WriteBuffer.WriteByte(FrontendMessageCode.Bind);
WriteBuffer.WriteInt32(messageLength - 1);
Debug.Assert(portal == string.Empty);
Expand Down Expand Up @@ -251,6 +256,7 @@ internal Task WriteClose(StatementOrPortal type, string name, bool async, Cancel
sizeof(byte) + // Statement or portal
name.Length + sizeof(byte); // Statement or portal name plus null terminator

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
return FlushAndWrite(len, type, name, async, cancellationToken);

Expand Down Expand Up @@ -279,14 +285,17 @@ internal async Task WriteQuery(string sql, bool async, CancellationToken cancell
{
var queryByteLen = TextEncoding.GetByteCount(sql);

var len = sizeof(byte) +
sizeof(int) + // Message length (including self excluding code)
queryByteLen + // Query byte length
sizeof(byte);

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < 1 + 4)
await Flush(async, cancellationToken);

WriteBuffer.WriteByte(FrontendMessageCode.Query);
WriteBuffer.WriteInt32(
sizeof(int) + // Message length (including self excluding code)
queryByteLen + // Query byte length
sizeof(byte)); // Null terminator
WriteBuffer.WriteInt32(len - 1);

await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken);
if (WriteBuffer.WriteSpaceLeft < 1)
Expand All @@ -301,6 +310,7 @@ internal async Task WriteCopyDone(bool async, CancellationToken cancellationToke
const int len = sizeof(byte) + // Message code
sizeof(int); // Length

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
await Flush(async, cancellationToken);

Expand All @@ -316,6 +326,7 @@ internal async Task WriteCopyFail(bool async, CancellationToken cancellationToke
sizeof(int) + // Length
sizeof(byte); // Error message is always empty (only a null terminator)

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
await Flush(async, cancellationToken);

Expand All @@ -333,6 +344,7 @@ internal void WriteCancelRequest(int backendProcessId, int backendSecretKey)

Debug.Assert(backendProcessId != 0);

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
Flush(false).GetAwaiter().GetResult();

Expand All @@ -347,6 +359,7 @@ internal void WriteTerminate()
const int len = sizeof(byte) + // Message code
sizeof(int); // Length

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
Flush(false).GetAwaiter().GetResult();

Expand All @@ -359,6 +372,7 @@ internal void WriteSslRequest()
const int len = sizeof(int) + // Length
sizeof(int); // SSL request code

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
Flush(false).GetAwaiter().GetResult();

Expand All @@ -379,6 +393,7 @@ internal void WriteStartup(Dictionary<string, string> parameters)
PGUtil.UTF8Encoding.GetByteCount(kvp.Value) + 1;

// Should really never happen, just in case
WriteBuffer.StartMessage(len);
if (len > WriteBuffer.Size)
throw new Exception("Startup message bigger than buffer");

Expand All @@ -402,8 +417,10 @@ internal void WriteStartup(Dictionary<string, string> parameters)

internal async Task WritePassword(byte[] payload, int offset, int count, bool async, CancellationToken cancellationToken = default)
{
WriteBuffer.StartMessage(sizeof(byte) + sizeof(int) + count);
if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int))
await WriteBuffer.Flush(async, cancellationToken);

WriteBuffer.WriteByte(FrontendMessageCode.Password);
WriteBuffer.WriteInt32(sizeof(int) + count);

Expand All @@ -426,6 +443,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
sizeof(int) + // Initial response length
(initialResponse?.Length ?? 0); // Initial response payload

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
await WriteBuffer.Flush(async, cancellationToken);

Expand All @@ -449,6 +467,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes

internal Task WritePregenerated(byte[] data, bool async = false, CancellationToken cancellationToken = default)
{
WriteBuffer.StartMessage(data.Length);
if (WriteBuffer.WriteSpaceLeft < data.Length)
return FlushAndWrite(data, async, cancellationToken);

Expand All @@ -466,4 +485,4 @@ async Task FlushAndWrite(byte[] data, bool async, CancellationToken cancellation
internal void Flush() => WriteBuffer.Flush(false).GetAwaiter().GetResult();

internal Task Flush(bool async, CancellationToken cancellationToken = default) => WriteBuffer.Flush(async, cancellationToken);
}
}
62 changes: 59 additions & 3 deletions src/Npgsql/Internal/NpgsqlWriteBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public sealed partial class NpgsqlWriteBuffer : IDisposable
internal Stream Underlying { private get; set; }

readonly Socket? _underlyingSocket;
internal bool MessageLengthValidation { get; set; } = true;

readonly ResettableCancellationTokenSource _timeoutCts;

Expand Down Expand Up @@ -72,6 +73,9 @@ internal TimeSpan Timeout

internal int WritePosition;

int _messageBytesFlushed;
int? _messageLength;

ParameterStream? _parameterStream;

bool _disposed;
Expand Down Expand Up @@ -126,6 +130,8 @@ public async Task Flush(bool async, CancellationToken cancellationToken = defaul
WritePosition = pos;
} else if (WritePosition == 0)
return;
else
AdvanceMessageBytesFlushed(WritePosition);

var finalCt = cancellationToken;
if (async && Timeout > TimeSpan.Zero)
Expand Down Expand Up @@ -193,15 +199,19 @@ internal void DirectWrite(ReadOnlySpan<byte> buffer)
Debug.Assert(WritePosition == 5);

WritePosition = 1;
WriteInt32(buffer.Length + 4);
WriteInt32(checked(buffer.Length + 4));
WritePosition = 5;
_copyMode = false;
StartMessage(5);
Flush();
_copyMode = true;
WriteCopyDataHeader(); // And ready the buffer after the direct write completes
}
else
{
Debug.Assert(WritePosition == 0);
AdvanceMessageBytesFlushed(buffer.Length);
}

try
{
Expand All @@ -224,15 +234,19 @@ internal async Task DirectWrite(ReadOnlyMemory<byte> memory, bool async, Cancell
Debug.Assert(WritePosition == 5);

WritePosition = 1;
WriteInt32(memory.Length + 4);
WriteInt32(checked(memory.Length + 4));
WritePosition = 5;
_copyMode = false;
StartMessage(5);
await Flush(async, cancellationToken);
_copyMode = true;
WriteCopyDataHeader(); // And ready the buffer after the direct write completes
}
else
{
Debug.Assert(WritePosition == 0);
AdvanceMessageBytesFlushed(memory.Length);
}

try
{
Expand Down Expand Up @@ -573,9 +587,51 @@ public void Dispose()

#region Misc

internal void StartMessage(int messageLength)
{
if (!MessageLengthValidation)
return;

if (_messageLength is not null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength)
Throw();

// Add negative WritePosition to compensate for previous message(s) written without flushing.
_messageBytesFlushed = -WritePosition;
_messageLength = messageLength;

void Throw()
{
throw Connector.Break(new OverflowException("Did not write the amount of bytes the message length specified"));
}
}

void AdvanceMessageBytesFlushed(int count)
{
if (!MessageLengthValidation)
return;

if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength)
Throw();

_messageBytesFlushed += count;

void Throw()
{
if (count < 0)
throw new ArgumentOutOfRangeException(nameof(count), "Can't advance by a negative count");

if (_messageLength is null)
throw Connector.Break(new InvalidOperationException("No message was started"));

if ((long)_messageBytesFlushed + count > _messageLength)
throw Connector.Break(new OverflowException("Tried to write more bytes than the message length specified"));
}
}

internal void Clear()
{
WritePosition = 0;
_messageLength = null;
}

/// <summary>
Expand All @@ -590,4 +646,4 @@ internal byte[] GetContents()
}

#endregion
}
}
11 changes: 1 addition & 10 deletions src/Npgsql/NpgsqlTransaction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -223,16 +223,7 @@ public void Save(string name)

// Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters.
// Since we are prepending, we assume below that the statement will always fit in the buffer.
_connector.WriteBuffer.WriteByte(FrontendMessageCode.Query);
_connector.WriteBuffer.WriteInt32(
sizeof(int) + // Message length (including self excluding code)
_connector.TextEncoding.GetByteCount("SAVEPOINT ") +
_connector.TextEncoding.GetByteCount(name) +
sizeof(byte)); // Null terminator

_connector.WriteBuffer.WriteString("SAVEPOINT ");
_connector.WriteBuffer.WriteString(name);
_connector.WriteBuffer.WriteByte(0);
_connector.WriteQuery("SAVEPOINT " + name);

_connector.PendingPrependedResponses += 2;
}
Expand Down

0 comments on commit 703d9af

Please sign in to comment.