Skip to content
Permalink
Browse files

Avoid mod operator when fast alternative available (#27299)

  • Loading branch information...
benaadams authored and jkotas committed Oct 26, 2019
1 parent a2f10df commit e532bf642a3a381d53ff52c234f29deb7d11e7a0
@@ -341,3 +341,20 @@ License notice for Xorshift (Wikipedia)

https://en.wikipedia.org/wiki/Xorshift
License: https://en.wikipedia.org/wiki/Wikipedia:Text_of_Creative_Commons_Attribution-ShareAlike_3.0_Unported_License

License for fastmod (https://github.com/lemire/fastmod)
--------------------------------------

Copyright 2018 Daniel Lemire

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
@@ -51,6 +51,9 @@ private struct Entry

private int[]? _buckets;
private Entry[]? _entries;
#if BIT64
private ulong _fastModMultiplier;
#endif
private int _count;
private int _freeList;
private int _freeCount;
@@ -330,16 +333,15 @@ public virtual void GetObjectData(SerializationInfo info, StreamingContext conte
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}

int[]? buckets = _buckets;
ref Entry entry = ref Unsafe.NullRef<Entry>();
if (buckets != null)
if (_buckets != null)
{
Debug.Assert(_entries != null, "expected entries to be != null");
IEqualityComparer<TKey>? comparer = _comparer;
if (comparer == null)
{
uint hashCode = (uint)key.GetHashCode();
int i = buckets[hashCode % (uint)buckets.Length];
int i = GetBucket(hashCode);
Entry[]? entries = _entries;
uint collisionCount = 0;
if (default(TKey)! != null) // TODO-NULLABLE: default(T) == null warning (https://github.com/dotnet/roslyn/issues/34757)
@@ -407,7 +409,7 @@ public virtual void GetObjectData(SerializationInfo info, StreamingContext conte
else
{
uint hashCode = (uint)comparer.GetHashCode(key);
int i = buckets[hashCode % (uint)buckets.Length];
int i = GetBucket(hashCode);
Entry[]? entries = _entries;
uint collisionCount = 0;
// Value in _buckets is 1-based; subtract 1 from i. We do it here so it fuses with the following conditional.
@@ -453,10 +455,16 @@ public virtual void GetObjectData(SerializationInfo info, StreamingContext conte
private int Initialize(int capacity)
{
int size = HashHelpers.GetPrime(capacity);
int[] buckets = new int[size];
Entry[] entries = new Entry[size];

// Assign member variables after both arrays allocated to guard against corruption from OOM if second fails
_freeList = -1;
_buckets = new int[size];
_entries = new Entry[size];
#if BIT64
_fastModMultiplier = HashHelpers.GetFastModMultiplier((uint)size);
#endif
_buckets = buckets;
_entries = entries;

return size;
}
@@ -481,7 +489,7 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior)
uint hashCode = (uint)((comparer == null) ? key.GetHashCode() : comparer.GetHashCode(key));

uint collisionCount = 0;
ref int bucket = ref _buckets[hashCode % (uint)_buckets.Length];
ref int bucket = ref GetBucket(hashCode);
// Value in _buckets is 1-based
int i = bucket - 1;

@@ -625,7 +633,7 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior)
if (count == entries.Length)
{
Resize();
bucket = ref _buckets[hashCode % (uint)_buckets.Length];
bucket = ref GetBucket(hashCode);
}
index = count;
_count = count + 1;
@@ -716,7 +724,6 @@ private void Resize(int newSize, bool forceNewHashCodes)
Debug.Assert(_entries != null, "_entries should be non-null");
Debug.Assert(newSize >= _entries.Length);

int[] buckets = new int[newSize];
Entry[] entries = new Entry[newSize];

int count = _count;
@@ -734,19 +741,23 @@ private void Resize(int newSize, bool forceNewHashCodes)
}
}

// Assign member variables after both arrays allocated to guard against corruption from OOM if second fails
_buckets = new int[newSize];
#if BIT64
_fastModMultiplier = HashHelpers.GetFastModMultiplier((uint)newSize);
#endif
for (int i = 0; i < count; i++)
{
if (entries[i].next >= -1)
{
uint bucket = entries[i].hashCode % (uint)newSize;
ref int bucket = ref GetBucket(entries[i].hashCode);
// Value in _buckets is 1-based
entries[i].next = buckets[bucket] - 1;
entries[i].next = bucket - 1;
// Value in _buckets is 1-based
buckets[bucket] = i + 1;
bucket = i + 1;
}
}

_buckets = buckets;
_entries = entries;
}

@@ -760,17 +771,16 @@ public bool Remove(TKey key)
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}

int[]? buckets = _buckets;
Entry[]? entries = _entries;
if (buckets != null)
if (_buckets != null)
{
Debug.Assert(entries != null, "entries should be non-null");
Debug.Assert(_entries != null, "entries should be non-null");
uint collisionCount = 0;
uint hashCode = (uint)(_comparer?.GetHashCode(key) ?? key.GetHashCode());
uint bucket = hashCode % (uint)buckets.Length;
ref int bucket = ref GetBucket(hashCode);
Entry[]? entries = _entries;
int last = -1;
// Value in buckets is 1-based
int i = buckets[bucket] - 1;
int i = bucket - 1;
while (i >= 0)
{
ref Entry entry = ref entries[i];
@@ -780,7 +790,7 @@ public bool Remove(TKey key)
if (last < 0)
{
// Value in buckets is 1-based
buckets[bucket] = entry.next + 1;
bucket = entry.next + 1;
}
else
{
@@ -829,17 +839,16 @@ public bool Remove(TKey key, [MaybeNullWhen(false)] out TValue value)
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}

int[]? buckets = _buckets;
Entry[]? entries = _entries;
if (buckets != null)
if (_buckets != null)
{
Debug.Assert(entries != null, "entries should be non-null");
Debug.Assert(_entries != null, "entries should be non-null");
uint collisionCount = 0;
uint hashCode = (uint)(_comparer?.GetHashCode(key) ?? key.GetHashCode());
uint bucket = hashCode % (uint)buckets.Length;
ref int bucket = ref GetBucket(hashCode);
Entry[]? entries = _entries;
int last = -1;
// Value in buckets is 1-based
int i = buckets[bucket] - 1;
int i = bucket - 1;
while (i >= 0)
{
ref Entry entry = ref entries[i];
@@ -849,7 +858,7 @@ public bool Remove(TKey key, [MaybeNullWhen(false)] out TValue value)
if (last < 0)
{
// Value in buckets is 1-based
buckets[bucket] = entry.next + 1;
bucket = entry.next + 1;
}
else
{
@@ -982,6 +991,7 @@ public int EnsureCapacity(int capacity)
_version++;
if (_buckets == null)
return Initialize(capacity);

int newSize = HashHelpers.GetPrime(capacity);
Resize(newSize, forceNewHashCodes: false);
return newSize;
@@ -1011,8 +1021,8 @@ public void TrimExcess(int capacity)
{
if (capacity < Count)
ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.capacity);
int newSize = HashHelpers.GetPrime(capacity);

int newSize = HashHelpers.GetPrime(capacity);
Entry[]? oldEntries = _entries;
int currentCapacity = oldEntries == null ? 0 : oldEntries.Length;
if (newSize >= currentCapacity)
@@ -1022,7 +1032,6 @@ public void TrimExcess(int capacity)
_version++;
Initialize(newSize);
Entry[]? entries = _entries;
int[]? buckets = _buckets;
int count = 0;
for (int i = 0; i < oldCount; i++)
{
@@ -1031,11 +1040,11 @@ public void TrimExcess(int capacity)
{
ref Entry entry = ref entries![count];
entry = oldEntries[i];
uint bucket = hashCode % (uint)newSize;
ref int bucket = ref GetBucket(hashCode);
// Value in _buckets is 1-based
entry.next = buckets![bucket] - 1; // If we get here, we have entries, therefore buckets is not null.
entry.next = bucket - 1;
// Value in _buckets is 1-based
buckets[bucket] = count + 1;
bucket = count + 1;
count++;
}
}
@@ -1153,6 +1162,17 @@ void IDictionary.Remove(object key)
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private ref int GetBucket(uint hashCode)
{
int[] buckets = _buckets!;
#if BIT64
return ref buckets[HashHelpers.FastMod(hashCode, (uint)buckets.Length, _fastModMultiplier)];
#else
return ref buckets[hashCode % (uint)buckets.Length];
#endif
}

public struct Enumerator : IEnumerator<KeyValuePair<TKey, TValue>>,
IDictionaryEnumerator
{
@@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System.Diagnostics;
using System.Runtime.CompilerServices;

namespace System.Collections
{
@@ -28,12 +29,14 @@ internal static partial class HashHelpers
// h1(key) + i*h2(key), 0 <= i < size. h2 and the size must be relatively prime.
// We prefer the low computation costs of higher prime numbers over the increased
// memory allocation of a fixed prime number i.e. when right sizing a HashSet.
public static readonly int[] primes = {
private static readonly int[] s_primes =
{
3, 7, 11, 17, 23, 29, 37, 47, 59, 71, 89, 107, 131, 163, 197, 239, 293, 353, 431, 521, 631, 761, 919,
1103, 1327, 1597, 1931, 2333, 2801, 3371, 4049, 4861, 5839, 7013, 8419, 10103, 12143, 14591,
17519, 21023, 25229, 30293, 36353, 43627, 52361, 62851, 75431, 90523, 108631, 130363, 156437,
187751, 225307, 270371, 324449, 389357, 467237, 560689, 672827, 807403, 968897, 1162687, 1395263,
1674319, 2009191, 2411033, 2893249, 3471899, 4166287, 4999559, 5999471, 7199369 };
1674319, 2009191, 2411033, 2893249, 3471899, 4166287, 4999559, 5999471, 7199369
};

public static bool IsPrime(int candidate)
{
@@ -55,9 +58,8 @@ public static int GetPrime(int min)
if (min < 0)
throw new ArgumentException(SR.Arg_HTCapacityOverflow);

for (int i = 0; i < primes.Length; i++)
foreach (int prime in s_primes)
{
int prime = primes[i];
if (prime >= min)
return prime;
}
@@ -86,5 +88,24 @@ public static int ExpandPrime(int oldSize)

return GetPrime(newSize);
}

#if BIT64
public static ulong GetFastModMultiplier(uint divisor)
=> ulong.MaxValue / divisor + 1;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static uint FastMod(uint value, uint divisor, ulong multiplier)
{
// Using fastmod from Daniel Lemire https://lemire.me/blog/2019/02/08/faster-remainders-when-the-divisor-is-a-constant-beating-compilers-and-libdivide/

ulong lowbits = multiplier * value;
// 64bit * 64bit => 128bit isn't currently supported by Math https://github.com/dotnet/corefx/issues/41822
// otherwise we'd want this to be (uint)Math.MultiplyHigh(lowbits, divisor)
uint high = (uint)((((ulong)(uint)lowbits * divisor >> 32) + (lowbits >> 32) * divisor) >> 32);

Debug.Assert(high == value % divisor);
return high;
}
#endif
}
}

0 comments on commit e532bf6

Please sign in to comment.
You can’t perform that action at this time.