From 16d8352f7ed0d38cbc58ace03b3429d693cf50c6 Mon Sep 17 00:00:00 2001 From: Scott Morrison Date: Thu, 14 Dec 2023 07:49:44 +1100 Subject: [PATCH] feat: MLList.ofTaskList (#397) --- Std.lean | 2 ++ Std/Data/MLList/Basic.lean | 18 +++++++++++++++- Std/Data/MLList/IO.lean | 24 +++++++++++++++++++++ Std/Lean/System/IO.lean | 44 ++++++++++++++++++++++++++++++++++++++ test/MLList.lean | 24 ++++++++++++++++++++- 5 files changed, 110 insertions(+), 2 deletions(-) create mode 100644 Std/Data/MLList/IO.lean create mode 100644 Std/Lean/System/IO.lean diff --git a/Std.lean b/Std.lean index a866ba703d..4a78963d77 100644 --- a/Std.lean +++ b/Std.lean @@ -53,6 +53,7 @@ import Std.Data.List.Pairwise import Std.Data.List.Perm import Std.Data.MLList.Basic import Std.Data.MLList.Heartbeats +import Std.Data.MLList.IO import Std.Data.Nat.Basic import Std.Data.Nat.Bitwise import Std.Data.Nat.Gcd @@ -109,6 +110,7 @@ import Std.Lean.Parser import Std.Lean.PersistentHashMap import Std.Lean.PersistentHashSet import Std.Lean.Position +import Std.Lean.System.IO import Std.Lean.Tactic import Std.Lean.TagAttribute import Std.Lean.Util.EnvSearch diff --git a/Std/Data/MLList/Basic.lean b/Std/Data/MLList/Basic.lean index c548667855..2d5ab2ed97 100644 --- a/Std/Data/MLList/Basic.lean +++ b/Std/Data/MLList/Basic.lean @@ -117,7 +117,23 @@ instance [Monad m] [MonadLiftT m n] : ForIn n (MLList m α) α where partial def fix [Monad m] (f : α → m α) (x : α) : MLList m α := cons x <| squash fun _ => fix f <$> f x -/-- Construct a `MLList` recursively. If `f` returns `none` the list will terminate. -/ +/-- +Constructs an `MLList` recursively, with state in `α`, recording terms from `β`. +If `f` returns `none` the list will terminate. + +Variant of `MLList.fix?` that allows returning values of a different type. +-/ +partial def fix?' [Monad m] (f : α → m (Option (β × α))) (init : α) : MLList m β := + squash fun _ => do + match ← f init with + | none => pure .nil + | some (b, a) => pure (.cons b (fix?' f a)) + +/-- +Constructs an `MLList` recursively. If `f` returns `none` the list will terminate. + +Returns the initial value as the first element. +-/ partial def fix? [Monad m] (f : α → m (Option α)) (x : α) : MLList m α := cons x <| squash fun _ => do match ← f x with diff --git a/Std/Data/MLList/IO.lean b/Std/Data/MLList/IO.lean new file mode 100644 index 0000000000..21f46bbee6 --- /dev/null +++ b/Std/Data/MLList/IO.lean @@ -0,0 +1,24 @@ +/- +Copyright (c) 2023 Scott Morrison. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Scott Morrison +-/ +import Std.Lean.System.IO +import Std.Data.MLList.Basic + +/-! +# IO operations using monadic lazy lists. +-/ + +namespace MLList + +/-- +Give a list of tasks, return the monadic lazy list which +returns the values as they become available. +-/ +def ofTaskList (tasks : List (Task α)) : MLList BaseIO α := + fix?' (init := tasks) fun t => do + if h : 0 < t.length then + some <$> IO.waitAny' t h + else + pure none diff --git a/Std/Lean/System/IO.lean b/Std/Lean/System/IO.lean new file mode 100644 index 0000000000..973d822918 --- /dev/null +++ b/Std/Lean/System/IO.lean @@ -0,0 +1,44 @@ +/- +Copyright (c) 2023 Scott Morrison. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Scott Morrison +-/ +import Std.Data.List.Lemmas +import Std.Data.MLList.Basic + +/-! +# Functions for manipulating a list of tasks + +* `IO.waitAny'` is a wrapper for `IO.waitAny` that also returns the remaining tasks. +* `List.waitAll : List (Task α) → Task (List α)` gathers a list of tasks into a task returning + the list of all results. +-/ + +set_option autoImplicit true + +-- duplicated from `lean4/src/Init/System/IO.lean` +local macro "nonempty_list" : tactic => + `(tactic| exact Nat.zero_lt_succ _) + +/-- +Given a non-empty list of tasks, wait for the first to complete. +Return the value and the list of remaining tasks. +-/ +def IO.waitAny' (tasks : List (Task α)) (h : 0 < tasks.length := by nonempty_list) : + BaseIO (α × List (Task α)) := do + let (i, a) ← IO.waitAny + -- It would be more efficient to use `mapIdx` rather than `.enum.map` here + -- but the lemma `List.mapIdx_length` is currently interred in `Mathlib.Data.List.Indexes` + (tasks.enum.map fun (i, t) => t.map (prio := .max) fun a => (i, a)) + ((tasks.enum.length_map _).symm ▸ tasks.enum_length ▸ h) + return (a, tasks.eraseIdx i) + +/-- +Given a list of tasks, create the task returning the list of results, +by waiting for each. +-/ +def List.waitAll (tasks : List (Task α)) : Task (List α) := + match tasks with + | [] => .pure [] + | task :: tasks => task.bind (prio := .max) fun a => + tasks.waitAll.map (prio := .max) fun as => a :: as diff --git a/test/MLList.lean b/test/MLList.lean index a03bdd26d2..eaf14cbc04 100644 --- a/test/MLList.lean +++ b/test/MLList.lean @@ -1,8 +1,10 @@ -import Std.Data.MLList.Basic +import Std.Data.MLList.IO import Std.Tactic.GuardMsgs set_option linter.missingDocs false +/-! ### Test fix to performance problem in `asArray`. -/ + def g (n : Nat) : MLList Lean.Meta.MetaM Nat := do for _ in [:n] do if true then @@ -13,3 +15,23 @@ def g (n : Nat) : MLList Lean.Meta.MetaM Nat := do -- This used to fail before add the `uncons?` field to the implementation of `MLList`. #guard_msgs in #eval MLList.asArray $ (g 3000) + +/-! +### Test `MLList.ofTaskList`. + +We generate three tasks which sleep for `100`, `10`, and `1` milliseconds respectively, +and then verify that `MLList.ofTaskList` return their results in the order they complete. +-/ + +def sleep (n : UInt32) : BaseIO (Task UInt32) := + IO.asTask (do IO.sleep n; return n) |>.map fun t => t.map fun + | .ok n => n + | .error _ => 0 + +def sleeps : MLList BaseIO UInt32 := .squash fun _ => do + let r ← [100,10,1].map sleep |>.traverse id + return .ofTaskList r + +/-- info: [1, 10, 100] -/ +#guard_msgs in +#eval sleeps.force