@@ -6,9 +6,9 @@ Author: Simon Hudon
6
6
Automation to construct `traversable` instances
7
7
-/
8
8
9
- import .basic
9
+ import .basic .instances
10
10
import category.basic
11
- import tactic.basic
11
+ import tactic.basic tactic.cache
12
12
13
13
namespace tactic.interactive
14
14
@@ -17,19 +17,53 @@ open tactic list monad functor
17
17
def succeeds {m : Type * → Type *} [alternative m] {α} (cmd : m α) : m bool :=
18
18
(tt <$ cmd) <|> pure ff
19
19
20
- meta def traverse_field (n : name) (cl f v e : expr) : tactic (option expr) :=
20
+ meta def nested_map (f v : expr) : expr → tactic expr
21
+ | t :=
22
+ do t ← instantiate_mvars t,
23
+ mcond (succeeds $ is_def_eq t v)
24
+ (pure f)
25
+ (if ¬ v.occurs (t.app_fn)
26
+ then do
27
+ cl ← mk_app ``functor [t.app_fn],
28
+ _inst ← mk_instance cl,
29
+ f' ← nested_map t.app_arg,
30
+ mk_mapp ``functor .map [t.app_fn,_inst,none,none,f']
31
+ else fail format!" type {t} is not a functor with respect to variable {v}" )
32
+
33
+ meta def map_field (n : name) (cl f v e : expr) : tactic expr :=
21
34
do t ← infer_type e >>= whnf,
22
35
if t.get_app_fn.const_name = n
23
- then return none
36
+ then fail " recursive types not supported "
24
37
else if v.occurs t
25
- then mcond (succeeds $ is_def_eq t v)
26
- (pure $ some $ expr.app f e)
27
- (if ¬ v.occurs (t.app_fn)
28
- then some <$> to_expr ``(compose.mk (traversable.traverse %%f %%e))
29
- else fail format!" type {t} is not traversable with respect to variable {v}" )
38
+ then do f' ← nested_map f v t,
39
+ pure $ f' e
30
40
else
31
- (is_def_eq t.app_fn cl >> some <$> to_expr ``(compose.mk %%e))
32
- <|> some <$> to_expr ``(@pure %%cl _ _ %%e)
41
+ (is_def_eq t.app_fn cl >> to_expr ``(compose.mk %%e))
42
+ <|> pure e
43
+
44
+ meta def nested_traverse (f v : expr) : expr → tactic expr
45
+ | t :=
46
+ do t ← instantiate_mvars t,
47
+ mcond (succeeds $ is_def_eq t v)
48
+ (pure f)
49
+ (if ¬ v.occurs (t.app_fn)
50
+ then do
51
+ cl ← mk_app ``traversable [t.app_fn],
52
+ _inst ← mk_instance cl,
53
+ f' ← nested_traverse t.app_arg,
54
+ mk_mapp ``traversable .traverse [t.app_fn,_inst,none,none,none,none,f']
55
+ else fail format!" type {t} is not traversable with respect to variable {v}" )
56
+
57
+ meta def traverse_field (n : name) (_inst cl f v e : expr) : tactic expr :=
58
+ do t ← infer_type e >>= whnf,
59
+ if t.get_app_fn.const_name = n
60
+ then fail " recursive types not supported"
61
+ else if v.occurs t
62
+ then do f' ← nested_traverse f v t,
63
+ pure $ f' e
64
+ else
65
+ (is_def_eq t.app_fn cl >> to_expr ``(compose.mk %%e))
66
+ <|> to_expr ``(@pure _ (%%_inst).to_has_pure _ (ulift.up %%e))
33
67
34
68
meta def seq_apply_constructor : list expr → expr → tactic expr
35
69
| (x :: xs) e := to_expr ``(%%e <*> %%x) >>= seq_apply_constructor xs
@@ -48,32 +82,132 @@ do c ← mk_const n,
48
82
t ← infer_type c,
49
83
fill_implicit_arg' c t
50
84
51
- meta def traverse_constructor (c n : name) (f v : expr) (args : list expr) : tactic unit :=
85
+ meta def mk_down (e : expr) : tactic expr :=
86
+ to_expr ``(ulift.down %%e) <|> pure e
87
+
88
+ meta def map_constructor (c n : name) (f v : expr) (args : list expr) : tactic unit :=
52
89
do g ← target,
53
- args' ← mmap (traverse_field n g.app_fn f v) args,
90
+ args' ← mmap (map_field n g.app_fn f v) args,
54
91
constr ← fill_implicit_arg c,
55
- constr' ← to_expr ``(@pure %%(g.app_fn) _ _ %%constr),
56
- r ← seq_apply_constructor (filter_map id args') constr',
57
- () <$ tactic.apply r
92
+ let r := constr.mk_app args',
93
+ () <$ tactic.exact r
58
94
59
- open applicative
95
+ meta def mk_map (type : name) (cs : list name) := do
96
+ `[intros α β f x],
97
+ reset_instance_cache,
98
+ x ← get_local `x ,
99
+ xs ← tactic.induction x,
100
+ f ← get_local `f ,
101
+ α ← get_local `α ,
102
+ β ← get_local `β ,
103
+ () <$ mmap₂'
104
+ (λ (c : name) (x : name × list expr × list (name × expr)),
105
+ solve1 (map_constructor c type f α x.2 .1 ))
106
+ cs xs
60
107
61
- meta def derive_traverse : tactic unit :=
62
- do `(traversable %%f) ← target | failed,
63
- env ← get_env,
64
- let n := f.get_app_fn.const_name,
65
- let cs := env.constructors_of n,
66
- constructor,
67
- `[intros m _ α β f x],
108
+ meta def traverse_constructor (c n : name) (_inst f v : expr) (args : list expr) : tactic unit :=
109
+ do g ← target,
110
+ args' ← mmap (traverse_field n _inst g.app_fn f v) args,
111
+ constr ← fill_implicit_arg c,
112
+ v ← mk_mvar,
113
+ constr' ← to_expr ``(@pure %%(g.app_fn) (%%_inst).to_has_pure _ %%v),
114
+ r ← seq_apply_constructor args' constr',
115
+ gs ← get_goals,
116
+ set_goals [v],
117
+ vs ← tactic.intros >>= mmap mk_down,
118
+ tactic.exact (constr.mk_app vs),
119
+ done,
120
+ set_goals gs,
121
+ () <$ tactic.exact r
122
+
123
+ meta def mk_traverse (type : name) (cs : list name) := do
124
+ `[intros m _inst α β f x],
68
125
reset_instance_cache,
69
126
x ← get_local `x ,
70
127
xs ← tactic.induction x,
71
128
f ← get_local `f ,
129
+ _inst ← get_local `_inst ,
72
130
α ← get_local `α ,
73
131
β ← get_local `β ,
74
132
m ← get_local `m ,
75
133
() <$ mmap₂'
76
- (λ (c : name) (x : name × list expr × _), solve1 (traverse_constructor c n f α x.2 .1 ))
134
+ (λ (c : name) (x : name × list expr × list (name × expr)),
135
+ solve1 (traverse_constructor c type _inst f α x.2 .1 ))
77
136
cs xs
78
137
138
+ open applicative
139
+
140
+ meta def derive_functor : tactic unit :=
141
+ do `(functor %%f) ← target | failed,
142
+ env ← get_env,
143
+ let n := f.get_app_fn.const_name,
144
+ let cs := env.constructors_of n,
145
+ refine ``( { functor . map := _ , .. } ),
146
+ mk_map n cs
147
+
148
+ meta def derive_traverse : tactic unit :=
149
+ do `(traversable %%f) ← target | failed,
150
+ env ← get_env,
151
+ let n := f.get_app_fn.const_name,
152
+ let cs := env.constructors_of n,
153
+ constructor,
154
+ mk_traverse n cs
155
+
156
+ meta def mk_one_instance
157
+ (n : name)
158
+ (cls : name)
159
+ (tac : tactic unit) : tactic unit :=
160
+ do decl ← get_decl n,
161
+ cls_decl ← get_decl cls,
162
+ env ← get_env,
163
+ guard (env.is_inductive n) <|> fail format!" failed to derive '{cls}', '{n}' is not an inductive type" ,
164
+ let ls := decl.univ_params.map $ λ n, level.param n,
165
+ -- incrementally build up target expression `Π (hp : p) [cls hp] ..., cls (n.{ls} hp ...)`
166
+ -- where `p ...` are the inductive parameter types of `n`
167
+ let tgt : expr := expr.const n ls,
168
+ ⟨params, _⟩ ← mk_local_pis (decl.type.instantiate_univ_params (decl.univ_params.zip ls)),
169
+ let params := params.init,
170
+ let tgt := tgt.mk_app params,
171
+ tgt ← mk_app cls [tgt] <|> fail " fish a" ,
172
+ tgt ← params.enum.mfoldr (λ ⟨i, param⟩ tgt,
173
+ do -- add typeclass hypothesis for each inductive parameter
174
+ tgt ← do {
175
+ guard $ i < env.inductive_num_params n,
176
+ param_cls ← mk_app cls [param] <|> fail " fish b" ,
177
+ -- TODO(sullrich): omit some typeclass parameters based on usage of `param`?
178
+ pure $ expr.pi `a binder_info.inst_implicit param_cls tgt
179
+ } <|> pure tgt,
180
+ pure $ tgt.bind_pi param
181
+ ) tgt,
182
+ () <$ mk_instance tgt <|> do
183
+ (_, val) ← tactic.solve_aux tgt (do
184
+ tactic.intros >> tac),
185
+ val ← instantiate_mvars val,
186
+ let trusted := decl.is_trusted ∧ cls_decl.is_trusted,
187
+ add_decl (declaration.defn (n ++ cls)
188
+ decl.univ_params
189
+ tgt val reducibility_hints.abbrev trusted),
190
+ set_basic_attribute `instance (n ++ cls) tt
191
+
192
+ open function
193
+ meta def higher_order_derive_handler
194
+ (cls : name)
195
+ (tac : tactic unit)
196
+ (deps : list (name × tactic unit) := []) :
197
+ derive_handler :=
198
+ λ p n,
199
+ if p.is_constant_of cls then
200
+ do mmap' (uncurry $ mk_one_instance n) deps,
201
+ mk_one_instance n cls tac,
202
+ pure true
203
+ else pure false
204
+
205
+ @[derive_handler]
206
+ meta def functor_derive_handler :=
207
+ higher_order_derive_handler ``functor derive_functor
208
+
209
+ @[derive_handler]
210
+ meta def traversable_derive_handler :=
211
+ higher_order_derive_handler ``traversable derive_traverse [(``functor ,derive_functor)]
212
+
79
213
end tactic.interactive
0 commit comments