From 04b46ebf9a2df27fe43fa5d1a080919ff0cfcde0 Mon Sep 17 00:00:00 2001 From: Hugo Herbelin Date: Sat, 27 Jan 2024 20:44:47 +0100 Subject: [PATCH] Fixing #4056: add refolding of named cofixpoints in tactic "simpl". We mostly reuse the code existing to refold fixpoints. --- pretyping/tacred.ml | 296 +++++++++++++++++++++++-------------- test-suite/success/simpl.v | 47 ++++++ 2 files changed, 228 insertions(+), 115 deletions(-) diff --git a/pretyping/tacred.ml b/pretyping/tacred.ml index 1c77e8e208186..96875d4b4a38a 100644 --- a/pretyping/tacred.ml +++ b/pretyping/tacred.ml @@ -191,9 +191,17 @@ type constant_elimination = | EliminationFix of fix_evaluation_data | EliminationCases of constr * int | EliminationProj of constr * int - | NotAnElimination + | NotAnElimination of constr + | NotAnEliminationConstant -(* [compute_consteval] determines whether f is an "elimination constant" +type constant_coelimination = + | CoEliminationCoFix of fix_evaluation_data + | CoEliminationConstruct of constr + | CoEliminationPrimitive of constr + | NotACoElimination of constr + | NotACoEliminationConstant + +(* [compute_constant_elimination] determines whether f is an "elimination constant" either [yn:Tn]..[y1:T1](match yi with f1..fk end g1 ..gp) @@ -217,7 +225,7 @@ type constant_elimination = the xp..x1. *) -let compute_fix_reversibility sigma labs args fix = +let compute_constant_reversibility sigma labs args fix = let nlam = List.length labs in let nargs = List.length args in if nargs > nlam then @@ -251,7 +259,7 @@ let compute_fix_reversibility sigma labs args fix = typed_reversible_args, nlam, nargs let check_fix_reversibility env sigma ref u labs args minarg refs ((lv,i),_ as fix) = - let li, nlam, nargs = compute_fix_reversibility sigma labs args (mkFix fix) in + let li, nlam, nargs = compute_constant_reversibility sigma labs args (mkFix fix) in let k = lv.(i) in let refolding_data = { refolding_names = refs; @@ -275,22 +283,34 @@ let check_fix_reversibility env sigma ref u labs args minarg refs ((lv,i),_ as f refolding_data; } -let compute_fix_wrapper ((cache,_),allowed_reds) env sigma ref u = +let check_cofix_reversibility env sigma ref u labs args minarg refs (i,_ as cofix) = + let li, nlam, nargs = compute_constant_reversibility sigma labs args (mkCoFix cofix) in + let refolding_data = { + refolding_names = refs; + refolding_wrapper_data = li; + expected_args = nlam; + } in + { + trigger_min_args = max minarg nlam; (* Does not matter; will be maximally applied anyway *) + refolding_target = ref; + refolding_data; + } + +let compute_recursive_wrapper ((cache,_,_),allowed_reds) env sigma ref u = try match reference_opt_value cache env sigma ref u with | None -> None | Some c -> let labs, ccl = whd_decompose_lambda env sigma c in let c, l = whd_stack_gen allowed_reds env sigma ccl in - assert (isFix sigma c); Some (labs, l) with Not_found (* Undefined ref *) -> None (* Heuristic to look if global names are associated to other components of a mutual fixpoint *) -let invert_names allowed_reds env sigma ref u names i = +let invert_recursive_names cache_reds env sigma ref u names i = let labs, l = - match compute_fix_wrapper allowed_reds env sigma ref u with + match compute_recursive_wrapper cache_reds env sigma ref u with | None -> assert false | Some (labs, l) -> labs, l in let make_name j = @@ -308,7 +328,7 @@ let invert_names allowed_reds env sigma ref u names i = match refi with | None -> None | Some ref -> - match compute_fix_wrapper allowed_reds env sigma ref u with + match compute_recursive_wrapper cache_reds env sigma ref u with | None -> None | Some (labs', l') -> let eq_constr c1 c2 = EConstr.eq_constr sigma c1 c2 in @@ -321,12 +341,12 @@ let deactivate_delta allowed_reds = (* Act both on Delta and transparent state as not all reduction functions work the same *) RedFlags.(red_add_transparent (red_sub allowed_reds fDELTA) TransparentState.empty) -(* [compute_consteval] stepwise expands an arbitrary long sequence of +(* [compute_constant_elimination] stepwise expands an arbitrary long sequence of reversible constants, eventually refolding to the initial constant for unary fixpoints and to the last constant encapsulating the Fix for mutual fixpoints *) -let compute_consteval ((cache,_),allowed_reds as cache_reds) env sigma ref u = +let compute_constant_elimination ((cache,_,_),allowed_reds as cache_reds) env sigma ref u = let allowed_reds_no_delta = deactivate_delta allowed_reds in let rec srec env all_abs lastref lastu onlyproj c stk = let c', args = whd_stack_gen allowed_reds_no_delta env sigma c in @@ -342,14 +362,14 @@ let compute_consteval ((cache,_),allowed_reds as cache_reds) env sigma ref u = let nbfix = Array.length lv in (if nbfix = 1 then (* Try to refold to [ref] *) - let names = [|Some (ref,u)|] in - try EliminationFix (check_fix_reversibility env sigma ref u all_abs args n_all_abs names fix) - with Elimconst -> NotAnElimination + let refs = [|Some (ref,u)|] in + try EliminationFix (check_fix_reversibility env sigma ref u all_abs args n_all_abs refs fix) + with Elimconst -> NotAnEliminationConstant else (* Try to refold to [lastref] *) - let last_labs, last_args, names = invert_names cache_reds env sigma lastref lastu names i in + let last_labs, last_args, names = invert_recursive_names cache_reds env sigma lastref lastu names i in try EliminationFix (check_fix_reversibility env sigma lastref lastu last_labs last_args n_all_abs names fix) - with Elimconst -> NotAnElimination) + with Elimconst -> NotAnEliminationConstant) | Case (_,_,_,_,_,d,_) when isRel sigma d && not onlyproj -> EliminationCases (it_mkLambda (Stack.zip sigma (c',Stack.append_app_list args stk)) all_abs, List.length all_abs) | Case (ci,u,pms,p,iv,d,lf) -> srec env all_abs lastref lastu true d Stack.(Case (mkCaseStk (ci,u,pms,p,iv,lf)) :: append_app_list args stk) @@ -359,28 +379,87 @@ let compute_consteval ((cache,_),allowed_reds as cache_reds) env sigma ref u = (* Continue stepwise unfolding from [c' args] *) let ref, u = destEvalRefU sigma c' in (match reference_opt_value cache env sigma ref u with - | None -> NotAnElimination (* e.g. if a rel *) + | None -> NotAnEliminationConstant (* e.g. if a rel *) | Some c -> srec env all_abs ref u onlyproj (applist (c, args)) stk) - | _ -> NotAnElimination + | _ -> NotAnEliminationConstant in match reference_opt_value cache env sigma ref u with - | None -> NotAnElimination - | Some c -> srec env [] ref u false c Stack.empty + | None -> NotAnEliminationConstant + | Some c -> match srec env [] ref u false c Stack.empty with NotAnEliminationConstant -> NotAnElimination c | e -> e + +(* [compute_constant_coelimination] stepwise expands an arbitrary long sequence of + reversible constants, eventually refolding to the initial constant + for unary cofixpoints and to the last constant encapsulating the CoFix + for mutual cofixpoints *) + +let compute_constant_coelimination ((cache,_,_),allowed_reds as cache_reds) env sigma ref u = + let allowed_reds_no_delta = deactivate_delta allowed_reds in + let rec srec env all_abs lastref lastu c = + let c', args = whd_stack_gen allowed_reds_no_delta env sigma c in + (* We now know that the initial [ref] evaluates to [fun all_abs => c' args] *) + (* and that the last visited name in the evaluation is [lastref] *) + match EConstr.kind sigma c' with + | Lambda (id,t,g) -> + assert (List.is_empty args); + let open Context.Rel.Declaration in + srec (push_rel (LocalAssum (id,t)) env) ((id,t)::all_abs) lastref lastu g + | Construct _ -> + let c = it_mkLambda (applist (c', args)) all_abs in + CoEliminationConstruct c + | Int _ | Float _ | Array _ (* reduced by primitives *) -> + let c = it_mkLambda (applist (c', args)) all_abs in + CoEliminationPrimitive c + | CoFix (i,(names,_,_) as cofix) -> + let n_all_abs = List.length all_abs in + let nbfix = Array.length names in + (if nbfix = 1 then + (* Try to refold to [ref] *) + let refs = [|Some (ref,u)|] in + try CoEliminationCoFix (check_cofix_reversibility env sigma ref u all_abs args n_all_abs refs cofix) + with Elimconst -> NotACoEliminationConstant + else + (* Try to refold to [lastref] *) + let last_labs, last_args, names = invert_recursive_names cache_reds env sigma lastref lastu names i in + try CoEliminationCoFix (check_cofix_reversibility env sigma lastref lastu last_labs last_args n_all_abs names cofix) + with Elimconst -> NotACoEliminationConstant) + | _ when isTransparentEvalRef env sigma (RedFlags.red_transparent allowed_reds) c' -> + (* Continue stepwise unfolding from [c' args] *) + let ref, u = destEvalRefU sigma c' in + (match reference_opt_value cache env sigma ref u with + | None -> NotACoEliminationConstant (* e.g. if a rel *) + | Some c -> srec env all_abs ref u (applist (c, args))) + | _ -> NotACoEliminationConstant + in + match reference_opt_value cache env sigma ref u with + | None -> NotACoEliminationConstant + | Some c -> match srec env [] ref u c with NotACoEliminationConstant -> NotACoElimination c | e -> e let make_simpl_cache () = - CacheTable.create 12, CacheTable.create 12 + CacheTable.create 12, CacheTable.create 12, CacheTable.create 12 + +let compute_reference_elimination ((_,elim_cache,_),_ as cache_reds) env sigma ref u = + match ref with + | EvalConst cst as ref -> + let cu = cst, EInstance.kind sigma u in + (match CacheTable.find_opt elim_cache cu with + | Some v -> v + | None -> + let v = compute_constant_elimination cache_reds env sigma ref u in + CacheTable.add elim_cache cu v; + v) + | ref -> compute_constant_elimination cache_reds env sigma ref u -let reference_eval ((_,cache),_ as cache_reds) env sigma ref u = +let compute_reference_coelimination ((_,_,coelim_cache),_ as cache_reds) env sigma ref u = match ref with | EvalConst cst as ref -> let cu = cst, EInstance.kind sigma u in - (match CacheTable.find_opt cache cu with + (match CacheTable.find_opt coelim_cache cu with | Some v -> v | None -> - let v = compute_consteval cache_reds env sigma ref u in - CacheTable.add cache cu v; + let v = compute_constant_coelimination cache_reds env sigma ref u in + CacheTable.add coelim_cache cu v; v) - | ref -> compute_consteval cache_reds env sigma ref u + | ref -> compute_constant_coelimination cache_reds env sigma ref u (* If f is bound to EliminationFix (n',refs,infos), then n' is the minimal number of args for triggering the reduction and infos is @@ -443,68 +522,49 @@ let mkLambda_with_eta sigma x t c = if isRelN sigma 1 b then applist (f, List.map (Vars.lift (-1)) args) else mkLambda (x, t, c) -let contract_fix env sigma f - ((recindices,bodynum),(_names,_types,bodies as typedbodies) as fixp) = match f with -| None -> contract_fix sigma fixp -| Some f -> - let {refolding_names; refolding_wrapper_data = lv; expected_args = n}, largs = f in - let lu = List.firstn n largs in - let p = List.length lv in - let lyi = List.map fst lv in - let la = - List.map_i (fun q aq -> - (* k from the comment is q+1 *) - try mkRel (p+1-(List.index Int.equal (n-q) lyi)) - with Not_found -> Vars.lift p aq) - 0 lu - in - let make_Fi i = match refolding_names.(i) with - | None -> mkFix((recindices,i),typedbodies) - | Some (ref, u) -> - let body = applist (mkEvalRef ref u, la) in - List.fold_left_i (fun q (* j = n+1-q *) c (ij,tij) -> - let subst = List.map (Vars.lift (-q)) (List.firstn (n-ij) la) in - let tij' = Vars.substl (List.rev subst) tij in - let x = make_annot xname Sorts.Relevant in (* TODO relevance *) - mkLambda_with_eta sigma x tij' c) +let contract_rec env sigma f nbodies mk_rec contract body = + match f with + | None -> contract () + | Some f -> + let {refolding_names; refolding_wrapper_data = lv; expected_args = n}, largs = f in + let lu = List.firstn n largs in + let p = List.length lv in + let lyi = List.map fst lv in + let la = + List.map_i (fun q aq -> + (* k from the comment is q+1 *) + try mkRel (p+1-(List.index Int.equal (n-q) lyi)) + with Not_found -> Vars.lift p aq) + 0 lu + in + let make_Fi i = match refolding_names.(i) with + | None -> mk_rec i + | Some (ref, u) -> + let body = applist (mkEvalRef ref u, la) in + List.fold_left_i (fun q (* j = n+1-q *) c (ij,tij) -> + let subst = List.map (Vars.lift (-q)) (List.firstn (n-ij) la) in + let tij' = Vars.substl (List.rev subst) tij in + let x = make_annot xname Sorts.Relevant in (* TODO relevance *) + mkLambda_with_eta sigma x tij' c) 1 body (List.rev lv) - in - let nbodies = Array.length recindices in - let lbodies = List.init nbodies make_Fi in - let c = substl_with_function (List.rev lbodies) sigma (nf_beta env sigma bodies.(bodynum)) in - nf_beta env sigma c - -let contract_cofix env sigma f - (bodynum,(names,_,bodies as typedbodies) as fixp) args = match f with -| None -> contract_cofix sigma fixp -| Some f -> - let make_Fi i = - let cofix = mkCoFix (i,typedbodies) in - match f with - | EvalConst kn, u -> - begin - if Int.equal i bodynum then mkConstU (kn, u) - else match names.(i).binder_name with - | Anonymous -> cofix - | Name id -> - (* In case of a call to another component of a block of - mutual inductive, try to reuse the global name if - the block was indeed initially built as a global - definition *) - let kn = Constant.change_label kn (Label.of_id id) in - let cst = (kn, EInstance.kind sigma u) in - try match constant_opt_value_in env cst with - | None -> cofix - (* TODO: check kn is correct *) - | Some _ -> mkConstU (kn, u) - with Not_found -> cofix - end - | _ -> - cofix in - let nbodies = Array.length bodies in - let subbodies = List.init nbodies make_Fi in - substl_with_function (List.rev subbodies) - sigma (nf_beta env sigma bodies.(bodynum)) + in + let lbodies = List.init nbodies make_Fi in + let c = substl_with_function (List.rev lbodies) sigma (nf_beta env sigma body) in + nf_beta env sigma c + +let contract_fix env sigma f ((recindices,bodynum),(_names,_types,bodies as typedbodies) as fixp) = + contract_rec env sigma f + (Array.length bodies) + (fun i -> mkFix((recindices,i),typedbodies)) + (fun () -> contract_fix sigma fixp) + bodies.(bodynum) + +let contract_cofix env sigma f (bodynum,(_names,_types,bodies as typedbodies) as cofixp) = + contract_rec env sigma f + (Array.length bodies) + (fun i -> mkCoFix(i,typedbodies)) + (fun () -> contract_cofix sigma cofixp) + bodies.(bodynum) let reducible_construct sigma c = match EConstr.kind sigma c with | Construct _ | CoFix _ (* reduced by case *) @@ -519,14 +579,12 @@ let reduce_mind_case env sigma f (ci, u, pms, p, iv, (hd, args), lf) = let ctx = EConstr.expand_branch env sigma u pms cstr br in let br = it_mkLambda_or_LetIn (snd br) ctx in Reduced (applist (br, real_cargs)) - (* TODO, consider the case of lambdas in front of the CoFix ?? *) | CoFix (bodynum,(names,_,_) as cofix) -> - let cofix_def = contract_cofix env sigma f cofix args in + let cofix_def = contract_cofix env sigma f cofix in Reduced (mkCase (ci, u, pms, p, iv, applist(cofix_def, args), lf)) | Int _ | Float _ | Array _ -> NotReducible | _ -> assert false - let match_eval_ref env sigma constr stack = match EConstr.kind sigma constr with | Const (sp, u) -> @@ -639,11 +697,19 @@ let make_simpl_reds env = let reds = red_add reds fBETA in reds +let rec descend cache env sigma target (ref,u) args = + let c = reference_value cache env sigma ref u in + if evaluable_reference_eq sigma ref target then + (c,args) + else + let c', lrest = whd_betalet_stack env sigma (applist (c, args)) in + descend cache env sigma target (destEvalRefU sigma c') lrest + (* [red_elim_const] contracts iota/fix/cofix redexes hidden behind constants by keeping the name of the constants in the recursive calls; it fails if no redex is around *) -let rec red_elim_const ((cache,_),_ as cache_reds) env sigma ref u largs = +let rec red_elim_const ((cache,_,_),_ as cache_reds) env sigma ref u largs = let open ReductionBehaviour in let nargs = List.length largs in let* largs, unfold_anyway, unfold_nonelim, nocase = @@ -681,7 +747,7 @@ let rec red_elim_const ((cache,_),_ as cache_reds) env sigma ref u largs = not is_empty && nargs >= n, true) in - let ans = match reference_eval cache_reds env sigma ref u with + let ans = match compute_reference_elimination cache_reds env sigma ref u with | EliminationCases (c,n) when nargs >= n -> let c', lrest = whd_nothing_for_iota env sigma (c, largs) in let* ans = special_red_case cache_reds env sigma (EConstr.destCase sigma c') in @@ -692,20 +758,12 @@ let rec red_elim_const ((cache,_),_ as cache_reds) env sigma ref u largs = Reduced ((ans, lrest), nocase) | EliminationFix {trigger_min_args; refolding_target; refolding_data} when nargs >= trigger_min_args -> - let rec descend (ref,u) args = - let c = reference_value cache env sigma ref u in - if evaluable_reference_eq sigma ref refolding_target then - (c,args) - else - let c', lrest = whd_betalet_stack env sigma (applist(c,args)) in - descend (destEvalRefU sigma c') lrest in - let (_, midargs as s) = descend (ref,u) largs in + let (_, midargs as s) = descend cache env sigma refolding_target (ref,u) largs in let d, lrest = whd_nothing_for_iota env sigma s in let f = refolding_data, midargs in let* (c, rest) = reduce_fix cache_reds env sigma (Some f) (destFix sigma d) lrest in Reduced ((c, rest), nocase) - | NotAnElimination when unfold_nonelim -> - let c = reference_value cache env sigma ref u in + | NotAnElimination c when unfold_nonelim -> Reduced ((whd_betaiotazeta env sigma (applist (c, largs)), []), nocase) | _ -> NotReducible in @@ -850,26 +908,34 @@ and reduce_proj allowed_reds env sigma c = in redrec c and special_red_case allowed_reds env sigma (ci, u, pms, p, iv, c, lf) = - let* f, head, args = whd_construct allowed_reds env sigma (c, []) in - reduce_mind_case env sigma f (ci, u, pms, p, iv, (head, args), lf) + let* f, s = whd_construct allowed_reds env sigma (c, []) in + reduce_mind_case env sigma f (ci, u, pms, p, iv, s, lf) and whd_construct_stack allowed_reds env sigma s = - let* _, head, args = whd_construct allowed_reds env sigma (s, []) in - Reduced (head, args) + let* _, s = whd_construct allowed_reds env sigma (s, []) in + Reduced s (* reduce until finding an applied constructor (or primitive value) or fail *) -and whd_construct ((cache,_),_ as allowed_reds) env sigma s = - let (constr, cargs) = whd_simpl_stack allowed_reds env sigma s in +and whd_construct ((cache,_,_),allowed_reds as cache_reds) env sigma c = + let (constr, cargs) = whd_simpl_stack cache_reds env sigma c in match match_eval_ref env sigma constr cargs with | Some (ref, u) -> - (match reference_opt_value cache env sigma ref u with - | None -> NotReducible - | Some gvalue -> - if reducible_construct sigma gvalue then Reduced (Some (ref, u), gvalue, cargs) - else whd_construct allowed_reds env sigma (gvalue, cargs)) + (match compute_reference_coelimination cache_reds env sigma ref u with + | CoEliminationConstruct c -> Reduced (None, whd_stack_gen allowed_reds env sigma (applist (c, cargs))) + | CoEliminationPrimitive c -> Reduced (None, whd_stack_gen allowed_reds env sigma (applist (c, cargs))) + | CoEliminationCoFix {refolding_target; refolding_data} -> + let (_, midargs as s) = descend cache env sigma refolding_target (ref,u) cargs in + let s = whd_nothing_for_iota env sigma s in + let f = refolding_data, midargs in + Reduced (Some f, s) + | NotACoElimination c -> + (* Now try to get a construct/cofix/prim using the arguments of the constant + so that possible internal iota-redexes are triggered *) + whd_construct cache_reds env sigma (c, cargs) + | NotACoEliminationConstant -> NotReducible) | None -> - if reducible_construct sigma constr then Reduced (None, constr, cargs) + if reducible_construct sigma constr then Reduced (None, (constr, cargs)) else NotReducible (************************************************************************) @@ -1318,7 +1384,7 @@ let find_hnf_rectype env sigma t = exception NotStepReducible let one_step_reduce env sigma c = - let (cache,_), _ as cache_reds = make_simpl_cache(), RedFlags.betadeltazeta in + let (cache,_,_), _ as cache_reds = make_simpl_cache(), RedFlags.betadeltazeta in let rec redrec (x, stack) = match EConstr.kind sigma x with | Lambda (n,t,c) -> diff --git a/test-suite/success/simpl.v b/test-suite/success/simpl.v index 03f46046cc48d..ca2b550b61859 100644 --- a/test-suite/success/simpl.v +++ b/test-suite/success/simpl.v @@ -264,3 +264,50 @@ Goal test_case 2 = true. simpl. match goal with [ |- true = _ ] => idtac end. Ab (* REDUCED *) End IotaTrigger3. + +Module Bug4056. + +CoInductive stream {A:Type} : Type := + | scons: A->stream->stream. + +Definition stream_unfold {A} (s: @ stream A) := match s with +| scons a s' => (a, scons a s') +end. + +Section A. + CoFixpoint inf_stream1 (x:nat) (n:nat) := + scons n (inf_stream1 x (n+x)). +End A. + +Section B. + Variable (x:nat). + CoFixpoint inf_stream2 (n:nat) := + scons n (inf_stream2 (n+x)). +End B. + +Goal (forall x n, stream_unfold (inf_stream1 x n) = stream_unfold (inf_stream2 x n)). +(* simpl was exposing the cofix on the rhs but not the lhs *) +intros. simpl. +match goal with [ |- (n, scons n (inf_stream1 x (n + x))) = (n, scons n (inf_stream2 x (n + x))) ] => idtac end. +Abort. + +Section C. + Variable (x:nat). + CoFixpoint mut_stream1 (n:nat) := scons n (mut_stream2 (n+x)) + with mut_stream2 (n:nat) := scons n (mut_stream1 (n+x)). +End C. + +Goal (forall x n, stream_unfold (mut_stream1 x n) = stream_unfold (mut_stream2 x n)). +intros. simpl. +match goal with [ |- (n, scons n (mut_stream2 x (n + x))) = (n, scons n (mut_stream1 x (n + x))) ] => idtac end. +Abort. + +Definition inf_stream2_copy n := inf_stream2 n. (* inversible *) +Definition mut_stream2_copy n := mut_stream2 n. (* inversible only towards mut_stream1/mut_stream2 *) + +Goal (forall x n, stream_unfold (inf_stream2_copy x n) = stream_unfold (mut_stream2_copy x n)). +intros. simpl. +match goal with [ |- (n, scons n (inf_stream2_copy x (n + x))) = (n, scons n (mut_stream1 x (n + x))) ] => idtac end. +Abort. + +End Bug4056.