Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.

Ensure the selector gets run during Count. #14435

Merged
merged 3 commits into from Dec 27, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
104 changes: 99 additions & 5 deletions src/System.Linq/src/System/Linq/Select.cs
Expand Up @@ -155,9 +155,41 @@ public TResult[] ToArray()
return builder.ToArray();
}

public List<TResult> ToList() => new List<TResult>(this);
public List<TResult> ToList()
{
    // Materialize the projection: every source element is handed to the
    // selector once, in enumeration order, and the results are collected.
    var results = new List<TResult>();

    foreach (TSource element in _source)
    {
        TResult projected = _selector(element);
        results.Add(projected);
    }

    return results;
}

public int GetCount(bool onlyIfCheap)
{
// In case someone uses Count() to force evaluation of
// the selector, run it provided `onlyIfCheap` is false.

if (onlyIfCheap)
{
return -1;
}

public int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : _source.Count();
int count = 0;

foreach (TSource item in _source)
{
_selector(item);
checked
{
count++;
}
}

return count;
}
}

internal sealed class SelectArrayIterator<TSource, TResult> : Iterator<TResult>, IPartition<TResult>
Expand Down Expand Up @@ -226,6 +258,17 @@ public List<TResult> ToList()

public int GetCount(bool onlyIfCheap)
{
// In case someone uses Count() to force evaluation of
// the selector, run it provided `onlyIfCheap` is false.

if (!onlyIfCheap)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for this is obscure and would likely benefit from being commented on. Without context this looks like pointless busy work that should be deleted to improve efficiency.

{
foreach (TSource item in _source)
{
_selector(item);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, if we just returned -1 in this case then the calling Count() method would do pretty much the above. I would imagine this would be slightly faster, but only slightly (that is just a guess though). Do we need the extra code here?

Copy link
Contributor Author

@jamesqo jamesqo Dec 12, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JonHanna During #12703, when I optimized Where.Select and the issue of running these selectors came up, I originally had an EnumerableHelpers.Count function that Enumerable.Count would call after checking for the Linq interfaces (just like what ToArray does today). This was the code I had written for the iterators:

// Leave it to Count to iterate through us
public int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : EnumerableHelpers.Count(this);

However, @stephentoub argued against this. See here for context: #12703 (comment) I ended up writing everything inline for GetCount in those iterators. So I just employed the same strategy here.


I would imagine this would be slightly faster, but only slightly

Virtual method calls are pretty expensive; going from 2 -> 3 virtual method calls (MoveNext & Current to MoveNext, Current & MoveNext) should probably be more than half of a 33% difference. I haven't measured either, but I'm not sure if it would be wise to regress perf here regardless.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough.


return _source.Length;
}

Expand Down Expand Up @@ -351,7 +394,20 @@ public List<TResult> ToList()

public int GetCount(bool onlyIfCheap)
{
return _source.Count;
// In case someone uses Count() to force evaluation of
// the selector, run it provided `onlyIfCheap` is false.

int count = _source.Count;

if (!onlyIfCheap)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise, your reason for doing this isn't obvious from the code alone, so should be commented on. And likely elsewhere, so I won't call out other cases.

{
for (int i = 0; i < count; i++)
{
_selector(_source[i]);
}
}

return count;
}

public IPartition<TResult> Skip(int count)
Expand Down Expand Up @@ -491,7 +547,20 @@ public List<TResult> ToList()

public int GetCount(bool onlyIfCheap)
{
return _source.Count;
// In case someone uses Count() to force evaluation of
// the selector, run it provided `onlyIfCheap` is false.

int count = _source.Count;

if (!onlyIfCheap)
{
for (int i = 0; i < count; i++)
{
_selector(_source[i]);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above, could we just return -1 here and let the caller do this?

}

return count;
}

public IPartition<TResult> Skip(int count)
Expand Down Expand Up @@ -703,6 +772,17 @@ public List<TResult> ToList()

public int GetCount(bool onlyIfCheap)
{
    // In case someone uses Count() to force evaluation of
    // the selector, run it provided `onlyIfCheap` is false.

    if (onlyIfCheap)
    {
        // The selector must not run on the cheap path, so the
        // question is simply forwarded to the source.
        return _source.GetCount(onlyIfCheap);
    }

    // Tally the elements in the same pass that runs the selector. Asking
    // the source for its count afterwards could enumerate it a second time.
    int count = 0;

    foreach (TSource item in _source)
    {
        _selector(item);
        checked
        {
            count++;
        }
    }

    return count;
}
}
Expand Down Expand Up @@ -852,7 +932,21 @@ public List<TResult> ToList()

public int GetCount(bool onlyIfCheap)
{
return Count;
// In case someone uses Count() to force evaluation of
// the selector, run it provided `onlyIfCheap` is false.

int count = Count;

if (!onlyIfCheap)
{
int end = _minIndexInclusive + count;
for (int i = _minIndexInclusive; i != end; ++i)
Copy link
Contributor Author

@jamesqo jamesqo Dec 11, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An issue came up here that I wasn't sure how best to approach:

  • lazyEnumerable.Select(i => i).Skip(1).Count() runs the selector lazyEnumerable.Count() times, because Select.Skip on a lazy enumerable isn't specially recognized and the selector gets run on the first item.

  • list.Select(i => i).Skip(1).Count() is specially recognized, however, and it returns a SelectListPartitionIterator which does not run the selector on the first item.

One way to fix this would be to start from 0 instead of _minIndexInclusive here. However, if we do that, we break Skip(1).Select(i => i) which also ends up here; patterns like those should definitely not run the selector on the first item.

Ideally, we would somehow have a way to differentiate if Skip or Select was called first from within the iterator, and start from _minIndexInclusive or 0 accordingly. But then we might need to add an extra field...

cc @JonHanna

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm inclined to think that we don't care.

A scenario that was called out as important is someone calling Count() on a Select result specifically to trigger side effects in selectors. (Not a sound practice IMO, but that's another matter). Such a use would be stymied by optimisations that skipped the selectors, and so we avoid such optimisation.

A user who skips something has indicated indifference to that thing. As such I'm inclined to think it doesn't matter whether we run n or n-1 selectors. Indeed, I'm happy running 0 in this case and just calculating what the result of Count() would be.

Others may not be as willing to go with quite so observable a difference from .NET Framework 4.6 behaviour, though. TBH, if this were my PR I'd be taking the fastest route, but I'd be prepared to back down if I failed to convince on that point.

{
_selector(_source[i]);
}
}

return count;
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions src/System.Linq/src/System/Linq/Where.cs
Expand Up @@ -434,6 +434,9 @@ public override Iterator<TResult> Clone()

public int GetCount(bool onlyIfCheap)
{
// In case someone uses Count() to force evaluation of
// the selector, run it provided `onlyIfCheap` is false.

if (onlyIfCheap)
{
return -1;
Expand Down Expand Up @@ -536,6 +539,9 @@ public override Iterator<TResult> Clone()

public int GetCount(bool onlyIfCheap)
{
// In case someone uses Count() to force evaluation of
// the selector, run it provided `onlyIfCheap` is false.

if (onlyIfCheap)
{
return -1;
Expand Down Expand Up @@ -658,6 +664,9 @@ public override void Dispose()

public int GetCount(bool onlyIfCheap)
{
// In case someone uses Count() to force evaluation of
// the selector, run it provided `onlyIfCheap` is false.

if (onlyIfCheap)
{
return -1;
Expand Down
60 changes: 60 additions & 0 deletions src/System.Linq/tests/SelectTests.cs
Expand Up @@ -1168,5 +1168,65 @@ public static IEnumerable<object[]> MoveNextAfterDisposeData()
yield return new object[] { new int[1] };
yield return new object[] { Enumerable.Range(1, 30) };
}

[Theory]
[MemberData(nameof(RunSelectorDuringCountData))]
public void RunSelectorDuringCount(IEnumerable<int> source)
{
    // Count() on a Select projection is expected to invoke the selector
    // once for every element of the source.
    int invocations = 0;

    source.Select(x => invocations++).Count();

    int expected = source.Count();
    Assert.Equal(expected, invocations);
}

// [Theory]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Disabled currently because the first assert is giving inconsistent results. See comment above

[MemberData(nameof(RunSelectorDuringCountData))]
public void RunSelectorDuringPartitionCount(IEnumerable<int> source)
{
int timesRun = 0;

var selected = source.Select(i => timesRun++);

if (source.Any())
{
selected.Skip(1).Count();
Assert.Equal(source.Count() - 1, timesRun);

selected.Take(source.Count() - 1).Count();
Assert.Equal(source.Count() * 2 - 2, timesRun);
}
}

// Supplies the sources for the selector-during-Count tests: for each length
// from 0 through 5, a sequence of random ints is passed through every transform
// below, and each transformed sequence is yielded as one test case.
public static IEnumerable<object[]> RunSelectorDuringCountData()
{
    var transforms = new Func<IEnumerable<int>, IEnumerable<int>>[]
    {
        e => e,
        e => ForceNotCollection(e),
        e => ForceNotCollection(e).Skip(1),
        e => ForceNotCollection(e).Where(i => true),
        e => e.ToArray().Where(i => true),
        e => e.ToList().Where(i => true),
        e => new LinkedList<int>(e).Where(i => true),
        e => e.Select(i => i),
        e => e.Take(e.Count()),
        e => e.ToArray(),
        e => e.ToList(),
        e => new LinkedList<int>(e) // Implements ICollection<T> but not IList<T>.
    };

    // Fixed seed, so the same sequence of random values is produced on every run.
    var r = new Random(unchecked((int)0x984bf1a3));

    for (int i = 0; i <= 5; i++)
    {
        // Lazy: every enumeration draws fresh values from `r`, but the
        // length is always `i`, which is all the count-based tests rely on.
        var enumerable = Enumerable.Range(1, i).Select(_ => r.Next());

        foreach (var transform in transforms)
        {
            yield return new object[] { transform(enumerable) };
        }
    }
}
}
}