Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generalize ReaderT instances #1447

Merged
merged 1 commit into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions src/Init/Control/Lawful.lean
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ end ExceptT

namespace ReaderT

theorem ext [Monad m] {x y : ReaderT ρ m α} (h : ∀ ctx, x.run ctx = y.run ctx) : x = y := by
theorem ext {x y : ReaderT ρ m α} (h : ∀ ctx, x.run ctx = y.run ctx) : x = y := by
simp [run] at h
exact funext h

Expand All @@ -192,34 +192,45 @@ theorem ext [Monad m] {x y : ReaderT ρ m α} (h : ∀ ctx, x.run ctx = y.run ct
@[simp] theorem run_bind [Monad m] (x : ReaderT ρ m α) (f : α → ReaderT ρ m β) (ctx : ρ)
: (x >>= f).run ctx = x.run ctx >>= λ a => (f a).run ctx := rfl

@[simp] theorem run_mapConst [Monad m] (a : α) (x : ReaderT ρ m β) (ctx : ρ)
: (Functor.mapConst a x).run ctx = Functor.mapConst a (x.run ctx) := rfl

@[simp] theorem run_map [Monad m] (f : α → β) (x : ReaderT ρ m α) (ctx : ρ)
: (f <$> x).run ctx = f <$> x.run ctx := rfl

@[simp] theorem run_monadLift [MonadLiftT n m] (x : n α) (ctx : ρ)
: (monadLift x : ReaderT ρ m α).run ctx = (monadLift x : m α) := rfl

@[simp] theorem run_monadMap [Monad m] [MonadFunctor n m] (f : {β : Type u} → n β → n β) (x : ReaderT ρ m α) (ctx : ρ)
@[simp] theorem run_monadMap [MonadFunctor n m] (f : {β : Type u} → n β → n β) (x : ReaderT ρ m α) (ctx : ρ)
: (monadMap @f x : ReaderT ρ m α).run ctx = monadMap @f (x.run ctx) := rfl

@[simp] theorem run_read [Monad m] (ctx : ρ) : (ReaderT.read : ReaderT ρ m ρ).run ctx = pure ctx := rfl

@[simp] theorem run_seq {α β : Type u} [Monad m] [LawfulMonad m] (f : ReaderT ρ m (α → β)) (x : ReaderT ρ m α) (ctx : ρ) : (f <*> x).run ctx = (f.run ctx <*> x.run ctx) := by
rw [seq_eq_bind (m := m)]; rfl
@[simp] theorem run_seq {α β : Type u} [Monad m] (f : ReaderT ρ m (α → β)) (x : ReaderT ρ m α) (ctx : ρ)
: (f <*> x).run ctx = (f.run ctx <*> x.run ctx) := rfl

@[simp] theorem run_seqRight [Monad m] (x : ReaderT ρ m α) (y : ReaderT ρ m β) (ctx : ρ)
: (x *> y).run ctx = (x.run ctx *> y.run ctx) := rfl

@[simp] theorem run_seqLeft [Monad m] (x : ReaderT ρ m α) (y : ReaderT ρ m β) (ctx : ρ)
: (x <* y).run ctx = (x.run ctx <* y.run ctx) := rfl

@[simp] theorem run_seqRight [Monad m] [LawfulMonad m] (x : ReaderT ρ m α) (y : ReaderT ρ m β) (ctx : ρ) : (x *> y).run ctx = (x.run ctx *> y.run ctx) := by
rw [seqRight_eq_bind (m := m)]; rfl
instance [Monad m] [LawfulFunctor m] : LawfulFunctor (ReaderT ρ m) where
id_map := by intros; apply ext; simp
map_const := by intros; funext a b; apply ext; intros; simp [map_const]
comp_map := by intros; apply ext; intros; simp [comp_map]

@[simp] theorem run_seqLeft [Monad m] [LawfulMonad m] (x : ReaderT ρ m α) (y : ReaderT ρ m β) (ctx : ρ) : (x <* y).run ctx = (x.run ctx <* y.run ctx) := by
rw [seqLeft_eq_bind (m := m)]; rfl
instance [Monad m] [LawfulApplicative m] : LawfulApplicative (ReaderT ρ m) where
seqLeft_eq := by intros; apply ext; intros; simp [seqLeft_eq]
seqRight_eq := by intros; apply ext; intros; simp [seqRight_eq]
pure_seq := by intros; apply ext; intros; simp [pure_seq]
map_pure := by intros; apply ext; intros; simp [map_pure]
seq_pure := by intros; apply ext; intros; simp [seq_pure]
seq_assoc := by intros; apply ext; intros; simp [seq_assoc]

instance [Monad m] [LawfulMonad m] : LawfulMonad (ReaderT ρ m) where
id_map := by intros; apply ext; intros; simp
map_const := by intros; rfl
seqLeft_eq := by intros; apply ext; intros; simp; apply LawfulApplicative.seqLeft_eq
seqRight_eq := by intros; apply ext; intros; simp; apply LawfulApplicative.seqRight_eq
pure_seq := by intros; apply ext; intros; simp; apply LawfulApplicative.pure_seq
bind_pure_comp := by intros; apply ext; intros; simp; apply LawfulMonad.bind_pure_comp
bind_map := by intros; rfl
bind_pure_comp := by intros; apply ext; intros; simp [LawfulMonad.bind_pure_comp]
bind_map := by intros; apply ext; intros; simp [bind_map]
pure_bind := by intros; apply ext; intros; simp
bind_assoc := by intros; apply ext; intros; simp

Expand Down
30 changes: 18 additions & 12 deletions src/Init/Prelude.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2838,35 +2838,41 @@ instance (ε) [MonadExceptOf ε m] : MonadExceptOf ε (ReaderT ρ m) where
end

section
variable {ρ : Type u} {m : Type u → Type v} [Monad m] {α β : Type u}
variable {ρ : Type u} {m : Type u → Type v}

/-- `(← read) : ρ` gets the read-only state of a `ReaderT ρ`. -/
@[inline] protected def read : ReaderT ρ m ρ := pure
@[inline] protected def read [Monad m] : ReaderT ρ m ρ :=
pure

/-- The `pure` operation of the `ReaderT` monad. -/
@[inline] protected def pure (a : α) : ReaderT ρ m α := fun _ => pure a
@[inline] protected def pure [Monad m] {α} (a : α) : ReaderT ρ m α :=
fun _ => pure a

/-- The `bind` operation of the `ReaderT` monad. -/
@[inline] protected def bind (x : ReaderT ρ m α) (f : α → ReaderT ρ m β) : ReaderT ρ m β :=
@[inline] protected def bind [Monad m] {α β} (x : ReaderT ρ m α) (f : α → ReaderT ρ m β) : ReaderT ρ m β :=
fun r => bind (x r) fun a => f a r

/-- The `map` operation of the `ReaderT` monad. -/
@[inline] protected def map (f : α → β) (x : ReaderT ρ m α) : ReaderT ρ m β :=
fun r => Functor.map f (x r)
instance [Monad m] : Functor (ReaderT ρ m) where
map f x r := Functor.map f (x r)
mapConst a x r := Functor.mapConst a (x r)

instance : Monad (ReaderT ρ m) where
pure := ReaderT.pure
instance [Monad m] : Applicative (ReaderT ρ m) where
pure := ReaderT.pure
seq f x r := Seq.seq (f r) fun _ => x () r
seqLeft a b r := SeqLeft.seqLeft (a r) fun _ => b () r
seqRight a b r := SeqRight.seqRight (a r) fun _ => b () r

instance [Monad m] : Monad (ReaderT ρ m) where
bind := ReaderT.bind
map := ReaderT.map

instance (ρ m) [Monad m] : MonadFunctor m (ReaderT ρ m) where
instance (ρ m) : MonadFunctor m (ReaderT ρ m) where
monadMap f x := fun ctx => f (x ctx)

/--
`adapt (f : ρ' → ρ)` precomposes function `f` on the reader state of a
`ReaderT ρ`, yielding a `ReaderT ρ'`.
-/
@[inline] protected def adapt {ρ' : Type u} [Monad m] {α : Type u} (f : ρ' → ρ) : ReaderT ρ m α → ReaderT ρ' m α :=
@[inline] protected def adapt {ρ' α : Type u} (f : ρ' → ρ) : ReaderT ρ m α → ReaderT ρ' m α :=
fun x r => x (f r)

end
Expand Down