Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add explicit protocol validation when reading RESP messages #332

Merged
merged 19 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions benchmark/BDN.benchmark/Resp/RespIntegerReadBenchmarks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ namespace BDN.benchmark.Resp
public unsafe class RespIntegerReadBenchmarks
{
[Benchmark]
[ArgumentsSource(nameof(SignedInt32EncodedValues))]
public int ReadInt32(AsciiTestCase testCase)
[ArgumentsSource(nameof(LengthHeaderValues))]
public int ReadLengthHeader(AsciiTestCase testCase)
{
fixed (byte* inputPtr = testCase.Bytes)
{
var start = inputPtr;
RespReadUtils.ReadInt(out var value, ref start, start + testCase.Bytes.Length);
RespReadUtils.ReadLengthHeader(out var value, ref start, start + testCase.Bytes.Length);
return value;
}
}
Expand Down Expand Up @@ -72,6 +72,9 @@ public ulong ReadULongWithLengthHeader(AsciiTestCase testCase)
public static IEnumerable<object> SignedInt32EncodedValues
=> ToRespIntegerTestCases(RespIntegerWriteBenchmarks.SignedInt32Values);

public static IEnumerable<object> LengthHeaderValues
=> ToRespLengthHeaderTestCases(RespIntegerWriteBenchmarks.SignedInt32Values);

public static IEnumerable<object> SignedInt64EncodedValues
=> ToRespIntegerTestCases(RespIntegerWriteBenchmarks.SignedInt64Values);

Expand All @@ -90,6 +93,9 @@ public static IEnumerable<object> UnsignedInt64EncodedValuesWithLengthHeader
public static IEnumerable<AsciiTestCase> ToRespIntegerTestCases<T>(T[] integerValues) where T : struct
=> integerValues.Select(testCase => new AsciiTestCase($":{testCase}\r\n"));

public static IEnumerable<AsciiTestCase> ToRespLengthHeaderTestCases<T>(T[] integerValues) where T : struct
=> integerValues.Select(testCase => new AsciiTestCase($"${testCase}\r\n"));

public static IEnumerable<AsciiTestCase> ToRespIntegerWithLengthHeader<T>(T[] integerValues) where T : struct
=> integerValues.Select(testCase => new AsciiTestCase($"${testCase.ToString()?.Length ?? 0}\r\n{testCase}\r\n"));

Expand Down
43 changes: 21 additions & 22 deletions libs/cluster/Session/ClusterSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Garnet.common;
using Garnet.common.Parsing;
using Garnet.networking;
using Garnet.server;
using Garnet.server.ACL;
Expand Down Expand Up @@ -225,38 +227,35 @@ bool CheckACLAdminPermissions()

ReadOnlySpan<byte> GetCommand(ReadOnlySpan<byte> bufSpan, out bool success)
{
if (bytesRead - readHead < 6)
success = false;

var ptr = recvBufferPtr + readHead;
var end = recvBufferPtr + bytesRead;

// Try to read the command length
if (!RespReadUtils.ReadLengthHeader(out int length, ref ptr, end))
{
success = false;
return default;
}

Debug.Assert(*(recvBufferPtr + readHead) == '$');
int psize = *(recvBufferPtr + readHead + 1) - '0';
readHead += 2;
while (*(recvBufferPtr + readHead) != '\r')
{
psize = psize * 10 + *(recvBufferPtr + readHead) - '0';
if (bytesRead - readHead < 1)
{
success = false;
return default;
}
readHead++;
}
if (bytesRead - readHead < 2 + psize + 2)
readHead = (int)(ptr - recvBufferPtr);

// Try to read the command value
ptr += length;
if (ptr + 2 > end)
{
success = false;
return default;
}
Debug.Assert(*(recvBufferPtr + readHead + 1) == '\n');

var result = bufSpan.Slice(readHead + 2, psize);
Debug.Assert(*(recvBufferPtr + readHead + 2 + psize) == '\r');
Debug.Assert(*(recvBufferPtr + readHead + 2 + psize + 1) == '\n');
if (*(ushort*)ptr != MemoryMarshal.Read<ushort>("\r\n"u8))
{
RespParsingException.ThrowUnexpectedToken(*ptr);
}

readHead += 2 + psize + 2;
success = true;
var result = bufSpan.Slice(readHead, length);
readHead += length + 2;

return result;
}
}
Expand Down
2 changes: 1 addition & 1 deletion libs/cluster/Session/MigrateCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ private bool TryMIGRATE(int count, byte* ptr)
//3. Key
byte* singleKeyPtr = null;
var sksize = 0;
if (!RespReadUtils.ReadPtrWithLengthHeader(ref singleKeyPtr, ref sksize, ref ptr, recvBufferPtr + bytesRead))
if (!RespReadUtils.ReadPtrWithLengthHeader(ref singleKeyPtr, ref sksize, ref ptr, recvBufferPtr + bytesRead, 0))
return false;

//4. Destination DB
Expand Down
82 changes: 82 additions & 0 deletions libs/common/Parsing/RespParsingException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

using System;
using System.Diagnostics.CodeAnalysis;

namespace Garnet.common.Parsing
{
/// <summary>
/// Exception wrapper for RESP parsing errors.
/// </summary>
public class RespParsingException : GarnetException
{
/// <summary>
/// Construct a new RESP parsing exception with the given message.
/// </summary>
/// <param name="message">Message that described the exception that has occurred.</param>
RespParsingException(String message) : base(message)
{
// Nothing...
}

/// <summary>
/// Throw an "Unexcepted Token" exception.
/// </summary>
/// <param name="c">The character that was unexpected.</param>
[DoesNotReturn]
public static void ThrowUnexpectedToken(byte c)
{
Throw($"Unexpected byte ({c}) in RESP command package.");
badrishc marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary>
/// Throw an invalid string length exception.
/// </summary>
/// <param name="len">The invalid string length.</param>
[DoesNotReturn]
public static void ThrowInvalidStringLength(long len)
{
Throw($"Invalid string length '{len}' in RESP command package.");
}

/// <summary>
/// Throw an invalid length exception.
/// </summary>
/// <param name="len">The invalid length.</param>
[DoesNotReturn]
public static void ThrowInvalidLength(long len)
{
Throw($"Invalid length '{len}' in RESP command package.");
}

/// <summary>
/// Throw NaN (not a number) exception.
/// </summary>
/// <param name="buffer">The input buffer that could not be converted into a number.</param>
[DoesNotReturn]
public static void ThrowNotANumber(ReadOnlySpan<byte> buffer)
{
var ascii = new System.Text.ASCIIEncoding();
lmaas marked this conversation as resolved.
Show resolved Hide resolved
Throw($"Unable to parse number: {ascii.GetString(buffer)}");
}

/// <summary>
/// Throw a exception indicating that an integer overflow has occurred.
/// </summary>
[DoesNotReturn]
public static void ThrowIntegerOverflow()
{
var ascii = new System.Text.ASCIIEncoding();
lmaas marked this conversation as resolved.
Show resolved Hide resolved
Throw($"Unable to parse integer. The given number is larger than allowed.");
}

/// <summary>
/// Throw helper that throws a RespParsingException.
/// </summary>
/// <param name="message">Exception message.</param>
[DoesNotReturn]
public static void Throw(string message) =>
throw new RespParsingException(message);
}
}
Loading