Skip to content

Commit

Permalink
4.x: Make TailRecursiveSink lock-free and have less allocations (#499)
Browse files Browse the repository at this point in the history
  • Loading branch information
akarnokd authored and Oren Novotny committed May 26, 2018
1 parent eb64149 commit 10a44ad
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 127 deletions.
2 changes: 1 addition & 1 deletion Rx.NET/Source/src/System.Reactive/Internal/ConcatSink.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ public ConcatSink(IObserver<TSource> observer, IDisposable cancel)

protected override IEnumerable<IObservable<TSource>> Extract(IObservable<TSource> source) => (source as IConcatenatable<TSource>)?.GetSources();

public override void OnCompleted() => _recurse();
public override void OnCompleted() => Recurse();
}
}
253 changes: 130 additions & 123 deletions Rx.NET/Source/src/System.Reactive/Internal/TailRecursiveSink.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Reactive.Concurrency;
using System.Reactive.Disposables;
using System.Threading;

namespace System.Reactive
{
Expand All @@ -15,164 +16,176 @@ public TailRecursiveSink(IObserver<TSource> observer, IDisposable cancel)
{
}

private bool _isDisposed;
private SerialDisposable _subscription;
private AsyncLock _gate;
private Stack<IEnumerator<IObservable<TSource>>> _stack;
private Stack<int?> _length;
protected Action _recurse;
bool _isDisposed;

int trampoline;

IDisposable currentSubscription;

Stack<IEnumerator<IObservable<TSource>>> stack;

public IDisposable Run(IEnumerable<IObservable<TSource>> sources)
{
_isDisposed = false;
_subscription = new SerialDisposable();
_gate = new AsyncLock();
_stack = new Stack<IEnumerator<IObservable<TSource>>>();
_length = new Stack<int?>();

if (!TryGetEnumerator(sources, out var e))
if (!TryGetEnumerator(sources, out var current))
return Disposable.Empty;

_stack.Push(e);
_length.Push(Helpers.GetLength(sources));
stack = new Stack<IEnumerator<IObservable<TSource>>>();
stack.Push(current);

var cancelable = SchedulerDefaults.TailRecursion.Schedule(self =>
{
_recurse = self;
_gate.Wait(MoveNext);
});
Drain();

return StableCompositeDisposable.Create(_subscription, cancelable, Disposable.Create(() => _gate.Wait(Dispose)));
return new RecursiveSinkDisposable(this);
}

protected abstract IEnumerable<IObservable<TSource>> Extract(IObservable<TSource> source);

private void MoveNext()
sealed class RecursiveSinkDisposable : IDisposable
{
var hasNext = false;
var next = default(IObservable<TSource>);
readonly TailRecursiveSink<TSource> parent;

do
public RecursiveSinkDisposable(TailRecursiveSink<TSource> parent)
{
if (_stack.Count == 0)
break;
this.parent = parent;
}

if (_isDisposed)
return;
public void Dispose()
{
parent.DisposeAll();
}
}

var e = _stack.Peek();
var l = _length.Peek();
void Drain()
{
if (Interlocked.Increment(ref trampoline) != 1)
{
return;
}

var current = default(IObservable<TSource>);
try
for (; ; )
{
if (Volatile.Read(ref _isDisposed))
{
hasNext = e.MoveNext();
if (hasNext)
while (stack.Count != 0)
{
current = e.Current;
var enumerator = stack.Pop();
enumerator.Dispose();
}
if (Volatile.Read(ref currentSubscription) != BooleanDisposable.True)
{
Interlocked.Exchange(ref currentSubscription, BooleanDisposable.True)?.Dispose();
}
}
catch (Exception ex)
{
e.Dispose();

//
// Failure to enumerate the sequence cannot be handled, even by
// operators like Catch, because it'd lead to another attempt at
// enumerating to find the next observable sequence. Therefore,
// we feed those errors directly to the observer.
//
_observer.OnError(ex);
base.Dispose();
return;
}

if (!hasNext)
{
e.Dispose();

_stack.Pop();
_length.Pop();
}
else
{
var r = l - 1;
_length.Pop();
_length.Push(r);

try
if (stack.Count != 0)
{
next = Helpers.Unpack(current);
}
catch (Exception exception)
{
//
// Errors from unpacking may produce side-effects that normally
// would occur during a SubscribeSafe operation. Those would feed
// back into the observer and be subject to the operator's error
// handling behavior. For example, Catch would allow to handle
// the error using a handler function.
//
if (!Fail(exception))
var currentEnumerator = stack.Peek();

var currentObservable = default(IObservable<TSource>);
var next = default(IObservable<TSource>);

try
{
if (currentEnumerator.MoveNext())
{
currentObservable = currentEnumerator.Current;
}
}
catch (Exception ex)
{
e.Dispose();
currentEnumerator.Dispose();
_observer.OnError(ex);
base.Dispose();
Volatile.Write(ref _isDisposed, true);
continue;
}

return;
}
try
{
next = Helpers.Unpack(currentObservable);

//
// Tail recursive case; drop the current frame.
//
if (r == 0)
{
e.Dispose();
}
catch (Exception ex)
{
next = null;
if (!Fail(ex))
{
Volatile.Write(ref _isDisposed, true);
}
continue;
}

_stack.Pop();
_length.Pop();
if (next != null)
{
var nextSeq = Extract(next);
if (nextSeq != null)
{
if (TryGetEnumerator(nextSeq, out var nextEnumerator))
{
stack.Push(nextEnumerator);
continue;
}
else
{
Volatile.Write(ref _isDisposed, true);
continue;
}
}
else
{
var sad = new SingleAssignmentDisposable();
if (Interlocked.CompareExchange(ref currentSubscription, sad, null) == null)
{
sad.Disposable = next.SubscribeSafe(this);
}
else
{
continue;
}
}
}
else
{
stack.Pop();
currentEnumerator.Dispose();
continue;
}
}

//
// Flattening of nested sequences. Prevents stack overflow in observers.
//
var nextSeq = Extract(next);
if (nextSeq != null)
else
{
if (!TryGetEnumerator(nextSeq, out var nextEnumerator))
return;

_stack.Push(nextEnumerator);
_length.Push(Helpers.GetLength(nextSeq));

hasNext = false;
Volatile.Write(ref _isDisposed, true);
Done();
}
}
} while (!hasNext);

if (!hasNext)
{
Done();
return;
if (Interlocked.Decrement(ref trampoline) == 0)
{
break;
}
}
}

var d = new SingleAssignmentDisposable();
_subscription.Disposable = d;
d.Disposable = next.SubscribeSafe(this);
void DisposeAll()
{
Volatile.Write(ref _isDisposed, true);
// the disposing of currentSubscription is deferred to drain due to some ObservableExTest.Iterate_Complete()
// Interlocked.Exchange(ref currentSubscription, BooleanDisposable.True)?.Dispose();
Drain();
}

private new void Dispose()
protected void Recurse()
{
while (_stack.Count > 0)
var d = Volatile.Read(ref currentSubscription);
if (d != BooleanDisposable.True)
{
var e = _stack.Pop();
_length.Pop();

e.Dispose();
d?.Dispose();
if (Interlocked.CompareExchange(ref currentSubscription, null, d) == d)
{
Drain();
}
}

_isDisposed = true;
}

protected abstract IEnumerable<IObservable<TSource>> Extract(IObservable<TSource> source);

private bool TryGetEnumerator(IEnumerable<IObservable<TSource>> sources, out IEnumerator<IObservable<TSource>> result)
{
try
Expand All @@ -182,12 +195,6 @@ private bool TryGetEnumerator(IEnumerable<IObservable<TSource>> sources, out IEn
}
catch (Exception exception)
{
//
// Failure to enumerate the sequence cannot be handled, even by
// operators like Catch, because it'd lead to another attempt at
// enumerating to find the next observable sequence. Therefore,
// we feed those errors directly to the observer.
//
_observer.OnError(exception);
base.Dispose();

Expand Down
2 changes: 1 addition & 1 deletion Rx.NET/Source/src/System.Reactive/Linq/Observable/Catch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public override void OnNext(TSource value)
public override void OnError(Exception error)
{
_lastException = error;
_recurse();
Recurse();
}

public override void OnCompleted()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ public override void OnNext(TSource value)

public override void OnError(Exception error)
{
_recurse();
Recurse();
}

public override void OnCompleted()
{
_recurse();
Recurse();
}

protected override bool Fail(Exception error)
Expand Down

0 comments on commit 10a44ad

Please sign in to comment.