/
Repr.lean
129 lines (117 loc) · 5.09 KB
/
Repr.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Transform
import Lean.Meta.Inductive
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
namespace Lean.Elab.Deriving.Repr
open Lean.Parser.Term
open Meta
open Std
/--
Builds the header used by the auxiliary `Repr` function generated for `indVal`.

Starts from the header returned by ``mkHeader `Repr 1 indVal`` and appends one extra
explicit binder, `(prec : Nat)`, after the binders that header already carries.
-/
def mkReprHeader (indVal : InductiveVal) : TermElabM Header := do
  let base ← mkHeader `Repr 1 indVal
  -- The generated function receives the surrounding precedence as its last argument.
  let precBinder ← `(bracketedBinderF| (prec : Nat))
  return { base with binders := base.binders.push precBinder }
/--
Builds the body of the derived `Repr` function for the structure `indVal`.

The generated term has the shape `Format.bracket "{ " <fields> " }"`, where `<fields>` is a
`Format` concatenation holding one `name := value` entry per structure field. Every entry
except the first is preceded by `"," ++ Format.line`. A field that is a type or a proof is
rendered as `_`; any other field is rendered with `repr` applied to the corresponding
projection of the first target in `header`, grouped and nested by the width of the
`name := ` prefix.

Throws an error when the constructor telescope does not contain exactly
`indVal.numParams + fieldNames.size` arguments.
-/
def mkBodyForStruct (header : Header) (indVal : InductiveVal) : TermElabM Term := do
  let ctorVal ← getConstInfoCtor indVal.ctors.head!
  let fieldNames := getStructureFields (← getEnv) indVal.name
  let numParams := indVal.numParams
  let target := mkIdent header.targetNames[0]!
  forallTelescopeReducing ctorVal.type fun xs _ => do
    unless xs.size == numParams + fieldNames.size do
      throwError "'deriving Repr' failed, unexpected number of fields in structure"
    let mut acc ← `(Format.nil)
    for idx in [:fieldNames.size] do
      let fieldName := fieldNames[idx]!
      let fieldStr := toString fieldName
      let fieldNameLit := Syntax.mkStrLit fieldStr
      -- Constructor arguments are the parameters first, then the fields.
      let fieldVar := xs[numParams + idx]!
      -- Separate this entry from the previous one.
      unless idx == 0 do
        acc ← `($acc ++ "," ++ Format.line)
      if (← isType fieldVar <||> isProof fieldVar) then
        acc ← `($acc ++ $fieldNameLit ++ " := " ++ "_")
      else
        -- Continuation lines of the value are indented past the `name := ` prefix.
        let indent := Syntax.mkNumLit (toString (fieldStr.length + " := ".length))
        acc ← `($acc ++ $fieldNameLit ++ " := " ++ (Format.group (Format.nest $indent (repr ($target.$(mkIdent fieldName):ident)))))
    `(Format.bracket "{ " $acc:term " }")
/--
Builds the body of the derived `Repr` function for the (non-structure) inductive `indVal`.

The result is a `match` over the discriminants returned by `mkDiscrs header indVal`, with one
alternative per constructor. Each alternative produces
`Repr.addAppParen (Format.group (Format.nest … (<ctor name> ++ Format.line ++ <arg> …))) prec`,
where only explicit constructor fields contribute an `<arg>`:
* a field whose type is an application of `indVal.name` is rendered by calling `auxFunName`
  on it with `max_prec`;
* every other explicit field is rendered with `reprArg`.
-/
def mkBodyForInduct (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do
  let discrs ← mkDiscrs header indVal
  let alts ← mkAlts
  `(match $[$discrs],* with $alts:matchAlt*)
where
  -- Produces the match alternatives, one per constructor, in the order of `indVal.ctors`.
  mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do
    let mut alts := #[]
    for ctorName in indVal.ctors do
      let ctorInfo ← getConstInfoCtor ctorName
      let alt ← forallTelescopeReducing ctorInfo.type fun xs _ => do
        let hole ← `(_)
        -- add `_` pattern for indices
        let mut patterns := #[]
        for _ in [:indVal.numIndices] do
          patterns := patterns.push hole
        -- add `_` for inductive parameters, they are inaccessible
        let mut ctorArgs := #[]
        for _ in [:indVal.numParams] do
          ctorArgs := ctorArgs.push hole
        -- The right-hand side starts with the constructor name as text.
        let ctorLit : Term := Syntax.mkStrLit (toString ctorInfo.name)
        let mut rhs ← `(Format.text $ctorLit)
        for i in [:ctorInfo.numFields] do
          let fieldVar := xs[indVal.numParams + i]!
          let arg := mkIdent (← mkFreshUserName `a)
          ctorArgs := ctorArgs.push arg
          -- Every field gets a pattern variable, but only explicit ones are displayed.
          if (← fieldVar.fvarId!.getDecl).binderInfo.isExplicit then
            if (← inferType fieldVar).isAppOf indVal.name then
              rhs ← `($rhs ++ Format.line ++ $(mkIdent auxFunName):ident $arg:ident max_prec)
            else
              rhs ← `($rhs ++ Format.line ++ reprArg $arg)
        patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs:term*))
        `(matchAltExpr| | $[$patterns:term],* => Repr.addAppParen (Format.group (Format.nest (if prec >= max_prec then 1 else 2) ($rhs:term))) prec)
      alts := alts.push alt
    return alts
/--
Chooses the body generator for `indVal`: `mkBodyForStruct` when `indVal.name` is a structure
in the current environment, `mkBodyForInduct` otherwise.
-/
def mkBody (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do
  let env ← getEnv
  match isStructure env indVal.name with
  | true  => mkBodyForStruct header indVal
  | false => mkBodyForInduct header indVal auxFunName
/--
Generates the definition of the `i`-th auxiliary `Repr` function described by `ctx`.

The function is named `ctx.auxFunNames[i]!`, takes the binders of `mkReprHeader` for
`ctx.typeInfos[i]!`, and returns a `Format`. When `ctx.usePartial` is set, the body is wrapped
in the let-declarations produced by ``mkLocalInstanceLetDecls ctx `Repr`` and the definition is
emitted as `private partial def`; otherwise it is emitted as a plain `private def`.
-/
def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
  let auxFunName := ctx.auxFunNames[i]!
  let indVal := ctx.typeInfos[i]!
  let header ← mkReprHeader indVal
  let body ← mkBody header indVal auxFunName
  let binders := header.binders
  if ctx.usePartial then
    let letDecls ← mkLocalInstanceLetDecls ctx `Repr header.argNames
    let wrapped ← mkLet letDecls body
    `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Format := $wrapped:term)
  else
    `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Format := $body:term)
/--
Generates one auxiliary function per entry of `ctx.typeInfos` (in index order) and wraps all of
them in a single `mutual … end` block.
-/
def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
  let auxDefs ← (Array.range ctx.typeInfos.size).mapM (mkAuxFunction ctx)
  `(mutual
      $auxDefs:command*
    end)
/--
Produces the commands that implement `deriving Repr` for `declName`: the mutual block of
auxiliary functions followed by the commands returned by ``mkInstanceCmds ctx `Repr``.
The resulting commands are also emitted under the `Elab.Deriving.repr` trace class.
-/
private def mkReprInstanceCmd (declName : Name) : TermElabM (Array Syntax) := do
  let ctx ← mkContext "repr" declName
  let mutualBlock ← mkMutualBlock ctx
  let instanceCmds ← mkInstanceCmds ctx `Repr #[declName]
  let cmds := #[mutualBlock] ++ instanceCmds
  trace[Elab.Deriving.repr] "\n{cmds}"
  return cmds
open Command
/--
Deriving handler for `Repr`.

Returns `false` without elaborating anything unless every name in `declNames` satisfies
`isInductive`. Otherwise, for each name in order, generates the commands with
`mkReprInstanceCmd`, elaborates them, and finally returns `true`.
-/
def mkReprInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
  unless (← declNames.allM isInductive) do
    return false
  for declName in declNames do
    let cmds ← liftTermElabM (mkReprInstanceCmd declName)
    for cmd in cmds do
      elabCommand cmd
  return true
builtin_initialize
  -- Register `mkReprInstanceHandler` as the deriving handler for the `Repr` class.
  registerDerivingHandler `Repr mkReprInstanceHandler
  -- Register the trace class used by `mkReprInstanceCmd` to show the generated commands.
  registerTraceClass `Elab.Deriving.repr
end Lean.Elab.Deriving.Repr