Skip to content

Commit

Permalink
feat: MLList.ofTaskList (#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
semorrison committed Dec 13, 2023
1 parent 7ef8b7c commit 16d8352
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Std.lean
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import Std.Data.List.Pairwise
import Std.Data.List.Perm
import Std.Data.MLList.Basic
import Std.Data.MLList.Heartbeats
import Std.Data.MLList.IO
import Std.Data.Nat.Basic
import Std.Data.Nat.Bitwise
import Std.Data.Nat.Gcd
Expand Down Expand Up @@ -109,6 +110,7 @@ import Std.Lean.Parser
import Std.Lean.PersistentHashMap
import Std.Lean.PersistentHashSet
import Std.Lean.Position
import Std.Lean.System.IO
import Std.Lean.Tactic
import Std.Lean.TagAttribute
import Std.Lean.Util.EnvSearch
Expand Down
18 changes: 17 additions & 1 deletion Std/Data/MLList/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,23 @@ instance [Monad m] [MonadLiftT m n] : ForIn n (MLList m α) α where
partial def fix [Monad m] (f : α → m α) (x : α) : MLList m α :=
cons x <| squash fun _ => fix f <$> f x

/-- Construct a `MLList` recursively. If `f` returns `none` the list will terminate. -/
/--
Constructs an `MLList` recursively, with state in `α`, recording terms from `β`.
If `f` returns `none` the list will terminate.
Variant of `MLList.fix?` that allows returning values of a different type.
-/
partial def fix?' [Monad m] (f : α → m (Option (β × α))) (init : α) : MLList m β :=
squash fun _ => do
match ← f init with
| none => pure .nil
| some (b, a) => pure (.cons b (fix?' f a))

/--
Constructs an `MLList` recursively. If `f` returns `none` the list will terminate.
Returns the initial value as the first element.
-/
partial def fix? [Monad m] (f : α → m (Option α)) (x : α) : MLList m α :=
cons x <| squash fun _ => do
match ← f x with
Expand Down
24 changes: 24 additions & 0 deletions Std/Data/MLList/IO.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/-
Copyright (c) 2023 Scott Morrison. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Scott Morrison
-/
import Std.Lean.System.IO
import Std.Data.MLList.Basic

/-!
# IO operations using monadic lazy lists.
-/

namespace MLList

/--
Give a list of tasks, return the monadic lazy list which
returns the values as they become available.
-/
def ofTaskList (tasks : List (Task α)) : MLList BaseIO α :=
fix?' (init := tasks) fun t => do
if h : 0 < t.length then
some <$> IO.waitAny' t h
else
pure none
44 changes: 44 additions & 0 deletions Std/Lean/System/IO.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/-
Copyright (c) 2023 Scott Morrison. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Scott Morrison
-/
import Std.Data.List.Lemmas
import Std.Data.MLList.Basic

/-!
# Functions for manipulating a list of tasks
* `IO.waitAny'` is a wrapper for `IO.waitAny` that also returns the remaining tasks.
* `List.waitAll : List (Task α) → Task (List α)` gathers a list of tasks into a task returning
the list of all results.
-/

set_option autoImplicit true

-- duplicated from `lean4/src/Init/System/IO.lean`
local macro "nonempty_list" : tactic =>
`(tactic| exact Nat.zero_lt_succ _)

/--
Given a non-empty list of tasks, wait for the first to complete.
Return the value and the list of remaining tasks.
-/
def IO.waitAny' (tasks : List (Task α)) (h : 0 < tasks.length := by nonempty_list) :
BaseIO (α × List (Task α)) := do
let (i, a) ← IO.waitAny
-- It would be more efficient to use `mapIdx` rather than `.enum.map` here
-- but the lemma `List.mapIdx_length` is currently interred in `Mathlib.Data.List.Indexes`
(tasks.enum.map fun (i, t) => t.map (prio := .max) fun a => (i, a))
((tasks.enum.length_map _).symm ▸ tasks.enum_length ▸ h)
return (a, tasks.eraseIdx i)

/--
Given a list of tasks, create the task returning the list of results,
by waiting for each.
-/
def List.waitAll (tasks : List (Task α)) : Task (List α) :=
match tasks with
| [] => .pure []
| task :: tasks => task.bind (prio := .max) fun a =>
tasks.waitAll.map (prio := .max) fun as => a :: as
24 changes: 23 additions & 1 deletion test/MLList.lean
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import Std.Data.MLList.Basic
import Std.Data.MLList.IO
import Std.Tactic.GuardMsgs

set_option linter.missingDocs false

/-! ### Test fix to performance problem in `asArray`. -/

def g (n : Nat) : MLList Lean.Meta.MetaM Nat := do
for _ in [:n] do
if true then
Expand All @@ -13,3 +15,23 @@ def g (n : Nat) : MLList Lean.Meta.MetaM Nat := do
-- This used to fail before add the `uncons?` field to the implementation of `MLList`.
#guard_msgs in
#eval MLList.asArray $ (g 3000)

/-!
### Test `MLList.ofTaskList`.
We generate three tasks which sleep for `100`, `10`, and `1` milliseconds respectively,
and then verify that `MLList.ofTaskList` return their results in the order they complete.
-/

def sleep (n : UInt32) : BaseIO (Task UInt32) :=
IO.asTask (do IO.sleep n; return n) |>.map fun t => t.map fun
| .ok n => n
| .error _ => 0

def sleeps : MLList BaseIO UInt32 := .squash fun _ => do
let r ← [100,10,1].map sleep |>.traverse id
return .ofTaskList r

/-- info: [1, 10, 100] -/
#guard_msgs in
#eval sleeps.force

0 comments on commit 16d8352

Please sign in to comment.