Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize Enumerable.Range initialization, take 2 #87992

Merged
merged 3 commits into from Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 35 additions & 6 deletions src/libraries/System.Linq/src/System/Linq/Range.SpeedOpt.cs
Expand Up @@ -2,6 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace System.Linq
{
Expand All @@ -16,15 +19,17 @@ public override IEnumerable<TResult> Select<TResult>(Func<int, TResult> selector

public int[] ToArray()
{
int[] array = new int[_end - _start];
Fill(array, _start);
int start = _start;
int[] array = new int[_end - start];
Fill(array, start);
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
return array;
}

public List<int> ToList()
{
List<int> list = new List<int>(_end - _start);
Fill(SetCountAndGetSpan(list, _end - _start), _start);
(int start, int end) = (_start, _end);
List<int> list = new List<int>(end - start);
Fill(SetCountAndGetSpan(list, end - start), start);
return list;
}

Expand All @@ -33,9 +38,33 @@ public List<int> ToList()

private static void Fill(Span<int> destination, int value)
{
for (int i = 0; i < destination.Length; i++, value++)
ref int pos = ref MemoryMarshal.GetReference(destination);
ref int end = ref Unsafe.Add(ref pos, destination.Length);

if (Vector.IsHardwareAccelerated &&
Vector<int>.Count <= 8 &&
destination.Length >= Vector<int>.Count)
{
Vector<int> init = new Vector<int>((ReadOnlySpan<int>)new int[] { 0, 1, 2, 3, 4, 5, 6, 7 });
Vector<int> current = new Vector<int>(value) + init;
Vector<int> increment = new Vector<int>(Vector<int>.Count);

ref int oneVectorFromEnd = ref Unsafe.Subtract(ref end, Vector<int>.Count);
do
{
current.StoreUnsafe(ref pos);
current += increment;
pos = ref Unsafe.Add(ref pos, Vector<int>.Count);
}
while (!Unsafe.IsAddressGreaterThan(ref pos, ref oneVectorFromEnd));

value = current[0];
}

while (Unsafe.IsAddressLessThan(ref pos, ref end))
{
destination[i] = value;
pos = value++;
pos = ref Unsafe.Add(ref pos, 1);
}
}

Expand Down
21 changes: 13 additions & 8 deletions src/libraries/System.Linq/tests/RangeTests.cs
@@ -1,11 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Xunit;

namespace System.Linq.Tests
Expand All @@ -26,11 +22,20 @@ public void Range_ProduceCorrectSequence()
Assert.Equal(100, expected);
}

[Fact]
public void Range_ToArray_ProduceCorrectResult()
public static IEnumerable<object[]> Range_ToArray_ProduceCorrectResult_MemberData()
{
for (int i = 0; i < 64; i++)
{
yield return new object[] { i };
}
}

[Theory]
[MemberData(nameof(Range_ToArray_ProduceCorrectResult_MemberData))]
public void Range_ToArray_ProduceCorrectResult(int length)
{
var array = Enumerable.Range(1, 100).ToArray();
Assert.Equal(100, array.Length);
var array = Enumerable.Range(1, length).ToArray();
Assert.Equal(length, array.Length);
for (var i = 0; i < array.Length; i++)
Assert.Equal(i + 1, array[i]);
}
Expand Down