From d5ed221c9b190fcfeff349707d989caa91b790c2 Mon Sep 17 00:00:00 2001 From: Sofia Rodrigues Date: Fri, 12 Sep 2025 13:27:45 -0300 Subject: [PATCH 1/6] feat: change all Task types to Async --- src/Std/Internal/Async/Basic.lean | 284 +++++++++++++++++++---- src/Std/Internal/Async/DNS.lean | 14 +- src/Std/Internal/Async/TCP.lean | 29 ++- src/Std/Internal/Async/Timer.lean | 49 ++-- src/Std/Internal/Async/UDP.lean | 12 +- tests/lean/run/async_base_functions.lean | 31 ++- 6 files changed, 317 insertions(+), 102 deletions(-) diff --git a/src/Std/Internal/Async/Basic.lean b/src/Std/Internal/Async/Basic.lean index fbd9a571ccd2..0245caabbb6f 100644 --- a/src/Std/Internal/Async/Basic.lean +++ b/src/Std/Internal/Async/Basic.lean @@ -369,10 +369,10 @@ def joinTask (t : Task (MaybeTask α)) : Task α := | .pure a => .pure a | .ofTask t => t -instance : Functor (MaybeTask) where +instance : Functor MaybeTask where map := MaybeTask.map -instance : Monad (MaybeTask) where +instance : Monad MaybeTask where pure := MaybeTask.pure bind := MaybeTask.bind @@ -494,6 +494,77 @@ instance : MonadAsync Task BaseAsync where instance [Inhabited α] : Inhabited (BaseAsync α) where default := .mk <| pure (MaybeTask.pure default) +instance : MonadFinally BaseAsync where + tryFinally' x f := do + let res ← x + Prod.mk res <$> f (some res) + + +/-- +Converts `Task` into `BaseAsync`. +-/ +@[inline] +protected def ofEAsyncTask (task : Task α) : BaseAsync α := do + pure (f := BaseIO) (MaybeTask.ofTask task) + +/-- +Converts `Except` to `BaseAsync`. +-/ +@[inline] +protected def ofExcept (except : Except Empty α) : BaseAsync α := + pure (f := BaseIO) <| MaybeTask.pure <| match except with | .ok res => res + +/-- +Runs two computations concurrently and returns both results as a pair. +-/ +@[inline, specialize] +def concurrently (x : BaseAsync α) (y : BaseAsync β) (prio := Task.Priority.default) : BaseAsync (α × β) := do + let taskX : Task _ ← MonadAsync.async x (prio := prio) + let taskY : Task _ ← MonadAsync.async y (prio := prio) + let resultX ← MonadAwait.await taskX + let resultY ← MonadAwait.await taskY + return (resultX, resultY) + +/-- +Runs two computations concurrently and returns the result of the one that finishes first. +The other result is lost and the other task is not cancelled, so the task will continue the execution +until the end. +-/ +@[inline, specialize] +def race [Inhabited α] (x : BaseAsync α) (y : BaseAsync α) (prio := Task.Priority.default) : BaseAsync α := do + let promise ← IO.Promise.new + + let task₁ : Task _ ← MonadAsync.async (prio := prio) x + let task₂ : Task _ ← MonadAsync.async (prio := prio) y + + BaseIO.chainTask task₁ (liftM ∘ promise.resolve) + BaseIO.chainTask task₂ (liftM ∘ promise.resolve) + + MonadAwait.await promise.result! + +/-- +Runs all computations in an `Array` concurrently and returns all results as an array. +-/ +@[inline, specialize] +def concurrentlyAll (xs : Array (BaseAsync α)) (prio := Task.Priority.default) : BaseAsync (Array α) := do + let tasks : Array (Task α) ← xs.mapM (MonadAsync.async (prio := prio)) + tasks.mapM MonadAwait.await + +/-- +Runs all computations concurrently and returns the result of the first one to finish. +All other results are lost, and the tasks are not cancelled, so they'll continue their executing +until the end. +-/ +@[inline, specialize] +def raceAll [Inhabited α] [ForM BaseAsync c (BaseAsync α)] (xs : c) (prio := Task.Priority.default) : BaseAsync α := do + let promise ← IO.Promise.new + + ForM.forM xs fun x => do + let task₁ ← MonadAsync.async (t := Task) (prio := prio) x + BaseIO.chainTask task₁ (liftM ∘ promise.resolve) + + MonadAwait.await promise.result! + end BaseAsync /-- @@ -578,6 +649,13 @@ Lifts an `EAsync` computation into an `ETask` that can be awaited and joined. protected def asTask (x : EAsync ε α) (prio := Task.Priority.default) : EIO ε (ETask ε α) := x |> BaseAsync.asTask (prio := prio) +/-- +Block until the `EAsync` finishes and returns its value. Propagates any error encountered during execution. +-/ +@[inline] +protected def block (x : EAsync ε α) (prio := Task.Priority.default) : EIO ε α := + x.asTask (prio := prio) >>= ETask.block + /-- Raises an error of type `ε` within the `EAsync` monad. -/ @@ -707,6 +785,75 @@ protected partial def forIn instance : ForIn (EAsync ε) Lean.Loop Unit where forIn _ := EAsync.forIn +/-- +Converts `ETask` into `EAsync`. +-/ +@[inline] +protected def ofEAsyncTask (task : ETask ε α) : EAsync ε α := do + pure (f := BaseIO) (MaybeTask.ofTask task) + +/-- +Converts `Except` to `EAsync`. +-/ +@[inline] +protected def ofExcept (except : Except ε α) : EAsync ε α := + pure (f := BaseIO) (MaybeTask.pure except) + +/-- +Runs two computations concurrently and returns both results as a pair. +-/ +@[inline, specialize] +def concurrently (x : EAsync ε α) (y : EAsync ε β) (prio := Task.Priority.default) : EAsync ε (α × β) := do + let taskX : ETask ε _ ← MonadAsync.async x (prio := prio) + let taskY : ETask ε _ ← MonadAsync.async y (prio := prio) + let resultX ← MonadAwait.await taskX + let resultY ← MonadAwait.await taskY + return (resultX, resultY) + +/-- +Runs two computations concurrently and returns the result of the one that finishes first. +The other result is lost and the other task is not cancelled, so the task will continue the execution +until the end. +-/ +@[inline, specialize] +def race [Inhabited α] (x : EAsync ε α) (y : EAsync ε α) + (prio := Task.Priority.default) : + EAsync ε α := do + let promise ← IO.Promise.new + + let task₁ : ETask ε _ ← MonadAsync.async (prio := prio) x + let task₂ : ETask ε _ ← MonadAsync.async (prio := prio) y + + BaseIO.chainTask task₁ (liftM ∘ promise.resolve) + BaseIO.chainTask task₂ (liftM ∘ promise.resolve) + + let result ← MonadAwait.await promise.result! + EAsync.ofExcept result + +/-- +Runs all computations in an `Array` concurrently and returns all results as an array. +-/ +@[inline, specialize] +def concurrentlyAll (xs : Array (EAsync ε α)) (prio := Task.Priority.default) : EAsync ε (Array α) := do + let tasks : Array (ETask ε α) ← xs.mapM (MonadAsync.async (prio := prio)) + tasks.mapM MonadAwait.await + +/-- +Runs all computations concurrently and returns the result of the first one to finish. +All other results are lost, and the tasks are not cancelled, so they'll continue their executing +until the end. +-/ +@[inline, specialize] +def raceAll [Inhabited α] [ForM (EAsync ε) c (EAsync ε α)] (xs : c) (prio := Task.Priority.default) : EAsync ε α := do + let promise ← IO.Promise.new + + ForM.forM xs fun x => do + let task₁ ← MonadAsync.async (t := ETask ε) (prio := prio) x + BaseIO.chainTask task₁ (liftM ∘ promise.resolve) + + let result ← MonadAwait.await promise.result! + EAsync.ofExcept result + end EAsync /-- @@ -723,6 +870,61 @@ Converts a `Async` to a `AsyncTask`. protected def toIO (x : Async α) : IO (AsyncTask α) := MaybeTask.toTask <$> x.toRawBaseIO +/-- +Block until the `Async` finishes and returns its value. Propagates any error encountered during execution. +-/ +@[inline] +protected def block (x : Async α) (prio := Task.Priority.default) : IO α := + x.asTask (prio := prio) >>= ETask.block + +/-- +Converts `Promise` into `Async`. +-/ +@[inline] +protected def ofPromise (task : IO (IO.Promise (Except IO.Error α))) : Async α := do + match ← task.toBaseIO with + | .ok data => pure (f := BaseIO) (MaybeTask.ofTask data.result!) + | .error err => pure (f := BaseIO) (MaybeTask.pure (.error err)) + +/-- +Converts `AsyncTask` into `Async`. +-/ +@[inline] +protected def ofAsyncTask (task : AsyncTask α) : Async α := do + pure (f := BaseIO) (MaybeTask.ofTask task) + +/-- +Converts `IO (Task α)` into `Async`. +-/ +@[inline] +protected def ofIOTask (task : IO (Task α)) : Async α := do + match ← task.toBaseIO with + | .ok data => .ofAsyncTask (data.map Except.ok) + | .error err => pure (f := BaseIO) (MaybeTask.pure (.error err)) + +/-- +Converts `Except` to `Async`. +-/ +@[inline] +protected def ofExcept (except : Except IO.Error α) : Async α := + pure (f := BaseIO) (MaybeTask.pure except) + +/-- +Converts `Task` to `Async`. +-/ +@[inline] +protected def ofTask (task : Task α) : Async α := do + .ofAsyncTask (task.map Except.ok) + +/-- +Converts `IO (IO.Promise α)` to `Async`. +-/ +@[inline] +protected def ofPurePromise (task : IO (IO.Promise α)) : Async α := do + match ← task.toBaseIO with + | .ok data => pure (f := BaseIO) (MaybeTask.ofTask <| data.result!.map (.ok)) + | .error err => pure (f := BaseIO) (MaybeTask.pure (.error err)) + @[default_instance] instance : MonadAsync AsyncTask Async := inferInstanceAs (MonadAsync (ETask IO.Error) (EAsync IO.Error)) @@ -733,31 +935,15 @@ instance : MonadAwait AsyncTask Async := instance : MonadAwait IO.Promise Async := inferInstanceAs (MonadAwait IO.Promise (EAsync IO.Error)) -end Async - -export MonadAsync (async) -export MonadAwait (await) - -/-- -This function transforms the operation inside the monad `m` into a task and let it run in the background. --/ -@[inline, specialize] -def background [Monad m] [MonadAsync t m] (action : m α) (prio := Task.Priority.default) : m Unit := - discard (async (t := t) (prio := prio) action) - /-- Runs two computations concurrently and returns both results as a pair. -/ @[inline, specialize] -def concurrently - [Monad m] [MonadAwait t m] [MonadAsync t m] - (x : m α) (y : m β) - (prio := Task.Priority.default) : - m (α × β) := do - let taskX : t α ← async x (prio := prio) - let taskY : t β ← async y (prio := prio) - let resultX ← await taskX - let resultY ← await taskY +def concurrently (x : Async α) (y : Async β) (prio := Task.Priority.default) : Async (α × β) := do + let taskX ← MonadAsync.async x (prio := prio) + let taskY ← MonadAsync.async y (prio := prio) + let resultX ← MonadAwait.await taskX + let resultY ← MonadAwait.await taskY return (resultX, resultY) /-- @@ -766,28 +952,27 @@ The other result is lost and the other task is not cancelled, so the task will c until the end. -/ @[inline, specialize] -def race - [MonadLiftT BaseIO m] [MonadAwait Task m] [MonadAsync t m] [MonadAwait t m] - [Monad m] [Inhabited α] (x : m α) (y : m α) +def race [Inhabited α] (x : Async α) (y : Async α) (prio := Task.Priority.default) : - m α := do + Async α := do let promise ← IO.Promise.new - discard (async (t := t) (prio := prio) <| Bind.bind x (liftM ∘ promise.resolve)) - discard (async (t := t) (prio := prio) <| Bind.bind y (liftM ∘ promise.resolve)) + let task₁ ← MonadAsync.async (t := AsyncTask) (prio := prio) x + let task₂ ← MonadAsync.async (t := AsyncTask) (prio := prio) y + + BaseIO.chainTask task₁ (liftM ∘ promise.resolve) + BaseIO.chainTask task₂ (liftM ∘ promise.resolve) - await promise.result! + let result ← MonadAwait.await promise.result! + Async.ofExcept result /-- Runs all computations in an `Array` concurrently and returns all results as an array. -/ @[inline, specialize] -def concurrentlyAll - [Monad m] [MonadAwait t m] [MonadAsync t m] (xs : Array (m α)) - (prio := Task.Priority.default) : - m (Array α) := do - let tasks : Array (t α) ← xs.mapM (async (prio := prio)) - tasks.mapM await +def concurrentlyAll (xs : Array (Async α)) (prio := Task.Priority.default) : Async (Array α) := do + let tasks : Array (AsyncTask α) ← xs.mapM (MonadAsync.async (prio := prio)) + tasks.mapM MonadAwait.await /-- Runs all computations concurrently and returns the result of the first one to finish. @@ -795,18 +980,27 @@ All other results are lost, and the tasks are not cancelled, so they'll continue until the end. -/ @[inline, specialize] -def raceAll - [ForM m c (m α)] [MonadLiftT BaseIO m] [MonadAwait Task m] - [MonadAsync t m] [MonadAwait t m] [Monad m] [Inhabited α] - (xs : c) - (prio := Task.Priority.default) : - m α := do +def raceAll [ForM Async c (Async α)] (xs : c) (prio := Task.Priority.default) : Async α := do let promise ← IO.Promise.new - ForM.forM xs fun x => - discard (async (t := t) (prio := prio) <| Bind.bind x (liftM ∘ promise.resolve)) + ForM.forM xs fun x => do + let task₁ ← MonadAsync.async (t := AsyncTask) (prio := prio) x + BaseIO.chainTask task₁ (liftM ∘ promise.resolve) - await promise.result! + let result ← MonadAwait.await promise.result! + Async.ofExcept result + +end Async + +export MonadAsync (async) +export MonadAwait (await) + +/-- +This function transforms the operation inside the monad `m` into a task and let it run in the background. +-/ +@[inline, specialize] +def background [Monad m] [MonadAsync t m] (action : m α) (prio := Task.Priority.default) : m Unit := + discard (async (t := t) (prio := prio) action) end Async end IO diff --git a/src/Std/Internal/Async/DNS.lean b/src/Std/Internal/Async/DNS.lean index a728b5d0afb3..5c13e4dd1735 100644 --- a/src/Std/Internal/Async/DNS.lean +++ b/src/Std/Internal/Async/DNS.lean @@ -39,10 +39,11 @@ structure NameInfo where Asynchronously resolves a hostname and service to an array of socket addresses. -/ @[inline] -def getAddrInfo (host : String) (service : String) (addressFamily : Option AddressFamily := none) : - IO (AsyncTask (Array IPAddr)) := - AsyncTask.ofPromise <$> UV.DNS.getAddrInfo host service - (match addressFamily with +def getAddrInfo (host : String) (service : String) (addrFamily : Option AddressFamily := none) : Async (Array IPAddr) := do + Async.ofPromise <| UV.DNS.getAddrInfo + host + service + (match addrFamily with | none => 0 | some .ipv4 => 1 | some .ipv6 => 2) @@ -51,9 +52,10 @@ def getAddrInfo (host : String) (service : String) (addressFamily : Option Addre Performs a reverse DNS lookup on a `SocketAddress`. -/ @[inline] -def getNameInfo (host : @& SocketAddress) : IO (AsyncTask NameInfo) := +def getNameInfo (host : @& SocketAddress) : Async NameInfo := UV.DNS.getNameInfo host - |>.map (Task.map (.map <| Function.uncurry NameInfo.mk) ∘ AsyncTask.ofPromise) + |> Async.ofPromise + |>.map (Function.uncurry NameInfo.mk) end DNS end Async diff --git a/src/Std/Internal/Async/TCP.lean b/src/Std/Internal/Async/TCP.lean index 1e6667f6f2eb..de2853fd9f56 100644 --- a/src/Std/Internal/Async/TCP.lean +++ b/src/Std/Internal/Async/TCP.lean @@ -18,7 +18,6 @@ namespace Internal namespace IO namespace Async namespace TCP - open Std.Net namespace Socket @@ -66,9 +65,10 @@ def listen (s : Server) (backlog : UInt32) : IO Unit := Accepts an incoming connection. -/ @[inline] -def accept (s : Server) : IO (AsyncTask Client) := do - let conn ← s.native.accept - return conn.result!.map (·.map Client.ofNative) +def accept (s : Server) : Async Client := do + s.native.accept + |> Async.ofPromise + |>.map Client.ofNative /-- Gets the local address of the server socket. @@ -115,15 +115,15 @@ def bind (s : Client) (addr : SocketAddress) : IO Unit := Connects the client socket to the given address. -/ @[inline] -def connect (s : Client) (addr : SocketAddress) : IO (AsyncTask Unit) := - AsyncTask.ofPromise <$> s.native.connect addr +def connect (s : Client) (addr : SocketAddress) : Async Unit := + Async.ofPromise <| s.native.connect addr /-- Sends data through the client socket. -/ @[inline] -def send (s : Client) (data : ByteArray) : IO (AsyncTask Unit) := - AsyncTask.ofPromise <$> s.native.send data +def send (s : Client) (data : ByteArray) : Async Unit := + Async.ofPromise <| s.native.send data /-- Receives data from the client socket. If data is received, it’s wrapped in .some. If EOF is reached, @@ -132,8 +132,8 @@ socket is not supported. Instead, we recommend binding multiple sockets to the s Furthermore calling this function in parallel with `recvSelector` is not supported. -/ @[inline] -def recv? (s : Client) (size : UInt64) : IO (AsyncTask (Option ByteArray)) := - AsyncTask.ofPromise <$> s.native.recv? size +def recv? (s : Client) (size : UInt64) : Async (Option ByteArray) := + Async.ofPromise <| s.native.recv? size /-- Creates a `Selector` that resolves once `s` has data available, up to at most `size` bytes, @@ -146,7 +146,7 @@ def recvSelector (s : TCP.Socket.Client) (size : UInt64) : IO (Selector (Option tryFn := do if ← readableWaiter.isResolved then -- We know that this read should not block - let res ← (← s.recv? size).block + let res ← (s.recv? size).block return some res else return none @@ -161,7 +161,7 @@ def recvSelector (s : TCP.Socket.Client) (size : UInt64) : IO (Selector (Option try discard <| IO.ofExcept res -- We know that this read should not block - let res ← (← s.recv? size).block + let res ← (s.recv? size).block promise.resolve (.ok res) catch e => promise.resolve (.error e) @@ -173,8 +173,8 @@ def recvSelector (s : TCP.Socket.Client) (size : UInt64) : IO (Selector (Option Shuts down the write side of the client socket. -/ @[inline] -def shutdown (s : Client) : IO (AsyncTask Unit) := - AsyncTask.ofPromise <$> s.native.shutdown +def shutdown (s : Client) : Async Unit := + Async.ofPromise <| s.native.shutdown /-- Gets the remote address of the client socket. @@ -205,7 +205,6 @@ def keepAlive (s : Client) (enable : Bool) (delay : Std.Time.Second.Offset) (_ : s.native.keepAlive enable.toInt8 delay.val.toNat.toUInt32 end Client - end Socket end TCP end Async diff --git a/src/Std/Internal/Async/Timer.lean b/src/Std/Internal/Async/Timer.lean index 879b60f2dde5..1f73d9c8b7eb 100644 --- a/src/Std/Internal/Async/Timer.lean +++ b/src/Std/Internal/Async/Timer.lean @@ -33,35 +33,34 @@ Set up a `Sleep` that waits for `duration` milliseconds. This function only initializes but does not yet start the timer. -/ @[inline] -def mk (duration : Std.Time.Millisecond.Offset) : IO Sleep := do +def mk (duration : Std.Time.Millisecond.Offset) : Async Sleep := do let native ← Internal.UV.Timer.mk duration.toInt.toNat.toUInt64 false return ofNative native /-- If: -- `s` is not yet running start it and return an `AsyncTask` that will resolve once the previously - configured `duration` has run out. -- `s` is already or not anymore running return the same `AsyncTask` as the first call to `wait`. +- `s` is not yet running start it and return an `Async` computation that will complete once the previously + configured `duration` has elapsed. +- `s` is already or not anymore running return the same `Async` computation as the first call to `wait`. -/ @[inline] -def wait (s : Sleep) : IO (AsyncTask Unit) := do - let promise ← s.native.next - return .ofPurePromise promise +def wait (s : Sleep) : Async Unit := + Async.ofPurePromise s.native.next /-- If: -- `s` is still running the timer restarts counting from now and finishes after `duration` +- `s` is still running the timer restarts counting from now and completes after `duration` milliseconds. - `s` is not yet or not anymore running this is a no-op. -/ @[inline] -def reset (s : Sleep) : IO Unit := +def reset (s : Sleep) : Async Unit := s.native.reset /-- If: -- `s` is still running this stops `s` without resolving any remaining `AsyncTask`s that were created - through `wait`. Note that if another `AsyncTask` is binding on any of these it is going hang +- `s` is still running this stops `s` without completing any remaining `Async` computations that were created + through `wait`. Note that if another `Async` computation is binding on any of these it will hang forever without further intervention. - `s` is not yet or not anymore running this is a no-op. -/ @@ -73,8 +72,8 @@ def stop (s : Sleep) : IO Unit := Create a `Selector` that resolves once `s` has finished. Note that calling this function starts `s` if it hasn't already started. -/ -def selector (s : Sleep) : IO (Selector Unit) := do - let sleepWaiter ← s.wait +def selector (s : Sleep) : Async (Selector Unit) := do + let sleepWaiter ← s.wait.asTask return { tryFn := do if ← IO.hasFinished sleepWaiter then @@ -92,16 +91,16 @@ def selector (s : Sleep) : IO (Selector Unit) := do end Sleep /-- -Return an `AsyncTask` that resolves after `duration`. +Return an `Async` computation that completes after `duration`. -/ -def sleep (duration : Std.Time.Millisecond.Offset) : IO (AsyncTask Unit) := do +def sleep (duration : Std.Time.Millisecond.Offset) : Async Unit := do let sleeper ← Sleep.mk duration sleeper.wait /-- -Return a `Selector` that resolves after `duration`. +Return a `Selector` that completes after `duration`. -/ -def Selector.sleep (duration : Std.Time.Millisecond.Offset) : IO (Selector Unit) := do +def Selector.sleep (duration : Std.Time.Millisecond.Offset) : Async (Selector Unit) := do let sleeper ← Sleep.mk duration sleeper.selector @@ -113,7 +112,6 @@ structure Interval where private ofNative :: native : Internal.UV.Timer - namespace Interval /-- @@ -127,19 +125,18 @@ def mk (duration : Std.Time.Millisecond.Offset) (_ : 0 < duration := by decide) /-- If: -- `i` is not yet running start it and return an `AsyncTask` that resolves right away as the 0th +- `i` is not yet running start it and return an `Async` computation that completes right away as the 0th multiple of `duration` has elapsed. - `i` is already running and: - - the tick from the last call of `i` has not yet finished return the same `AsyncTask` as the last + - the tick from the last call of `i` has not yet finished return the same `Async` computation as the last call - - the tick from the last call of `i` has finished return a new `AsyncTask` that waits for the + - the tick from the last call of `i` has finished return a new `Async` computation that waits for the closest next tick from the time of calling this function. - `i` is not running anymore this is a no-op. -/ @[inline] -def tick (i : Interval) : IO (AsyncTask Unit) := do - let promise ← i.native.next - return .ofPurePromise promise +def tick (i : Interval) : Async Unit := do + Async.ofPurePromise i.native.next /-- If: @@ -153,8 +150,8 @@ def reset (i : Interval) : IO Unit := /-- If: -- `i` is still running this stops `i` without resolving any remaining `AsyncTask` that were created - through `tick`. Note that if another `AsyncTask` is binding on any of these it is going hang +- `i` is still running this stops `i` without completing any remaining `Async` computations that were created + through `tick`. Note that if another `Async` computation is binding on any of these it will hang forever without further intervention. - `i` is not yet or not anymore running this is a no-op. -/ diff --git a/src/Std/Internal/Async/UDP.lean b/src/Std/Internal/Async/UDP.lean index 37f84fc37bff..9c1d2d7e8371 100644 --- a/src/Std/Internal/Async/UDP.lean +++ b/src/Std/Internal/Async/UDP.lean @@ -66,8 +66,8 @@ Sends data through an UDP socket. The `addr` parameter specifies the destination is `none`, the data is sent to the default peer address set by `connect`. -/ @[inline] -def send (s : Socket) (data : ByteArray) (addr : Option SocketAddress := none) : IO (AsyncTask Unit) := - AsyncTask.ofPromise <$> s.native.send data addr +def send (s : Socket) (data : ByteArray) (addr : Option SocketAddress := none) : Async Unit := + Async.ofPromise <| s.native.send data addr /-- Receives data from an UDP socket. `size` is for the maximum bytes to receive. @@ -77,8 +77,8 @@ has not been previously bound with `bind`, it is automatically bound to `0.0.0.0 Furthermore calling this function in parallel with `recvSelector` is not supported. -/ @[inline] -def recv (s : Socket) (size : UInt64) : IO (AsyncTask (ByteArray × Option SocketAddress)) := - AsyncTask.ofPromise <$> s.native.recv size +def recv (s : Socket) (size : UInt64) : Async (ByteArray × Option SocketAddress) := + Async.ofPromise <| s.native.recv size /-- Creates a `Selector` that resolves once `s` has data available, up to at most `size` bytes, @@ -93,7 +93,7 @@ def recvSelector (s : Socket) (size : UInt64) : tryFn := do if ← readableWaiter.isResolved then -- We know that this read should not block - let res ← (← s.recv size).block + let res ← (s.recv size).block return some res else return none @@ -108,7 +108,7 @@ def recvSelector (s : Socket) (size : UInt64) : try discard <| IO.ofExcept res -- We know that this read should not block - let res ← (← s.recv size).block + let res ← (s.recv size).block promise.resolve (.ok res) catch e => promise.resolve (.error e) diff --git a/tests/lean/run/async_base_functions.lean b/tests/lean/run/async_base_functions.lean index dec2fe81c9f9..c0d79803d77b 100644 --- a/tests/lean/run/async_base_functions.lean +++ b/tests/lean/run/async_base_functions.lean @@ -23,7 +23,7 @@ def sequential : Async Unit := do def conc : Async Unit := do let ref ← Std.Mutex.new 0 - discard <| concurrently (wait 200 ref 1) (wait 400 ref 2) + discard <| Async.concurrently (wait 200 ref 1) (wait 1000 ref 2) ref.atomically (·.modify (· * 10)) assert! (← ref.atomically (·.get)) == 30 @@ -31,7 +31,7 @@ def conc : Async Unit := do def racer : Async Unit := do let ref ← Std.Mutex.new 0 - race (wait 200 ref 1) (wait 400 ref 2) + Async.race (wait 200 ref 1) (wait 1000 ref 2) ref.atomically (·.modify (· * 10)) assert! (← ref.atomically (·.get)) == 10 @@ -39,7 +39,7 @@ def racer : Async Unit := do def concAll : Async Unit := do let ref ← Std.Mutex.new 0 - discard <| concurrentlyAll #[(wait 200 ref 1), (wait 400 ref 2)] + discard <| Async.concurrentlyAll #[(wait 200 ref 1), (wait 1000 ref 2)] ref.atomically (·.modify (· * 10)) assert! (← ref.atomically (·.get)) == 30 @@ -47,8 +47,31 @@ def concAll : Async Unit := do def racerAll : Async Unit := do let ref ← Std.Mutex.new 0 - raceAll #[(wait 200 ref 1), (wait 400 ref 2)] + Async.raceAll #[(wait 200 ref 1), (wait 1000 ref 2)] ref.atomically (·.modify (· * 10)) assert! (← ref.atomically (·.get)) == 10 #eval do (← racerAll.toEIO).block + +def racerAllNotCancels : Async Unit := do + let ref ← Std.Mutex.new 0 + Async.raceAll #[(wait 200 ref 1), (wait 700 ref 2)] + ref.atomically (·.modify (· * 10)) + IO.sleep 1000 + assert! (← ref.atomically (·.get)) == 12 + +#eval do (← racerAllNotCancels.toEIO).block + +def racerAllError : Async Unit := do + let ref ← Std.Mutex.new 0 + Async.raceAll #[(wait 400 ref 2), throw (IO.userError "error wins")] + +/-- error: error wins -/ +#guard_msgs in +#eval do (← racerAllError.toEIO).block + +def racerAllErrorLost : Async Unit := do + let result ← Async.raceAll #[(do IO.sleep 1000; throw (IO.userError "error wins")) , (do IO.sleep 200; pure 10)] + assert! result = 10 + +#eval do (← racerAllErrorLost.toEIO).block From 290971d9f443278a2928e7429a157319d46311c2 Mon Sep 17 00:00:00 2001 From: Sofia Rodrigues Date: Fri, 12 Sep 2025 13:52:01 -0300 Subject: [PATCH 2/6] feat: change Select to Async --- src/Std/Internal/Async/Select.lean | 33 ++--- src/Std/Internal/Async/TCP.lean | 2 +- src/Std/Internal/Async/UDP.lean | 2 +- tests/lean/run/async_dns.lean | 36 ++---- tests/lean/run/async_select_channel.lean | 40 +++--- tests/lean/run/async_select_socket.lean | 130 ++++++++++---------- tests/lean/run/async_select_timer.lean | 23 ++-- tests/lean/run/async_sleep.lean | 86 ++++++------- tests/lean/run/async_surface_sleep.lean | 88 ++++++------- tests/lean/run/async_tcp_fname_errors.lean | 12 +- tests/lean/run/async_tcp_half.lean | 12 +- tests/lean/run/async_tcp_server_client.lean | 28 ++--- tests/lean/run/async_udp_sockets.lean | 4 +- 13 files changed, 242 insertions(+), 254 deletions(-) diff --git a/src/Std/Internal/Async/Select.lean b/src/Std/Internal/Async/Select.lean index c6bd776dfbb6..f68940d1c0c0 100644 --- a/src/Std/Internal/Async/Select.lean +++ b/src/Std/Internal/Async/Select.lean @@ -99,7 +99,7 @@ structure Selectable (α : Type) where /-- The continuation that is called on results from the event source. -/ - cont : β → IO (AsyncTask α) + cont : β → Async α private def shuffleIt {α : Type u} (xs : Array α) (gen : StdGen) : Array α := go xs gen 0 @@ -123,16 +123,18 @@ The protocol for this is as follows: Once one of them resolves the `Waiter`, all `Selector.unregisterFn` functions are called, and the `Selectable.cont` of the winning `Selector` is executed and returned. -/ -partial def Selectable.one (selectables : Array (Selectable α)) : IO (AsyncTask α) := do +partial def Selectable.one (selectables : Array (Selectable α)) : Async α := do if selectables.isEmpty then throw <| .userError "Selectable.one requires at least one Selectable" let seed := UInt64.toNat (ByteArray.toUInt64LE! (← IO.getRandomBytes 8)) let gen := mkStdGen seed let selectables := shuffleIt selectables gen + for selectable in selectables do if let some val ← selectable.selector.tryFn then - return ← selectable.cont val + let result ← selectable.cont val + return result let finished ← IO.mkRef false let promise ← IO.Promise.new @@ -142,27 +144,30 @@ partial def Selectable.one (selectables : Array (Selectable α)) : IO (AsyncTask let waiter := Waiter.mk finished waiterPromise selectable.selector.registerFn waiter - IO.chainTask (t := waiterPromise.result?) fun res? => do + discard <| IO.bindTask (t := waiterPromise.result?) fun res? => do match res? with | none => /- If we get `none` that means the waiterPromise was dropped, usually due to cancellation. In this situation just do nothing. -/ - return () + return (Task.pure (.ok ())) | some res => - try - let res ← IO.ofExcept res + let async : Async _ := + try + let res ← IO.ofExcept res + + for selectable in selectables do + selectable.selector.unregisterFn - for selectable in selectables do - selectable.selector.unregisterFn + let contRes ← selectable.cont res + promise.resolve (.ok contRes) + catch e => + promise.resolve (.error e) - let contRes ← selectable.cont res - discard <| contRes.mapIO (promise.resolve <| .ok ·) - catch e => - promise.resolve (.error e) + async.toBaseIO - return AsyncTask.ofPromise promise + Async.ofPromise (pure promise) end Async end IO diff --git a/src/Std/Internal/Async/TCP.lean b/src/Std/Internal/Async/TCP.lean index de2853fd9f56..bf08cfd5dc15 100644 --- a/src/Std/Internal/Async/TCP.lean +++ b/src/Std/Internal/Async/TCP.lean @@ -140,7 +140,7 @@ Creates a `Selector` that resolves once `s` has data available, up to at most `s and provides that data. Calling this function starts the data wait, so it must not be called in parallel with `recv?`. -/ -def recvSelector (s : TCP.Socket.Client) (size : UInt64) : IO (Selector (Option ByteArray)) := do +def recvSelector (s : TCP.Socket.Client) (size : UInt64) : Async (Selector (Option ByteArray)) := do let readableWaiter ← s.native.waitReadable return { tryFn := do diff --git a/src/Std/Internal/Async/UDP.lean b/src/Std/Internal/Async/UDP.lean index 9c1d2d7e8371..f1d9c9c6ce1c 100644 --- a/src/Std/Internal/Async/UDP.lean +++ b/src/Std/Internal/Async/UDP.lean @@ -87,7 +87,7 @@ automatically bound to `0.0.0.0` (all interfaces) with a random port. Calling this function starts the data wait, so it must not be called in parallel with `recv`. -/ def recvSelector (s : Socket) (size : UInt64) : - IO (Selector (ByteArray × Option SocketAddress)) := do + Async (Selector (ByteArray × Option SocketAddress)) := do let readableWaiter ← s.native.waitReadable return { tryFn := do diff --git a/tests/lean/run/async_dns.lean b/tests/lean/run/async_dns.lean index b273282bc272..c6c5dbf617b6 100644 --- a/tests/lean/run/async_dns.lean +++ b/tests/lean/run/async_dns.lean @@ -12,47 +12,27 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do throw <| IO.userError <| s!"expected '{expected}', got '{actual}'" -def baseSelector (asyncWaiter : AsyncTask α) : Selector α := - { - tryFn := do - if ← IO.hasFinished asyncWaiter then - let result ← IO.ofExcept asyncWaiter.get - return some result - else - return none - registerFn waiter := do - discard <| AsyncTask.mapIO (x := asyncWaiter) fun data => do - let lose := return () - let win promise := promise.resolve (.ok data) - waiter.race lose win - unregisterFn := pure () - } +def timeout [Inhabited α] (a : Async α) (time : Std.Time.Millisecond.Offset) : Async α := do + let result ← Async.race (a.map Except.ok) (sleep time |>.map Except.error) -def race (a : AsyncTask α) (b : AsyncTask β) (map : Except α β → AsyncTask γ) : IO (AsyncTask γ) := do - Selectable.one #[ - .case (baseSelector a) fun a => return map (.error a), - .case (baseSelector b) fun b => return map (.ok b), - ] - -def timeout (a : AsyncTask α) (time : Std.Time.Millisecond.Offset) : IO (AsyncTask α) := do - race (← sleep time) a fun - | .ok res => Task.pure (.ok res) - | .error _ => Task.pure (.error (IO.userError "Timeout.")) + match result with + | .ok res => pure res + | .error _ => throw (.userError "timeout") def runDNS : Async Unit := do - let infos ← await <| (← timeout (← DNS.getAddrInfo "google.com" "http") 10000) + let infos ← timeout (DNS.getAddrInfo "google.com" "http") 1000 unless infos.size > 0 do (throw <| IO.userError <| "No DNS results for google.com" : IO _) def runDNSNoAscii : Async Unit := do - let infos ← await <| (← timeout (← DNS.getAddrInfo "google.com▸" "http") 10000) + let infos ← timeout (DNS.getAddrInfo "google.com▸" "http") 10000 unless infos.size > 0 do (throw <| IO.userError <| "No DNS results for google.com" : IO _) def runReverseDNS : Async Unit := do - let result ← await (← DNS.getNameInfo (.v4 ⟨.ofParts 8 8 8 8, 53⟩)) + let result ← DNS.getNameInfo (.v4 ⟨.ofParts 8 8 8 8, 53⟩) assertBEq result.service "domain" assertBEq result.host "dns.google" diff --git a/tests/lean/run/async_select_channel.lean b/tests/lean/run/async_select_channel.lean index 2f6b641af9d4..258a67b3cf69 100644 --- a/tests/lean/run/async_select_channel.lean +++ b/tests/lean/run/async_select_channel.lean @@ -4,24 +4,24 @@ open Std Internal IO Async namespace A -def testReceiver (ch1 ch2 : Std.Channel Nat) (count : Nat) : IO (AsyncTask Nat) := do +def testReceiver (ch1 ch2 : Std.Channel Nat) (count : Nat) : Async Nat := do go ch1 ch2 count 0 where - go (ch1 ch2 : Std.Channel Nat) (count : Nat) (acc : Nat) : IO (AsyncTask Nat) := do + go (ch1 ch2 : Std.Channel Nat) (count : Nat) (acc : Nat) : Async Nat := do match count with - | 0 => return AsyncTask.pure acc + | 0 => return acc | count + 1 => Selectable.one #[ .case ch1.recvSelector fun data => go ch1 ch2 count (acc + data), .case ch2.recvSelector fun data => go ch1 ch2 count (acc + data), ] -def testIt (capacity : Option Nat) : IO Bool := do +def testIt (capacity : Option Nat) : Async Bool := do let amount := 1000 let messages := Array.range amount let ch1 ← Std.Channel.new capacity let ch2 ← Std.Channel.new capacity - let recvTask ← testReceiver ch1 ch2 amount + let recvTask ← async (testReceiver ch1 ch2 amount) for msg in messages do if (← IO.rand 0 1) = 0 then @@ -29,47 +29,47 @@ def testIt (capacity : Option Nat) : IO Bool := do else ch2.sync.send msg - let acc ← recvTask.block + let acc ← await recvTask return acc == messages.sum /-- info: true -/ #guard_msgs in -#eval testIt none +#eval testIt none |>.block /-- info: true -/ #guard_msgs in -#eval testIt (some 0) +#eval testIt (some 0) |>.block /-- info: true -/ #guard_msgs in -#eval testIt (some 1) +#eval testIt (some 1) |>.block /-- info: true -/ #guard_msgs in -#eval testIt (some 128) +#eval testIt (some 128) |>.block end A namespace B -def testReceiver (ch1 ch2 : Std.CloseableChannel Nat) (count : Nat) : IO (AsyncTask Nat) := do +def testReceiver (ch1 ch2 : Std.CloseableChannel Nat) (count : Nat) : Async Nat := do go ch1 ch2 count 0 where - go (ch1 ch2 : Std.CloseableChannel Nat) (count : Nat) (acc : Nat) : IO (AsyncTask Nat) := do + go (ch1 ch2 : Std.CloseableChannel Nat) (count : Nat) (acc : Nat) : Async Nat := do match count with - | 0 => return AsyncTask.pure acc + | 0 => return acc | count + 1 => Selectable.one #[ .case ch1.recvSelector fun data => go ch1 ch2 count (acc + data.getD 0), .case ch2.recvSelector fun data => go ch1 ch2 count (acc + data.getD 0), ] -def testIt (capacity : Option Nat) : IO Bool := do +def testIt (capacity : Option Nat) : Async Bool := do let amount := 1000 let messages := Array.range amount let ch1 ← Std.CloseableChannel.new capacity let ch2 ← Std.CloseableChannel.new capacity - let recvTask ← testReceiver ch1 ch2 amount + let recvTask ← async (testReceiver ch1 ch2 amount) for msg in messages do if (← IO.rand 0 1) = 0 then @@ -77,23 +77,23 @@ def testIt (capacity : Option Nat) : IO Bool := do else ch2.sync.send msg - let acc ← recvTask.block + let acc ← await recvTask return acc == messages.sum /-- info: true -/ #guard_msgs in -#eval testIt none +#eval testIt none |>.block /-- info: true -/ #guard_msgs in -#eval testIt (some 0) +#eval testIt (some 0) |>.block /-- info: true -/ #guard_msgs in -#eval testIt (some 1) +#eval testIt (some 1) |>.block /-- info: true -/ #guard_msgs in -#eval testIt (some 128) +#eval testIt (some 128) |>.block end B diff --git a/tests/lean/run/async_select_socket.lean b/tests/lean/run/async_select_socket.lean index 79aa912aa79a..5bf9e5459d9a 100644 --- a/tests/lean/run/async_select_socket.lean +++ b/tests/lean/run/async_select_socket.lean @@ -6,92 +6,98 @@ open Std Internal IO Async namespace TCP -def testClient (addr : Net.SocketAddress) : IO (AsyncTask String) := do +def testClient (addr : Net.SocketAddress) : Async String := do let client ← TCP.Socket.Client.mk - (← client.connect addr).bindIO fun _ => do - Selectable.one #[ - .case (← Selector.sleep 1000) fun _ => return AsyncTask.pure "Timeout", - .case (← client.recvSelector 4096) fun data? => do - if let some data := data? then - return AsyncTask.pure <| String.fromUTF8! data - else - return AsyncTask.pure "Closed" - ] - -def test (serverFn : TCP.Socket.Server → IO (AsyncTask Unit)) (addr : Net.SocketAddress) : - IO Unit := do + client.connect addr + + Selectable.one #[ + .case (← Selector.sleep 1000) fun _ => return "Timeout", + .case (← client.recvSelector 4096) fun data? => do + if let some data := data? then + return String.fromUTF8! data + else + return "Closed" + ] + +def test (serverFn : TCP.Socket.Server → Async Unit) (addr : Net.SocketAddress) : Async String := do let server ← TCP.Socket.Server.mk server.bind addr server.listen 1 - let serverTask ← serverFn server - let clientTask ← testClient addr - serverTask.block - IO.println (← clientTask.block) + let serverTask ← async (serverFn server) + let clientTask ← async (testClient addr) + await serverTask + await clientTask -def testServerSend (server : TCP.Socket.Server) : IO (AsyncTask Unit) := do - (← server.accept).bindIO fun client => do - client.send (String.toUTF8 "Success") +def testServerSend (server : TCP.Socket.Server) : Async Unit := do + let client ← server.accept + client.send (String.toUTF8 "Success") -def testServerTimeout (server : TCP.Socket.Server) : IO (AsyncTask Unit) := do - (← server.accept).bindIO fun client => do - (← Async.sleep 1500).bindIO fun _ => do - client.shutdown +def testServerTimeout (server : TCP.Socket.Server) : Async Unit := do + let client ← server.accept + Async.sleep 1500 + client.shutdown -def testServerClose (server : TCP.Socket.Server) : IO (AsyncTask Unit) := do - (← server.accept).bindIO fun client => client.shutdown +def testServerClose (server : TCP.Socket.Server) : Async Unit := do + let client ← server.accept + client.shutdown -/-- info: Success -/ +/-- info: "Success" -/ #guard_msgs in -#eval test testServerSend (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7070) +#eval test testServerSend (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7070) |>.block -/-- info: Closed -/ +/-- info: "Closed" -/ #guard_msgs in -#eval test testServerClose (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7071) +#eval test testServerClose (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7071) |>.block -/-- info: Timeout -/ +/-- info: "Timeout" -/ #guard_msgs in -#eval test testServerTimeout (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7072) +#eval test testServerTimeout (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7072) |>.block end TCP - namespace UDP -def testClient (addr : Net.SocketAddress) : IO (AsyncTask String) := do +def testClient (addr : Net.SocketAddress) : Async String := do + IO.println "sending client" let client ← UDP.Socket.mk client.connect addr - (← client.send "ping".toUTF8).bindIO fun _ => do - Selectable.one #[ - .case (← Selector.sleep 1000) fun _ => return AsyncTask.pure "Timeout", - .case (← client.recvSelector 4096) fun (data, _) => do - return AsyncTask.pure <| String.fromUTF8! data - ] - -def test (serverFn : UDP.Socket → IO (AsyncTask Unit)) (addr : Net.SocketAddress) : IO Unit := do + client.send "ping".toUTF8 + + Selectable.one #[ + .case (← Selector.sleep 1000) fun _ => return "Timeout", + .case (← client.recvSelector 4096) fun (data, _) => do + return String.fromUTF8! data + ] + +def test (serverFn : UDP.Socket → Async Unit) (addr : Net.SocketAddress) : Async String := do let server ← UDP.Socket.mk server.bind addr - let serverTask ← serverFn server - let clientTask ← testClient addr - serverTask.block - IO.println (← clientTask.block) - -def testServerSend (server : UDP.Socket) : IO (AsyncTask Unit) := do - (← server.recv 4096).bindIO fun (_, client?) => do - let client := client?.get! - server.send (String.toUTF8 "Success") client - -def testServerTimeout (server : UDP.Socket) : IO (AsyncTask Unit) := do - (← server.recv 4096).bindIO fun (_, client?) => do - let client := client?.get! - (← Async.sleep 1500).bindIO fun _ => do - server.send (String.toUTF8 "Success") client - -/-- info: Success -/ + let serverTask ← async (serverFn server) + let clientTask ← async (testClient addr) + await serverTask + await clientTask + +def testServerSend (server : UDP.Socket) : Async Unit := do + let (_, client?) ← server.recv 4096 + let client := client?.get! + server.send (String.toUTF8 "Success") client + +def testServerTimeout (server : UDP.Socket) : Async Unit := do + let (_, client?) ← server.recv 4096 + let client := client?.get! + Async.sleep 1500 + server.send (String.toUTF8 "Success") client + +/-- +info: "Success" +-/ #guard_msgs in -#eval test testServerSend (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7070) +#eval test testServerSend (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7075) |>.block -/-- info: Timeout -/ +/-- +info: "Timeout" +-/ #guard_msgs in -#eval test testServerTimeout (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7072) +#eval test testServerTimeout (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7075) |>.block end UDP diff --git a/tests/lean/run/async_select_timer.lean b/tests/lean/run/async_select_timer.lean index 7999798c7b5d..e6e528741fe0 100644 --- a/tests/lean/run/async_select_timer.lean +++ b/tests/lean/run/async_select_timer.lean @@ -2,33 +2,30 @@ import Std.Internal.Async.Timer open Std Internal IO Async -def test1 : IO (AsyncTask Nat) := do +def test1 : Async Nat := do let s1 ← Sleep.mk 1000 let s2 ← Sleep.mk 1500 Selectable.one #[ - .case (← s2.selector) fun _ => return AsyncTask.pure 2, - .case (← s1.selector) fun _ => return AsyncTask.pure 1, + .case (← s2.selector) fun _ => return 2, + .case (← s1.selector) fun _ => return 1, ] /-- info: 1 -/ #guard_msgs in -#eval show IO _ from do - let task ← test1 - IO.ofExcept task.get +#eval test1 |>.block -def test2 : IO (AsyncTask Nat) := do +def test2 : Async Nat := do Selectable.one #[ - .case (← Selector.sleep 1500) fun _ => return AsyncTask.pure 2, - .case (← Selector.sleep 1000) fun _ => return AsyncTask.pure 1, + .case (← Selector.sleep 1500) fun _ => return 2, + .case (← Selector.sleep 1000) fun _ => return 1, ] /-- info: 1 -/ #guard_msgs in -#eval show IO _ from do - let task ← test2 - IO.ofExcept task.get +#eval EAsync.block <| show Async _ from do + test2 /-- error: Selectable.one requires at least one Selectable -/ #guard_msgs in -#eval show IO _ from do +#eval EAsync.block <| show Async _ from do let foo ← Selectable.one (α := Unit) #[] diff --git a/tests/lean/run/async_sleep.lean b/tests/lean/run/async_sleep.lean index b8a438515235..3045b058edbb 100644 --- a/tests/lean/run/async_sleep.lean +++ b/tests/lean/run/async_sleep.lean @@ -27,12 +27,12 @@ def oneShotSleep : IO Unit := do assertDuration BASE_DURATION EPS do let timer ← Timer.mk BASE_DURATION.toUInt64 false let p ← timer.next - await p.result + await p.result! def promiseBehavior1 : IO Unit := do let timer ← Timer.mk BASE_DURATION.toUInt64 false let p ← timer.next - let r := p.result + let r := p.result! assert! (← IO.getTaskState r) != .finished IO.sleep (BASE_DURATION + EPS).toUInt32 assert! (← IO.getTaskState r) == .finished @@ -41,35 +41,35 @@ def promiseBehavior2 : IO Unit := do let timer ← Timer.mk BASE_DURATION.toUInt64 false let p1 ← timer.next let p2 ← timer.next - assert! (← IO.getTaskState p1.result) != .finished - assert! (← IO.getTaskState p2.result) != .finished + assert! (← IO.getTaskState p1.result!) != .finished + assert! (← IO.getTaskState p2.result!) != .finished IO.sleep (BASE_DURATION + EPS).toUInt32 - assert! (← IO.getTaskState p1.result) == .finished - assert! (← IO.getTaskState p2.result) == .finished + assert! (← IO.getTaskState p1.result!) == .finished + assert! (← IO.getTaskState p2.result!) == .finished def promiseBehavior3 : IO Unit := do let timer ← Timer.mk BASE_DURATION.toUInt64 false let p1 ← timer.next - assert! (← IO.getTaskState p1.result) != .finished + assert! (← IO.getTaskState p1.result!) != .finished IO.sleep (BASE_DURATION + EPS).toUInt32 - assert! (← IO.getTaskState p1.result) == .finished + assert! (← IO.getTaskState p1.result!) == .finished let p3 ← timer.next - assert! (← IO.getTaskState p3.result) == .finished + assert! (← IO.getTaskState p3.result!) == .finished def resetBehavior : IO Unit := do let timer ← Timer.mk BASE_DURATION.toUInt64 false let p ← timer.next - assert! (← IO.getTaskState p.result) != .finished + assert! (← IO.getTaskState p.result!) != .finished IO.sleep (BASE_DURATION / 2).toUInt32 - assert! (← IO.getTaskState p.result) != .finished + assert! (← IO.getTaskState p.result!) != .finished timer.reset IO.sleep (BASE_DURATION / 2).toUInt32 - assert! (← IO.getTaskState p.result) != .finished + assert! (← IO.getTaskState p.result!) != .finished IO.sleep ((BASE_DURATION / 2) + EPS).toUInt32 - assert! (← IO.getTaskState p.result) == .finished + assert! (← IO.getTaskState p.result!) == .finished #eval oneShotSleep #eval promiseBehavior1 @@ -88,7 +88,7 @@ where go : IO Unit := do let timer ← Timer.mk BASE_DURATION.toUInt64 true let prom ← timer.next - await prom.result + await prom.result! timer.stop def sleepSecond : IO Unit := do @@ -98,8 +98,8 @@ where let timer ← Timer.mk BASE_DURATION.toUInt64 true let task ← - IO.bindTask (← timer.next).result fun _ => do - IO.bindTask (← timer.next).result fun _ => pure (Task.pure (.ok 2)) + IO.bindTask (← timer.next).result! fun _ => do + IO.bindTask (← timer.next).result! fun _ => pure (Task.pure (.ok 2)) discard <| await task timer.stop @@ -108,88 +108,88 @@ def promiseBehavior1 : IO Unit := do let timer ← Timer.mk BASE_DURATION.toUInt64 true let p1 ← timer.next IO.sleep EPS.toUInt32 - assert! (← IO.getTaskState p1.result) == .finished + assert! (← IO.getTaskState p1.result!) == .finished let p2 ← timer.next - assert! (← IO.getTaskState p2.result) != .finished + assert! (← IO.getTaskState p2.result!) != .finished IO.sleep (BASE_DURATION + EPS).toUInt32 - assert! (← IO.getTaskState p2.result) == .finished + assert! (← IO.getTaskState p2.result!) == .finished timer.stop def promiseBehavior2 : IO Unit := do let timer ← Timer.mk BASE_DURATION.toUInt64 true let p1 ← timer.next IO.sleep EPS.toUInt32 - assert! (← IO.getTaskState p1.result) == .finished + assert! (← IO.getTaskState p1.result!) == .finished let prom1 ← timer.next let prom2 ← timer.next - assert! (← IO.getTaskState prom1.result) != .finished - assert! (← IO.getTaskState prom2.result) != .finished + assert! (← IO.getTaskState prom1.result!) != .finished + assert! (← IO.getTaskState prom2.result!) != .finished IO.sleep (BASE_DURATION + EPS).toUInt32 - assert! (← IO.getTaskState prom1.result) == .finished - assert! (← IO.getTaskState prom2.result) == .finished + assert! (← IO.getTaskState prom1.result!) == .finished + assert! (← IO.getTaskState prom2.result!) == .finished timer.stop def promiseBehavior3 : IO Unit := do let timer ← Timer.mk BASE_DURATION.toUInt64 true let p1 ← timer.next IO.sleep EPS.toUInt32 - assert! (← IO.getTaskState p1.result) == .finished + assert! (← IO.getTaskState p1.result!) == .finished let prom1 ← timer.next - assert! (← IO.getTaskState prom1.result) != .finished + assert! (← IO.getTaskState prom1.result!) != .finished IO.sleep (BASE_DURATION + EPS).toUInt32 - assert! (← IO.getTaskState prom1.result) == .finished + assert! (← IO.getTaskState prom1.result!) == .finished let prom2 ← timer.next - assert! (← IO.getTaskState prom2.result) != .finished + assert! (← IO.getTaskState prom2.result!) != .finished IO.sleep (BASE_DURATION + EPS).toUInt32 - assert! (← IO.getTaskState prom2.result) == .finished + assert! (← IO.getTaskState prom2.result!) == .finished timer.stop def delayedTickBehavior : IO Unit := do let timer ← Timer.mk BASE_DURATION.toUInt64 true let p1 ← timer.next IO.sleep EPS.toUInt32 - assert! (← IO.getTaskState p1.result) == .finished + assert! (← IO.getTaskState p1.result!) == .finished IO.sleep (BASE_DURATION / 2).toUInt32 let p2 ← timer.next - assert! (← IO.getTaskState p2.result) != .finished + assert! (← IO.getTaskState p2.result!) != .finished IO.sleep ((BASE_DURATION / 2) + EPS).toUInt32 - assert! (← IO.getTaskState p2.result) == .finished + assert! (← IO.getTaskState p2.result!) == .finished timer.stop def skippedTickBehavior : IO Unit := do let timer ← Timer.mk BASE_DURATION.toUInt64 true let p1 ← timer.next IO.sleep EPS.toUInt32 - assert! (← IO.getTaskState p1.result) == .finished + assert! (← IO.getTaskState p1.result!) == .finished IO.sleep (2 * BASE_DURATION + BASE_DURATION / 2).toUInt32 let p2 ← timer.next - assert! (← IO.getTaskState p2.result) != .finished + assert! (← IO.getTaskState p2.result!) != .finished IO.sleep ((BASE_DURATION / 2) + EPS).toUInt32 - assert! (← IO.getTaskState p2.result) == .finished + assert! (← IO.getTaskState p2.result!) == .finished timer.stop def resetBehavior : IO Unit := do let timer ← Timer.mk BASE_DURATION.toUInt64 true let p1 ← timer.next IO.sleep EPS.toUInt32 - assert! (← IO.getTaskState p1.result) == .finished + assert! (← IO.getTaskState p1.result!) == .finished let prom ← timer.next - assert! (← IO.getTaskState prom.result) != .finished + assert! (← IO.getTaskState prom.result!) != .finished IO.sleep (BASE_DURATION / 2).toUInt32 - assert! (← IO.getTaskState prom.result) != .finished + assert! (← IO.getTaskState prom.result!) != .finished timer.reset IO.sleep (BASE_DURATION / 2).toUInt32 - assert! (← IO.getTaskState prom.result) != .finished + assert! (← IO.getTaskState prom.result!) != .finished IO.sleep ((BASE_DURATION / 2) + EPS).toUInt32 - assert! (← IO.getTaskState prom.result) == .finished + assert! (← IO.getTaskState prom.result!) == .finished timer.stop def sequentialSleep : IO Unit := do @@ -199,9 +199,9 @@ where let timer ← Timer.mk (BASE_DURATION / 2).toUInt64 true -- 0th interval ticks instantly let task ← - IO.bindTask (← timer.next).result fun _ => do - IO.bindTask (← timer.next).result fun _ => do - IO.bindTask (← timer.next).result fun _ => pure (Task.pure (.ok 2)) + IO.bindTask (← timer.next).result! fun _ => do + IO.bindTask (← timer.next).result! fun _ => do + IO.bindTask (← timer.next).result! fun _ => pure (Task.pure (.ok 2)) discard <| await task timer.stop diff --git a/tests/lean/run/async_surface_sleep.lean b/tests/lean/run/async_surface_sleep.lean index b548fd058f2c..d7590d1f4592 100644 --- a/tests/lean/run/async_surface_sleep.lean +++ b/tests/lean/run/async_surface_sleep.lean @@ -12,42 +12,42 @@ def BASE_DURATION : Std.Time.Millisecond.Offset := 10 namespace SleepTest def oneSleep : IO Unit := do - let task ← go - assert! (← task.block) == 37 + let task ← go.block + assert! task == 37 where - go : IO (AsyncTask Nat) := do + go : Async Nat := do let sleep ← Sleep.mk BASE_DURATION - (← sleep.wait).mapIO fun _ => - return 37 + sleep.wait + return 37 def doubleSleep : IO Unit := do - let task ← go - assert! (← task.block) == 37 + let task ← go.block + assert! task == 37 where - go : IO (AsyncTask Nat) := do + go : Async Nat := do let sleep ← Sleep.mk BASE_DURATION - (← sleep.wait).bindIO fun _ => do - (← sleep.wait).mapIO fun _ => - return 37 + sleep.wait + sleep.wait + return 37 def resetSleep : IO Unit := do - let task ← go - assert! (← task.block) == 37 + let task ← go.block + assert! task == 37 where - go : IO (AsyncTask Nat) := do + go : Async Nat := do let sleep ← Sleep.mk BASE_DURATION - let waiter ← sleep.wait + sleep.wait sleep.reset - waiter.mapIO fun _ => - return 37 + sleep.wait + return 37 def simpleSleep : IO Unit := do - let task ← go - assert! (← task.block) == 37 + let task ← go.block + assert! task == 37 where - go : IO (AsyncTask Nat) := do - (← sleep BASE_DURATION).mapIO fun _ => - return 37 + go : Async Nat := do + sleep BASE_DURATION + return 37 #eval oneSleep #eval doubleSleep @@ -59,38 +59,38 @@ end SleepTest namespace IntervalTest def oneSleep : IO Unit := do - let task ← go - assert! (← task.block) == 37 + let task ← go.block + assert! task == 37 where - go : IO (AsyncTask Nat) := do + go : Async Nat := do let interval ← Interval.mk BASE_DURATION - (← interval.tick).mapIO fun _ => do - interval.stop - return 37 + interval.tick + interval.stop + return 37 def doubleSleep : IO Unit := do - let task ← go - assert! (← task.block) == 37 + let task ← go.block + assert! task == 37 where - go : IO (AsyncTask Nat) := do + go : Async Nat := do let interval ← Interval.mk BASE_DURATION - (← interval.tick).bindIO fun _ => do - (← interval.tick).mapIO fun _ => do - interval.stop - return 37 + interval.tick + interval.tick + interval.stop + return 37 def resetSleep : IO Unit := do - let task ← go - assert! (← task.block) == 37 + let task ← go.block + assert! task == 37 where - go : IO (AsyncTask Nat) := do + go : Async Nat := do let interval ← Interval.mk BASE_DURATION - (← interval.tick).bindIO fun _ => do - let waiter ← interval.tick - interval.reset - waiter.mapIO fun _ => do - interval.stop - return 37 + interval.tick + let waiter := interval.tick + interval.reset + waiter + interval.stop + return 37 #eval oneSleep #eval doubleSleep diff --git a/tests/lean/run/async_tcp_fname_errors.lean b/tests/lean/run/async_tcp_fname_errors.lean index 01c75b4f46d0..4af75f991791 100644 --- a/tests/lean/run/async_tcp_fname_errors.lean +++ b/tests/lean/run/async_tcp_fname_errors.lean @@ -12,18 +12,18 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do /-- Mike is another client. -/ def runMike (client: TCP.Socket.Client) : Async Unit := do - let message ← await (← client.recv? 1024) + let message ← client.recv? 1024 assertBEq (String.fromUTF8? =<< message) none /-- Joe is another client. -/ def runJoe (client: TCP.Socket.Client) : Async Unit := do - let message ← await (← client.recv? 1024) + let message ← client.recv? 1024 assertBEq (String.fromUTF8? =<< message) none /-- Robert is the server. -/ def runRobert (server: TCP.Socket.Server) : Async Unit := do - discard <| await (← server.accept) - discard <| await (← server.accept) + discard <| server.accept + discard <| server.accept def clientServer : IO Unit := do let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8083 @@ -35,7 +35,7 @@ def clientServer : IO Unit := do assertBEq (← server.getSockName).port 8083 let joe ← TCP.Socket.Client.mk - let task ← joe.connect addr + let task ← joe.connect addr |>.toBaseIO task.block assertBEq (← joe.getPeerName).port 8083 @@ -43,7 +43,7 @@ def clientServer : IO Unit := do joe.noDelay let mike ← TCP.Socket.Client.mk - let task ← mike.connect addr + let task ← mike.connect addr |>.toBaseIO task.block assertBEq (← mike.getPeerName).port 8083 diff --git a/tests/lean/run/async_tcp_half.lean b/tests/lean/run/async_tcp_half.lean index c2fdf655e1b0..ab9f12aab674 100644 --- a/tests/lean/run/async_tcp_half.lean +++ b/tests/lean/run/async_tcp_half.lean @@ -15,9 +15,9 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do def runJoe (addr: SocketAddress) : Async Unit := do let client ← TCP.Socket.Client.mk - await (← client.connect addr) - await (← client.send (String.toUTF8 "hello robert!")) - await (← client.shutdown) + client.connect addr + client.send (String.toUTF8 "hello robert!") + client.shutdown def listenClose : IO Unit := do let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8080 @@ -35,15 +35,15 @@ def acceptClose : IO Unit := do let joeTask ← (runJoe addr).toIO - let task ← server.accept + let task ← server.accept |>.toBaseIO let client ← task.block - let mes ← client.recv? 1024 + let mes ← client.recv? 1024 |>.toBaseIO let msg ← mes.block assertBEq (String.fromUTF8? =<< msg) ("hello robert!") - let mes ← client.recv? 1024 + let mes ← client.recv? 1024 |>.toBaseIO let msg ← mes.block assertBEq (String.fromUTF8? =<< msg) none diff --git a/tests/lean/run/async_tcp_server_client.lean b/tests/lean/run/async_tcp_server_client.lean index 48ef4b71496a..5bcd1b2b8fee 100644 --- a/tests/lean/run/async_tcp_server_client.lean +++ b/tests/lean/run/async_tcp_server_client.lean @@ -15,29 +15,29 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do /-- Mike is another client. -/ def runMike (client: TCP.Socket.Client) : Async Unit := do - let mes ← await (← client.recv? 1024) + let mes ← client.recv? 1024 assertBEq (String.fromUTF8? =<< mes) (some "hi mike!! :)") - await (← client.send (String.toUTF8 "hello robert!!")) - await (← client.shutdown) + client.send (String.toUTF8 "hello robert!!") + client.shutdown /-- Joe is another client. -/ def runJoe (client: TCP.Socket.Client) : Async Unit := do - let mes ← await (← client.recv? 1024) + let mes ← client.recv? 1024 assertBEq (String.fromUTF8? =<< mes) (some "hi joe! :)") - await (← client.send (String.toUTF8 "hello robert!")) - await (← client.shutdown) + client.send (String.toUTF8 "hello robert!") + client.shutdown /-- Robert is the server. -/ def runRobert (server: TCP.Socket.Server) : Async Unit := do - let joe ← await (← server.accept) - let mike ← await (← server.accept) + let joe ← server.accept + let mike ← server.accept - await (← joe.send (String.toUTF8 "hi joe! :)")) - let mes ← await (← joe.recv? 1024) + joe.send (String.toUTF8 "hi joe! :)") + let mes ← joe.recv? 1024 assertBEq (String.fromUTF8? =<< mes) (some "hello robert!") - await (← mike.send (String.toUTF8 "hi mike!! :)")) - let mes ← await (← mike.recv? 1024) + mike.send (String.toUTF8 "hi mike!! :)") + let mes ← mike.recv? 1024 assertBEq (String.fromUTF8? =<< mes) (some "hello robert!!") pure () @@ -54,7 +54,7 @@ def clientServer (addr : SocketAddress) : IO Unit := do assertBEq (← server.getSockName).port addr.port let joe ← TCP.Socket.Client.mk - let task ← joe.connect addr + let task ← joe.connect addr |>.toBaseIO task.block assertBEq (← joe.getPeerName).port addr.port @@ -62,7 +62,7 @@ def clientServer (addr : SocketAddress) : IO Unit := do joe.noDelay let mike ← TCP.Socket.Client.mk - let task ← mike.connect addr + let task ← mike.connect addr |>.toBaseIO task.block assertBEq (← mike.getPeerName).port addr.port diff --git a/tests/lean/run/async_udp_sockets.lean b/tests/lean/run/async_udp_sockets.lean index a7bd64b5a1ef..2e93cca92818 100644 --- a/tests/lean/run/async_udp_sockets.lean +++ b/tests/lean/run/async_udp_sockets.lean @@ -18,7 +18,7 @@ def runJoe (addr : UInt16 → SocketAddress) (first second : UInt16) : Async Uni client.bind (addr second) client.connect (addr first) - await (← client.send (String.toUTF8 "hello robert!")) + client.send (String.toUTF8 "hello robert!") def acceptClose (addr : UInt16 → SocketAddress) (first second : UInt16) : IO Unit := do @@ -29,7 +29,7 @@ def acceptClose (addr : UInt16 → SocketAddress) (first second : UInt16) : IO U let res ← (runJoe addr first second).toIO res.block - let res ← server.recv 1024 + let res ← server.recv 1024 |>.toBaseIO let (msg, addr) ← res.block if let some addr := addr then From 3ea8218567b7c72929627cb5a402d5618a59c35a Mon Sep 17 00:00:00 2001 From: Sofia Rodrigues Date: Mon, 15 Sep 2025 21:06:32 -0300 Subject: [PATCH 3/6] fix: change selectable base to async --- src/Std/Internal/Async/Select.lean | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Std/Internal/Async/Select.lean b/src/Std/Internal/Async/Select.lean index f68940d1c0c0..aa30feb20fee 100644 --- a/src/Std/Internal/Async/Select.lean +++ b/src/Std/Internal/Async/Select.lean @@ -72,18 +72,18 @@ structure Selector (α : Type) where Attempts to retrieve a piece of data from the event source in a non-blocking fashion, returning `some` if data is available and `none` otherwise. -/ - tryFn : IO (Option α) + tryFn : Async (Option α) /-- Registers a `Waiter` with the event source. Once data is available, the event source should attempt to call `Waiter.race` and resolve the `Waiter`'s promise if it wins. It is crucial that data is never actually consumed from the event source unless `Waiter.race` wins in order to prevent data loss. -/ - registerFn : Waiter α → IO Unit + registerFn : Waiter α → Async Unit /-- A cleanup function that is called once any `Selector` has won the `Selectable.one` race. -/ - unregisterFn : IO Unit + unregisterFn : Async Unit /-- An event source together with a continuation to call on data obtained from that event source, From a0a95df8518f95e99e624f25af4cb282651ab517 Mon Sep 17 00:00:00 2001 From: Sofia Rodrigues Date: Mon, 15 Sep 2025 21:25:39 -0300 Subject: [PATCH 4/6] fix: channel --- src/Std/Sync/Channel.lean | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/Std/Sync/Channel.lean b/src/Std/Sync/Channel.lean index bad33f035563..19e1e01e3282 100644 --- a/src/Std/Sync/Channel.lean +++ b/src/Std/Sync/Channel.lean @@ -578,7 +578,7 @@ private partial def recvSelector (ch : Bounded α) : Selector (Option α) where set { st with consumers } where - registerAux (ch : Bounded α) (waiter : Waiter (Option α)) : IO Unit := do + registerAux (ch : Bounded α) (waiter : Waiter (Option α)) : Async Unit := do ch.state.atomically do -- We did drop the lock between `tryFn` and now so maybe ready? if ← recvReady' then @@ -597,16 +597,17 @@ where let promise ← IO.Promise.new modify fun st => { st with consumers := st.consumers.enqueue ⟨promise, some waiter⟩ } - IO.chainTask promise.result? fun res? => do - match res? with - | none => return () - | some res => - if res then - registerAux ch waiter - else - let lose := return () - let win promise := promise.resolve (.ok none) - waiter.race lose win + let result ← await promise.result? + + match result with + | none => return () + | some res => + if res then + registerAux ch waiter + else + let lose := return () + let win promise := promise.resolve (.ok none) + waiter.race lose win end Bounded From e3ca4d15a8facfa5ea92961d41e813940bf2c353 Mon Sep 17 00:00:00 2001 From: Sofia Rodrigues Date: Mon, 15 Sep 2025 23:25:51 -0300 Subject: [PATCH 5/6] fix: channel --- src/Std/Sync/Channel.lean | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/Std/Sync/Channel.lean b/src/Std/Sync/Channel.lean index 19e1e01e3282..37f0ae237a65 100644 --- a/src/Std/Sync/Channel.lean +++ b/src/Std/Sync/Channel.lean @@ -566,7 +566,7 @@ private partial def recvSelector (ch : Bounded α) : Selector (Option α) where else return none - registerFn := registerAux ch + registerFn x := registerAux ch x unregisterFn := do ch.state.atomically do @@ -578,7 +578,7 @@ private partial def recvSelector (ch : Bounded α) : Selector (Option α) where set { st with consumers } where - registerAux (ch : Bounded α) (waiter : Waiter (Option α)) : Async Unit := do + registerAux (ch : Bounded α) (waiter : Waiter (Option α)) : IO Unit := do ch.state.atomically do -- We did drop the lock between `tryFn` and now so maybe ready? if ← recvReady' then @@ -597,17 +597,15 @@ where let promise ← IO.Promise.new modify fun st => { st with consumers := st.consumers.enqueue ⟨promise, some waiter⟩ } - let result ← await promise.result? - - match result with - | none => return () - | some res => - if res then - registerAux ch waiter - else - let lose := return () - let win promise := promise.resolve (.ok none) - waiter.race lose win + IO.chainTask promise.result? fun + | none => return () + | some res => + if res then + registerAux ch waiter + else + let lose := return () + let win promise := promise.resolve (.ok none) + waiter.race lose win end Bounded From 73ad0d26d31b0865009aace3f38ea04cab08c876 Mon Sep 17 00:00:00 2001 From: Sofia Rodrigues Date: Sat, 20 Sep 2025 13:09:59 -0300 Subject: [PATCH 6/6] fix: remove some helper functiosn --- src/Std/Internal/Async/Basic.lean | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/Std/Internal/Async/Basic.lean b/src/Std/Internal/Async/Basic.lean index 0245caabbb6f..fa6285edd138 100644 --- a/src/Std/Internal/Async/Basic.lean +++ b/src/Std/Internal/Async/Basic.lean @@ -499,14 +499,6 @@ instance : MonadFinally BaseAsync where let res ← x Prod.mk res <$> f (some res) - -/-- -Converts `Task` into `BaseAsync`. --/ -@[inline] -protected def ofEAsyncTask (task : Task α) : BaseAsync α := do - pure (f := BaseIO) (MaybeTask.ofTask task) - /-- Converts `Except` to `BaseAsync`. -/ @@ -785,13 +777,6 @@ protected partial def forIn instance : ForIn (EAsync ε) Lean.Loop Unit where forIn _ := EAsync.forIn -/-- -Converts `ETask` into `EAsync`. --/ -@[inline] -protected def ofEAsyncTask (task : ETask ε α) : EAsync ε α := do - pure (f := BaseIO) (MaybeTask.ofTask task) - /-- Converts `Except` to `EAsync`. -/