Skip to content
This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit c726c12

Browse files
cipher1024mergify[bot]
authored andcommitted
feat(category/monad/cont): monad_cont instances for state_t, reader_t, except_t and option_t (#733)
* feat(category/monad/cont): monad_cont instances for state_t, reader_t, except_t and option_t * feat(category/monad/writer): writer monad transformer
1 parent 98ba07b commit c726c12

File tree

3 files changed

+272
-5
lines changed

3 files changed

+272
-5
lines changed

src/category/monad/basic.lean

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
2+
import tactic.interactive
3+
4+
attribute [extensionality] reader_t.ext state_t.ext except_t.ext option_t.ext
5+
attribute [functor_norm] bind_assoc pure_bind bind_pure
6+
universes u v
7+
8+
lemma map_eq_bind_pure_comp (m : Type u → Type v) [monad m] [is_lawful_monad m] {α β : Type u} (f : α → β) (x : m α) :
9+
f <$> x = x >>= pure ∘ f := by rw bind_pure_comp_eq_map

src/category/monad/cont.lean

Lines changed: 104 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ http://hackage.haskell.org/package/mtl-2.2.2/docs/Control-Monad-Cont.html
99
-/
1010

1111
import tactic.ext
12+
import category.monad.basic category.monad.writer
1213

1314
universes u v w
1415

@@ -17,13 +18,12 @@ structure monad_cont.label (α : Type w) (m : Type u → Type v) (β : Type u) :
1718

1819
def monad_cont.goto {α β} {m : Type u → Type v} (f : monad_cont.label α m β) (x : α) := f.apply x
1920

20-
class monad_cont (m : Type u → Type v)
21-
extends monad m :=
21+
class monad_cont (m : Type u → Type v) :=
2222
(call_cc : Π {α β}, ((monad_cont.label α m β) → m α) → m α)
2323

2424
open monad_cont
2525

26-
class is_lawful_monad_cont (m : Type u → Type v) [monad_cont m]
26+
class is_lawful_monad_cont (m : Type u → Type v) [monad m] [monad_cont m]
2727
extends is_lawful_monad m :=
2828
(call_cc_bind_right {α ω γ} (cmd : m α) (next : (label ω m γ) → α → m ω) :
2929
call_cc (λ f, cmd >>= next f) = cmd >>= λ x, call_cc (λ f, next f x))
@@ -36,6 +36,8 @@ export is_lawful_monad_cont
3636

3737
def cont_t (r : Type u) (m : Type u → Type v) (α : Type w) := (α → m r) → m r
3838

39+
@[reducible] def cont (r : Type u) (α : Type w) := cont_t r id α
40+
3941
namespace cont_t
4042

4143
export monad_cont (label goto)
@@ -55,6 +57,11 @@ def with_cont_t (f : (β → m r) → α → m r) (x : cont_t r m α) : cont_t r
5557
lemma run_with_cont_t (f : (β → m r) → α → m r) (x : cont_t r m α) :
5658
run (with_cont_t f x) = run x ∘ f := rfl
5759

60+
@[extensionality]
61+
protected lemma ext {x y : cont_t r m α}
62+
(h : ∀ f, x.run f = y.run f) :
63+
x = y := by { ext; apply h }
64+
5865
instance : monad (cont_t r m) :=
5966
{ pure := λ α x f, f x,
6067
bind := λ α β x f g, x $ λ i, f i g }
@@ -64,12 +71,15 @@ instance : is_lawful_monad (cont_t r m) :=
6471
pure_bind := by { intros, ext, refl },
6572
bind_assoc := by { intros, ext, refl } }
6673

74+
def cont_t.monad_lift [monad m] {α} : m α → cont_t r m α :=
75+
λ x f, x >>= f
76+
6777
instance [monad m] : has_monad_lift m (cont_t r m) :=
68-
{ monad_lift := λ a x f, x >>= f }
78+
{ monad_lift := λ α, cont_t.monad_lift }
6979

7080
lemma monad_lift_bind [monad m] [is_lawful_monad m] {α β} (x : m α) (f : α → m β) :
7181
(monad_lift (x >>= f) : cont_t r m β) = monad_lift x >>= monad_lift ∘ f :=
72-
by { ext, simp only [monad_lift,has_monad_lift.monad_lift,(∘),(>>=),bind_assoc,id.def] }
82+
by { ext, simp only [monad_lift,has_monad_lift.monad_lift,(∘),(>>=),bind_assoc,id.def,run,cont_t.monad_lift] }
7383

7484
instance : monad_cont (cont_t r m) :=
7585
{ call_cc := λ α β f g, f ⟨λ x h, g x⟩ g }
@@ -79,4 +89,93 @@ instance : is_lawful_monad_cont (cont_t r m) :=
7989
call_cc_bind_left := by intros; ext; refl,
8090
call_cc_dummy := by intros; ext; refl }
8191

92+
instance (ε) [monad_except ε m] : monad_except ε (cont_t r m) :=
93+
{ throw := λ x e f, throw e,
94+
catch := λ α act h f, catch (act f) (λ e, h e f) }
95+
96+
instance : monad_run (λ α, (α → m r) → ulift.{u v} (m r)) (cont_t.{u v u} r m) :=
97+
{ run := λ α f x, ⟨ f x ⟩ }
98+
8299
end cont_t
100+
101+
variables {m : Type u → Type v} [monad m]
102+
103+
def except_t.mk_label {α β ε} : label (except.{u u} ε α) m β → label α (except_t ε m) β
104+
| ⟨ f ⟩ := ⟨ λ a, monad_lift $ f (except.ok a) ⟩
105+
106+
lemma except_t.goto_mk_label {α β ε : Type*} (x : label (except.{u u} ε α) m β) (i : α) :
107+
goto (except_t.mk_label x) i = ⟨ except.ok <$> goto x (except.ok i) ⟩ := by cases x; refl
108+
109+
def except_t.call_cc {ε} [monad_cont m] {α β : Type*} (f : label α (except_t ε m) β → except_t ε m α) : except_t ε m α :=
110+
except_t.mk (call_cc $ λ x : label _ m β, except_t.run $ f (except_t.mk_label x) : m (except ε α))
111+
112+
instance {ε} [monad_cont m] : monad_cont (except_t ε m) :=
113+
{ call_cc := λ α β, except_t.call_cc }
114+
115+
instance {ε} [monad_cont m] [is_lawful_monad_cont m] : is_lawful_monad_cont (except_t ε m) :=
116+
{ call_cc_bind_right := by { intros, simp [call_cc,except_t.call_cc,call_cc_bind_right], ext, dsimp, congr, ext ⟨ ⟩; simp [except_t.bind_cont,@call_cc_dummy m _], },
117+
call_cc_bind_left := by { intros, simp [call_cc,except_t.call_cc,call_cc_bind_right,except_t.goto_mk_label,map_eq_bind_pure_comp,bind_assoc,@call_cc_bind_left m _], ext, refl },
118+
call_cc_dummy := by { intros, simp [call_cc,except_t.call_cc,@call_cc_dummy m _], ext, refl }, }
119+
120+
def option_t.mk_label {α β} : label (option.{u} α) m β → label α (option_t m) β
121+
| ⟨ f ⟩ := ⟨ λ a, monad_lift $ f (some a) ⟩
122+
123+
lemma option_t.goto_mk_label {α β : Type*} (x : label (option.{u} α) m β) (i : α) :
124+
goto (option_t.mk_label x) i = ⟨ some <$> goto x (some i) ⟩ := by cases x; refl
125+
126+
def option_t.call_cc [monad_cont m] {α β : Type*} (f : label α (option_t m) β → option_t m α) : option_t m α :=
127+
option_t.mk (call_cc $ λ x : label _ m β, option_t.run $ f (option_t.mk_label x) : m (option α))
128+
129+
instance [monad_cont m] : monad_cont (option_t m) :=
130+
{ call_cc := λ α β, option_t.call_cc }
131+
132+
instance [monad_cont m] [is_lawful_monad_cont m] : is_lawful_monad_cont (option_t m) :=
133+
{ call_cc_bind_right := by { intros, simp [call_cc,option_t.call_cc,call_cc_bind_right], ext, dsimp, congr, ext ⟨ ⟩; simp [option_t.bind_cont,@call_cc_dummy m _], },
134+
call_cc_bind_left := by { intros, simp [call_cc,option_t.call_cc,call_cc_bind_right,option_t.goto_mk_label,map_eq_bind_pure_comp,bind_assoc,@call_cc_bind_left m _], ext, refl },
135+
call_cc_dummy := by { intros, simp [call_cc,option_t.call_cc,@call_cc_dummy m _], ext, refl }, }
136+
137+
def writer_t.mk_label {α β ω} [has_one ω] : label (α × ω) m β → label α (writer_t ω m) β
138+
| ⟨ f ⟩ := ⟨ λ a, monad_lift $ f (a,1) ⟩
139+
140+
lemma writer_t.goto_mk_label {α β ω : Type*} [has_one ω] (x : label (α × ω) m β) (i : α) :
141+
goto (writer_t.mk_label x) i = monad_lift (goto x (i,1)) := by cases x; refl
142+
143+
def writer_t.call_cc [monad_cont m] {α β ω : Type*} [has_one ω] (f : label α (writer_t ω m) β → writer_t ω m α) : writer_t ω m α :=
144+
⟨ call_cc (writer_t.run ∘ f ∘ writer_t.mk_label : label (α × ω) m β → m (α × ω)) ⟩
145+
146+
instance (ω) [monad m] [has_one ω] [monad_cont m] : monad_cont (writer_t ω m) :=
147+
{ call_cc := λ α β, writer_t.call_cc }
148+
149+
def state_t.mk_label {α β σ : Type u} : label (α × σ) m (β × σ) → label α (state_t σ m) β
150+
| ⟨ f ⟩ := ⟨ λ a, ⟨ λ s, f (a,s) ⟩ ⟩
151+
152+
lemma state_t.goto_mk_label {α β σ : Type u} (x : label (α × σ) m (β × σ)) (i : α) :
153+
goto (state_t.mk_label x) i = ⟨ λ s, (goto x (i,s)) ⟩ := by cases x; refl
154+
155+
def state_t.call_cc {σ} [monad_cont m] {α β : Type*} (f : label α (state_t σ m) β → state_t σ m α) : state_t σ m α :=
156+
⟨ λ r, call_cc (λ f', (f $ state_t.mk_label f').run r) ⟩
157+
158+
instance {σ} [monad_cont m] : monad_cont (state_t σ m) :=
159+
{ call_cc := λ α β, state_t.call_cc }
160+
161+
instance {σ} [monad_cont m] [is_lawful_monad_cont m] : is_lawful_monad_cont (state_t σ m) :=
162+
{ call_cc_bind_right := by { intros, simp [call_cc,state_t.call_cc,call_cc_bind_right,(>>=),state_t.bind], ext, dsimp, congr, ext ⟨x₀,x₁⟩, refl },
163+
call_cc_bind_left := by { intros, simp [call_cc,state_t.call_cc,call_cc_bind_left,(>>=),state_t.bind,state_t.goto_mk_label], ext, refl },
164+
call_cc_dummy := by { intros, simp [call_cc,state_t.call_cc,call_cc_bind_right,(>>=),state_t.bind,@call_cc_dummy m _], ext, refl }, }
165+
166+
def reader_t.mk_label {α β} (ρ) : label α m β → label α (reader_t ρ m) β
167+
| ⟨ f ⟩ := ⟨ monad_lift ∘ f ⟩
168+
169+
lemma reader_t.goto_mk_label {α ρ β} (x : label α m β) (i : α) :
170+
goto (reader_t.mk_label ρ x) i = monad_lift (goto x i) := by cases x; refl
171+
172+
def reader_t.call_cc {ε} [monad_cont m] {α β : Type*} (f : label α (reader_t ε m) β → reader_t ε m α) : reader_t ε m α :=
173+
⟨ λ r, call_cc (λ f', (f $ reader_t.mk_label _ f').run r) ⟩
174+
175+
instance {ρ} [monad_cont m] : monad_cont (reader_t ρ m) :=
176+
{ call_cc := λ α β, reader_t.call_cc }
177+
178+
instance {ρ} [monad_cont m] [is_lawful_monad_cont m] : is_lawful_monad_cont (reader_t ρ m) :=
179+
{ call_cc_bind_right := by { intros, simp [call_cc,reader_t.call_cc,call_cc_bind_right], ext, refl },
180+
call_cc_bind_left := by { intros, simp [call_cc,reader_t.call_cc,call_cc_bind_left,reader_t.goto_mk_label], ext, refl },
181+
call_cc_dummy := by { intros, simp [call_cc,reader_t.call_cc,@call_cc_dummy m _], ext, refl } }

src/category/monad/writer.lean

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
2+
/-
3+
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
4+
Released under Apache 2.0 license as described in the file LICENSE.
5+
Authors: Simon Hudon
6+
7+
The writer monad transformer for passing immutable state.
8+
-/
9+
10+
import tactic.interactive category.monad.basic
11+
universes u v w
12+
13+
structure writer_t (ω : Type u) (m : Type u → Type v) (α : Type u) : Type (max u v) :=
14+
(run : m (α × ω))
15+
16+
@[reducible] def writer (ω : Type u) := writer_t ω id
17+
18+
attribute [pp_using_anonymous_constructor] writer_t
19+
20+
namespace writer_t
21+
section
22+
variable {ω : Type u}
23+
variable {m : Type u → Type v}
24+
variable [monad m]
25+
variables {α β : Type u}
26+
open function
27+
28+
@[extensionality]
29+
protected lemma ext (x x' : writer_t ω m α)
30+
(h : x.run = x'.run) :
31+
x = x' := by cases x; cases x'; congr; apply h
32+
33+
@[inline] protected def tell (w : ω) : writer_t ω m punit :=
34+
⟨pure (punit.star, w)⟩
35+
36+
@[inline] protected def listen : writer_t ω m α → writer_t ω m (α × ω)
37+
| ⟨ cmd ⟩ := ⟨ (λ x : α × ω, ((x.1,x.2),x.2)) <$> cmd ⟩
38+
39+
@[inline] protected def pass : writer_t ω m (α × (ω → ω)) → writer_t ω m α
40+
| ⟨ cmd ⟩ := ⟨ uncurry (uncurry $ λ x (f : ω → ω) w, (x,f w)) <$> cmd ⟩
41+
42+
@[inline] protected def pure [has_one ω] (a : α) : writer_t ω m α :=
43+
⟨ pure (a,1) ⟩
44+
45+
@[inline] protected def bind [has_mul ω] (x : writer_t ω m α) (f : α → writer_t ω m β) : writer_t ω m β :=
46+
do x ← x.run,
47+
x' ← (f x.1).run,
48+
pure (x'.1,x.2 * x'.2) ⟩
49+
50+
instance [has_one ω] [has_mul ω] : monad (writer_t ω m) :=
51+
{ pure := λ α, writer_t.pure, bind := λ α β, writer_t.bind }
52+
53+
instance [monoid ω] [is_lawful_monad m] : is_lawful_monad (writer_t ω m) :=
54+
{ id_map := by { intros, cases x, simp [(<$>),writer_t.bind,writer_t.pure] },
55+
pure_bind := by { intros, simp [has_pure.pure,writer_t.pure,(>>=),writer_t.bind], ext; refl },
56+
bind_assoc := by { intros, simp [(>>=),writer_t.bind,mul_assoc] with functor_norm } }
57+
58+
@[inline] protected def lift [has_one ω] (a : m α) : writer_t ω m α :=
59+
⟨ flip prod.mk 1 <$> a ⟩
60+
61+
instance (m) [monad m] [has_one ω] : has_monad_lift m (writer_t ω m) :=
62+
⟨ λ α, writer_t.lift ⟩
63+
64+
@[inline] protected def monad_map {m m'} [monad m] [monad m'] {α} (f : Π {α}, m α → m' α) : writer_t ω m α → writer_t ω m' α :=
65+
λ x, ⟨ f x.run ⟩
66+
67+
instance (m m') [monad m] [monad m'] : monad_functor m m' (writer_t ω m) (writer_t ω m') :=
68+
⟨@writer_t.monad_map ω m m' _ _⟩
69+
70+
@[inline] protected def adapt {ω' : Type u} [monad m] {α : Type u} (f : ω → ω') : writer_t ω m α → writer_t ω' m α :=
71+
λ x, ⟨prod.map id f <$> x.run⟩
72+
73+
instance (ε) [has_one ω] [monad m] [monad_except ε m] : monad_except ε (writer_t ω m) :=
74+
{ throw := λ α, writer_t.lift ∘ throw,
75+
catch := λ α x c, ⟨catch x.run (λ e, (c e).run)⟩ }
76+
end
77+
end writer_t
78+
79+
80+
/-- An implementation of [MonadReader](https://hackage.haskell.org/package/mtl-2.2.2/docs/Control-Monad-Reader-Class.html#t:MonadReader).
81+
It does not contain `local` because this function cannot be lifted using `monad_lift`.
82+
Instead, the `monad_reader_adapter` class provides the more general `adapt_reader` function.
83+
84+
Note: This class can be seen as a simplification of the more "principled" definition
85+
```
86+
class monad_reader (ρ : out_param (Type u)) (n : Type u → Type u) :=
87+
(lift {} {α : Type u} : (∀ {m : Type u → Type u} [monad m], reader_t ρ m α) → n α)
88+
```
89+
-/
90+
class monad_writer (ω : out_param (Type u)) (m : Type u → Type v) :=
91+
(tell {} (w : ω) : m punit)
92+
(listen {α} : m α → m (α × ω))
93+
(pass {α : Type u} : m (α × (ω → ω)) → m α)
94+
95+
export monad_writer
96+
97+
instance {ω : Type u} {m : Type u → Type v} [monad m] : monad_writer ω (writer_t ω m) :=
98+
{ tell := writer_t.tell,
99+
listen := λ α, writer_t.listen,
100+
pass := λ α, writer_t.pass }
101+
102+
instance {ω ρ : Type u} {m : Type u → Type v} [monad m] [monad_writer ω m] : monad_writer ω (reader_t ρ m) :=
103+
{ tell := λ x, monad_lift (tell x : m punit),
104+
listen := λ α ⟨ cmd ⟩, ⟨ λ r, listen (cmd r) ⟩,
105+
pass := λ α ⟨ cmd ⟩, ⟨ λ r, pass (cmd r) ⟩ }
106+
107+
def swap_right {α β γ} : (α × β) × γ → (α × γ) × β
108+
| ⟨⟨x,y⟩,z⟩ := ((x,z),y)
109+
110+
instance {ω σ : Type u} {m : Type u → Type v} [monad m] [monad_writer ω m] : monad_writer ω (state_t σ m) :=
111+
{ tell := λ x, monad_lift (tell x : m punit),
112+
listen := λ α ⟨ cmd ⟩, ⟨ λ r, swap_right <$> listen (cmd r) ⟩,
113+
pass := λ α ⟨ cmd ⟩, ⟨ λ r, pass (swap_right <$> cmd r) ⟩ }
114+
open function
115+
116+
def except_t.pass_aux {ε α ω} : except ε (α × (ω → ω)) → except ε α × (ω → ω)
117+
| (except.error a) := (except.error a,id)
118+
| (except.ok (x,y)) := (except.ok x,y)
119+
120+
instance {ω ε : Type u} {m : Type u → Type v} [monad m] [monad_writer ω m] : monad_writer ω (except_t ε m) :=
121+
{ tell := λ x, monad_lift (tell x : m punit),
122+
listen := λ α ⟨ cmd ⟩, ⟨ uncurry (λ x y, flip prod.mk y <$> x) <$> listen cmd ⟩,
123+
pass := λ α ⟨ cmd ⟩, ⟨ pass (except_t.pass_aux <$> cmd) ⟩ }
124+
125+
def option_t.pass_aux {α ω} : option (α × (ω → ω)) → option α × (ω → ω)
126+
| none := (none ,id)
127+
| (some (x,y)) := (some x,y)
128+
129+
instance {ω : Type u} {m : Type u → Type v} [monad m] [monad_writer ω m] : monad_writer ω (option_t m) :=
130+
{ tell := λ x, monad_lift (tell x : m punit),
131+
listen := λ α ⟨ cmd ⟩, ⟨ uncurry (λ x y, flip prod.mk y <$> x) <$> listen cmd ⟩,
132+
pass := λ α ⟨ cmd ⟩, ⟨ pass (option_t.pass_aux <$> cmd) ⟩ }
133+
134+
/-- Adapt a monad stack, changing the type of its top-most environment.
135+
136+
This class is comparable to [Control.Lens.Magnify](https://hackage.haskell.org/package/lens-4.15.4/docs/Control-Lens-Zoom.html#t:Magnify), but does not use lenses (why would it), and is derived automatically for any transformer implementing `monad_functor`.
137+
138+
Note: This class can be seen as a simplification of the more "principled" definition
139+
```
140+
class monad_reader_functor (ρ ρ' : out_param (Type u)) (n n' : Type u → Type u) :=
141+
(map {} {α : Type u} : (∀ {m : Type u → Type u} [monad m], reader_t ρ m α → reader_t ρ' m α) → n α → n' α)
142+
```
143+
-/
144+
class monad_writer_adapter (ω ω' : out_param (Type u)) (m m' : Type u → Type v) :=
145+
(adapt_writer {} {α : Type u} : (ω → ω') → m α → m' α)
146+
export monad_writer_adapter (adapt_writer)
147+
148+
section
149+
variables {ω ω' : Type u} {m m' : Type u → Type v}
150+
151+
instance monad_writer_adapter_trans {n n' : Type u → Type v} [monad_functor m m' n n'] [monad_writer_adapter ω ω' m m'] : monad_writer_adapter ω ω' n n' :=
152+
⟨λ α f, monad_map (λ α, (adapt_writer f : m α → m' α))⟩
153+
154+
instance [monad m] : monad_writer_adapter ω ω' (writer_t ω m) (writer_t ω' m) :=
155+
⟨λ α, writer_t.adapt⟩
156+
end
157+
158+
instance (ω : Type u) (m out) [monad_run out m] : monad_run (λ α, out (α × ω)) (writer_t ω m) :=
159+
⟨λ α x, run $ x.run ⟩

0 commit comments

Comments
 (0)