Skip to content

Commit

Permalink
Merge pull request from GHSA-x9vc-6hfv-hg8c
Browse files Browse the repository at this point in the history
* Changes to make project compile

* Add message length validation to flush and direct write

(cherry picked from commit be0c32912ea6783cb4da86f45ae6e5e121d0cdc9)

---------

Co-authored-by: Nino Floris <mail@ninofloris.com>
  • Loading branch information
roji and NinoFloris committed May 9, 2024
1 parent 42d5384 commit 3183efb
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 35 deletions.
6 changes: 3 additions & 3 deletions Directory.Build.targets
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
<PackageReference Update="NetTopologySuite.IO.PostGIS" Version="2.1.0" />
<PackageReference Update="NodaTime" Version="3.0.1" />
<PackageReference Update="GeoJSON.Net" Version="1.1.73" />
<PackageReference Update="Newtonsoft.Json" Version="12.0.2" />
<PackageReference Update="Newtonsoft.Json" Version="13.0.3" />

<!-- Tests -->
<PackageReference Update="NUnit" Version="3.13.0" />
<PackageReference Update="NLog" Version="4.6.7" />
<PackageReference Update="Microsoft.CSharp" Version="4.6.0" />
<PackageReference Update="Microsoft.NET.Test.Sdk" Version="16.5.0" />
<PackageReference Update="NUnit3TestAdapter" Version="3.17.0" />
<PackageReference Update="Microsoft.NET.Test.Sdk" Version="17.9.0" />
<PackageReference Update="NUnit3TestAdapter" Version="4.5.0" />
<PackageReference Update="xunit" Version="2.4.1" />
<PackageReference Update="xunit.runner.visualstudio" Version="2.4.1" />
<PackageReference Update="GitHubActionsTestLogger" Version="1.1.0" />
Expand Down
2 changes: 1 addition & 1 deletion src/Npgsql.Json.NET/JsonHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ protected override async ValueTask<T> Read<T>(NpgsqlReadBuffer buf, int len, boo
return await base.Read<T>(buf, len, async, fieldDescription);
}

return JsonConvert.DeserializeObject<T>(await base.Read<string>(buf, len, async, fieldDescription), _settings);
return JsonConvert.DeserializeObject<T>(await base.Read<string>(buf, len, async, fieldDescription), _settings)!;
}

protected override int ValidateAndGetLength<T2>(T2 value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter)
Expand Down
2 changes: 1 addition & 1 deletion src/Npgsql.Json.NET/JsonbHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ protected override async ValueTask<T> Read<T>(NpgsqlReadBuffer buf, int len, boo
return await base.Read<T>(buf, len, async, fieldDescription);
}

return JsonConvert.DeserializeObject<T>(await base.Read<string>(buf, len, async, fieldDescription), _settings);
return JsonConvert.DeserializeObject<T>(await base.Read<string>(buf, len, async, fieldDescription), _settings)!;
}

protected override int ValidateAndGetLength<T2>(T2 value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter)
Expand Down
45 changes: 32 additions & 13 deletions src/Npgsql/NpgsqlConnector.FrontendMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ internal Task WriteDescribe(StatementOrPortal statementOrPortal, string name, bo
sizeof(byte) + // Statement or portal
(name.Length + 1); // Statement/portal name

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
return FlushAndWrite(len, statementOrPortal, name, async);

Expand Down Expand Up @@ -47,6 +48,7 @@ internal Task WriteSync(bool async, CancellationToken cancellationToken = defaul
const int len = sizeof(byte) + // Message code
sizeof(int); // Length

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
return FlushAndWrite(async);

Expand Down Expand Up @@ -76,6 +78,7 @@ internal Task WriteExecute(int maxRows, bool async, CancellationToken cancellati
sizeof(byte) + // Null-terminated portal name (always empty for now)
sizeof(int); // Max number of rows

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
return FlushAndWrite(maxRows, async);

Expand Down Expand Up @@ -113,9 +116,6 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
throw;
}

if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
await Flush(async, cancellationToken);

var messageLength =
sizeof(byte) + // Message code
sizeof(int) + // Length
Expand All @@ -125,6 +125,10 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
sizeof(ushort) + // Number of parameters
inputParameters.Count * sizeof(int); // Parameter OIDs

WriteBuffer.StartMessage(messageLength);
if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
await Flush(async, cancellationToken);

WriteBuffer.WriteByte(FrontendMessageCode.Parse);
WriteBuffer.WriteInt32(messageLength - 1);
WriteBuffer.WriteNullTerminatedString(statementName);
Expand Down Expand Up @@ -164,12 +168,6 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
statement.Length + sizeof(byte) + // Statement name plus null terminator
sizeof(ushort); // Number of parameter format codes that follow

if (WriteBuffer.WriteSpaceLeft < headerLength)
{
Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
await Flush(async, cancellationToken);
}

var formatCodesSum = 0;
var paramsLength = 0;
foreach (var p in inputParameters)
Expand All @@ -189,6 +187,13 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
sizeof(short) + // Number of result format codes
sizeof(short) * (unknownResultTypeList?.Length ?? 1); // Result format codes

WriteBuffer.StartMessage(messageLength);
if (WriteBuffer.WriteSpaceLeft < headerLength)
{
Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
await Flush(async, cancellationToken);
}

WriteBuffer.WriteByte(FrontendMessageCode.Bind);
WriteBuffer.WriteInt32(messageLength - 1);
Debug.Assert(portal == string.Empty);
Expand Down Expand Up @@ -249,6 +254,7 @@ internal Task WriteClose(StatementOrPortal type, string name, bool async, Cancel
sizeof(byte) + // Statement or portal
name.Length + sizeof(byte); // Statement or portal name plus null terminator

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
return FlushAndWrite(len, type, name, async);

Expand Down Expand Up @@ -277,14 +283,17 @@ internal async Task WriteQuery(string sql, bool async, CancellationToken cancell
{
var queryByteLen = TextEncoding.GetByteCount(sql);

var len = sizeof(byte) +
sizeof(int) + // Message length (including self excluding code)
queryByteLen + // Query byte length
sizeof(byte);

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < 1 + 4)
await Flush(async, cancellationToken);

WriteBuffer.WriteByte(FrontendMessageCode.Query);
WriteBuffer.WriteInt32(
sizeof(int) + // Message length (including self excluding code)
queryByteLen + // Query byte length
sizeof(byte)); // Null terminator
WriteBuffer.WriteInt32(len - 1);

await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken);
if (WriteBuffer.WriteSpaceLeft < 1)
Expand All @@ -299,6 +308,7 @@ internal async Task WriteCopyDone(bool async, CancellationToken cancellationToke
const int len = sizeof(byte) + // Message code
sizeof(int); // Length

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
await Flush(async, cancellationToken);

Expand All @@ -314,6 +324,7 @@ internal async Task WriteCopyFail(bool async, CancellationToken cancellationToke
sizeof(int) + // Length
sizeof(byte); // Error message is always empty (only a null terminator)

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
await Flush(async, cancellationToken);

Expand All @@ -331,6 +342,7 @@ internal void WriteCancelRequest(int backendProcessId, int backendSecretKey)

Debug.Assert(backendProcessId != 0);

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
Flush(false).GetAwaiter().GetResult();

Expand All @@ -345,6 +357,7 @@ internal void WriteTerminate()
const int len = sizeof(byte) + // Message code
sizeof(int); // Length

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
Flush(false).GetAwaiter().GetResult();

Expand All @@ -357,6 +370,7 @@ internal void WriteSslRequest()
const int len = sizeof(int) + // Length
sizeof(int); // SSL request code

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
Flush(false).GetAwaiter().GetResult();

Expand All @@ -377,6 +391,7 @@ internal void WriteStartup(Dictionary<string, string> parameters)
PGUtil.UTF8Encoding.GetByteCount(kvp.Value) + 1;

// Should really never happen, just in case
WriteBuffer.StartMessage(len);
if (len > WriteBuffer.Size)
throw new Exception("Startup message bigger than buffer");

Expand All @@ -400,8 +415,10 @@ internal void WriteStartup(Dictionary<string, string> parameters)

internal async Task WritePassword(byte[] payload, int offset, int count, bool async, CancellationToken cancellationToken = default)
{
WriteBuffer.StartMessage(sizeof(byte) + sizeof(int) + count);
if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int))
await WriteBuffer.Flush(async, cancellationToken);

WriteBuffer.WriteByte(FrontendMessageCode.Password);
WriteBuffer.WriteInt32(sizeof(int) + count);

Expand All @@ -424,6 +441,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
sizeof(int) + // Initial response length
(initialResponse?.Length ?? 0); // Initial response payload

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < len)
await WriteBuffer.Flush(async, cancellationToken);

Expand All @@ -447,6 +465,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes

internal Task WritePregenerated(byte[] data, bool async = false, CancellationToken cancellationToken = default)
{
WriteBuffer.StartMessage(data.Length);
if (WriteBuffer.WriteSpaceLeft < data.Length)
return FlushAndWrite(data, async);

Expand Down
13 changes: 2 additions & 11 deletions src/Npgsql/NpgsqlTransaction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -220,16 +220,7 @@ public void Save(string name)

// Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters.
// Since we are prepending, we assume below that the statement will always fit in the buffer.
_connector.WriteBuffer.WriteByte(FrontendMessageCode.Query);
_connector.WriteBuffer.WriteInt32(
sizeof(int) + // Message length (including self excluding code)
_connector.TextEncoding.GetByteCount("SAVEPOINT ") +
_connector.TextEncoding.GetByteCount(name) +
sizeof(byte)); // Null terminator

_connector.WriteBuffer.WriteString("SAVEPOINT ");
_connector.WriteBuffer.WriteString(name);
_connector.WriteBuffer.WriteByte(0);
_connector.WriteQuery("SAVEPOINT " + name);

_connector.PendingPrependedResponses += 2;
}
Expand Down Expand Up @@ -414,7 +405,7 @@ async ValueTask DisposeAsyncInternal()
Debug.Assert(_connector.IsBroken);
Log.Error("Exception while disposing a transaction", ex, _connector.Id);
}

IsDisposed = true;
_connector?.Connection?.EndBindingScope(ConnectorBindingScope.Transaction);
}
Expand Down
60 changes: 58 additions & 2 deletions src/Npgsql/NpgsqlWriteBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public sealed partial class NpgsqlWriteBuffer : IDisposable
internal Stream Underlying { private get; set; }

readonly Socket? _underlyingSocket;
internal bool MessageLengthValidation { get; set; } = true;

readonly ResettableCancellationTokenSource _timeoutCts;

Expand Down Expand Up @@ -72,6 +73,9 @@ internal TimeSpan Timeout

internal int WritePosition;

int _messageBytesFlushed;
int? _messageLength;

ParameterStream? _parameterStream;

bool _disposed;
Expand Down Expand Up @@ -120,6 +124,8 @@ public async Task Flush(bool async, CancellationToken cancellationToken = defaul
WritePosition = pos;
} else if (WritePosition == 0)
return;
else
AdvanceMessageBytesFlushed(WritePosition);

var finalCt = cancellationToken;
if (async && Timeout > TimeSpan.Zero)
Expand Down Expand Up @@ -187,15 +193,19 @@ internal void DirectWrite(ReadOnlySpan<byte> buffer)
Debug.Assert(WritePosition == 5);

WritePosition = 1;
WriteInt32(buffer.Length + 4);
WriteInt32(checked(buffer.Length + 4));
WritePosition = 5;
_copyMode = false;
StartMessage(5);
Flush();
_copyMode = true;
WriteCopyDataHeader(); // And ready the buffer after the direct write completes
}
else
{
Debug.Assert(WritePosition == 0);
AdvanceMessageBytesFlushed(buffer.Length);
}

try
{
Expand All @@ -218,15 +228,19 @@ internal async Task DirectWrite(ReadOnlyMemory<byte> memory, bool async, Cancell
Debug.Assert(WritePosition == 5);

WritePosition = 1;
WriteInt32(memory.Length + 4);
WriteInt32(checked(memory.Length + 4));
WritePosition = 5;
_copyMode = false;
StartMessage(5);
await Flush(async, cancellationToken);
_copyMode = true;
WriteCopyDataHeader(); // And ready the buffer after the direct write completes
}
else
{
Debug.Assert(WritePosition == 0);
AdvanceMessageBytesFlushed(memory.Length);
}

try
{
Expand Down Expand Up @@ -569,9 +583,51 @@ public void Dispose()

#region Misc

internal void StartMessage(int messageLength)
{
if (!MessageLengthValidation)
return;

if (_messageLength is not null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength)
Throw();

// Add negative WritePosition to compensate for previous message(s) written without flushing.
_messageBytesFlushed = -WritePosition;
_messageLength = messageLength;

void Throw()
{
throw Connector.Break(new OverflowException("Did not write the amount of bytes the message length specified"));
}
}

void AdvanceMessageBytesFlushed(int count)
{
if (!MessageLengthValidation)
return;

if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength)
Throw();

_messageBytesFlushed += count;

void Throw()
{
if (count < 0)
throw new ArgumentOutOfRangeException(nameof(count), "Can't advance by a negative count");

if (_messageLength is null)
throw Connector.Break(new InvalidOperationException("No message was started"));

if ((long)_messageBytesFlushed + count > _messageLength)
throw Connector.Break(new OverflowException("Tried to write more bytes than the message length specified"));
}
}

internal void Clear()
{
WritePosition = 0;
_messageLength = null;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public override int ValidateAndGetLength(T value, ref NpgsqlLengthCache? lengthC
foreach (var member in _memberHandlers)
length += member.ValidateAndGetLength(value, ref lengthCache);

return lengthCache.Lengths[position] = length;
return lengthCache!.Lengths[position] = length;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand Down
2 changes: 1 addition & 1 deletion src/Npgsql/TypeHandlers/HstoreHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public int ValidateAndGetLength(IDictionary<string, string?> value, ref NpgsqlLe
totalLen += _textHandler.ValidateAndGetLength(kv.Value!, ref lengthCache, null);
}

return lengthCache.Lengths[pos] = totalLen;
return lengthCache!.Lengths[pos] = totalLen;
}

/// <inheritdoc />
Expand Down
2 changes: 1 addition & 1 deletion test/Npgsql.PluginTests/JsonNetTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public void RoundtripJObject()
{
reader.Read();
var actual = reader.GetFieldValue<JObject>(0);
Assert.That((int)actual["Bar"], Is.EqualTo(8));
Assert.That((int)actual["Bar"]!, Is.EqualTo(8));
}
}
}
Expand Down

0 comments on commit 3183efb

Please sign in to comment.