diff --git a/src/FSharp.Control.AsyncSeq/AsyncSeq.fs b/src/FSharp.Control.AsyncSeq/AsyncSeq.fs index da15b10..f11d424 100644 --- a/src/FSharp.Control.AsyncSeq/AsyncSeq.fs +++ b/src/FSharp.Control.AsyncSeq/AsyncSeq.fs @@ -297,20 +297,20 @@ module AsyncSeqOp = type OptimizedUnfoldEnumerator<'S, 'T> (f:'S -> Async<('T * 'S) option>, init:'S) = let mutable currentState = init let mutable disposed = false - + interface IAsyncEnumerator<'T> with - member __.MoveNext () : Async<'T option> = + member __.MoveNext () : Async<'T option> = if disposed then async.Return None else async { let! result = f currentState match result with - | None -> + | None -> return None | Some (value, nextState) -> currentState <- nextState return Some value } - member __.Dispose () = + member __.Dispose () = disposed <- true type UnfoldAsyncEnumerator<'S, 'T> (f:'S -> Async<('T * 'S) option>, init:'S) = @@ -458,13 +458,6 @@ module AsyncSeq = type AsyncSeqBuilder() = member x.Yield(v) = singleton v - // This looks weird, but it is needed to allow: - // - // while foo do - // do! something - // - // because F# translates body as Bind(something, fun () -> Return()) - member x.Return () = empty member x.YieldFrom(s:AsyncSeq<'T>) = s member x.Zero () = empty @@ -606,10 +599,10 @@ module AsyncSeq = // Optimized collect implementation using direct field access instead of ref cells type OptimizedCollectEnumerator<'T, 'U>(f: 'T -> AsyncSeq<'U>, inp: AsyncSeq<'T>) = // Mutable fields instead of ref cells to reduce allocations - let mutable inputEnumerator: IAsyncEnumerator<'T> option = None + let mutable inputEnumerator: IAsyncEnumerator<'T> option = None let mutable innerEnumerator: IAsyncEnumerator<'U> option = None let mutable disposed = false - + // Tail-recursive optimization to avoid deep continuation chains let rec moveNextLoop () : Async<'U option> = async { if disposed then return None @@ -642,7 +635,7 @@ module AsyncSeq = inputEnumerator <- Some newOuter return! moveNextLoop () } - + interface IAsyncEnumerator<'U> with member _.MoveNext() = moveNextLoop () member _.Dispose() = @@ -651,13 +644,13 @@ module AsyncSeq = match innerEnumerator with | Some inner -> inner.Dispose(); innerEnumerator <- None | None -> () - match inputEnumerator with + match inputEnumerator with | Some outer -> outer.Dispose(); inputEnumerator <- None | None -> () let collect (f: 'T -> AsyncSeq<'U>) (inp: AsyncSeq<'T>) : AsyncSeq<'U> = { new IAsyncEnumerable<'U> with - member _.GetEnumerator() = + member _.GetEnumerator() = new OptimizedCollectEnumerator<'T, 'U>(f, inp) :> IAsyncEnumerator<'U> } // let collect (f: 'T -> AsyncSeq<'U>) (inp: AsyncSeq<'T>) : AsyncSeq<'U> = @@ -749,7 +742,7 @@ module AsyncSeq = // Optimized iterAsync implementation to reduce allocations type internal OptimizedIterAsyncEnumerator<'T>(enumerator: IAsyncEnumerator<'T>, f: 'T -> Async) = let mutable disposed = false - + member _.IterateAsync() = let rec loop() = async { let! next = enumerator.MoveNext() @@ -760,17 +753,17 @@ module AsyncSeq = | None -> return () } loop() - + interface IDisposable with member _.Dispose() = if not disposed then disposed <- true enumerator.Dispose() - // Optimized iteriAsync implementation with direct tail recursion + // Optimized iteriAsync implementation with direct tail recursion type internal OptimizedIteriAsyncEnumerator<'T>(enumerator: IAsyncEnumerator<'T>, f: int -> 'T -> Async) = let mutable disposed = false - + member _.IterateAsync() = let rec loop count = async { let! next = enumerator.MoveNext() @@ -781,7 +774,7 @@ module AsyncSeq = | None -> return () } loop 0 - + interface IDisposable with member _.Dispose() = if not disposed then @@ -798,7 +791,7 @@ module AsyncSeq = let iterAsync (f: 'T -> Async) (source: AsyncSeq<'T>) = match source with | :? AsyncSeqOp<'T> as source -> source.IterAsync f - | _ -> + | _ -> async { let enum = source.GetEnumerator() use optimizer = new OptimizedIterAsyncEnumerator<_>(enum, f) @@ -864,7 +857,7 @@ module AsyncSeq = // Optimized mapAsync enumerator that avoids computation builder overhead type private OptimizedMapAsyncEnumerator<'T, 'TResult>(source: IAsyncEnumerator<'T>, f: 'T -> Async<'TResult>) = let mutable disposed = false - + interface IAsyncEnumerator<'TResult> with member _.MoveNext() = async { let! moveResult = source.MoveNext() @@ -874,7 +867,7 @@ module AsyncSeq = let! mapped = f value return Some mapped } - + member _.Dispose() = if not disposed then disposed <- true @@ -885,7 +878,7 @@ module AsyncSeq = | :? AsyncSeqOp<'T> as source -> source.MapAsync f | _ -> { new IAsyncEnumerable<'TResult> with - member _.GetEnumerator() = + member _.GetEnumerator() = new OptimizedMapAsyncEnumerator<'T, 'TResult>(source.GetEnumerator(), f) :> IAsyncEnumerator<'TResult> } let mapiAsync f (source : AsyncSeq<'T>) : AsyncSeq<'TResult> = asyncSeq { diff --git a/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs b/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs index 3810215..35b1d2a 100644 --- a/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs +++ b/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs @@ -581,6 +581,31 @@ let ``AsyncSeq.bufferByTimeAndCount empty``() = // // Assert.True ((actual = expected)) +[] +let ``AsyncSeq.while do CE is possible`` () = + let mutable i = 0 + let mutable foo = true + let something = + async { + i <- i + 1 + foo <- i < 3 + do! Async.Sleep 10 + } + let actual = + asyncSeq { + yield "a" + + while foo do + do! something + + yield "b" + yield "c" + } + |> AsyncSeq.toListAsync + |> Async.RunSynchronously + + Assert.AreEqual([ "a"; "b"; "c" ], actual) + [] let ``AsyncSeq.bufferByCountAndTime should not block`` () = let op =