Skip to content

Commit

Permalink
Fixes #11237: support constructor parameters in prim notations for pa…
Browse files Browse the repository at this point in the history
…tterns.

Additionally, we restructurate the code a bit, moving the check on
well-formedness of patterns in constrintern.ml so that the separation
of roles between check_allowed_ref_in_pat (just check the validity so
that another interpretation can be taken if it fails) and rcp_of_pat
(which interpret the pattern) is clearer.

We also check that the "in" clause of "match" accepts a notation for
inductive type at the head.
  • Loading branch information
herbelin committed Jul 31, 2023
1 parent ce3025d commit b9a867b
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 26 deletions.
48 changes: 36 additions & 12 deletions interp/constrintern.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1706,28 +1706,50 @@ let drop_notations_pattern (test_kind_top,test_kind_inner) genv env pat =
error_invalid_pattern_notation ?loc ()
in
(* [rcp_of_glob] : from [glob_constr] to [raw_cases_pattern_expr] *)
let rec rcp_of_glob scopes x = DAst.(map (function
| GVar id -> RCPatAtom (Some (CAst.make ?loc:x.loc id,scopes))
let make_pars ?loc g =
let env = Global.env () in
let n = match g with
| GlobRef.ConstructRef (ind,_) -> Inductiveops.inductive_nparams env ind
| _ -> 0 in
List.make n (DAst.make ?loc @@ RCPatAtom None)
in
(* Check Ind/Construct structure of patterns for primitive notation *)
let rec check_allowed_ref_in_pat test_kind = DAst.(with_loc_val (fun ?loc -> function
| GVar _ | GHole _ -> ()
| GRef (g,_) -> test_kind.test_kind ?loc g
| GApp (f, l) ->
begin match DAst.get f with
| GRef (g, _) ->
test_kind.test_kind ?loc g;
let nparams = match g with
| IndRef ind | ConstructRef (ind,_) -> Inductiveops.inductive_nparams (Global.env ()) ind
| _ -> assert false in
let l = try List.skipn nparams l with Failure _ -> raise Not_found in
List.iter (check_allowed_ref_in_pat test_kind_inner) l
| _ -> raise Not_found
end
| _ -> raise Not_found)) in
(* Interpret a primitive notation (part of Glob_ops.cases_pattern_of_glob_constr) *)
let rec rcp_of_glob scopes x = DAst.(map_with_loc (fun ?loc -> function
| GVar id -> RCPatAtom (Some (CAst.make ?loc id,scopes))
| GHole (_,_) -> RCPatAtom None
| GRef (g,_) -> RCPatCstr (g, [])
| GApp (r, l) ->
begin match DAst.get r with
| GRef (g,_) ->
let allscs = find_arguments_scope g in
let allscs = simple_adjust_scopes (List.length l) allscs in
RCPatCstr (g, List.map2 (fun sc a -> rcp_of_glob (sc,snd scopes) a) allscs l)
let params = make_pars ?loc g in (* Rem: no letins *)
let nparams = List.length params in
let allscs = List.skipn nparams allscs in
let l = List.skipn nparams l in
let pl = List.map2 (fun sc a -> rcp_of_glob (sc,snd scopes) a) allscs l in
RCPatCstr (g, params @ pl)
| _ ->
CErrors.anomaly Pp.(str "Invalid return pattern from Notation.interp_prim_token_cases_pattern_expr.")
end
| _ -> CErrors.anomaly Pp.(str "Invalid return pattern from Notation.interp_prim_token_cases_pattern_expr."))) x
in
let make_pars ?loc g =
let env = Global.env () in
let n = match g with
| GlobRef.ConstructRef (ind,_) -> Inductiveops.inductive_nparams env ind
| _ -> 0 in
List.make n (DAst.make ?loc @@ RCPatAtom None)
in
let rec drop_abbrev {test_kind} ?loc scopes qid add_par_if_no_ntn_with_par no_impl pats =
try
if qualid_is_ident qid && Option.cata (Id.Set.mem (qualid_basename qid)) false env.pat_ids && List.is_empty pats then
Expand Down Expand Up @@ -1775,7 +1797,8 @@ let drop_notations_pattern (test_kind_top,test_kind_inner) genv env pat =
end
| CPatNotation (_,(InConstrEntry,"- _"),([a],[],[]),[]) when is_non_zero_pat a ->
let p = match a.CAst.v with CPatPrim (Number (_, p)) -> p | _ -> assert false in
let pat, _df = Notation.interp_prim_token_cases_pattern_expr ?loc (ensure_kind test_kind_inner) (Number (SMinus,p)) scopes in
let pat, _df = Notation.interp_prim_token_cases_pattern_expr ?loc
(check_allowed_ref_in_pat test_kind) (Number (SMinus,p)) scopes in
rcp_of_glob scopes pat
| CPatNotation (_,(InConstrEntry,"( _ )"),([a],[],[]),[]) ->
in_pat test_kind scopes a
Expand All @@ -1788,7 +1811,8 @@ let drop_notations_pattern (test_kind_top,test_kind_inner) genv env pat =
| CPatDelimiters (key, e) ->
in_pat test_kind ([],find_delimiters_scope ?loc key::snd scopes) e
| CPatPrim p ->
let pat, _df = Notation.interp_prim_token_cases_pattern_expr ?loc test_kind_inner.test_kind p scopes in
let pat, _df = Notation.interp_prim_token_cases_pattern_expr ?loc
(check_allowed_ref_in_pat test_kind) p scopes in
rcp_of_glob scopes pat
| CPatAtom (Some id) ->
begin
Expand Down
15 changes: 2 additions & 13 deletions interp/notation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1449,19 +1449,8 @@ let interp_prim_token_gen ?loc g p local_scopes =
let interp_prim_token ?loc =
interp_prim_token_gen ?loc (fun _ -> ())

let rec check_allowed_ref_in_pat looked_for = DAst.(with_val (function
| GVar _ | GHole _ -> ()
| GRef (g,_) -> looked_for g
| GApp (f, l) ->
begin match DAst.get f with
| GRef (g, _) ->
looked_for g; List.iter (check_allowed_ref_in_pat looked_for) l
| _ -> raise Not_found
end
| _ -> raise Not_found))

let interp_prim_token_cases_pattern_expr ?loc looked_for p =
interp_prim_token_gen ?loc (check_allowed_ref_in_pat looked_for) p
let interp_prim_token_cases_pattern_expr ?loc check_allowed p =
interp_prim_token_gen ?loc check_allowed p

let warn_deprecated_notation =
Deprecation.create_warning ~object_name:"Notation" ~warning_name_if_no_since:"deprecated-notation"
Expand Down
2 changes: 1 addition & 1 deletion interp/notation.mli
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ val declare_string_interpreter : ?local:bool -> scope_name -> required_module ->
val interp_prim_token : ?loc:Loc.t -> prim_token -> subscopes ->
glob_constr * scope_name option
(* This function returns a glob_const representing a pattern *)
val interp_prim_token_cases_pattern_expr : ?loc:Loc.t -> (GlobRef.t -> unit) -> prim_token ->
val interp_prim_token_cases_pattern_expr : ?loc:Loc.t -> (Glob_term.glob_constr -> unit) -> prim_token ->
subscopes -> glob_constr * scope_name option

(** Return the primitive token associated to a [term]/[cases_pattern];
Expand Down
19 changes: 19 additions & 0 deletions test-suite/output/StringSyntaxPrimitive.v
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,22 @@ Module Test4.
Definition float' := float.
Check mk_floatList (@cons float' 1 [0; 0])%float%list.
End Test4.

Module Bug11237.

Inductive bytes := wrap_bytes { unwrap_bytes : list byte }.

Delimit Scope bytes_scope with bytes.
Bind Scope bytes_scope with bytes.
String Notation bytes wrap_bytes unwrap_bytes : bytes_scope.

Open Scope bytes_scope.

Example test_match :=
match "foo" with
| "foo" => "bar"
| "bar" => "foo"
| x => x
end.

End Bug11237.
7 changes: 7 additions & 0 deletions test-suite/success/Notations2.v
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,10 @@ Fail Notation "[ x ]" := (id x) (x custom doesntexist, only printing).
Fail Notation "# x" := (id x) (in custom doesntexist, only printing).

End TestNonExistentCustomOnlyPrinting.

Module NotationClauseIn.

Notation "1" := unit.
Check fun x => match x in 1 with tt => 0 end.

End NotationClauseIn.

0 comments on commit b9a867b

Please sign in to comment.