-
Notifications
You must be signed in to change notification settings - Fork 437
/
Repr.lean
133 lines (121 loc) · 5.38 KB
/
Repr.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Transform
import Lean.Meta.Inductive
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
namespace Lean.Elab.Deriving.Repr
open Lean.Parser.Term
open Meta
open Std
/-- Builds the header used by the auxiliary `Repr` function generated for `indVal`:
the result of ``mkHeader `Repr 1 indVal`` with an additional explicit `(prec : Nat)`
binder appended to its binder list. -/
def mkReprHeader (indVal : InductiveVal) : TermElabM Header := do
  let base ← mkHeader `Repr 1 indVal
  -- The generated bodies refer to this binder by the name `prec`.
  let precBinder ← `(bracketedBinderF| (prec : Nat))
  return { base with binders := base.binders.push precBinder }
/-- Builds the body of the auxiliary `Repr` function for a structure.

The produced term is `Format.bracket "{ " <fields> " }"`, where `<fields>` concatenates
one `name := value` entry per structure field, with entries separated by `","` followed
by `Format.line`. Fields that are types or proofs are rendered as `_`; every other field
is rendered with `repr` applied to the projection of the first target in `header`.

Throws if the constructor does not take exactly `numParams + fieldNames.size` arguments. -/
def mkBodyForStruct (header : Header) (indVal : InductiveVal) : TermElabM Term := do
  let ctorVal ← getConstInfoCtor indVal.ctors.head!
  let fieldNames := getStructureFields (← getEnv) indVal.name
  let numParams := indVal.numParams
  let target := mkIdent header.targetNames[0]!
  forallTelescopeReducing ctorVal.type fun xs _ => do
    let mut acc ← `(Format.nil)
    unless xs.size == numParams + fieldNames.size do
      throwError "'deriving Repr' failed, unexpected number of fields in structure"
    for h : idx in [:fieldNames.size] do
      let fieldName := fieldNames[idx]
      let fieldStr := toString fieldName
      let fieldNameLit := Syntax.mkStrLit fieldStr
      -- Constructor arguments come after the inductive parameters.
      let fieldFVar := xs[numParams + idx]!
      -- Separator before every entry except the first.
      if idx != 0 then
        acc ← `($acc ++ "," ++ Format.line)
      if (← isType fieldFVar <||> isProof fieldFVar) then
        acc ← `($acc ++ $fieldNameLit ++ " := " ++ "_")
      else
        -- Nest the value by the width of the `name := ` prefix.
        let nestWidth := Syntax.mkNumLit (toString (fieldStr.length + " := ".length))
        acc ← `($acc ++ $fieldNameLit ++ " := " ++ (Format.group (Format.nest $nestWidth (repr ($target.$(mkIdent fieldName):ident)))))
    `(Format.bracket "{ " $acc:term " }")
/-- Builds the body of the auxiliary `Repr` function for a (non-structure) inductive type:
a `match` on the discriminants produced by `mkDiscrs header indVal`, with one alternative
per constructor (see `mkAlts`). `auxFunName` is the name of the auxiliary function being
defined; it is called on explicit constructor arguments whose type is an application of
`indVal.name`. -/
def mkBodyForInduct (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do
  let discrs ← mkDiscrs header indVal
  let alts ← mkAlts
  `(match $[$discrs],* with $alts:matchAlt*)
where
  /-- Produces one match alternative for each constructor in `indVal.ctors`. -/
  mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do
    let mut alts := #[]
    for ctorName in indVal.ctors do
      let ctorInfo ← getConstInfoCtor ctorName
      -- Open the constructor type so that `xs` holds one free variable per constructor argument.
      let alt ← forallTelescopeReducing ctorInfo.type fun xs _ => do
        let mut patterns := #[]
        -- add `_` pattern for indices
        for _ in [:indVal.numIndices] do
          patterns := patterns.push (← `(_))
        let mut ctorArgs := #[]
        -- The right-hand side starts with the constructor name as text;
        -- one entry per explicit argument is appended below.
        let mut rhs : Term := Syntax.mkStrLit (toString ctorInfo.name)
        rhs ← `(Format.text $rhs)
        for h : i in [:xs.size] do
          -- Note: some inductive parameters are explicit if they were promoted from indices,
          -- so we process all constructor arguments in the same loop.
          let x := xs[i]
          -- Parameters reuse the corresponding header argument name; other arguments get a fresh name.
          let a ← mkIdent <$> if i < indVal.numParams then pure header.argNames[i]! else mkFreshUserName `a
          if i < indVal.numParams then
            -- add `_` for inductive parameters, they are inaccessible
            ctorArgs := ctorArgs.push (← `(_))
          else
            ctorArgs := ctorArgs.push a
          -- Only explicit arguments contribute to the output.
          if (← x.fvarId!.getBinderInfo).isExplicit then
            if (← inferType x).isAppOf indVal.name then
              -- Argument of the type being derived: call the auxiliary function with `max_prec`.
              rhs ← `($rhs ++ Format.line ++ $(mkIdent auxFunName):ident $a:ident max_prec)
            else if (← isType x <||> isProof x) then
              -- Types and proofs are rendered as `_`.
              rhs ← `($rhs ++ Format.line ++ "_")
            else
              rhs ← `($rhs ++ Format.line ++ reprArg $a)
        -- The constructor pattern uses `@` and lists every argument, including the `_` for parameters.
        patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs:term*))
        -- The result is grouped, nested by 1 when `prec >= max_prec` and by 2 otherwise,
        -- and passed to `Repr.addAppParen` together with `prec`.
        `(matchAltExpr| | $[$patterns:term],* => Repr.addAppParen (Format.group (Format.nest (if prec >= max_prec then 1 else 2) ($rhs:term))) prec)
      alts := alts.push alt
    return alts
/-- Chooses the body generator for `indVal`: `mkBodyForStruct` when `indVal.name` is a
structure in the current environment, and `mkBodyForInduct` otherwise. -/
def mkBody (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do
  let env ← getEnv
  if isStructure env indVal.name then
    mkBodyForStruct header indVal
  else
    mkBodyForInduct header indVal auxFunName
/-- Generates the auxiliary function definition for the `i`-th type in `ctx`.

The function is named `ctx.auxFunNames[i]!`, takes the binders from `mkReprHeader`, and
returns a `Format`. When `ctx.usePartial` is set, the body is wrapped with the let
declarations from ``mkLocalInstanceLetDecls ctx `Repr`` and the definition is emitted as
`private partial def`; otherwise it is emitted as a plain `private def`. -/
def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
  let auxFunName := ctx.auxFunNames[i]!
  let indVal := ctx.typeInfos[i]!
  let header ← mkReprHeader indVal
  let plainBody ← mkBody header indVal auxFunName
  let fnIdent := mkIdent auxFunName
  let binders := header.binders
  if ctx.usePartial then
    let letDecls ← mkLocalInstanceLetDecls ctx `Repr header.argNames
    let wrappedBody ← mkLet letDecls plainBody
    `(private partial def $fnIdent:ident $binders:bracketedBinder* : Format := $wrappedBody:term)
  else
    `(private def $fnIdent:ident $binders:bracketedBinder* : Format := $plainBody:term)
/-- Wraps the auxiliary functions for every type in `ctx.typeInfos` (generated in index
order by `mkAuxFunction`) in a single `mutual ... end` block. -/
def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
  let auxDefs ← (Array.range ctx.typeInfos.size).mapM fun idx => mkAuxFunction ctx idx
  `(mutual
     $auxDefs:command*
    end)
/-- Produces the commands that derive `Repr` for `declName`: the mutual block of auxiliary
functions followed by the instance commands from `mkInstanceCmds`. The resulting commands
are also emitted under the `Elab.Deriving.repr` trace class. -/
private def mkReprInstanceCmd (declName : Name) : TermElabM (Array Syntax) := do
  let ctx ← mkContext "repr" declName
  let mutualBlock ← mkMutualBlock ctx
  let instanceCmds ← mkInstanceCmds ctx `Repr #[declName]
  let cmds := #[mutualBlock] ++ instanceCmds
  trace[Elab.Deriving.repr] "\n{cmds}"
  return cmds
open Command
/-- Deriving handler for `Repr`. Returns `false` without elaborating anything unless every
name in `declNames` is an inductive type; otherwise generates and elaborates the `Repr`
commands for each declaration in turn and returns `true`. -/
def mkReprInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
  unless (← declNames.allM isInductive) do
    return false
  for declName in declNames do
    let cmds ← liftTermElabM (mkReprInstanceCmd declName)
    cmds.forM elabCommand
  return true
-- Initialization: install `mkReprInstanceHandler` as the deriving handler for `Repr`,
-- and register the `Elab.Deriving.repr` trace class used by `mkReprInstanceCmd`.
builtin_initialize
  registerDerivingHandler `Repr mkReprInstanceHandler
  registerTraceClass `Elab.Deriving.repr
end Lean.Elab.Deriving.Repr