Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rewrite rules to Coq #18038

Merged
merged 31 commits into from Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3162e2a
Add symbols, patterns, rewrite rule fields
Sep 18, 2023
7818d79
Add kernel (lazy) reduction for rewrite rules
Sep 18, 2023
ec048b9
Add cbv reduction for rewrite rules
Sep 18, 2023
99f5cae
Add cbn / simpl reduction for rewrite rules
Sep 18, 2023
76dfc9f
Add evar unification support for rewrite rules
Sep 18, 2023
653aa83
Add syntax to declare symbols and rewrite rules
Sep 18, 2023
2053c7d
Add minimal support for native and checker for rewrite rules
Sep 18, 2023
a2c4347
Add tests for rewrite rules
Sep 18, 2023
0be32c7
Add documentation for rewrite rules
Sep 18, 2023
5e4d125
Add a global flag for rewrite rules
Sep 22, 2023
d1be404
Add a flag to protect patvars from unification
Sep 26, 2023
0d1bbe1
Possibly problematic changes to elaboration
Sep 27, 2023
2c9b9a6
Support for sort-polymorphic rewrite rules
Nov 13, 2023
35a7bdc
Add a warning for redexes in rewrite rules
Nov 13, 2023
df104ad
Allow matching on algebraic instances in rewrite rules
Nov 28, 2023
554fde6
Treat universe levels correctly in rewrite rules
Dec 1, 2023
e911d06
Add checks for irrelevance or eta-patterns through types
Dec 20, 2023
72146e1
Documentation overhaul
Jan 8, 2024
984dc0b
Change syntax to use pipes as separators
Jan 8, 2024
cf05cad
Improve documentation
Jan 10, 2024
e220e80
Change order for assum_list production
Jan 11, 2024
88563e7
Apply review suggestions
Jan 18, 2024
c4d4cc6
Change pattern representation for indexed holes
Jan 24, 2024
63882fe
Change Rewrite Rule Goption to compiler flag
Jan 26, 2024
0e66d8c
Clearer error message when remaining evars, explanations of test file
Jan 26, 2024
b70986f
Better name for additional arguments to coqtop in manual
Jan 29, 2024
32d8a64
Fix bugs (stack overflow on unification, repeated warning declaration…
Jan 31, 2024
8798c8e
More informative message with universe inconsistencies in RR
Feb 2, 2024
3b9768f
Add overlay
Feb 5, 2024
7ba8516
Amend coqchk description for rewrite rules
Feb 8, 2024
a1c5f39
Fix relevance on symbols
Feb 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions boot/usage.ml
Expand Up @@ -82,6 +82,7 @@ let print_usage_common co command =
\n -impredicative-set set sort Set impredicative\
\n -allow-sprop allow using the proof irrelevant SProp sort\
\n -disallow-sprop forbid using the proof irrelevant SProp sort\
\n -allow-rewrite-rules allows declaring symbols and rewrite rules\
\n -indices-matter levels of indices (and nonuniform parameters) contribute to the level of inductives\
\n -type-in-type disable universe consistency checking\
\n -mangle-names x mangle auto-generated names using prefix x\
Expand Down
5 changes: 5 additions & 0 deletions checker/check_stat.ml
Expand Up @@ -28,6 +28,10 @@ let pr_impredicative_set env =
if is_impredicative_set env then str "Theory: Set is impredicative"
else str "Theory: Set is predicative"

let pr_rewrite_rules env =
if rewrite_rules_allowed env then str "Theory: Rewrite rules are allowed (consistency, subject reduction, confluence and normalization might be broken)"
else str "Theory: Rewrite rules are not allowed"

let pr_assumptions ass axs =
if axs = [] then
str ass ++ str ": <none>"
Expand Down Expand Up @@ -65,6 +69,7 @@ let print_context env opac =
(fnl() ++ str"CONTEXT SUMMARY" ++ fnl() ++
str"===============" ++ fnl() ++ fnl() ++
str "* " ++ hov 0 (pr_impredicative_set env ++ fnl()) ++ fnl() ++
str "* " ++ hov 0 (pr_rewrite_rules env ++ fnl()) ++ fnl() ++
str "* " ++ hov 0 (pr_axioms env opac ++ fnl()) ++ fnl() ++
str "* " ++ hov 0 (pr_type_in_type env ++ fnl()) ++ fnl() ++
str "* " ++ hov 0 (pr_unguarded env ++ fnl()) ++ fnl() ++
Expand Down
97 changes: 95 additions & 2 deletions checker/mod_checking.ml
Expand Up @@ -53,7 +53,7 @@ let check_constant_declaration env opac kn cb opacify =
if not (Sorts.relevance_equal cb.const_relevance (Sorts.relevance_of_sort jty.utj_type))
then raise Pp.(BadConstant (kn, str "incorrect const_relevance"));
let body, env = match cb.const_body with
| Undef _ | Primitive _ -> None, env
| Undef _ | Primitive _ | Symbol _ -> None, env
| Def c -> Some c, env
| OpaqueDef o ->
let c, u = !indirect_accessor o in
Expand Down Expand Up @@ -81,6 +81,96 @@ let check_constant_declaration env opac kn cb opacify =
let opac = check_constant_declaration env opac kn cb opacify in
Environ.add_constant kn cb env, opac

let check_instance_mask udecl umask lincheck =
match udecl, umask with
| _, None -> lincheck
| Monomorphic, Some ([||], [||]) -> lincheck
| Polymorphic uctx, Some (qmask, umask) ->
let lincheck = Array.fold_left_i (fun i lincheck mask -> Partial_subst.maybe_add_quality mask () lincheck) lincheck qmask in
let lincheck = Array.fold_left_i (fun i lincheck mask -> Partial_subst.maybe_add_univ mask () lincheck) lincheck umask in
if (Array.length qmask, Array.length umask) <> UVars.AbstractContext.size uctx then CErrors.anomaly Pp.(str "Bad univ mask length.");
lincheck
| _ -> CErrors.anomaly Pp.(str "Bad univ mask length.")

let rec get_holes_profiles env nargs ndecls lincheck el =
List.fold_left (get_holes_profiles_elim env nargs ndecls) lincheck el

and get_holes_profiles_elim env nargs ndecls lincheck = function
| PEApp args -> Array.fold_left (get_holes_profiles_parg env nargs ndecls) lincheck args
| PECase (ind, u, ret, brs) ->
let mib, mip = Inductive.lookup_mind_specif env ind in
let lincheck = check_instance_mask mib.mind_universes u lincheck in
let lincheck = get_holes_profiles_parg env (nargs + mip.mind_nrealargs +1) (ndecls + mip.mind_nrealdecls) lincheck ret in
Array.fold_left3 (fun lincheck nargs_b ndecls_b -> get_holes_profiles_parg env (nargs + nargs_b) (ndecls + ndecls_b) lincheck) lincheck mip.mind_consnrealargs mip.mind_consnrealdecls brs
| PEProj proj ->
let () = lookup_projection proj env |> ignore in
lincheck

and get_holes_profiles_parg env nargs ndecls lincheck = function
| EHoleIgnored -> lincheck
| EHole i -> Partial_subst.add_term i nargs lincheck
| ERigid (h, el) ->
let lincheck = get_holes_profiles_head env nargs ndecls lincheck h in
get_holes_profiles env nargs ndecls lincheck el

and get_holes_profiles_head env nargs ndecls lincheck = function
| PHRel n -> if n <= ndecls then lincheck else Type_errors.error_unbound_rel env n
| PHSymbol (c, u) ->
let cb = lookup_constant c env in
check_instance_mask cb.const_universes u lincheck
| PHConstr (c, u) ->
let (mib, _) = Inductive.lookup_mind_specif env (inductive_of_constructor c) in
check_instance_mask mib.mind_universes u lincheck
| PHInd (ind, u) ->
let (mib, _) = Inductive.lookup_mind_specif env ind in
check_instance_mask mib.mind_universes u lincheck
| PHInt _ | PHFloat _ -> lincheck
| PHSort PSSProp -> if Environ.sprop_allowed env then lincheck else Type_errors.error_disallowed_sprop env
| PHSort PSType io -> Partial_subst.maybe_add_univ io () lincheck
| PHSort PSQSort (qio, uio) ->
lincheck
|> Partial_subst.maybe_add_quality qio ()
|> Partial_subst.maybe_add_univ uio ()
| PHSort _ -> lincheck
| PHLambda (tys, bod) | PHProd (tys, bod) ->
let lincheck = Array.fold_left_i (fun i -> get_holes_profiles_parg env (nargs + i) (ndecls + i)) lincheck tys in
let lincheck = get_holes_profiles_parg env (nargs + Array.length tys) (ndecls + Array.length tys) lincheck bod in
lincheck

let check_rhs env holes_profile rhs =
let rec check i c = match Constr.kind c with
| App (f, args) when Constr.isRel f ->
let n = Constr.destRel f in
if n <= i then () else
if n - i > Array.length holes_profile then CErrors.anomaly Pp.(str "Malformed right-hand-side substitution site");
let d = holes_profile.(n-i-1) in
if Array.length args >= d then () else CErrors.anomaly Pp.(str "Malformed right-hand-side substitution site")
| Rel n when n > i ->
if n - i > Array.length holes_profile then CErrors.anomaly Pp.(str "Malformed right-hand-side substitution site");
let d = holes_profile.(n-i-1) in
if d = 0 then () else CErrors.anomaly Pp.(str "Malformed right-hand-side substitution site")
| _ -> Constr.iter_with_binders succ check i c
in
check 0 rhs

let check_rewrite_rule env lab i (symb, rule) =
Flags.if_verbose Feedback.msg_notice (str " checking rule:" ++ Label.print lab ++ str"#" ++ Pp.int i);
let { nvars; lhs_pat; rhs } = rule in
let symb_cb = Environ.lookup_constant symb env in
let () =
match symb_cb.const_body with Symbol _ -> ()
| _ -> ignore @@ invalid_arg "Rule defined on non-symbol"
in
let lincheck = Partial_subst.make nvars in
let lincheck = check_instance_mask symb_cb.const_universes (fst lhs_pat) lincheck in
let lincheck = get_holes_profiles env 0 0 lincheck (snd lhs_pat) in
let holes_profile, _, _ = Partial_subst.to_arrays lincheck in
let () = check_rhs env holes_profile rhs in
()

let check_rewrite_rules_body env lab rrb =
List.iteri (check_rewrite_rule env lab) rrb.rewrules_rules

(** {6 Checking modules } *)

(** We currently ignore the [mod_type_alg] and [typ_expr_alg] fields.
Expand All @@ -107,7 +197,7 @@ let rec collect_constants_without_body sign mp accu =
let c = Constant.make2 mp lab in
if Declareops.constant_has_body cb then s else Cset.add c s
| SFBmodule msb -> collect_constants_without_body msb.mod_type (MPdot(mp,lab)) s
| SFBmind _ | SFBmodtype _ -> s in
| SFBmind _ | SFBrules _ | SFBmodtype _ -> s in
match sign with
| MoreFunctor _ -> Cset.empty (* currently ignored *)
| NoFunctor struc ->
Expand Down Expand Up @@ -184,6 +274,9 @@ and check_structure_field env opac mp lab res opacify = function
| SFBmodtype mty ->
check_module_type env mty;
add_modtype mty env, opac
| SFBrules rrb ->
check_rewrite_rules_body env lab rrb;
Environ.add_rewrite_rules rrb.rewrules_rules env, opac

and check_signature env opac sign mp_mse res opacify = match sign with
| MoreFunctor (arg_id, mtb, body) ->
Expand Down
1 change: 1 addition & 0 deletions checker/safe_checking.ml
Expand Up @@ -12,6 +12,7 @@ open Declarations
open Environ

let import senv opac clib univs digest =
let senv = Safe_typing.check_flags_for_library clib senv in
let mb = Safe_typing.module_of_library clib in
let env = Safe_typing.env_of_safe_env senv in
let env = push_context_set ~strict:true (Safe_typing.univs_of_library clib) env in
Expand Down
43 changes: 41 additions & 2 deletions checker/values.ml
Expand Up @@ -244,7 +244,7 @@ let v_primitive =

let v_cst_def =
v_sum "constant_def" 0
[|[|Opt Int|]; [|v_constr|]; [|v_opaque|]; [|v_primitive|]|]
[|[|Opt Int|]; [|v_constr|]; [|v_opaque|]; [|v_primitive|]; [|v_bool|]|]

let v_typing_flags =
v_tuple "typing_flags"
Expand Down Expand Up @@ -341,6 +341,43 @@ let v_retro_action =
let v_retroknowledge =
v_sum "module_retroknowledge" 1 [|[|List v_retro_action|]|]

let v_instance_mask = Opt (v_pair (Array (Opt Int)) (Array (Opt Int)))

let v_sort_pattern = Sum ("sort_pattern", 3,
[|[|Opt Int|]; (* PSType *)
[|Opt Int; Opt Int|] (* PSQSort *)
|])

let rec v_hpattern = Sum ("head_pattern", 0,
[|[|Int|]; (* PHRel *)
[|v_sort_pattern|]; (* PHSort *)
[|v_cst; v_instance_mask|]; (* PHSymbol *)
[|v_ind; v_instance_mask|]; (* PHInd *)
[|v_cons; v_instance_mask|]; (* PHConstr *)
[|v_uint63|]; (* PHInt *)
[|Float64|]; (* PHFloat *)
[|Array v_patarg; v_patarg|]; (* PHLambda *)
[|Array v_patarg; v_patarg|]; (* PHProd *)
|])

and v_elimination = Sum ("pattern_elimination", 0,
[|[|Array v_patarg|]; (* PEApp *)
[|v_ind; v_instance_mask; v_patarg; Array v_patarg|]; (* PECase *)
[|v_proj|]; (* PEProj *)
|])

and v_head_elim = Tuple ("head*elims", [|v_hpattern; List v_elimination|])

and v_patarg = Sum ("pattern_argument", 1,
[|[|Int|]; (* EHole *)
[|v_head_elim|]; (* ERigid *)
|])

let v_rewrule = v_tuple "rewrite_rule"
[| v_tuple "nvars" [| Int; Int; Int |]; v_pair v_instance_mask (List v_elimination); v_constr |]
let v_rrb = v_tuple "rewrite_rules_body"
[| List (v_pair v_cst v_rewrule) |]

let rec v_mae =
Sum ("module_alg_expr",0,
[|[|v_mp|]; (* SEBident *)
Expand All @@ -352,6 +389,7 @@ let rec v_sfb =
Sum ("struct_field_body",0,
[|[|v_cb|]; (* SFBconst *)
[|v_ind_pack|]; (* SFBmind *)
[|v_rrb|]; (* SFBrules *)
[|v_module|]; (* SFBmodule *)
[|v_modtype|] (* SFBmodtype *)
|])
Expand Down Expand Up @@ -380,8 +418,9 @@ and v_modtype =

let v_vodigest = Sum ("module_impl",0, [| [|String|]; [|String;String|] |])
let v_deps = Array (v_tuple "dep" [|v_dp;v_vodigest|])
let v_flags = v_tuple "flags" [|v_bool|] (* Allow Rewrite Rules *)
let v_compiled_lib =
v_tuple "compiled" [|v_dp;v_module;v_context_set;v_deps|]
v_tuple "compiled" [|v_dp;v_module;v_context_set;v_deps; v_flags|]

(** STM objects *)

Expand Down
51 changes: 51 additions & 0 deletions clib/cArray.ml
Expand Up @@ -25,6 +25,7 @@ sig
'a array -> 'b array -> 'c array -> 'd array -> bool
val for_all_i : (int -> 'a -> bool) -> int -> 'a array -> bool
val findi : (int -> 'a -> bool) -> 'a array -> int option
val find2_map : ('a -> 'b -> 'c option) -> 'a array -> 'b array -> 'c option
val hd : 'a array -> 'a
val tl : 'a array -> 'a array
val last : 'a array -> 'a
Expand All @@ -45,11 +46,16 @@ sig
('a -> 'b -> 'c -> 'd -> 'e -> 'a) -> 'a -> 'b array -> 'c array -> 'd array -> 'e array -> 'a
val fold_left2_i :
(int -> 'a -> 'b -> 'c -> 'a) -> 'a -> 'b array -> 'c array -> 'a
val fold_left3_i :
(int -> 'a -> 'b -> 'c -> 'd -> 'a) -> 'a -> 'b array -> 'c array -> 'd array -> 'a
val fold_left_from : int -> ('a -> 'b -> 'a) -> 'a -> 'b array -> 'a
val map_to_list : ('a -> 'b) -> 'a array -> 'b list
val map_of_list : ('a -> 'b) -> 'a list -> 'b array
val chop : int -> 'a array -> 'a array * 'a array
val split : ('a * 'b) array -> 'a array * 'b array
val split3 : ('a * 'b * 'c) array -> 'a array * 'b array * 'c array
val split4 : ('a * 'b * 'c * 'd) array -> 'a array * 'b array * 'c array * 'd array
val transpose : 'a array array -> 'a array array
val map2_i : (int -> 'a -> 'b -> 'c) -> 'a array -> 'b array -> 'c array
val map3 :
('a -> 'b -> 'c -> 'd) -> 'a array -> 'b array -> 'c array -> 'd array
Expand All @@ -60,6 +66,7 @@ sig
val iter3 : ('a -> 'b -> 'c -> unit) -> 'a array -> 'b array -> 'c array -> unit
val fold_left_map : ('a -> 'b -> 'a * 'c) -> 'a -> 'b array -> 'a * 'c array
val fold_right_map : ('a -> 'c -> 'b * 'c) -> 'a array -> 'c -> 'b array * 'c
val fold_left_map_i : (int -> 'a -> 'b -> 'a * 'c) -> 'a -> 'b array -> 'a * 'c array
val fold_left2_map : ('a -> 'b -> 'c -> 'a * 'd) -> 'a -> 'b array -> 'c array -> 'a * 'd array
val fold_left2_map_i : (int -> 'a -> 'b -> 'c -> 'a * 'd) -> 'a -> 'b array -> 'c array -> 'a * 'd array
val fold_right2_map : ('a -> 'b -> 'c -> 'd * 'c) -> 'a array -> 'b array -> 'c -> 'd array * 'c
Expand Down Expand Up @@ -189,6 +196,19 @@ let findi (pred: int -> 'a -> bool) (arr: 'a array) : int option =
None
with Found i -> Some i

let find2_map (type a) pred arr1 arr2 =
let exception Found of a in
let n = Array.length arr1 in
if not (Array.length arr2 = n) then failwith "Array.find2_map";
try
for i=0 to n - 1 do
match pred (Array.unsafe_get arr1 i) (Array.unsafe_get arr2 i) with
| Some r -> raise (Found r)
| None -> ()
done;
None
with Found i -> Some i

let hd v =
match Array.length v with
| 0 -> failwith "Array.hd"
Expand Down Expand Up @@ -279,6 +299,15 @@ let fold_left3 f a v1 v2 v3 =
invalid_arg "Array.fold_left3";
fold a 0

let fold_left3_i f a v1 v2 v3 =
let lv1 = Array.length v1 in
let rec fold a n =
if n >= lv1 then a else fold (f n a (uget v1 n) (uget v2 n) (uget v3 n)) (succ n)
in
if Array.length v2 <> lv1 || Array.length v3 <> lv1 then
invalid_arg "Array.fold_left3_i";
fold a 0

let fold_left4 f a v1 v2 v3 v4 =
let lv1 = Array.length v1 in
let rec fold a n =
Expand Down Expand Up @@ -336,6 +365,23 @@ let chop n v =
let split v =
(Array.map fst v, Array.map snd v)

let split3 v =
(Array.map (fun (a, _, _) -> a) v,
Array.map (fun (_, b, _) -> b) v,
Array.map (fun (_, _, c) -> c) v)

let split4 v =
(Array.map (fun (a, _, _, _) -> a) v,
Array.map (fun (_, b, _, _) -> b) v,
Array.map (fun (_, _, c, _) -> c) v,
Array.map (fun (_, _, _, d) -> d) v)

let transpose a =
let n = Array.length a in
if n = 0 then [||] else
let n' = Array.length (Array.unsafe_get a 0) in
Array.init n' (fun i -> Array.init n (fun j -> a.(j).(i)))

let map2_i f v1 v2 =
let len1 = Array.length v1 in
let len2 = Array.length v2 in
Expand Down Expand Up @@ -447,6 +493,11 @@ let fold_left2_map f e v1 v2 =
let v' = map2 (fun x1 x2 -> let (e,y) = f !e' x1 x2 in e' := e; y) v1 v2 in
(!e',v')

let fold_left_map_i f e v =
let e' = ref e in
let v' = mapi (fun idx x -> let (e,y) = f idx !e' x in e' := e; y) v in
(!e',v')

let fold_left2_map_i f e v1 v2 =
let e' = ref e in
let v' = map2_i (fun idx x1 x2 -> let (e,y) = f idx !e' x1 x2 in e' := e; y) v1 v2 in
Expand Down
12 changes: 12 additions & 0 deletions clib/cArray.mli
Expand Up @@ -38,6 +38,10 @@ sig

val findi : (int -> 'a -> bool) -> 'a array -> int option

val find2_map : ('a -> 'b -> 'c option) -> 'a array -> 'b array -> 'c option
(** First result which is not None, or None;
[Failure "Array.find2_map"] if the arrays don't have the same length *)

val hd : 'a array -> 'a
(** First element of an array, or [Failure "Array.hd"] if empty. *)

Expand Down Expand Up @@ -68,6 +72,8 @@ sig
('a -> 'b -> 'c -> 'd -> 'e -> 'a) -> 'a -> 'b array -> 'c array -> 'd array -> 'e array -> 'a
val fold_left2_i :
(int -> 'a -> 'b -> 'c -> 'a) -> 'a -> 'b array -> 'c array -> 'a
val fold_left3_i :
(int -> 'a -> 'b -> 'c -> 'd -> 'a) -> 'a -> 'b array -> 'c array -> 'd array -> 'a
val fold_left_from : int -> ('a -> 'b -> 'a) -> 'a -> 'b array -> 'a

val map_to_list : ('a -> 'b) -> 'a array -> 'b list
Expand All @@ -81,6 +87,9 @@ sig
Raise [Failure "Array.chop"] if [i] is not a valid index. *)

val split : ('a * 'b) array -> 'a array * 'b array
val split3 : ('a * 'b * 'c) array -> 'a array * 'b array * 'c array
val split4 : ('a * 'b * 'c * 'd) array -> 'a array * 'b array * 'c array * 'd array
val transpose : 'a array array -> 'a array array

val map2_i : (int -> 'a -> 'b -> 'c) -> 'a array -> 'b array -> 'c array
val map3 :
Expand All @@ -104,6 +113,9 @@ sig
val fold_right_map : ('a -> 'c -> 'b * 'c) -> 'a array -> 'c -> 'b array * 'c
(** Same, folding on the right *)

val fold_left_map_i : (int -> 'a -> 'b -> 'a * 'c) -> 'a -> 'b array -> 'a * 'c array
(** Same than [fold_left_map] but passing the index of the array *)

val fold_left2_map : ('a -> 'b -> 'c -> 'a * 'd) -> 'a -> 'b array -> 'c array -> 'a * 'd array
(** Same with two arrays, folding on the left; see also [Smart.fold_left2_map] *)

Expand Down