Skip to content

Commit

Permalink
Add recursion guard and exception type filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
AArnott committed Nov 17, 2021
1 parent 09258cf commit 49a0269
Show file tree
Hide file tree
Showing 8 changed files with 416 additions and 55 deletions.
77 changes: 77 additions & 0 deletions src/StreamJsonRpc/ExceptionSettings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace StreamJsonRpc
{
using System;
using System.Runtime.Serialization;
using Microsoft;

/// <summary>
/// Contains security-related settings that influence how errors are serialized and deserialized.
/// </summary>
public abstract record ExceptionSettings
{
/// <summary>
/// The recommended settings for use when communicating with a trusted party.
/// </summary>
public static readonly ExceptionSettings TrustedData = new DefaultExceptionSettings(int.MaxValue, trusted: true);

/// <summary>
/// The recommended settings for use when communicating with an untrusted party.
/// </summary>
public static readonly ExceptionSettings UntrustedData = new DefaultExceptionSettings(50, trusted: false);

/// <summary>
/// Initializes a new instance of the <see cref="ExceptionSettings"/> class.
/// </summary>
/// <param name="recursionLimit">The maximum number of nested errors to serialize or deserialize.</param>
protected ExceptionSettings(int recursionLimit)
{
Requires.Range(recursionLimit > 0, nameof(recursionLimit));
this.RecursionLimit = recursionLimit;
}

/// <summary>
/// Gets the maximum number of nested errors to serialize or deserialize.
/// </summary>
/// <value>The default value is 50.</value>
/// <remarks>
/// This can help mitigate DoS attacks from unbounded recursion that otherwise error deserialization
/// becomes perhaps uniquely vulnerable to since the data structure allows recursion.
/// </remarks>
public int RecursionLimit { get; init; }

/// <summary>
/// Tests whether a type can be deserialized as part of deserializing an exception.
/// </summary>
/// <param name="type">The type that may be deserialized.</param>
/// <returns><see langword="true" /> if the type is safe to deserialize; <see langword="false" /> otherwise.</returns>
/// <remarks>
/// <para>
/// The default implementation returns <see langword="true" /> for all types in <see cref="TrustedData"/>-based instances;
/// or for <see cref="UntrustedData"/>-based instances will return <see langword="true" /> for
/// <see cref="Exception"/>-derived types that are expected to be safe to deserialize.
/// </para>
/// <para>
/// <see cref="Exception"/>-derived types that may deserialize data that would be unsafe coming from an untrusted party <em>should</em>
/// consider the <see cref="StreamingContext"/> passed to their deserializing constructor and skip deserializing of potentitally
/// dangerous data when <see cref="StreamingContext.State"/> includes the <see cref="StreamingContextStates.Remoting"/> flag.
/// </para>
/// </remarks>
public abstract bool CanDeserialize(Type type);

private record DefaultExceptionSettings : ExceptionSettings
{
private readonly bool trusted;

public DefaultExceptionSettings(int recursionLimit, bool trusted)
: base(recursionLimit)
{
this.trusted = trusted;
}

public override bool CanDeserialize(Type type) => this.trusted || typeof(Exception).IsAssignableFrom(type);
}
}
}
94 changes: 66 additions & 28 deletions src/StreamJsonRpc/JsonMessageFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1589,6 +1589,11 @@ public object Convert(object value, TypeCode typeCode)

private class ExceptionConverter : JsonConverter<Exception?>
{
/// <summary>
/// Tracks recursion count while serializing or deserializing an exception.
/// </summary>
private static ThreadLocal<int> exceptionRecursionCounter = new();

private readonly JsonMessageFormatter formatter;

internal ExceptionConverter(JsonMessageFormatter formatter)
Expand All @@ -1604,37 +1609,53 @@ internal ExceptionConverter(JsonMessageFormatter formatter)
return null;
}

if (reader.TokenType != JsonToken.StartObject)
exceptionRecursionCounter.Value++;
try
{
throw new InvalidOperationException("Expected a StartObject token.");
}
if (reader.TokenType != JsonToken.StartObject)
{
throw new InvalidOperationException("Expected a StartObject token.");
}

SerializationInfo? info = new SerializationInfo(objectType, new JsonConverterFormatter(serializer));
while (reader.Read())
{
if (reader.TokenType == JsonToken.EndObject)
if (exceptionRecursionCounter.Value > this.formatter.rpc.ExceptionOptions.RecursionLimit)
{
break;
// Exception recursion has gone too deep. Skip this value and return null as if there were no inner exception.
// Note that in skipping, the parser may use recursion internally and may still throw if its own limits are exceeded.
reader.Skip();
return null;
}

if (reader.TokenType == JsonToken.PropertyName)
SerializationInfo? info = new SerializationInfo(objectType, new JsonConverterFormatter(serializer));
while (reader.Read())
{
string name = (string)reader.Value!;
if (!reader.Read())
if (reader.TokenType == JsonToken.EndObject)
{
throw new EndOfStreamException();
break;
}

JToken? value = reader.TokenType == JsonToken.Null ? null : JToken.Load(reader);
info.AddSafeValue(name, value);
}
else
{
throw new InvalidOperationException("Expected PropertyName token but encountered: " + reader.TokenType);
if (reader.TokenType == JsonToken.PropertyName)
{
string name = (string)reader.Value!;
if (!reader.Read())
{
throw new EndOfStreamException();
}

JToken? value = reader.TokenType == JsonToken.Null ? null : JToken.Load(reader);
info.AddSafeValue(name, value);
}
else
{
throw new InvalidOperationException("Expected PropertyName token but encountered: " + reader.TokenType);
}
}
}

return ExceptionSerializationHelpers.Deserialize<Exception>(this.formatter.rpc, info, this.formatter.rpc?.TraceSource);
return ExceptionSerializationHelpers.Deserialize<Exception>(this.formatter.rpc, info, this.formatter.rpc?.TraceSource);
}
finally
{
exceptionRecursionCounter.Value--;
}
}

public override void WriteJson(JsonWriter writer, Exception? value, JsonSerializer serializer)
Expand All @@ -1645,16 +1666,33 @@ public override void WriteJson(JsonWriter writer, Exception? value, JsonSerializ
return;
}

SerializationInfo info = new SerializationInfo(value.GetType(), new JsonConverterFormatter(serializer));
ExceptionSerializationHelpers.Serialize(value, info);
writer.WriteStartObject();
foreach (SerializationEntry element in info.GetSafeMembers())
// We have to guard our own recursion because the serializer has no visibility into inner exceptions.
// Each exception in the russian doll is a new serialization job from its perspective.
exceptionRecursionCounter.Value++;
try
{
writer.WritePropertyName(element.Name);
serializer.Serialize(writer, element.Value);
}
if (exceptionRecursionCounter.Value > this.formatter.rpc?.ExceptionOptions.RecursionLimit)
{
// Exception recursion has gone too deep. Skip this value and write null as if there were no inner exception.
writer.WriteNull();
return;
}

writer.WriteEndObject();
SerializationInfo info = new SerializationInfo(value.GetType(), new JsonConverterFormatter(serializer));
ExceptionSerializationHelpers.Serialize(value, info);
writer.WriteStartObject();
foreach (SerializationEntry element in info.GetSafeMembers())
{
writer.WritePropertyName(element.Name);
serializer.Serialize(writer, element.Value);
}

writer.WriteEndObject();
}
finally
{
exceptionRecursionCounter.Value--;
}
}
}

Expand Down
19 changes: 19 additions & 0 deletions src/StreamJsonRpc/JsonRpc.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR
/// </summary>
private ExceptionProcessing exceptionStrategy;

/// <summary>
/// Backing field for <see cref="ExceptionOptions"/>.
/// </summary>
private ExceptionSettings exceptionSettings = ExceptionSettings.UntrustedData;

/// <summary>
/// Backing field for the <see cref="IJsonRpcFormatterCallbacks.RequestTransmissionAborted"/> event.
/// </summary>
Expand Down Expand Up @@ -549,6 +554,20 @@ public ExceptionProcessing ExceptionStrategy
}
}

/// <summary>
/// Gets or sets the settings to use for serializing/deserializing exceptions.
/// </summary>
public ExceptionSettings ExceptionOptions
{
get => this.exceptionSettings;
set
{
Requires.NotNull(value, nameof(value));
this.ThrowIfConfigurationLocked();
this.exceptionSettings = value;
}
}

/// <summary>
/// Gets or sets the strategy for propagating activity IDs over RPC.
/// </summary>
Expand Down
83 changes: 63 additions & 20 deletions src/StreamJsonRpc/MessagePackFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace StreamJsonRpc
using System.Runtime.ExceptionServices;
using System.Runtime.Serialization;
using System.Text;
using System.Threading;
using MessagePack;
using MessagePack.Formatters;
using MessagePack.Resolvers;
Expand Down Expand Up @@ -1494,6 +1495,15 @@ public void Serialize(ref MessagePackWriter writer, T? value, MessagePackSeriali
/// </remarks>
private class MessagePackExceptionResolver : IFormatterResolver
{
/// <summary>
/// Tracks recursion count while serializing or deserializing an exception.
/// </summary>
/// <devremarks>
/// This is placed here (<em>outside</em> the generic <see cref="ExceptionFormatter{T}"/> class)
/// so that it's one counter shared across all exception types that may be serialized or deserialized.
/// </devremarks>
private static ThreadLocal<int> exceptionRecursionCounter = new();

private readonly object[] formatterActivationArgs;

private readonly Dictionary<Type, object?> formatterCache = new Dictionary<Type, object?>();
Expand Down Expand Up @@ -1545,24 +1555,42 @@ public T Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions
return null;
}

var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(options));
int memberCount = reader.ReadMapHeader();
for (int i = 0; i < memberCount; i++)
// We have to guard our own recursion because the serializer has no visibility into inner exceptions.
// Each exception in the russian doll is a new serialization job from its perspective.
exceptionRecursionCounter.Value++;
try
{
string name = reader.ReadString();
if (exceptionRecursionCounter.Value > this.formatter.rpc.ExceptionOptions.RecursionLimit)
{
// Exception recursion has gone too deep. Skip this value and return null as if there were no inner exception.
// Note that in skipping, the parser may use recursion internally and may still throw if its own limits are exceeded.
reader.Skip();
return null;
}

var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(options));
int memberCount = reader.ReadMapHeader();
for (int i = 0; i < memberCount; i++)
{
string name = reader.ReadString();

// SerializationInfo.GetValue(string, typeof(object)) does not call our formatter,
// so the caller will get a boxed RawMessagePack struct in that case.
// Although we can't do much about *that* in general, we can at least ensure that null values
// are represented as null instead of this boxed struct.
var value = reader.TryReadNil() ? null : (object)RawMessagePack.ReadRaw(ref reader, false);
// SerializationInfo.GetValue(string, typeof(object)) does not call our formatter,
// so the caller will get a boxed RawMessagePack struct in that case.
// Although we can't do much about *that* in general, we can at least ensure that null values
// are represented as null instead of this boxed struct.
var value = reader.TryReadNil() ? null : (object)RawMessagePack.ReadRaw(ref reader, false);

info.AddSafeValue(name, value);
}
info.AddSafeValue(name, value);
}

var resolverWrapper = options.Resolver as ResolverWrapper;
Report.If(resolverWrapper is null, "Unexpected resolver type.");
return ExceptionSerializationHelpers.Deserialize<T>(this.formatter.rpc, info, resolverWrapper?.Formatter.rpc?.TraceSource);
var resolverWrapper = options.Resolver as ResolverWrapper;
Report.If(resolverWrapper is null, "Unexpected resolver type.");
return ExceptionSerializationHelpers.Deserialize<T>(this.formatter.rpc, info, resolverWrapper?.Formatter.rpc?.TraceSource);
}
finally
{
exceptionRecursionCounter.Value--;
}
}

public void Serialize(ref MessagePackWriter writer, T? value, MessagePackSerializerOptions options)
Expand All @@ -1573,13 +1601,28 @@ public void Serialize(ref MessagePackWriter writer, T? value, MessagePackSeriali
return;
}

var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(options));
ExceptionSerializationHelpers.Serialize(value, info);
writer.WriteMapHeader(info.GetSafeMemberCount());
foreach (SerializationEntry element in info.GetSafeMembers())
exceptionRecursionCounter.Value++;
try
{
if (exceptionRecursionCounter.Value > this.formatter.rpc?.ExceptionOptions.RecursionLimit)
{
// Exception recursion has gone too deep. Skip this value and write null as if there were no inner exception.
writer.WriteNil();
return;
}

var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(options));
ExceptionSerializationHelpers.Serialize(value, info);
writer.WriteMapHeader(info.GetSafeMemberCount());
foreach (SerializationEntry element in info.GetSafeMembers())
{
writer.Write(element.Name);
MessagePackSerializer.Serialize(element.ObjectType, ref writer, element.Value, options);
}
}
finally
{
writer.Write(element.Name);
MessagePackSerializer.Serialize(element.ObjectType, ref writer, element.Value, options);
exceptionRecursionCounter.Value--;
}
}
}
Expand Down
Loading

0 comments on commit 49a0269

Please sign in to comment.