Skip to content

Commit

Permalink
Mitigate risk of hash collision attacks
Browse files Browse the repository at this point in the history
  • Loading branch information
AArnott committed Jan 14, 2020
1 parent caf846f commit 74062a1
Show file tree
Hide file tree
Showing 11 changed files with 1,136 additions and 16 deletions.
8 changes: 4 additions & 4 deletions src/MessagePack.ImmutableCollection/Formatters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ protected override void Add(ImmutableDictionary<TKey, TValue>.Builder collection

protected override ImmutableDictionary<TKey, TValue>.Builder Create(int count)
{
return ImmutableDictionary.CreateBuilder<TKey, TValue>();
return ImmutableDictionary.CreateBuilder<TKey, TValue>(MessagePackSecurity.Active.GetEqualityComparer<TKey>());
}

protected override ImmutableDictionary<TKey, TValue>.Enumerator GetSourceEnumerator(ImmutableDictionary<TKey, TValue> source)
Expand All @@ -117,7 +117,7 @@ protected override ImmutableHashSet<T> Complete(ImmutableHashSet<T>.Builder inte

protected override ImmutableHashSet<T>.Builder Create(int count)
{
return ImmutableHashSet.CreateBuilder<T>();
return ImmutableHashSet.CreateBuilder<T>(MessagePackSecurity.Active.GetEqualityComparer<T>());
}

protected override ImmutableHashSet<T>.Enumerator GetSourceEnumerator(ImmutableHashSet<T> source)
Expand Down Expand Up @@ -242,7 +242,7 @@ protected override void Add(ImmutableDictionary<TKey, TValue>.Builder collection

protected override ImmutableDictionary<TKey, TValue>.Builder Create(int count)
{
return ImmutableDictionary.CreateBuilder<TKey, TValue>();
return ImmutableDictionary.CreateBuilder<TKey, TValue>(MessagePackSecurity.Active.GetEqualityComparer<TKey>());
}
}

Expand All @@ -260,7 +260,7 @@ protected override IImmutableSet<T> Complete(ImmutableHashSet<T>.Builder interme

protected override ImmutableHashSet<T>.Builder Create(int count)
{
return ImmutableHashSet.CreateBuilder<T>();
return ImmutableHashSet.CreateBuilder<T>(MessagePackSecurity.Active.GetEqualityComparer<T>());
}
}

Expand Down
36 changes: 36 additions & 0 deletions src/MessagePack/BitOperations.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
// <auto-generated />

#if !NETCOREAPP

using System.Runtime.CompilerServices;

// Some routines inspired by the Stanford Bit Twiddling Hacks by Sean Eron Anderson:
// http://graphics.stanford.edu/~seander/bithacks.html

namespace System.Numerics
{
/// <summary>
/// Utility methods for intrinsic bit-twiddling operations.
/// The methods use hardware intrinsics when available on the underlying platform,
/// otherwise they use optimized software fallbacks.
/// </summary>
internal static class BitOperations
{
/// <summary>
/// Rotates the specified value left by the specified number of bits.
/// Similar in behavior to the x86 instruction ROL.
/// </summary>
/// <param name="value">The value to rotate.</param>
/// <param name="offset">The number of bits to rotate by.
/// Any value outside the range [0..31] is treated as congruent mod 32.</param>
/// <returns>The rotated value.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static uint RotateLeft(uint value, int offset)
=> (value << offset) | (value >> (32 - offset));
}
}

#endif
8 changes: 4 additions & 4 deletions src/MessagePack/Formatters/CollectionFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ protected override HashSet<T> Complete(HashSet<T> intermediateCollection)

protected override HashSet<T> Create(int count)
{
return new HashSet<T>();
return new HashSet<T>(MessagePackSecurity.Active.GetEqualityComparer<T>());
}

protected override HashSet<T>.Enumerator GetSourceEnumerator(HashSet<T> source)
Expand Down Expand Up @@ -847,7 +847,7 @@ public T Deserialize(byte[] bytes, int offset, IFormatterResolver formatterResol
var count = MessagePackBinary.ReadMapHeader(bytes, offset, out readSize);
offset += readSize;

var dict = new T();
var dict = CollectionHelpers<T, IEqualityComparer>.CreateHashCollection(count, MessagePackSecurity.Active.GetEqualityComparer());
for (int i = 0; i < count; i++)
{
var key = formatter.Deserialize(bytes, offset, formatterResolver, out readSize);
Expand Down Expand Up @@ -906,7 +906,7 @@ public IDictionary Deserialize(byte[] bytes, int offset, IFormatterResolver form
var count = MessagePackBinary.ReadMapHeader(bytes, offset, out readSize);
offset += readSize;

var dict = new Dictionary<object, object>(count);
var dict = new Dictionary<object, object>(count, MessagePackSecurity.Active.GetEqualityComparer<object>());
for (int i = 0; i < count; i++)
{
var key = formatter.Deserialize(bytes, offset, formatterResolver, out readSize);
Expand Down Expand Up @@ -1004,7 +1004,7 @@ protected override ISet<T> Complete(HashSet<T> intermediateCollection)

protected override HashSet<T> Create(int count)
{
return new HashSet<T>();
return new HashSet<T>(MessagePackSecurity.Active.GetEqualityComparer<T>());
}
}

Expand Down
53 changes: 53 additions & 0 deletions src/MessagePack/Formatters/CollectionHelpers`2.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) All contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Linq.Expressions;
using System.Reflection;

namespace MessagePack.Formatters
{
/// <summary>
/// Provides general helpers for creating collections (including dictionaries).
/// </summary>
/// <typeparam name="TCollection">The concrete type of collection to create.</typeparam>
/// <typeparam name="TEqualityComparer">The type of equality comparer that we would hope to pass into the collection's constructor.</typeparam>
internal static class CollectionHelpers<TCollection, TEqualityComparer>
where TCollection : new()
{
/// <summary>
/// The delegate that will create the collection, if the typical (int count, IEqualityComparer{T} equalityComparer) constructor was found.
/// </summary>
private static Func<int, TEqualityComparer, TCollection> collectionCreator;

/// <summary>
/// Initializes static members of the <see cref="CollectionHelpers{TCollection, TEqualityComparer}"/> class.
/// </summary>
/// <remarks>
/// Initializes a delegate that is optimized to create a collection of a given size and using the given equality comparer, if possible.
/// </remarks>
static CollectionHelpers()
{
var ctor = typeof(TCollection).GetTypeInfo().GetConstructor(new Type[] { typeof(int), typeof(TEqualityComparer) });
if (ctor != null)
{
ParameterExpression param1 = Expression.Parameter(typeof(int), "count");
ParameterExpression param2 = Expression.Parameter(typeof(TEqualityComparer), "equalityComparer");
NewExpression body = Expression.New(ctor, param1, param2);
collectionCreator = Expression.Lambda<Func<int, TEqualityComparer, TCollection>>(body, param1, param2).Compile();
}
}

/// <summary>
/// Initializes a new instance of the <typeparamref name="TCollection"/> collection.
/// </summary>
/// <param name="count">The number of elements the collection should be prepared to receive.</param>
/// <param name="equalityComparer">The equality comparer to initialize the collection with.</param>
/// <returns>The newly initialized collection.</returns>
/// <remarks>
/// Use of the <paramref name="count"/> and <paramref name="equalityComparer"/> are a best effort.
/// If we can't find a constructor on the collection in the expected shape, we'll just instantiate the collection with its default constructor.
/// </remarks>
internal static TCollection CreateHashCollection(int count, TEqualityComparer equalityComparer) => collectionCreator != null ? collectionCreator.Invoke(count, equalityComparer) : new TCollection();
}
}
12 changes: 6 additions & 6 deletions src/MessagePack/Formatters/DictionaryFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ protected override void Add(Dictionary<TKey, TValue> collection, int index, TKey

protected override Dictionary<TKey, TValue> Create(int count)
{
return new Dictionary<TKey, TValue>(count);
return new Dictionary<TKey, TValue>(count, MessagePackSecurity.Active.GetEqualityComparer<TKey>());
}

protected override Dictionary<TKey, TValue>.Enumerator GetSourceEnumerator(Dictionary<TKey, TValue> source)
Expand All @@ -166,7 +166,7 @@ protected override void Add(TDictionary collection, int index, TKey key, TValue

protected override TDictionary Create(int count)
{
return new TDictionary();
return CollectionHelpers<TDictionary, IEqualityComparer<TKey>>.CreateHashCollection(count, MessagePackSecurity.Active.GetEqualityComparer<TKey>());
}
}

Expand All @@ -179,7 +179,7 @@ protected override void Add(Dictionary<TKey, TValue> collection, int index, TKey

protected override Dictionary<TKey, TValue> Create(int count)
{
return new Dictionary<TKey, TValue>(count);
return new Dictionary<TKey, TValue>(count, MessagePackSecurity.Active.GetEqualityComparer<TKey>());
}

protected override IDictionary<TKey, TValue> Complete(Dictionary<TKey, TValue> intermediateCollection)
Expand Down Expand Up @@ -238,7 +238,7 @@ protected override void Add(Dictionary<TKey, TValue> collection, int index, TKey

protected override Dictionary<TKey, TValue> Create(int count)
{
return new Dictionary<TKey, TValue>(count);
return new Dictionary<TKey, TValue>(count, MessagePackSecurity.Active.GetEqualityComparer<TKey>());
}
}

Expand All @@ -256,7 +256,7 @@ protected override void Add(Dictionary<TKey, TValue> collection, int index, TKey

protected override Dictionary<TKey, TValue> Create(int count)
{
return new Dictionary<TKey, TValue>(count);
return new Dictionary<TKey, TValue>(count, MessagePackSecurity.Active.GetEqualityComparer<TKey>());
}
}

Expand All @@ -270,7 +270,7 @@ protected override void Add(ConcurrentDictionary<TKey, TValue> collection, int i
protected override ConcurrentDictionary<TKey, TValue> Create(int count)
{
// concurrent dictionary can't access defaultConcurrecyLevel so does not use count overload.
return new ConcurrentDictionary<TKey, TValue>();
return new ConcurrentDictionary<TKey, TValue>(MessagePackSecurity.Active.GetEqualityComparer<TKey>());
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/MessagePack/Formatters/PrimitiveObjectFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public sealed class PrimitiveObjectFormatter : IMessagePackFormatter<object>

static readonly Dictionary<Type, int> typeToJumpCode = new Dictionary<Type, int>()
{
// When adding types whose size exceeds 32-bits, add support in MessagePackSecurity.GetHashCollisionResistantEqualityComparer<T>()
{ typeof(Boolean), 0 },
{ typeof(Char), 1 },
{ typeof(SByte), 2 },
Expand Down Expand Up @@ -220,7 +221,7 @@ public object Deserialize(byte[] bytes, int offset, IFormatterResolver formatter
offset += readSize;

var objectFormatter = formatterResolver.GetFormatter<object>();
var hash = new Dictionary<object, object>(length);
var hash = new Dictionary<object, object>(length, MessagePackSecurity.Active.GetEqualityComparer<object>());
for (int i = 0; i < length; i++)
{
var key = objectFormatter.Deserialize(bytes, offset, formatterResolver, out readSize);
Expand Down
Loading

0 comments on commit 74062a1

Please sign in to comment.