Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add explicit protocol validation when reading RESP messages #332

Merged
merged 19 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions benchmark/BDN.benchmark/Resp/RespIntegerReadBenchmarks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ namespace BDN.benchmark.Resp
public unsafe class RespIntegerReadBenchmarks
{
[Benchmark]
[ArgumentsSource(nameof(SignedInt32EncodedValues))]
public int ReadInt32(AsciiTestCase testCase)
[ArgumentsSource(nameof(LengthHeaderValues))]
public int ReadLengthHeader(AsciiTestCase testCase)
{
fixed (byte* inputPtr = testCase.Bytes)
{
var start = inputPtr;
RespReadUtils.ReadInt(out var value, ref start, start + testCase.Bytes.Length);
RespReadUtils.ReadLengthHeader(out var value, ref start, start + testCase.Bytes.Length, allowNull: true);
return value;
}
}
Expand Down Expand Up @@ -72,6 +72,9 @@ public ulong ReadULongWithLengthHeader(AsciiTestCase testCase)
public static IEnumerable<object> SignedInt32EncodedValues
=> ToRespIntegerTestCases(RespIntegerWriteBenchmarks.SignedInt32Values);

public static IEnumerable<object> LengthHeaderValues
=> ToRespLengthHeaderTestCases(RespIntegerWriteBenchmarks.SignedInt32Values);

public static IEnumerable<object> SignedInt64EncodedValues
=> ToRespIntegerTestCases(RespIntegerWriteBenchmarks.SignedInt64Values);

Expand All @@ -90,6 +93,9 @@ public static IEnumerable<object> UnsignedInt64EncodedValuesWithLengthHeader
public static IEnumerable<AsciiTestCase> ToRespIntegerTestCases<T>(T[] integerValues) where T : struct
=> integerValues.Select(testCase => new AsciiTestCase($":{testCase}\r\n"));

public static IEnumerable<AsciiTestCase> ToRespLengthHeaderTestCases<T>(T[] integerValues) where T : struct
=> integerValues.Select(testCase => new AsciiTestCase($"${testCase}\r\n"));

public static IEnumerable<AsciiTestCase> ToRespIntegerWithLengthHeader<T>(T[] integerValues) where T : struct
=> integerValues.Select(testCase => new AsciiTestCase($"${testCase.ToString()?.Length ?? 0}\r\n{testCase}\r\n"));

Expand Down
2 changes: 1 addition & 1 deletion libs/client/GarnetClientProcessReplies.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ unsafe bool ProcessReplyAsString(ref byte* ptr, byte* end, out string result, ou
break;

case (byte)'$':
if (!RespReadUtils.ReadStringWithLengthHeader(out result, ref ptr, end))
if (!RespReadUtils.ReadStringWithLengthHeader(out result, ref ptr, end, allowNull: true))
return false;
break;

Expand Down
43 changes: 21 additions & 22 deletions libs/cluster/Session/ClusterSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Garnet.common;
using Garnet.common.Parsing;
using Garnet.networking;
using Garnet.server;
using Garnet.server.ACL;
Expand Down Expand Up @@ -225,38 +227,35 @@ bool CheckACLAdminPermissions()

ReadOnlySpan<byte> GetCommand(ReadOnlySpan<byte> bufSpan, out bool success)
{
if (bytesRead - readHead < 6)
success = false;

var ptr = recvBufferPtr + readHead;
var end = recvBufferPtr + bytesRead;

// Try to read the command length
if (!RespReadUtils.ReadLengthHeader(out int length, ref ptr, end))
{
success = false;
return default;
}

Debug.Assert(*(recvBufferPtr + readHead) == '$');
int psize = *(recvBufferPtr + readHead + 1) - '0';
readHead += 2;
while (*(recvBufferPtr + readHead) != '\r')
{
psize = psize * 10 + *(recvBufferPtr + readHead) - '0';
if (bytesRead - readHead < 1)
{
success = false;
return default;
}
readHead++;
}
if (bytesRead - readHead < 2 + psize + 2)
readHead = (int)(ptr - recvBufferPtr);

// Try to read the command value
ptr += length;
if (ptr + 2 > end)
{
success = false;
return default;
}
Debug.Assert(*(recvBufferPtr + readHead + 1) == '\n');

var result = bufSpan.Slice(readHead + 2, psize);
Debug.Assert(*(recvBufferPtr + readHead + 2 + psize) == '\r');
Debug.Assert(*(recvBufferPtr + readHead + 2 + psize + 1) == '\n');
if (*(ushort*)ptr != MemoryMarshal.Read<ushort>("\r\n"u8))
{
RespParsingException.ThrowUnexpectedToken(*ptr);
}

readHead += 2 + psize + 2;
success = true;
var result = bufSpan.Slice(readHead, length);
readHead += length + 2;

return result;
}
}
Expand Down
85 changes: 85 additions & 0 deletions libs/common/Parsing/RespParsingException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

using System.Diagnostics.CodeAnalysis;
using System.Text;

namespace Garnet.common.Parsing
{
/// <summary>
/// Exception wrapper for RESP parsing errors.
/// </summary>
public class RespParsingException : GarnetException
{
/// <summary>
/// Construct a new RESP parsing exception with the given message.
/// </summary>
/// <param name="message">Message that described the exception that has occurred.</param>
RespParsingException(string message) : base(message)
{
// Nothing...
}

/// <summary>
/// Throw an "Unexcepted Token" exception.
/// </summary>
/// <param name="token">The character that was unexpected.</param>
[DoesNotReturn]
public static void ThrowUnexpectedToken(byte token)
{
var c = (char)token;
var escaped = char.IsControl(c) ? $"\\x{token:x2}" : c.ToString();
Throw($"Unexpected character '{escaped}'.");
}

/// <summary>
/// Throw an invalid string length exception.
/// </summary>
/// <param name="len">The invalid string length.</param>
[DoesNotReturn]
public static void ThrowInvalidStringLength(long len)
{
Throw($"Invalid string length '{len}'.");
}

/// <summary>
/// Throw an invalid length exception.
/// </summary>
/// <param name="len">The invalid length.</param>
[DoesNotReturn]
public static void ThrowInvalidLength(long len)
{
Throw($"Invalid length '{len}'.");
}

/// <summary>
/// Throw NaN (not a number) exception.
/// </summary>
/// <param name="buffer">Pointer to an ASCII-encoded byte buffer containing the string that could not be converted.</param>
/// <param name="length">Length of the buffer.</param>
[DoesNotReturn]
public static unsafe void ThrowNotANumber(byte* buffer, int length)
{
Throw($"Unable to parse number: {Encoding.ASCII.GetString(buffer, length)}");
}

/// <summary>
/// Throw a exception indicating that an integer overflow has occurred.
/// </summary>
/// <param name="buffer">Pointer to an ASCII-encoded byte buffer containing the string that caused the overflow.</param>
/// <param name="length">Length of the buffer.</param>
[DoesNotReturn]
public static unsafe void ThrowIntegerOverflow(byte* buffer, int length)
{
Throw($"Unable to parse integer. The given number is larger than allowed: {Encoding.ASCII.GetString(buffer, length)}");
}

/// <summary>
/// Throw helper that throws a RespParsingException.
/// </summary>
/// <param name="message">Exception message.</param>
[DoesNotReturn]
public static void Throw(string message) =>
throw new RespParsingException(message);
}
}
Loading