Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 18 additions & 25 deletions src/FSharp.Control.AsyncSeq/AsyncSeq.fs
Original file line number Diff line number Diff line change
Expand Up @@ -297,20 +297,20 @@ module AsyncSeqOp =
type OptimizedUnfoldEnumerator<'S, 'T> (f:'S -> Async<('T * 'S) option>, init:'S) =
let mutable currentState = init
let mutable disposed = false

interface IAsyncEnumerator<'T> with
member __.MoveNext () : Async<'T option> =
member __.MoveNext () : Async<'T option> =
if disposed then async.Return None
else async {
let! result = f currentState
match result with
| None ->
| None ->
return None
| Some (value, nextState) ->
currentState <- nextState
return Some value
}
member __.Dispose () =
member __.Dispose () =
disposed <- true

type UnfoldAsyncEnumerator<'S, 'T> (f:'S -> Async<('T * 'S) option>, init:'S) =
Expand Down Expand Up @@ -458,13 +458,6 @@ module AsyncSeq =
type AsyncSeqBuilder() =
member x.Yield(v) =
singleton v
// This looks weird, but it is needed to allow:
//
// while foo do
// do! something
//
// because F# translates body as Bind(something, fun () -> Return())
member x.Return () = empty
member x.YieldFrom(s:AsyncSeq<'T>) =
s
member x.Zero () = empty
Expand Down Expand Up @@ -606,10 +599,10 @@ module AsyncSeq =
// Optimized collect implementation using direct field access instead of ref cells
type OptimizedCollectEnumerator<'T, 'U>(f: 'T -> AsyncSeq<'U>, inp: AsyncSeq<'T>) =
// Mutable fields instead of ref cells to reduce allocations
let mutable inputEnumerator: IAsyncEnumerator<'T> option = None
let mutable inputEnumerator: IAsyncEnumerator<'T> option = None
let mutable innerEnumerator: IAsyncEnumerator<'U> option = None
let mutable disposed = false

// Tail-recursive optimization to avoid deep continuation chains
let rec moveNextLoop () : Async<'U option> = async {
if disposed then return None
Expand Down Expand Up @@ -642,7 +635,7 @@ module AsyncSeq =
inputEnumerator <- Some newOuter
return! moveNextLoop ()
}

interface IAsyncEnumerator<'U> with
member _.MoveNext() = moveNextLoop ()
member _.Dispose() =
Expand All @@ -651,13 +644,13 @@ module AsyncSeq =
match innerEnumerator with
| Some inner -> inner.Dispose(); innerEnumerator <- None
| None -> ()
match inputEnumerator with
match inputEnumerator with
| Some outer -> outer.Dispose(); inputEnumerator <- None
| None -> ()

let collect (f: 'T -> AsyncSeq<'U>) (inp: AsyncSeq<'T>) : AsyncSeq<'U> =
{ new IAsyncEnumerable<'U> with
member _.GetEnumerator() =
member _.GetEnumerator() =
new OptimizedCollectEnumerator<'T, 'U>(f, inp) :> IAsyncEnumerator<'U> }

// let collect (f: 'T -> AsyncSeq<'U>) (inp: AsyncSeq<'T>) : AsyncSeq<'U> =
Expand Down Expand Up @@ -749,7 +742,7 @@ module AsyncSeq =
// Optimized iterAsync implementation to reduce allocations
type internal OptimizedIterAsyncEnumerator<'T>(enumerator: IAsyncEnumerator<'T>, f: 'T -> Async<unit>) =
let mutable disposed = false

member _.IterateAsync() =
let rec loop() = async {
let! next = enumerator.MoveNext()
Expand All @@ -760,17 +753,17 @@ module AsyncSeq =
| None -> return ()
}
loop()

interface IDisposable with
member _.Dispose() =
if not disposed then
disposed <- true
enumerator.Dispose()

// Optimized iteriAsync implementation with direct tail recursion
// Optimized iteriAsync implementation with direct tail recursion
type internal OptimizedIteriAsyncEnumerator<'T>(enumerator: IAsyncEnumerator<'T>, f: int -> 'T -> Async<unit>) =
let mutable disposed = false

member _.IterateAsync() =
let rec loop count = async {
let! next = enumerator.MoveNext()
Expand All @@ -781,7 +774,7 @@ module AsyncSeq =
| None -> return ()
}
loop 0

interface IDisposable with
member _.Dispose() =
if not disposed then
Expand All @@ -798,7 +791,7 @@ module AsyncSeq =
let iterAsync (f: 'T -> Async<unit>) (source: AsyncSeq<'T>) =
match source with
| :? AsyncSeqOp<'T> as source -> source.IterAsync f
| _ ->
| _ ->
async {
let enum = source.GetEnumerator()
use optimizer = new OptimizedIterAsyncEnumerator<_>(enum, f)
Expand Down Expand Up @@ -864,7 +857,7 @@ module AsyncSeq =
// Optimized mapAsync enumerator that avoids computation builder overhead
type private OptimizedMapAsyncEnumerator<'T, 'TResult>(source: IAsyncEnumerator<'T>, f: 'T -> Async<'TResult>) =
let mutable disposed = false

interface IAsyncEnumerator<'TResult> with
member _.MoveNext() = async {
let! moveResult = source.MoveNext()
Expand All @@ -874,7 +867,7 @@ module AsyncSeq =
let! mapped = f value
return Some mapped
}

member _.Dispose() =
if not disposed then
disposed <- true
Expand All @@ -885,7 +878,7 @@ module AsyncSeq =
| :? AsyncSeqOp<'T> as source -> source.MapAsync f
| _ ->
{ new IAsyncEnumerable<'TResult> with
member _.GetEnumerator() =
member _.GetEnumerator() =
new OptimizedMapAsyncEnumerator<'T, 'TResult>(source.GetEnumerator(), f) :> IAsyncEnumerator<'TResult> }

let mapiAsync f (source : AsyncSeq<'T>) : AsyncSeq<'TResult> = asyncSeq {
Expand Down
25 changes: 25 additions & 0 deletions tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,31 @@ let ``AsyncSeq.bufferByTimeAndCount empty``() =
//
// Assert.True ((actual = expected))

[<Test>]
let ``AsyncSeq.while do CE is possible`` () =
let mutable i = 0
let mutable foo = true
let something =
async {
i <- i + 1
foo <- i < 3
do! Async.Sleep 10
}
let actual =
asyncSeq {
yield "a"

while foo do
do! something

yield "b"
yield "c"
}
|> AsyncSeq.toListAsync
|> Async.RunSynchronously

Assert.AreEqual([ "a"; "b"; "c" ], actual)

[<Test>]
let ``AsyncSeq.bufferByCountAndTime should not block`` () =
let op =
Expand Down