|
| 1 | +/- |
| 2 | +Copyright (c) 2023 Kyle Miller. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Kyle Miller |
| 5 | +-/ |
| 6 | +import Lean |
| 7 | +import Mathlib.Tactic.ToLevel |
| 8 | + |
| 9 | +/-! |
| 10 | +# A `ToExpr` derive handler |
| 11 | +
|
| 12 | +This module defines a `ToExpr` derive handler for inductive types. It supports mutually inductive |
| 13 | +types as well. |
| 14 | +
|
| 15 | +The `ToExpr` derive handlers support universe level polymorphism. This is implemented using the |
| 16 | +`Lean.ToLevel` class. To use `ToExpr` in places where there is universe polymorphism, make sure |
| 17 | +to have a `[ToLevel.{u}]` instance available. |
| 18 | +
|
| 19 | +**Warning:** Import `Mathlib.Tactic.ToExpr` instead of this one. This ensures that you are using |
| 20 | +the universe polymorphic `ToExpr` instances that override the ones from Lean 4 core. |
| 21 | +
|
| 22 | +Implementation note: this derive handler was originally modeled after the `Repr` derive handler. |
| 23 | +-/ |
| 24 | + |
| 25 | +namespace Mathlib.Deriving.ToExpr |
| 26 | + |
| 27 | +open Lean Elab Lean.Parser.Term |
| 28 | +open Meta Command Deriving |
| 29 | + |
| 30 | +def mkToExprHeader (indVal : InductiveVal) : TermElabM Header := do |
| 31 | + -- The auxiliary functions we produce are `indtype -> Expr`. |
| 32 | + let header ← mkHeader ``ToExpr 1 indVal |
| 33 | + return header |
| 34 | + |
| 35 | +/-- Give a term that is equivalent to `(term|mkAppN $f #[$args,*])`. |
| 36 | +As an optimization, `mkAppN` is pre-expanded out to use `Expr.app` directly. -/ |
| 37 | +def mkAppNTerm (f : Term) (args : Array Term) : MetaM Term := |
| 38 | + args.foldlM (fun a b => `(Expr.app $a $b)) f |
| 39 | + |
| 40 | +def mkToExprBody (header : Header) (indVal : InductiveVal) (auxFunName : Name) : |
| 41 | + TermElabM Term := do |
| 42 | + let discrs ← mkDiscrs header indVal |
| 43 | + let alts ← mkAlts |
| 44 | + `(match $[$discrs],* with $alts:matchAlt*) |
| 45 | +where |
| 46 | + mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do |
| 47 | + let mut alts := #[] |
| 48 | + for ctorName in indVal.ctors do |
| 49 | + let ctorInfo ← getConstInfoCtor ctorName |
| 50 | + let alt ← forallTelescopeReducing ctorInfo.type fun xs _ => do |
| 51 | + let mut patterns := #[] |
| 52 | + -- add `_` pattern for indices |
| 53 | + for _ in [:indVal.numIndices] do |
| 54 | + patterns := patterns.push (← `(_)) |
| 55 | + let mut ctorArgs := #[] |
| 56 | + let mut rhsArgs : Array Term := #[] |
| 57 | + let mkArg (x : Expr) (a : Term) : TermElabM Term := do |
| 58 | + if (← inferType x).isAppOf indVal.name then |
| 59 | + `($(mkIdent auxFunName) $a) |
| 60 | + else if ← Meta.isType x then |
| 61 | + `(toTypeExpr $a) |
| 62 | + else |
| 63 | + `(toExpr $a) |
| 64 | + -- add `_` pattern for inductive parameters, which are inaccessible |
| 65 | + for i in [:ctorInfo.numParams] do |
| 66 | + let a := mkIdent header.argNames[i]! |
| 67 | + ctorArgs := ctorArgs.push (← `(_)) |
| 68 | + rhsArgs := rhsArgs.push <| ← mkArg xs[i]! a |
| 69 | + for i in [:ctorInfo.numFields] do |
| 70 | + let a := mkIdent (← mkFreshUserName `a) |
| 71 | + ctorArgs := ctorArgs.push a |
| 72 | + rhsArgs := rhsArgs.push <| ← mkArg xs[ctorInfo.numParams + i]! a |
| 73 | + patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs:term*)) |
| 74 | + let levels ← indVal.levelParams.toArray.mapM (fun u => `(toLevel.{$(mkIdent u)})) |
| 75 | + let rhs : Term ← |
| 76 | + mkAppNTerm (← `(Expr.const $(quote ctorInfo.name) [$levels,*])) rhsArgs |
| 77 | + `(matchAltExpr| | $[$patterns:term],* => $rhs) |
| 78 | + alts := alts.push alt |
| 79 | + return alts |
| 80 | + |
| 81 | +def mkToTypeExpr (argNames : Array Name) (indVal : InductiveVal) : TermElabM Term := do |
| 82 | + let levels ← indVal.levelParams.toArray.mapM (fun u => `(toLevel.{$(mkIdent u)})) |
| 83 | + forallTelescopeReducing indVal.type fun xs _ => do |
| 84 | + let mut args : Array Term := #[] |
| 85 | + for i in [:xs.size] do |
| 86 | + let x := xs[i]! |
| 87 | + let a := mkIdent argNames[i]! |
| 88 | + if ← Meta.isType x then |
| 89 | + args := args.push <| ← `(toTypeExpr $a) |
| 90 | + else |
| 91 | + args := args.push <| ← `(toExpr $a) |
| 92 | + mkAppNTerm (← `((Expr.const $(quote indVal.name) [$levels,*]))) args |
| 93 | + |
| 94 | +def mkLocalInstanceLetDecls (ctx : Deriving.Context) (argNames : Array Name) : |
| 95 | + TermElabM (Array (TSyntax ``Parser.Term.letDecl)) := do |
| 96 | + let mut letDecls := #[] |
| 97 | + for i in [:ctx.typeInfos.size] do |
| 98 | + let indVal := ctx.typeInfos[i]! |
| 99 | + let auxFunName := ctx.auxFunNames[i]! |
| 100 | + let currArgNames ← mkInductArgNames indVal |
| 101 | + let numParams := indVal.numParams |
| 102 | + let currIndices := currArgNames[numParams:] |
| 103 | + let binders ← mkImplicitBinders currIndices |
| 104 | + let argNamesNew := argNames[:numParams] ++ currIndices |
| 105 | + let indType ← mkInductiveApp indVal argNamesNew |
| 106 | + let instName ← mkFreshUserName `localinst |
| 107 | + let toTypeExpr ← mkToTypeExpr argNames indVal |
| 108 | + let letDecl ← `(Parser.Term.letDecl| $(mkIdent instName):ident $binders:implicitBinder* : |
| 109 | + ToExpr $indType := |
| 110 | + { toExpr := $(mkIdent auxFunName), toTypeExpr := $toTypeExpr }) |
| 111 | + letDecls := letDecls.push letDecl |
| 112 | + return letDecls |
| 113 | + |
| 114 | +/-- Fix the output of `mkInductiveApp` to explicitly reference universe levels. -/ |
| 115 | +def fixIndType (indVal : InductiveVal) (t : Term) : TermElabM Term := |
| 116 | + match t with |
| 117 | + | `(@$f $args*) => |
| 118 | + let levels := indVal.levelParams.toArray.map mkIdent |
| 119 | + `(@$f.{$levels,*} $args*) |
| 120 | + | _ => throwError "(internal error) expecting output of `mkInductiveApp`" |
| 121 | + |
| 122 | +/-- Make `ToLevel` instance binders for all the level variables. -/ |
| 123 | +def mkToLevelBinders (indVal : InductiveVal) : TermElabM (TSyntaxArray ``instBinderF) := do |
| 124 | + indVal.levelParams.toArray.mapM (fun u => `(instBinderF| [ToLevel.{$(mkIdent u)}])) |
| 125 | + |
| 126 | +open TSyntax.Compat in |
| 127 | +def mkAuxFunction (ctx : Deriving.Context) (i : Nat) : TermElabM Command := do |
| 128 | + let auxFunName := ctx.auxFunNames[i]! |
| 129 | + let indVal := ctx.typeInfos[i]! |
| 130 | + let header ← mkToExprHeader indVal |
| 131 | + let mut body ← mkToExprBody header indVal auxFunName |
| 132 | + if ctx.usePartial then |
| 133 | + let letDecls ← mkLocalInstanceLetDecls ctx header.argNames |
| 134 | + body ← mkLet letDecls body |
| 135 | + -- We need to alter the last binder (the one for the "target") to have explicit universe levels |
| 136 | + -- so that the `ToLevel` instance arguments can use them. |
| 137 | + let addLevels binder := |
| 138 | + match binder with |
| 139 | + | `(bracketedBinderF| ($a : $ty)) => do `(bracketedBinderF| ($a : $(← fixIndType indVal ty))) |
| 140 | + | _ => throwError "(internal error) expecting inst binder" |
| 141 | + let binders := header.binders.pop |
| 142 | + ++ (← mkToLevelBinders indVal) |
| 143 | + ++ #[← addLevels header.binders.back] |
| 144 | + let levels := indVal.levelParams.toArray.map mkIdent |
| 145 | + if ctx.usePartial then |
| 146 | + `(private partial def $(mkIdent auxFunName):ident.{$levels,*} $binders:bracketedBinder* : |
| 147 | + Expr := $body:term) |
| 148 | + else |
| 149 | + `(private def $(mkIdent auxFunName):ident.{$levels,*} $binders:bracketedBinder* : |
| 150 | + Expr := $body:term) |
| 151 | + |
| 152 | +def mkMutualBlock (ctx : Deriving.Context) : TermElabM Syntax := do |
| 153 | + let mut auxDefs := #[] |
| 154 | + for i in [:ctx.typeInfos.size] do |
| 155 | + auxDefs := auxDefs.push (← mkAuxFunction ctx i) |
| 156 | + `(mutual $auxDefs:command* end) |
| 157 | + |
| 158 | +open TSyntax.Compat in |
| 159 | +def mkInstanceCmds (ctx : Deriving.Context) (typeNames : Array Name) : |
| 160 | + TermElabM (Array Command) := do |
| 161 | + let mut instances := #[] |
| 162 | + for i in [:ctx.typeInfos.size] do |
| 163 | + let indVal := ctx.typeInfos[i]! |
| 164 | + if typeNames.contains indVal.name then |
| 165 | + let auxFunName := ctx.auxFunNames[i]! |
| 166 | + let argNames ← mkInductArgNames indVal |
| 167 | + let binders ← mkImplicitBinders argNames |
| 168 | + let binders := binders ++ (← mkInstImplicitBinders ``ToExpr indVal argNames) |
| 169 | + let binders := binders ++ (← mkToLevelBinders indVal) |
| 170 | + let indType ← fixIndType indVal (← mkInductiveApp indVal argNames) |
| 171 | + let toTypeExpr ← mkToTypeExpr argNames indVal |
| 172 | + let levels := indVal.levelParams.toArray.map mkIdent |
| 173 | + let instCmd ← `(instance $binders:implicitBinder* : ToExpr $indType where |
| 174 | + toExpr := $(mkIdent auxFunName).{$levels,*} |
| 175 | + toTypeExpr := $toTypeExpr) |
| 176 | + instances := instances.push instCmd |
| 177 | + return instances |
| 178 | + |
| 179 | +def mkToExprInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do |
| 180 | + let ctx ← mkContext "toExpr" declNames[0]! |
| 181 | + let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx declNames) |
| 182 | + trace[Elab.Deriving.toExpr] "\n{cmds}" |
| 183 | + return cmds |
| 184 | + |
| 185 | +def mkToExprInstanceHandler (declNames : Array Name) : CommandElabM Bool := do |
| 186 | + if (← declNames.allM isInductive) && declNames.size > 0 then |
| 187 | + let cmds ← liftTermElabM <| mkToExprInstanceCmds declNames |
| 188 | + cmds.forM elabCommand |
| 189 | + return true |
| 190 | + else |
| 191 | + return false |
| 192 | + |
| 193 | +initialize |
| 194 | + registerDerivingHandler `Lean.ToExpr mkToExprInstanceHandler |
| 195 | + registerTraceClass `Elab.Deriving.toExpr |
| 196 | + |
| 197 | +end Mathlib.Deriving.ToExpr |
0 commit comments