diff --git a/src/category/monad/cont.lean b/src/category/monad/cont.lean index a27ed1792ee8f..94a72ea27b820 100644 --- a/src/category/monad/cont.lean +++ b/src/category/monad/cont.lean @@ -9,6 +9,7 @@ http://hackage.haskell.org/package/mtl-2.2.2/docs/Control-Monad-Cont.html -/ import tactic.interactive +import category.monad.basic universes u v w @@ -17,13 +18,12 @@ structure monad_cont.label (α : Type w) (m : Type u → Type v) (β : Type u) : def monad_cont.goto {α β} {m : Type u → Type v} (f : monad_cont.label α m β) (x : α) := f.apply x -class monad_cont (m : Type u → Type v) -extends monad m := +class monad_cont (m : Type u → Type v) := (call_cc : Π {α β}, ((monad_cont.label α m β) → m α) → m α) open monad_cont -class is_lawful_monad_cont (m : Type u → Type v) [monad_cont m] +class is_lawful_monad_cont (m : Type u → Type v) [monad m] [monad_cont m] extends is_lawful_monad m := (call_cc_bind_right {α ω γ} (cmd : m α) (next : (label ω m γ) → α → m ω) : call_cc (λ f, cmd >>= next f) = cmd >>= λ x, call_cc (λ f, next f x)) @@ -55,6 +55,11 @@ def with_cont_t (f : (β → m r) → α → m r) (x : cont_t r m α) : cont_t r lemma run_with_cont_t (f : (β → m r) → α → m r) (x : cont_t r m α) : run (with_cont_t f x) = run x ∘ f := rfl +@[extensionality] +protected lemma ext {x y : cont_t r m α} + (h : ∀ f, x.run f = y.run f) : + x = y := by { ext; apply h } + instance : monad (cont_t r m) := { pure := λ α x f, f x, bind := λ α β x f g, x $ λ i, f i g } @@ -69,7 +74,7 @@ instance [monad m] : has_monad_lift m (cont_t r m) := lemma monad_lift_bind [monad m] [is_lawful_monad m] {α β} (x : m α) (f : α → m β) : (monad_lift (x >>= f) : cont_t r m β) = monad_lift x >>= monad_lift ∘ f := -by { ext, simp only [monad_lift,has_monad_lift.monad_lift,(∘),(>>=),bind_assoc,id.def] } +by { ext, simp only [monad_lift,has_monad_lift.monad_lift,(∘),(>>=),bind_assoc,id.def,run], } instance : monad_cont (cont_t r m) := { call_cc := λ α β f g, f ⟨λ x h, g x⟩ g } @@ -79,4 +84,81 @@ instance : is_lawful_monad_cont (cont_t r m) := call_cc_bind_left := by intros; ext; refl, call_cc_dummy := by intros; ext; refl } +instance (ε) [monad_except ε m] : monad_except ε (cont_t r m) := +{ throw := λ x e f, throw e, + catch := λ α act h f, catch (act f) (λ e, h e f) } + +instance : monad_run (λ α, (α → m r) → ulift.{u v} (m r)) (cont_t.{u v u} r m) := +{ run := λ α f x, ⟨ f x ⟩ } + end cont_t + +variables {m : Type u → Type v} [monad m] + +def except_t.mk_label {α β ε} : label (except.{u u} ε α) m β → label α (except_t ε m) β +| ⟨ f ⟩ := ⟨ λ a, monad_lift $ f (except.ok a) ⟩ + +lemma except_t.goto_mk_label {α β ε : Type*} (x : label (except.{u u} ε α) m β) (i : α) : + goto (except_t.mk_label x) i = ⟨ except.ok <$> goto x (except.ok i) ⟩ := by cases x; refl + +def except_t.call_cc {ε} [monad_cont m] {α β : Type*} (f : label α (except_t ε m) β → except_t ε m α) : except_t ε m α := +except_t.mk (call_cc $ λ x : label _ m β, except_t.run $ f (except_t.mk_label x) : m (except ε α)) + +instance {ε} [monad_cont m] : monad_cont (except_t ε m) := +{ call_cc := λ α β, except_t.call_cc } + +instance {ε} [monad_cont m] [is_lawful_monad_cont m] : is_lawful_monad_cont (except_t ε m) := +{ call_cc_bind_right := by { intros, simp [call_cc,except_t.call_cc,call_cc_bind_right], ext, dsimp, congr, ext ⟨ ⟩; simp [except_t.bind_cont,@call_cc_dummy m _], }, + call_cc_bind_left := by { intros, simp [call_cc,except_t.call_cc,call_cc_bind_right,except_t.goto_mk_label,map_eq_bind_pure_comp,bind_assoc,@call_cc_bind_left m _], ext, refl }, + call_cc_dummy := by { intros, simp [call_cc,except_t.call_cc,@call_cc_dummy m _], ext, refl }, } + +def option_t.mk_label {α β} : label (option.{u} α) m β → label α (option_t m) β +| ⟨ f ⟩ := ⟨ λ a, monad_lift $ f (some a) ⟩ + +lemma option_t.goto_mk_label {α β : Type*} (x : label (option.{u} α) m β) (i : α) : + goto (option_t.mk_label x) i = ⟨ some <$> goto x (some i) ⟩ := by cases x; refl + +def option_t.call_cc [monad_cont m] {α β : Type*} (f : label α (option_t m) β → option_t m α) : option_t m α := +option_t.mk (call_cc $ λ x : label _ m β, option_t.run $ f (option_t.mk_label x) : m (option α)) + +instance [monad_cont m] : monad_cont (option_t m) := +{ call_cc := λ α β, option_t.call_cc } + +instance [monad_cont m] [is_lawful_monad_cont m] : is_lawful_monad_cont (option_t m) := +{ call_cc_bind_right := by { intros, simp [call_cc,option_t.call_cc,call_cc_bind_right], ext, dsimp, congr, ext ⟨ ⟩; simp [option_t.bind_cont,@call_cc_dummy m _], }, + call_cc_bind_left := by { intros, simp [call_cc,option_t.call_cc,call_cc_bind_right,option_t.goto_mk_label,map_eq_bind_pure_comp,bind_assoc,@call_cc_bind_left m _], ext, refl }, + call_cc_dummy := by { intros, simp [call_cc,option_t.call_cc,@call_cc_dummy m _], ext, refl }, } + +def state_t.mk_label {α β σ : Type u} : label (α × σ) m (β × σ) → label α (state_t σ m) β +| ⟨ f ⟩ := ⟨ λ a, ⟨ λ s, f (a,s) ⟩ ⟩ + +lemma state_t.goto_mk_label {α β σ : Type u} (x : label (α × σ) m (β × σ)) (i : α) : + goto (state_t.mk_label x) i = ⟨ λ s, (goto x (i,s)) ⟩ := by cases x; refl + +def state_t.call_cc {σ} [monad_cont m] {α β : Type*} (f : label α (state_t σ m) β → state_t σ m α) : state_t σ m α := +⟨ λ r, call_cc (λ f', (f $ state_t.mk_label f').run r) ⟩ + +instance {σ} [monad_cont m] : monad_cont (state_t σ m) := +{ call_cc := λ α β, state_t.call_cc } + +instance {σ} [monad_cont m] [is_lawful_monad_cont m] : is_lawful_monad_cont (state_t σ m) := +{ call_cc_bind_right := by { intros, simp [call_cc,state_t.call_cc,call_cc_bind_right,(>>=),state_t.bind], ext, dsimp, congr, ext ⟨x₀,x₁⟩, refl }, + call_cc_bind_left := by { intros, simp [call_cc,state_t.call_cc,call_cc_bind_left,(>>=),state_t.bind,state_t.goto_mk_label], ext, refl }, + call_cc_dummy := by { intros, simp [call_cc,state_t.call_cc,call_cc_bind_right,(>>=),state_t.bind,@call_cc_dummy m _], ext, refl }, } + +def reader_t.mk_label {α β} (ρ) : label α m β → label α (reader_t ρ m) β +| ⟨ f ⟩ := ⟨ monad_lift ∘ f ⟩ + +lemma reader_t.goto_mk_label {α ρ β} (x : label α m β) (i : α) : + goto (reader_t.mk_label ρ x) i = monad_lift (goto x i) := by cases x; refl + +def reader_t.call_cc {ε} [monad_cont m] {α β : Type*} (f : label α (reader_t ε m) β → reader_t ε m α) : reader_t ε m α := +⟨ λ r, call_cc (λ f', (f $ reader_t.mk_label _ f').run r) ⟩ + +instance {ρ} [monad_cont m] : monad_cont (reader_t ρ m) := +{ call_cc := λ α β, reader_t.call_cc } + +instance {ρ} [monad_cont m] [is_lawful_monad_cont m] : is_lawful_monad_cont (reader_t ρ m) := +{ call_cc_bind_right := by { intros, simp [call_cc,reader_t.call_cc,call_cc_bind_right], ext, refl }, + call_cc_bind_left := by { intros, simp [call_cc,reader_t.call_cc,call_cc_bind_left,reader_t.goto_mk_label], ext, refl }, + call_cc_dummy := by { intros, simp [call_cc,reader_t.call_cc,@call_cc_dummy m _], ext, refl } }