-
Notifications
You must be signed in to change notification settings - Fork 18
/
BetaReduction.scala
140 lines (118 loc) · 4.3 KB
/
BetaReduction.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package gapt.expr
import gapt.expr.formula._
import gapt.expr.formula.hol.universalClosure
import gapt.expr.subst.Substitution
import gapt.expr.util.freeVariables
import gapt.expr.util.syntacticMatching
import gapt.proofs.context.Context
import scala.annotation.tailrec
import scala.collection.mutable
/**
 * A rewrite rule `lhs -> rhs` used by [[Normalizer]].
 *
 * The left-hand side must be a constant applied to zero or more arguments (construction
 * fails with a `MatchError` otherwise). Both sides must have the same type, and every free
 * variable of `rhs` must also be free in `lhs`.
 *
 * @param lhs the pattern that is matched; its head must be a constant
 * @param rhs the replacement; may only use free variables of `lhs`
 */
case class ReductionRule( lhs: Expr, rhs: Expr ) {
  // Do not render `lhs === rhs` here: building the equation fails when the types differ.
  require(
    lhs.ty == rhs.ty,
    s"Left-hand side $lhs of rule has type ${lhs.ty}, but right-hand side $rhs has type ${rhs.ty}" )
  require(
    freeVariables( rhs ).subsetOf( freeVariables( lhs ) ),
    s"Right-hand side of rule contains variables ${
      ( freeVariables( rhs ) -- freeVariables( lhs ) ).mkString( ", " )
    } which are not in the left hand side:\n"
      + ( lhs === rhs ) )

  /** Head constant of `lhs`, its name, and the arguments it is applied to. */
  val Apps( lhsHead @ Const( lhsHeadName, _, _ ), lhsArgs ) = lhs

  /** Number of arguments the head constant of `lhs` is applied to. */
  val lhsArgsSize: Int = lhsArgs.size

  /**
   * Whether `lhs` is non-linear: some variable occurs in it more than once, or it contains a
   * lambda abstraction. For such rules every argument position is normalized before matching.
   */
  val isNonLinear: Boolean = {
    val seen = mutable.Set[Var]()
    def go( e: Expr ): Boolean =
      e match {
        case App( a, b )      => go( a ) || go( b )
        case Abs( _, _ )      => true
        case Const( _, _, _ ) => false
        case v: Var =>
          // Seen before => repeated variable; otherwise record it and keep scanning.
          seen( v ) || { seen += v; false }
      }
    go( lhs )
  }

  /** Indices of the arguments of `lhs` that are not variables. */
  val nonVarArgs: Set[Int] =
    lhsArgs.zipWithIndex.filterNot( _._1.isInstanceOf[Var] ).map( _._2 ).toSet

  /**
   * Indices of the arguments of `lhs` of the form `c x1 ... xn`: a constant applied only to
   * variables (including a bare constant).
   */
  val structuralRecArgs: Set[Int] =
    lhsArgs.zipWithIndex.collect {
      case ( Apps( _: Const, vs ), i ) if vs.forall( _.isInstanceOf[Var] ) =>
        i
    }.toSet

  /**
   * Argument positions that are fully normalized before matching: all of them for a
   * non-linear rule, otherwise the non-variable arguments not in [[structuralRecArgs]].
   */
  val normalizeArgs: Set[Int] =
    if ( isNonLinear ) lhsArgs.indices.toSet else nonVarArgs -- structuralRecArgs

  /** Argument positions that are only reduced to weak head normal form before matching. */
  val whnfArgs: Set[Int] =
    structuralRecArgs -- normalizeArgs
}
/** Implicit conversions that build a [[ReductionRule]] from a pair or from an equation. */
object ReductionRule {

  /** Builds the rule `rule._1 -> rule._2`. */
  implicit def apply( rule: ( Expr, Expr ) ): ReductionRule =
    ReductionRule( lhs = rule._1, rhs = rule._2 )

  /**
   * Builds the rule `lhs -> rhs` from the equation `lhs = rhs`.
   * Any formula that is not an equation is rejected with a `MatchError`.
   */
  implicit def apply( atom: Formula ): ReductionRule =
    atom match {
      case Eq( l, r ) => ReductionRule( lhs = l, rhs = r )
    }
}
/**
 * Normalizes expressions with respect to beta-reduction and the rewrite rules in `rules`.
 *
 * Rules are indexed by the name of the head constant of their left-hand side. Before the
 * rules for a head are tried, the argument positions recorded in [[headMap]] are reduced to
 * weak head normal form or fully normalized, so that the rules can match syntactically.
 *
 * @param rules the rewrite rules applied in addition to beta-reduction
 */
case class Normalizer( rules: Set[ReductionRule] ) {

  /**
   * Maps each head constant name to `(rs, whnfArgs, normalizeArgs)`: the rules `rs` with that
   * head, the argument positions to reduce to weak head normal form, and the argument
   * positions to normalize fully. Positions are merged over all rules of the head; a position
   * that some rule wants fully normalized is not also listed in `whnfArgs`.
   */
  val headMap: Map[String, ( Set[ReductionRule], Set[Int], Set[Int] )] =
    rules.groupBy( _.lhsHeadName ).map {
      case ( headName, rs ) =>
        val normalizeArgs = rs.flatMap( _.normalizeArgs )
        val whnfArgs = rs.flatMap( _.whnfArgs ) -- normalizeArgs
        ( headName, ( rs, whnfArgs, normalizeArgs ) )
    }

  /** Returns a normalizer whose rules are the current ones plus `rule`. */
  def +( rule: ReductionRule ): Normalizer =
    Normalizer( rules + rule )

  /** The conjunction of the universal closures of the equations `lhs = rhs` of all rules. */
  def toFormula: Formula = And( rules.map { case ReductionRule( lhs, rhs ) => universalClosure( lhs === rhs ) } )

  /**
   * Computes the normal form of `expr`: reduces it to weak head normal form, then normalizes
   * the body of a leading lambda block (if any) and every argument recursively.
   */
  def normalize( expr: Expr ): Expr = {
    val Apps( hd_, as_ ) = whnf( expr )
    Apps( hd_ match {
      case Abs.Block( xs, e ) if xs.nonEmpty =>
        Abs.Block( xs, normalize( e ) )
      case _ => hd_
    }, as_.map( normalize ) )
  }

  /** Applies [[reduce1]] repeatedly until no further step is possible. */
  @tailrec
  final def whnf( expr: Expr ): Expr =
    reduce1( expr ) match {
      case Some( expr_ ) => whnf( expr_ )
      case None          => expr
    }

  /**
   * Performs a single head-reduction step on `expr`, or returns `None` if none applies.
   *
   *  - If the head is a lambda block applied to arguments, the first
   *    `min(#binders, #arguments)` arguments are substituted at once.
   *  - If the head is a constant with rules in [[headMap]], the argument positions listed
   *    there are reduced first. The rules are then tried in the iteration order of the rule
   *    set, and the first match is used. Arguments beyond the rule's arity are re-applied to
   *    the instantiated right-hand side. If no rule matches, the argument reductions done
   *    here are discarded.
   */
  def reduce1( expr: Expr ): Option[Expr] = {
    val Apps( hd, as ) = expr
    hd match {
      case Abs.Block( vs, hd_ ) if vs.nonEmpty && as.nonEmpty =>
        val n = math.min( as.size, vs.size )
        Some( Apps( Substitution( vs.take( n ) zip as.take( n ) )( Abs.Block( vs.drop( n ), hd_ ) ), as.drop( n ) ) )
      case hd @ Const( c, _, _ ) =>
        headMap.get( c ).flatMap {
          case ( rs, whnfArgs, normalizeArgs ) =>
            val as_ = as.zipWithIndex.map {
              case ( a, i ) if whnfArgs( i )      => whnf( a )
              case ( a, i ) if normalizeArgs( i ) => normalize( a )
              case ( a, _ )                       => a
            }
            // Lazily try rules; only the first successful match is instantiated.
            rs.view.flatMap { r =>
              syntacticMatching( r.lhs, Apps( hd, as_.take( r.lhsArgsSize ) ) ).map { subst =>
                Apps( subst( r.rhs ), as_.drop( r.lhsArgsSize ) )
              }
            }.headOption
        }
      case _ =>
        None
    }
  }

  /** Whether `a` and `b` have equal normal forms under this normalizer. */
  def isDefEq( a: Expr, b: Expr ): Boolean =
    normalize( a ) == normalize( b )
}
/** Factory methods for [[Normalizer]] accepting arbitrary collections of rules. */
object Normalizer {

  /** Builds a normalizer from the rules in `rules`, collected into a set. */
  def apply( rules: Iterable[ReductionRule] ): Normalizer =
    new Normalizer( rules.toSet )

  /** Builds a normalizer from the given rules, collected into a set. */
  def apply( rules: ReductionRule* ): Normalizer =
    new Normalizer( rules.toSet )
}
/**
 * Normalizes an expression using the normalizer of the implicit [[Context]] when one is
 * supplied, and with [[BetaReduction]] (no rewrite rules) when `ctx` is left as `null`.
 */
object normalize {
  def apply( expr: Expr )( implicit ctx: Context = null ): Expr =
    Option( ctx ).fold( BetaReduction.normalize( expr ) )( _.normalizer.normalize( expr ) )
}
/** The normalizer with no rewrite rules, i.e. plain beta-reduction. */
object BetaReduction extends Normalizer( Set.empty[ReductionRule] ) {

  /** Beta-normalizes `expression`. */
  def betaNormalize( expression: Expr ): Expr =
    this.normalize( expression )

  /** Beta-normalizes the formula `f`; the normal form is cast back to [[Formula]]. */
  def betaNormalize( f: Formula ): Formula = {
    val normalized: Expr = this.normalize( f )
    normalized.asInstanceOf[Formula]
  }
}