Skip to content

Commit a31c278

Browse files
committed
derive functor and traversable instances
1 parent 48522a3 commit a31c278

File tree

3 files changed

+205
-27
lines changed

3 files changed

+205
-27
lines changed

data/traversable/derive.lean

Lines changed: 159 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ Author: Simon Hudon
66
Automation to construct `traversable` instances
77
-/
88

9-
import .basic
9+
import .basic .instances
1010
import category.basic
11-
import tactic.basic
11+
import tactic.basic tactic.cache
1212

1313
namespace tactic.interactive
1414

@@ -17,19 +17,53 @@ open tactic list monad functor
1717
def succeeds {m : Type* → Type*} [alternative m] {α} (cmd : m α) : m bool :=
1818
(tt <$ cmd) <|> pure ff
1919

20-
meta def traverse_field (n : name) (cl f v e : expr) : tactic (option expr) :=
20+
meta def nested_map (f v : expr) : expr → tactic expr
21+
| t :=
22+
do t ← instantiate_mvars t,
23+
mcond (succeeds $ is_def_eq t v)
24+
(pure f)
25+
(if ¬ v.occurs (t.app_fn)
26+
then do
27+
cl ← mk_app ``functor [t.app_fn],
28+
_inst ← mk_instance cl,
29+
f' ← nested_map t.app_arg,
30+
mk_mapp ``functor.map [t.app_fn,_inst,none,none,f']
31+
else fail format!"type {t} is not a functor with respect to variable {v}")
32+
33+
meta def map_field (n : name) (cl f v e : expr) : tactic expr :=
2134
do t ← infer_type e >>= whnf,
2235
if t.get_app_fn.const_name = n
23-
then return none
36+
then fail "recursive types not supported"
2437
else if v.occurs t
25-
then mcond (succeeds $ is_def_eq t v)
26-
(pure $ some $ expr.app f e)
27-
(if ¬ v.occurs (t.app_fn)
28-
then some <$> to_expr ``(compose.mk (traversable.traverse %%f %%e))
29-
else fail format!"type {t} is not traversable with respect to variable {v}")
38+
then do f' ← nested_map f v t,
39+
pure $ f' e
3040
else
31-
(is_def_eq t.app_fn cl >> some <$> to_expr ``(compose.mk %%e))
32-
<|> some <$> to_expr ``(@pure %%cl _ _ %%e)
41+
(is_def_eq t.app_fn cl >> to_expr ``(compose.mk %%e))
42+
<|> pure e
43+
44+
meta def nested_traverse (f v : expr) : expr → tactic expr
45+
| t :=
46+
do t ← instantiate_mvars t,
47+
mcond (succeeds $ is_def_eq t v)
48+
(pure f)
49+
(if ¬ v.occurs (t.app_fn)
50+
then do
51+
cl ← mk_app ``traversable [t.app_fn],
52+
_inst ← mk_instance cl,
53+
f' ← nested_traverse t.app_arg,
54+
mk_mapp ``traversable.traverse [t.app_fn,_inst,none,none,none,none,f']
55+
else fail format!"type {t} is not traversable with respect to variable {v}")
56+
57+
meta def traverse_field (n : name) (_inst cl f v e : expr) : tactic expr :=
58+
do t ← infer_type e >>= whnf,
59+
if t.get_app_fn.const_name = n
60+
then fail "recursive types not supported"
61+
else if v.occurs t
62+
then do f' ← nested_traverse f v t,
63+
pure $ f' e
64+
else
65+
(is_def_eq t.app_fn cl >> to_expr ``(compose.mk %%e))
66+
<|> to_expr ``(@pure _ (%%_inst).to_has_pure _ (ulift.up %%e))
3367

3468
meta def seq_apply_constructor : list expr → expr → tactic expr
3569
| (x :: xs) e := to_expr ``(%%e <*> %%x) >>= seq_apply_constructor xs
@@ -48,32 +82,132 @@ do c ← mk_const n,
4882
t ← infer_type c,
4983
fill_implicit_arg' c t
5084

51-
meta def traverse_constructor (c n : name) (f v : expr) (args : list expr) : tactic unit :=
85+
meta def mk_down (e : expr) : tactic expr :=
86+
to_expr ``(ulift.down %%e) <|> pure e
87+
88+
meta def map_constructor (c n : name) (f v : expr) (args : list expr) : tactic unit :=
5289
do g ← target,
53-
args' ← mmap (traverse_field n g.app_fn f v) args,
90+
args' ← mmap (map_field n g.app_fn f v) args,
5491
constr ← fill_implicit_arg c,
55-
constr' ← to_expr ``(@pure %%(g.app_fn) _ _ %%constr),
56-
r ← seq_apply_constructor (filter_map id args') constr',
57-
() <$ tactic.apply r
92+
let r := constr.mk_app args',
93+
() <$ tactic.exact r
5894

59-
open applicative
95+
meta def mk_map (type : name) (cs : list name) := do
96+
`[intros α β f x],
97+
reset_instance_cache,
98+
x ← get_local `x,
99+
xs ← tactic.induction x,
100+
f ← get_local `f,
101+
α ← get_local ,
102+
β ← get_local ,
103+
() <$ mmap₂'
104+
(λ (c : name) (x : name × list expr × list (name × expr)),
105+
solve1 (map_constructor c type f α x.2.1))
106+
cs xs
60107

61-
meta def derive_traverse : tactic unit :=
62-
do `(traversable %%f) ← target | failed,
63-
env ← get_env,
64-
let n := f.get_app_fn.const_name,
65-
let cs := env.constructors_of n,
66-
constructor,
67-
`[intros m _ α β f x],
108+
meta def traverse_constructor (c n : name) (_inst f v : expr) (args : list expr) : tactic unit :=
109+
do g ← target,
110+
args' ← mmap (traverse_field n _inst g.app_fn f v) args,
111+
constr ← fill_implicit_arg c,
112+
v ← mk_mvar,
113+
constr' ← to_expr ``(@pure %%(g.app_fn) (%%_inst).to_has_pure _ %%v),
114+
r ← seq_apply_constructor args' constr',
115+
gs ← get_goals,
116+
set_goals [v],
117+
vs ← tactic.intros >>= mmap mk_down,
118+
tactic.exact (constr.mk_app vs),
119+
done,
120+
set_goals gs,
121+
() <$ tactic.exact r
122+
123+
meta def mk_traverse (type : name) (cs : list name) := do
124+
`[intros m _inst α β f x],
68125
reset_instance_cache,
69126
x ← get_local `x,
70127
xs ← tactic.induction x,
71128
f ← get_local `f,
129+
_inst ← get_local `_inst,
72130
α ← get_local ,
73131
β ← get_local ,
74132
m ← get_local `m,
75133
() <$ mmap₂'
76-
(λ (c : name) (x : name × list expr × _), solve1 (traverse_constructor c n f α x.2.1))
134+
(λ (c : name) (x : name × list expr × list (name × expr)),
135+
solve1 (traverse_constructor c type _inst f α x.2.1))
77136
cs xs
78137

138+
open applicative
139+
140+
meta def derive_functor : tactic unit :=
141+
do `(functor %%f) ← target | failed,
142+
env ← get_env,
143+
let n := f.get_app_fn.const_name,
144+
let cs := env.constructors_of n,
145+
refine ``( { functor . map := _ , .. } ),
146+
mk_map n cs
147+
148+
meta def derive_traverse : tactic unit :=
149+
do `(traversable %%f) ← target | failed,
150+
env ← get_env,
151+
let n := f.get_app_fn.const_name,
152+
let cs := env.constructors_of n,
153+
constructor,
154+
mk_traverse n cs
155+
156+
meta def mk_one_instance
157+
(n : name)
158+
(cls : name)
159+
(tac : tactic unit) : tactic unit :=
160+
do decl ← get_decl n,
161+
cls_decl ← get_decl cls,
162+
env ← get_env,
163+
guard (env.is_inductive n) <|> fail format!"failed to derive '{cls}', '{n}' is not an inductive type",
164+
let ls := decl.univ_params.map $ λ n, level.param n,
165+
-- incrementally build up target expression `Π (hp : p) [cls hp] ..., cls (n.{ls} hp ...)`
166+
-- where `p ...` are the inductive parameter types of `n`
167+
let tgt : expr := expr.const n ls,
168+
⟨params, _⟩ ← mk_local_pis (decl.type.instantiate_univ_params (decl.univ_params.zip ls)),
169+
let params := params.init,
170+
let tgt := tgt.mk_app params,
171+
tgt ← mk_app cls [tgt] <|> fail "fish a",
172+
tgt ← params.enum.mfoldr (λ ⟨i, param⟩ tgt,
173+
do -- add typeclass hypothesis for each inductive parameter
174+
tgt ← do {
175+
guard $ i < env.inductive_num_params n,
176+
param_cls ← mk_app cls [param] <|> fail "fish b",
177+
-- TODO(sullrich): omit some typeclass parameters based on usage of `param`?
178+
pure $ expr.pi `a binder_info.inst_implicit param_cls tgt
179+
} <|> pure tgt,
180+
pure $ tgt.bind_pi param
181+
) tgt,
182+
() <$ mk_instance tgt <|> do
183+
(_, val) ← tactic.solve_aux tgt (do
184+
tactic.intros >> tac),
185+
val ← instantiate_mvars val,
186+
let trusted := decl.is_trusted ∧ cls_decl.is_trusted,
187+
add_decl (declaration.defn (n ++ cls)
188+
decl.univ_params
189+
tgt val reducibility_hints.abbrev trusted),
190+
set_basic_attribute `instance (n ++ cls) tt
191+
192+
open function
193+
meta def higher_order_derive_handler
194+
(cls : name)
195+
(tac : tactic unit)
196+
(deps : list (name × tactic unit) := []) :
197+
derive_handler :=
198+
λ p n,
199+
if p.is_constant_of cls then
200+
do mmap' (uncurry $ mk_one_instance n) deps,
201+
mk_one_instance n cls tac,
202+
pure true
203+
else pure false
204+
205+
@[derive_handler]
206+
meta def functor_derive_handler :=
207+
higher_order_derive_handler ``functor derive_functor
208+
209+
@[derive_handler]
210+
meta def traversable_derive_handler :=
211+
higher_order_derive_handler ``traversable derive_traverse [(``functor,derive_functor)]
212+
79213
end tactic.interactive

data/traversable/instances.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ open function functor
6666
variables {f f' : Type u → Type u}
6767
variables [applicative f] [applicative f']
6868

69-
def option.traverse {α β : Type u} (g : α → f β) : option α → f (option β)
69+
protected def option.traverse {α β : Type u} (g : α → f β) : option α → f (option β)
7070
| none := pure none
7171
| (some x) := some <$> g x
7272

@@ -114,7 +114,7 @@ variables {α β : Type u}
114114
open applicative functor
115115
open list (cons)
116116

117-
def list.traverse (g : α → f β) : list α → f (list β)
117+
protected def list.traverse (g : α → f β) : list α → f (list β)
118118
| [] := pure []
119119
| (x :: xs) := cons <$> g x <*> list.traverse xs
120120

tests/examples.lean

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import tactic data.stream.basic data.set.basic data.finset data.multiset
2+
data.traversable.derive
23
open tactic
34

45
universe u
@@ -92,3 +93,46 @@ begin
9293
admit },
9394
trivial
9495
end
96+
97+
/- traversable -/
98+
99+
meta def check_defn (n : name) (e : pexpr) : tactic unit :=
100+
do (declaration.defn _ _ _ d _ _) ← get_decl n,
101+
e' ← to_expr e,
102+
guard (d =ₐ e') <|> trace d >> failed
103+
104+
set_option trace.app_builder true
105+
106+
@[derive traversable]
107+
structure my_struct (α : Type) :=
108+
(y : ℤ)
109+
110+
run_cmd do
111+
check_defn ``my_struct.traversable
112+
``( { traversable .
113+
to_functor := my_struct.functor,
114+
traverse := λ (m : TypeType) (_inst : applicative m) (α β : Type) (f : α → m β) (x : my_struct α),
115+
my_struct.rec (λ (x : ℤ), pure (λ (a : ulift ℤ), {y := a.down}) <*> pure {down := x}) x} )
116+
117+
@[derive traversable]
118+
structure my_struct2 (α : Type u) : Type u :=
119+
(x : α)
120+
(y : ℤ)
121+
(z : list α)
122+
(k : list (list α))
123+
124+
run_cmd do
125+
check_defn ``my_struct2.traversable
126+
``( { traversable .
127+
to_functor := my_struct2.functor,
128+
traverse := λ (m : Type u → Type u) (_inst : applicative m) (α β : Type u) (f : α → m β) (x : my_struct2 α),
129+
my_struct2.rec
130+
(λ (x_x : α) (x_y : ℤ) (x_z : list α) (x_k : list (list α)),
131+
pure
132+
(λ (a : β) (a_1 : ulift ℤ) (a_2 : list β) (a_3 : list (list β)),
133+
{x := a, y := a_1.down, z := a_2, k := a_3}) <*>
134+
f x_x <*>
135+
pure {down := x_y} <*>
136+
traverse f x_z <*>
137+
traverse (traverse f) x_k)
138+
x } )

0 commit comments

Comments
 (0)