Skip to content

Commit

Permalink
Close connection when invalid framing is detected
Browse files Browse the repository at this point in the history
  • Loading branch information
ReubenBond committed Oct 26, 2023
1 parent 81333e7 commit 0289290
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 13 deletions.
41 changes: 41 additions & 0 deletions src/Orleans.Core/Messaging/InvalidMessageFrameException.cs
@@ -0,0 +1,41 @@
#nullable enable

using System;
using System.Runtime.Serialization;

namespace Orleans.Runtime.Messaging;

/// <summary>
/// Indicates that a message frame is invalid, either when sending a messare or receiving a message.
/// </summary>
[GenerateSerializer]
public sealed class InvalidMessageFrameException : OrleansException
{
/// <summary>
/// Initializes a new instance of the <see cref="InvalidMessageFrameException"/> class.
/// </summary>
public InvalidMessageFrameException()
{
}

/// <summary>
/// Initializes a new instance of the <see cref="InvalidMessageFrameException"/> class.
/// </summary>
/// <param name="message">The message that describes the error.</param>
public InvalidMessageFrameException(string message) : base(message)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="InvalidMessageFrameException"/> class.
/// </summary>
/// <param name="message">The message that describes the error.</param>
/// <param name="innerException">The exception that is the cause of the current exception.</param>
public InvalidMessageFrameException(string message, Exception innerException) : base(message, innerException)
{
}

protected InvalidMessageFrameException(SerializationInfo info, StreamingContext context) : base(info, context)
{
}
}
5 changes: 3 additions & 2 deletions src/Orleans.Core/Messaging/MessageSerializer.cs
Expand Up @@ -189,6 +189,7 @@ private ResponseCodec GetRawCodec(Type fieldType)
}

var bodyLength = bufferWriter.CommittedBytes - headerLength;

// Before completing, check lengths
ThrowIfLengthsInvalid(headerLength, bodyLength);

Expand Down Expand Up @@ -217,8 +218,8 @@ private void ThrowIfLengthsInvalid(int headerLength, int bodyLength)
if ((uint)bodyLength > (uint)_maxBodyLength) ThrowInvalidBodyLength(bodyLength);
}

private void ThrowInvalidHeaderLength(int headerLength) => throw new OrleansException($"Invalid header size: {headerLength} (max configured value is {_maxHeaderLength}, see {nameof(MessagingOptions.MaxMessageHeaderSize)})");
private void ThrowInvalidBodyLength(int bodyLength) => throw new OrleansException($"Invalid body size: {bodyLength} (max configured value is {_maxBodyLength}, see {nameof(MessagingOptions.MaxMessageBodySize)})");
private void ThrowInvalidHeaderLength(int headerLength) => throw new InvalidMessageFrameException($"Invalid header size: {headerLength} (max configured value is {_maxHeaderLength}, see {nameof(MessagingOptions.MaxMessageHeaderSize)})");
private void ThrowInvalidBodyLength(int bodyLength) => throw new InvalidMessageFrameException($"Invalid body size: {bodyLength} (max configured value is {_maxBodyLength}, see {nameof(MessagingOptions.MaxMessageBodySize)})");

private void Serialize<TBufferWriter>(ref Writer<TBufferWriter> writer, Message value, PackedHeaders headers) where TBufferWriter : IBufferWriter<byte>
{
Expand Down
30 changes: 19 additions & 11 deletions src/Orleans.Core/Networking/Connection.cs
Expand Up @@ -311,7 +311,7 @@ private async Task ProcessIncoming()
ThreadPool.UnsafeQueueUserWorkItem(handler, preferLocal: true);
}
}
catch (Exception exception) when (this.HandleReceiveMessageFailure(message, exception))
catch (Exception exception) when (HandleReceiveMessageFailure(message, exception))
{
}
} while (requiredBytes == 0);
Expand Down Expand Up @@ -374,9 +374,8 @@ private async Task ProcessOutgoing()
message = null;
}
}
catch (Exception exception) when (message != default)
catch (Exception exception) when (HandleSendMessageFailure(message, exception))
{
this.OnMessageSerializationFailure(message, exception);
}

var flushResult = await output.FlushAsync();
Expand Down Expand Up @@ -443,15 +442,15 @@ private static EndPoint NormalizeEndpoint(EndPoint endpoint)
/// <returns><see langword="true"/> if the exception should not be caught and <see langword="false"/> if it should be caught.</returns>
private bool HandleReceiveMessageFailure(Message message, Exception exception)
{
this.Log.LogWarning(
this.Log.LogError(
exception,
"Exception reading message {Message} from remote endpoint {Remote} to local endpoint {Local}",
message,
this.RemoteEndPoint,
this.LocalEndPoint);

// If deserialization completely failed, rethrow the exception so that it can be handled at another level.
if (message is null)
if (message is null || exception is InvalidMessageFrameException)
{
// Returning false here informs the caller that the exception should not be caught.
return false;
Expand Down Expand Up @@ -485,16 +484,23 @@ private bool HandleReceiveMessageFailure(Message message, Exception exception)
return true;
}

private void OnMessageSerializationFailure(Message message, Exception exception)
private bool HandleSendMessageFailure(Message message, Exception exception)
{
// we only get here if we failed to serialize the msg (or any other catastrophic failure).
// We get here if we failed to serialize the msg (or any other catastrophic failure).
// Request msg fails to serialize on the sender, so we just enqueue a rejection msg.
// Response msg fails to serialize on the responding silo, so we try to send an error response back.
this.Log.LogWarning(
(int)ErrorCode.Messaging_SerializationError,
this.Log.LogError(
exception,
"Unexpected error serializing message {Message}",
message);
"Exception sending message {Message} to remote endpoint {Remote} from local endpoint {Local}",
message,
this.RemoteEndPoint,
this.LocalEndPoint);

if (message is null || exception is InvalidMessageFrameException)
{
// Returning false here informs the caller that the exception should not be caught.
return false;
}

MessagingInstruments.OnFailedSentMessage(message);

Expand Down Expand Up @@ -526,6 +532,8 @@ private void OnMessageSerializationFailure(Message message, Exception exception)

MessagingInstruments.OnDroppedSentMessage(message);
}

return true;
}

private sealed class MessageHandlerPoolPolicy : PooledObjectPolicy<MessageHandler>
Expand Down

0 comments on commit 0289290

Please sign in to comment.