-
Notifications
You must be signed in to change notification settings - Fork 345
/
AlphaEqv.lean
137 lines (117 loc) · 4.64 KB
/
AlphaEqv.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Compiler.LCNF.Basic
namespace Lean.Compiler.LCNF
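/-!
Alpha equivalence for LCNF `Code`: two pieces of code are considered equivalent when
they coincide up to a renaming of the variables they bind.
-/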
namespace AlphaEqv
/--
Monad used by the alpha-equivalence checker. The environment is a renaming that sends
variables bound in the *second* term to the corresponding variables of the *first* term.
-/
abbrev EqvM : Type → Type := ReaderM (FVarIdMap FVarId)
/--
Check whether `fvarId₁` (from the first term) matches `fvarId₂` (from the second term).
`fvarId₂` is translated through the current renaming first. When it has no entry there,
it is compared unchanged, so variables not bound by the comparison must be identical.
-/
def eqvFVar (fvarId₁ fvarId₂ : FVarId) : EqvM Bool := do
  let renaming ← read
  match renaming.find? fvarId₂ with
  | some mapped => return fvarId₁ == mapped
  | none => return fvarId₁ == fvarId₂
/--
Structural comparison of two LCNF types up to the current renaming.

Free variables are resolved with `eqvFVar`. Applications and `forallE` binders are
compared component-wise: argument before function, domain before body. Binder names
and binder info of `forallE` are ignored. Every other shape falls back to syntactic
equality.
-/
def eqvType (e₁ e₂ : Expr) : EqvM Bool := do
  match e₁, e₂ with
  | .fvar fvarId₁, .fvar fvarId₂ => eqvFVar fvarId₁ fvarId₂
  | .app f₁ a₁, .app f₂ a₂ =>
    -- stop early: the function parts are only inspected if the arguments agree
    if (← eqvType a₁ a₂) then eqvType f₁ f₂ else return false
  | .forallE _ d₁ b₁ _, .forallE _ d₂ b₂ _ =>
    if (← eqvType d₁ d₂) then eqvType b₁ b₂ else return false
  | _, _ => return e₁ == e₂
/--
Pointwise `eqvType` on two arrays of types. Arrays of different length are never
equivalent.
-/
def eqvTypes (es₁ es₂ : Array Expr) : EqvM Bool := do
  if es₁.size != es₂.size then
    return false
  let pairs := es₁.zip es₂
  pairs.allM fun (e₁, e₂) => eqvType e₁ e₂
/--
Compare two arguments. Arguments of the same kind are compared by their payload:
type arguments with `eqvType`, variables with `eqvFVar`. Two erased arguments are
always equivalent, and arguments of different kinds never are.
-/
def eqvArg (a₁ a₂ : Arg) : EqvM Bool :=
  match a₁, a₂ with
  | .erased, .erased => pure true
  | .fvar x₁, .fvar x₂ => eqvFVar x₁ x₂
  | .type e₁, .type e₂ => eqvType e₁ e₂
  | _, _ => pure false
/--
Pointwise `eqvArg` on two argument arrays. Arrays of different length are never
equivalent, and the scan stops at the first mismatch.
-/
def eqvArgs (as₁ as₂ : Array Arg) : EqvM Bool := do
  unless as₁.size == as₂.size do
    return false
  for (a₁, a₂) in as₁.zip as₂ do
    if !(← eqvArg a₁ a₂) then
      return false
  return true
/--
Compare the right-hand sides of two `let` declarations.

Literals are compared with `==`, and erased values always match. Projections must agree
on structure name and field index, and constant applications on declaration name and
universe levels. Their variable/argument parts are then compared up to the renaming.
Values of different kinds never match.
-/
def eqvLetValue (e₁ e₂ : LetValue) : EqvM Bool := do
  match e₁, e₂ with
  | .erased, .erased => return true
  | .value v₁, .value v₂ => return v₁ == v₂
  | .proj s₁ i₁ x₁, .proj s₂ i₂ x₂ =>
    unless s₁ == s₂ && i₁ == i₂ do
      return false
    eqvFVar x₁ x₂
  | .const n₁ us₁ as₁, .const n₂ us₂ as₂ =>
    unless n₁ == n₂ && us₁ == us₂ do
      return false
    eqvArgs as₁ as₂
  | .fvar f₁ as₁, .fvar f₂ as₂ =>
    unless (← eqvFVar f₁ f₂) do
      return false
    eqvArgs as₁ as₂
  | _, _ => return false
/--
Run `x` with the extra renaming `fvarId₂ ↦ fvarId₁`. Use it when the first and second
term bind `fvarId₁` and `fvarId₂` at corresponding positions.
-/
@[inline] def withFVar (fvarId₁ fvarId₂ : FVarId) (x : EqvM α) : EqvM α :=
  withReader (fun renaming => renaming.insert fvarId₂ fvarId₁) x
/--
Compare two parameter lists and, if they match, run `x` with each parameter of
`params₂` renamed to the corresponding parameter of `params₁`.

Pairs are processed left to right. The `i`-th pair's types are compared while the
renamings for pairs `0 … i-1` are already active, so a type that refers to an earlier
parameter is compared under that renaming. `x` runs only after every pair has been
bound. Returns `false`, without running `x`, on a length mismatch or at the first
type mismatch.
-/
@[inline] def withParams (params₁ params₂ : Array Param) (x : EqvM Bool) : EqvM Bool := do
  if params₂.size ≠ params₁.size then
    return false
  let rec @[specialize] go : List (Param × Param) → EqvM Bool
    | [] => x
    | (p₁, p₂) :: rest => do
      unless (← eqvType p₁.type p₂.type) do
        return false
      withFVar p₁.fvarId p₂.fvarId (go rest)
  go (params₁.zip params₂).toList
/--
Put `cases` alternatives into a canonical order. Constructor alternatives come first,
sorted by constructor name with `Name.lt`, and `default` alternatives come after them.
This lets two `cases` that list their alternatives in different orders be compared
pairwise.
-/
def sortAlts (alts : Array Alt) : Array Alt :=
  let precedes (a b : Alt) : Bool :=
    match a, b with
    | .alt c₁ .., .alt c₂ .. => Name.lt c₁ c₂
    | .alt .., .default .. => true
    | _, _ => false
  alts.qsort precedes
mutual
/--
Compare two arrays of `cases` alternatives up to alpha equivalence.

Both arrays are first put in canonical order with `sortAlts`, so the order in which
alternatives are listed does not matter. Paired constructor alternatives must name the
same constructor. Their field parameters are then bound pairwise with `withParams`
while the bodies are compared. Two `default` alternatives are compared by body only.
Any other pairing, or a difference in count, fails.
-/
partial def eqvAlts (alts₁ alts₂ : Array Alt) : EqvM Bool := do
  unless alts₁.size = alts₂.size do
    return false
  for (alt₁, alt₂) in (sortAlts alts₁).zip (sortAlts alts₂) do
    match alt₁, alt₂ with
    | .alt ctorName₁ ps₁ k₁, .alt ctorName₂ ps₂ k₂ =>
      if ctorName₁ != ctorName₂ then
        return false
      unless (← withParams ps₁ ps₂ (eqv k₁ k₂)) do
        return false
    | .default k₁, .default k₂ =>
      unless (← eqv k₁ k₂) do
        return false
    | _, _ => return false
  return true

/--
Alpha-equivalence check for LCNF code under the current renaming.

Binders (`let`, `fun`, `jp`) must agree on their declared type and on their
value/body before their continuation is compared. During that comparison, the second
term's bound variable is renamed to the first term's. `fun`/`jp` bodies are compared
with their parameters bound through `withParams`. Terminators (`return`, `jmp`,
`unreach`) and `cases` are compared component-wise. Code of different shapes is never
equivalent.
-/
partial def eqv (code₁ code₂ : Code) : EqvM Bool := do
  match code₁, code₂ with
  | .let decl₁ k₁, .let decl₂ k₂ =>
    unless (← eqvType decl₁.type decl₂.type) do
      return false
    unless (← eqvLetValue decl₁.value decl₂.value) do
      return false
    withFVar decl₁.fvarId decl₂.fvarId (eqv k₁ k₂)
  | .fun decl₁ k₁, .fun decl₂ k₂
  | .jp decl₁ k₁, .jp decl₂ k₂ =>
    unless (← eqvType decl₁.type decl₂.type) do
      return false
    unless (← withParams decl₁.params decl₂.params (eqv decl₁.value decl₂.value)) do
      return false
    withFVar decl₁.fvarId decl₂.fvarId (eqv k₁ k₂)
  | .return fvarId₁, .return fvarId₂ => eqvFVar fvarId₁ fvarId₂
  | .jmp fvarId₁ args₁, .jmp fvarId₂ args₂ =>
    unless (← eqvFVar fvarId₁ fvarId₂) do
      return false
    eqvArgs args₁ args₂
  | .unreach type₁, .unreach type₂ => eqvType type₁ type₂
  | .cases c₁, .cases c₂ =>
    unless (← eqvFVar c₁.discr c₂.discr) do
      return false
    unless (← eqvType c₁.resultType c₂.resultType) do
      return false
    eqvAlts c₁.alts c₂.alts
  | _, _ => return false
end
end AlphaEqv
/--
Return `true` if `c₁` and `c₂` are alpha equivalent, i.e. they coincide up to a
renaming of the variables bound inside them. The check starts from an empty renaming,
so variables that neither term binds must be literally the same in both.
-/
def Code.alphaEqv (c₁ c₂ : Code) : Bool :=
  (AlphaEqv.eqv c₁ c₂).run {}
end Lean.Compiler.LCNF