Skip to content

Commit

Permalink
Merge pull request from GHSA-x9vc-6hfv-hg8c
Browse files Browse the repository at this point in the history
* Changes to make project compile

(cherry picked from commit c3079f1492d27d38564255716678ce8a33ee7039)

* Add message length validation to flush and direct write

(cherry picked from commit 0e833ba4c278d6dab71a460a1fcc34f89cbb6f49)

* Fix direct write byte count

---------

Co-authored-by: Nino Floris <mail@ninofloris.com>
  • Loading branch information
roji and NinoFloris committed May 9, 2024
1 parent 337acb3 commit 67acbe0
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 24 deletions.
6 changes: 3 additions & 3 deletions Directory.Build.targets
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
<PackageReference Update="NetTopologySuite.IO.PostGIS" Version="2.0.0" />
<PackageReference Update="NodaTime" Version="2.4.7" />
<PackageReference Update="GeoJSON.Net" Version="1.1.73" />
<PackageReference Update="Newtonsoft.Json" Version="11.0.2" />
<PackageReference Update="Newtonsoft.Json" Version="13.0.3" />

<!-- Tests -->
<PackageReference Update="NUnit" Version="3.12.0" />
<PackageReference Update="NLog" Version="4.6.7" />
<PackageReference Update="Microsoft.CSharp" Version="4.6.0" />
<PackageReference Update="Microsoft.NET.Test.Sdk" Version="16.5.0" />
<PackageReference Update="NUnit3TestAdapter" Version="3.15.1" />
<PackageReference Update="Microsoft.NET.Test.Sdk" Version="17.9.0" />
<PackageReference Update="NUnit3TestAdapter" Version="4.5.0" />
<PackageReference Update="xunit" Version="2.4.1" />
<PackageReference Update="xunit.runner.visualstudio" Version="2.4.1" />
<PackageReference Update="GitHubActionsTestLogger" Version="1.1.0" />
Expand Down
6 changes: 3 additions & 3 deletions global.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"sdk": {
"version": "3.1.302",
"rollForward": "minor",
"allowPrerelease": "false"
"version": "6.0.401",
"rollForward": "latestMajor",
"allowPrerelease": "true"
}
}
43 changes: 31 additions & 12 deletions src/Npgsql/NpgsqlConnector.FrontendMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ internal Task WriteDescribe(StatementOrPortal statementOrPortal, string name, bo
sizeof(byte) + // Statement or portal
(name.Length + 1); // Statement/portal name

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
return FlushAndWrite(len, statementOrPortal, name, async);

Expand Down Expand Up @@ -46,6 +47,7 @@ internal Task WriteSync(bool async)
const int len = sizeof(byte) + // Message code
sizeof(int); // Length

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
return FlushAndWrite(async);

Expand Down Expand Up @@ -75,6 +77,7 @@ internal Task WriteExecute(int maxRows, bool async)
sizeof(byte) + // Null-terminated portal name (always empty for now)
sizeof(int); // Max number of rows

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
return FlushAndWrite(maxRows, async);

Expand Down Expand Up @@ -102,8 +105,6 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
Debug.Assert(statementName.All(c => c < 128));

var queryByteLen = TextEncoding.GetByteCount(sql);
if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
await Flush(async);

var messageLength =
sizeof(byte) + // Message code
Expand All @@ -114,6 +115,10 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
sizeof(ushort) + // Number of parameters
inputParameters.Count * sizeof(int); // Parameter OIDs

WriteBuffer.StartMessage(messageLength);
if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
await Flush(async);

WriteBuffer.WriteByte(FrontendMessageCode.Parse);
WriteBuffer.WriteInt32(messageLength - 1);
WriteBuffer.WriteNullTerminatedString(statementName);
Expand Down Expand Up @@ -152,12 +157,6 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
statement.Length + sizeof(byte) + // Statement name plus null terminator
sizeof(ushort); // Number of parameter format codes that follow

if (WriteBuffer.WriteSpaceLeft < headerLength)
{
Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
await Flush(async);
}

var formatCodesSum = 0;
var paramsLength = 0;
foreach (var p in inputParameters)
Expand All @@ -177,6 +176,13 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
sizeof(short) + // Number of result format codes
sizeof(short) * (unknownResultTypeList?.Length ?? 1); // Result format codes

WriteBuffer.StartMessage(messageLength);
if (WriteBuffer.WriteSpaceLeft < headerLength)
{
Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
await Flush(async);
}

WriteBuffer.WriteByte(FrontendMessageCode.Bind);
WriteBuffer.WriteInt32(messageLength - 1);
Debug.Assert(portal == string.Empty);
Expand Down Expand Up @@ -237,6 +243,7 @@ internal Task WriteClose(StatementOrPortal type, string name, bool async)
sizeof(byte) + // Statement or portal
name.Length + sizeof(byte); // Statement or portal name plus null terminator

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < 10)
return FlushAndWrite(len, type, name, async);

Expand Down Expand Up @@ -265,14 +272,17 @@ internal async Task WriteQuery(string sql, bool async)
{
var queryByteLen = TextEncoding.GetByteCount(sql);

var len = sizeof(byte) +
sizeof(int) + // Message length (including self excluding code)
queryByteLen + // Query byte length
sizeof(byte);

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < 1 + 4)
await Flush(async);

WriteBuffer.WriteByte(FrontendMessageCode.Query);
WriteBuffer.WriteInt32(
sizeof(int) + // Message length (including self excluding code)
queryByteLen + // Query byte length
sizeof(byte)); // Null terminator
WriteBuffer.WriteInt32(len - 1);

await WriteBuffer.WriteString(sql, queryByteLen, async);
if (WriteBuffer.WriteSpaceLeft < 1)
Expand All @@ -287,6 +297,7 @@ internal async Task WriteCopyDone(bool async)
const int len = sizeof(byte) + // Message code
sizeof(int); // Length

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
await Flush(async);

Expand All @@ -302,6 +313,7 @@ internal async Task WriteCopyFail(bool async)
sizeof(int) + // Length
sizeof(byte); // Error message is always empty (only a null terminator)

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
await Flush(async);

Expand All @@ -319,6 +331,7 @@ internal void WriteCancelRequest(int backendProcessId, int backendSecretKey)

Debug.Assert(backendProcessId != 0);

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
Flush(false).GetAwaiter().GetResult();

Expand All @@ -333,6 +346,7 @@ internal void WriteTerminate()
const int len = sizeof(byte) + // Message code
sizeof(int); // Length

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
Flush(false).GetAwaiter().GetResult();

Expand All @@ -345,6 +359,7 @@ internal void WriteSslRequest()
const int len = sizeof(int) + // Length
sizeof(int); // SSL request code

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
Flush(false).GetAwaiter().GetResult();

Expand All @@ -365,6 +380,7 @@ internal void WriteStartup(Dictionary<string, string> parameters)
PGUtil.UTF8Encoding.GetByteCount(kvp.Value) + 1;

// Should really never happen, just in case
WriteBuffer.StartMessage(len);
if (len > WriteBuffer.Size)
throw new Exception("Startup message bigger than buffer");

Expand All @@ -388,6 +404,7 @@ internal void WriteStartup(Dictionary<string, string> parameters)

internal async Task WritePassword(byte[] payload, int offset, int count, bool async)
{
WriteBuffer.StartMessage(sizeof(byte) + sizeof(int) + count);
if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int))
await WriteBuffer.Flush(async);
WriteBuffer.WriteByte(FrontendMessageCode.Password);
Expand All @@ -412,6 +429,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
sizeof(int) + // Initial response length
(initialResponse?.Length ?? 0); // Initial response payload

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
await WriteBuffer.Flush(async);

Expand All @@ -435,6 +453,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes

internal Task WritePregenerated(byte[] data, bool async=false)
{
WriteBuffer.StartMessage(data.Length);
if (WriteBuffer.WriteSpaceLeft < data.Length)
return FlushAndWrite(data, async);

Expand Down
3 changes: 3 additions & 0 deletions src/Npgsql/NpgsqlTransaction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ async Task Save(string name, bool async)
CheckReady();
if (!_connector.DatabaseInfo.SupportsTransactions)
return;

// Note that creating a savepoint doesn't actually send anything to the backend (only prepends), so strictly speaking we don't
// have to start a user action. However, we do this for consistency as if we did (for the checks and exceptions)
using (_connector.StartUserAction())
{
Log.Debug($"Creating savepoint {name}", _connector.Id);
Expand Down
69 changes: 66 additions & 3 deletions src/Npgsql/NpgsqlWriteBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public sealed partial class NpgsqlWriteBuffer
internal readonly NpgsqlConnector Connector;

internal Stream Underlying { private get; set; }
internal bool MessageLengthValidation { get; set; } = true;

/// <summary>
/// The total byte length of the buffer.
Expand All @@ -37,6 +38,9 @@ public sealed partial class NpgsqlWriteBuffer

internal int WritePosition;

int _messageBytesFlushed;
int? _messageLength;

ParameterStream? _parameterStream;

/// <summary>
Expand Down Expand Up @@ -81,6 +85,8 @@ public async Task Flush(bool async)
WritePosition = pos;
} else if (WritePosition == 0)
return;
else
AdvanceMessageBytesFlushed(WritePosition);

try
{
Expand Down Expand Up @@ -133,15 +139,19 @@ internal async Task DirectWrite(byte[] buffer, int offset, int count, bool async
Debug.Assert(WritePosition == 5);

WritePosition = 1;
WriteInt32(count + 4);
WriteInt32(checked(count + 4));
WritePosition = 5;
_copyMode = false;
await Flush(async);
StartMessage(5);
Flush();
_copyMode = true;
WriteCopyDataHeader(); // And ready the buffer after the direct write completes
}
else
{
Debug.Assert(WritePosition == 0);
AdvanceMessageBytesFlushed(count);
}

try
{
Expand Down Expand Up @@ -169,15 +179,19 @@ internal async Task DirectWrite(ReadOnlyMemory<byte> memory, bool async)
Debug.Assert(WritePosition == 5);

WritePosition = 1;
WriteInt32(memory.Length + 4);
WriteInt32(checked(memory.Length + 4));
WritePosition = 5;
_copyMode = false;
StartMessage(5);
await Flush(async);
_copyMode = true;
WriteCopyDataHeader(); // And ready the buffer after the direct write completes
}
else
{
Debug.Assert(WritePosition == 0);
AdvanceMessageBytesFlushed(memory.Length);
}


try
Expand Down Expand Up @@ -508,9 +522,58 @@ void WriteCopyDataHeader()

#region Misc

internal void StartMessage(int messageLength)
{
if (!MessageLengthValidation)
return;

if (_messageLength != null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength)
Throw();

// Add negative WritePosition to compensate for previous message(s) written without flushing.
_messageBytesFlushed = -WritePosition;
_messageLength = messageLength;

void Throw()
{
Connector.Break();
throw new OverflowException("Did not write the amount of bytes the message length specified");
}
}

void AdvanceMessageBytesFlushed(int count)
{
if (!MessageLengthValidation)
return;

if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength)
Throw();

_messageBytesFlushed += count;

void Throw()
{
if (count < 0)
throw new ArgumentOutOfRangeException(nameof(count), "Can't advance by a negative count");

if (_messageLength is null)
{
Connector.Break();
new InvalidOperationException("No message was started");
}

if ((long)_messageBytesFlushed + count > _messageLength)
{
Connector.Break();
throw new OverflowException("Tried to write more bytes than the message length specified");
}
}
}

internal void Clear()
{
WritePosition = 0;
_messageLength = null;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public override int ValidateAndGetLength(T value, ref NpgsqlLengthCache? lengthC
foreach (var member in _memberHandlers)
length += member.ValidateAndGetLength(value, ref lengthCache);

return lengthCache.Lengths[position] = length;
return lengthCache!.Lengths[position] = length;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand Down
2 changes: 1 addition & 1 deletion src/Npgsql/TypeHandlers/HstoreHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public int ValidateAndGetLength(IDictionary<string, string?> value, ref NpgsqlLe
totalLen += _textHandler.ValidateAndGetLength(kv.Value!, ref lengthCache, null);
}

return lengthCache.Lengths[pos] = totalLen;
return lengthCache!.Lengths[pos] = totalLen;
}

/// <inheritdoc />
Expand Down
2 changes: 1 addition & 1 deletion test/Npgsql.PluginTests/JsonNetTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public void RoundtripJObject()
{
reader.Read();
var actual = reader.GetFieldValue<JObject>(0);
Assert.That((int)actual["Bar"], Is.EqualTo(8));
Assert.That((int)actual["Bar"]!, Is.EqualTo(8));
}
}
}
Expand Down

0 comments on commit 67acbe0

Please sign in to comment.