This repository has been archived by the owner on Feb 19, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 82
/
OracleParser.scala
241 lines (201 loc) · 9.97 KB
/
OracleParser.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
package epic.parser.projections
import epic.constraints.{CachedChartConstraintsFactory, ChartConstraints}
import epic.trees._
import epic.parser._
import breeze.numerics.I
import epic.util.{Optional, CacheBroker}
import epic.parser.GrammarAnchoring.StructureDelegatingAnchoring
import breeze.config.{Configuration, CommandLineParser, Help}
import java.io.File
import breeze.util._
import epic.parser.ParseExtractionException
import epic.trees.ProcessedTreebank
import epic.trees.TreeInstance
import epic.parser.ViterbiDecoder
import epic.parser.ParserParams.XbarGrammar
import breeze.collection.mutable.TriangularArray
import epic.constraints.ChartConstraints.UnifiedFactory
/**
* Finds the best tree (relative to the gold tree) s.t. it's reacheable given the current anchoring.
* Best is measured as number of correct labeled spans, as usual. If the given treebank symbol is correct,
* bonus points can be awarded for getting the right refinement.
*
* On the training set, the "best" reachable tree will not always (~5% of the time) be the correct tree,
* because pruning will remove the right answer. We don't want to try to train towards an unreachable tree,
* because the training algorithm will do bad things. Instead, we want the best possible tree that
* our parser could conceivably produce. That is why this class exists.
*
* If backupGrammar is provided, it will be used to find such a tree in the case that no tree can be
* found with grammar (given the current constraints).
*
* Typically, the first grammar will be a treebank grammar that has no horizontal markovization (i.e.
* it is not forgetfully binarized) and it also remembers the functional tags like -TMP.
* The backup grammar is usually the grammar
* with which the pruning masks were produced; because of the way we prune, that parser will always
* be able to find a tree (assuming that it was able to find a tree without pruning.)
*
* TODO: should be a cascade of grammars
*
* @author dlwh
*/
class OracleParser[L, L2, W](val grammar: SimpleGrammar[L, L2, W], backupGrammar: Optional[SimpleGrammar[L, L2, W]] = None) extends SerializableLogging {
private val cache = CacheBroker().make[IndexedSeq[W], BinarizedTree[L2]]("OracleParser")
private var problems = 0
private var total = 0
def forTree(tree: BinarizedTree[L2],
words: IndexedSeq[W],
constraints: ChartConstraints[L]) = try {
val projectedTree: BinarizedTree[L] = tree.map(grammar.refinements.labels.project)
cache.getOrElseUpdate(words, {
val treeconstraints = ChartConstraints.fromTree(grammar.topology.labelIndex, projectedTree)
if (constraints.top.containsAll(treeconstraints.top) && constraints.bot.containsAll(treeconstraints.bot)) {
synchronized(total += 1)
tree
} else try {
logger.warn {
val ratio = synchronized {
problems += 1
total += 1
problems * 1.0 / total
}
f"Gold tree for $words not reachable. $ratio%.2f are bad so far. "
}
(IndexedSeq(grammar) ++ backupGrammar.iterator).iterator.flatMap { grammar =>
try {
val w = words
val marg = makeGoldPromotingAnchoring(grammar, w, tree, treeconstraints, constraints).maxMarginal
val closest: BinarizedTree[(L, Int)] = new ViterbiDecoder[L, W]().extractMaxDerivationParse(marg)
val globalizedClosest: BinarizedTree[L2] = closest.map({
grammar.refinements.labels.globalize(_: L, _: Int)
}.tupled)
logger.trace {
s"Reachable: ${globalizedClosest.render(words, newline = true)}\nGold:${tree.render(words, newline = true)}"
}
Some(globalizedClosest)
} catch {
case ex: ParseExtractionException => None
}
}.next()
} catch {
case ex:NoSuchElementException =>
logger.error("Couldn't find a tree for $words")
tree
}
})
} catch {
case ex: Exception =>
logger.error(s"while handling projectability for $tree $words: " + ex.getMessage, ex)
throw ex
}
def makeGoldPromotingAnchoring(grammar: SimpleGrammar[L, L2, W],
w: IndexedSeq[W],
tree: BinarizedTree[L2],
treeconstraints: ChartConstraints[L],
constraints: ChartConstraints[L]): GrammarAnchoring[L, W] = {
import grammar.{refinements => _, _}
val correctRefinedSpans = GoldTagPolicy.goldTreeForcing(tree.map(grammar.refinements.labels.fineIndex))
val correctUnaryChains = {
val arr = TriangularArray.fill(w.length + 1)(IndexedSeq.empty[String])
for(UnaryTree(a, btree, chain, span) <- tree.allChildren) {
arr(span.begin, span.end) = chain
}
arr
}
new StructureDelegatingAnchoring[L, W] {
protected val baseAnchoring: GrammarAnchoring[L, W] = grammar.anchor(w)
override def addConstraints(cs: ChartConstraints[L]): GrammarAnchoring[L, W] = {
makeGoldPromotingAnchoring(grammar, w, tree, treeconstraints, constraints & cs)
}
override def sparsityPattern: ChartConstraints[L] = constraints
def scoreBinaryRule(begin: Int, split: Int, end: Int, rule: Int, ref: Int): Double = {
0.1 * baseAnchoring.scoreBinaryRule(begin, split, end, rule, ref)
}
def scoreUnaryRule(begin: Int, end: Int, rule: Int, ref: Int): Double = {
val top = topology.parent(rule)
val isAllowed = treeconstraints.top.isAllowedLabeledSpan(begin, end, top)
val isRightRefined = isAllowed && {
val parentRef = baseAnchoring.parentRefinement(rule, ref)
val refTop = grammar.refinements.labels.globalize(top, parentRef)
correctRefinedSpans.isGoldTopTag(begin, end, refTop)
}
val isCorrectChain = {
isAllowed && topology.chain(rule) == correctUnaryChains(begin, end)
}
200 * I(isAllowed) + 10 * I(isRightRefined) + 5 * I(isCorrectChain) + 0.1 * baseAnchoring.scoreUnaryRule(begin, end, rule, ref)
}
def scoreSpan(begin: Int, end: Int, tag: Int, ref: Int): Double = {
val globalized = grammar.refinements.labels.globalize(tag, ref)
200 * I(treeconstraints.bot.isAllowedLabeledSpan(begin, end, tag)) +
10 * I(correctRefinedSpans.isGoldBotTag(begin, end, globalized)) + 0.1 * baseAnchoring.scoreSpan(begin, end, tag, ref)
}
}
}
def oracleMarginalFactory(trees: IndexedSeq[TreeInstance[L2, W]]):ParseMarginal.Factory[L, W] = new ParseMarginal.Factory[L, W] {
val knownTrees = trees.iterator.map(ti => ti.words -> ti.tree).toMap
def apply(w: IndexedSeq[W], constraints: ChartConstraints[L]): RefinedChartMarginal[L, W] = {
val refAnchoring = knownTrees.get(w).map { t =>
val projectedTree: BinarizedTree[L] = t.map(grammar.refinements.labels.project)
val treeconstraints = ChartConstraints.fromTree(grammar.topology.labelIndex, projectedTree)
makeGoldPromotingAnchoring(grammar, w, t, treeconstraints, constraints)
}.getOrElse(grammar.anchor(w))
RefinedChartMarginal(refAnchoring, true)
}
}
def oracleParser(constraintGrammar: ChartConstraints.Factory[L, W], trees: IndexedSeq[TreeInstance[L2, W]])(implicit deb: Debinarizer[L]): Parser[L, W] = {
new Parser(grammar.topology, grammar.lexicon, constraintGrammar, oracleMarginalFactory(trees), ViterbiDecoder())
}
}
object OracleParser {
case class Params(treebank: ProcessedTreebank,
@Help(text="Prefix for the name of the eval directory. Sentences are dumped for EVALB to read.")
name: String,
@Help(text="Path to the parser file. Look in parsers/")
parser: File = null,
implicit val cache: CacheBroker,
@Help(text="Print this and exit.")
help: Boolean = false,
@Help(text="How many threads to parse with. Default is whatever Scala wants")
threads: Int = -1)
def main(args: Array[String]) {
val params = CommandLineParser.readIn[Params](args)
println("Command line arguments for recovery:\n" + Configuration.fromObject(params).toCommandLineString)
println("Evaluating Parser...")
import params.treebank._
val ann = GenerativeParser.defaultAnnotator()
val initialParser = params.parser match {
case null =>
val (grammar, lexicon) = XbarGrammar().xbarGrammar(trainTrees)
GenerativeParser.annotatedParser(grammar, lexicon, ann, trainTrees)
// GenerativeParser.annotatedParser(grammar, lexicon, Xbarize(), trainTrees)
case f =>
readObject[Parser[AnnotatedLabel, String]](f)
}
val constraints: ChartConstraints.Factory[AnnotatedLabel, String] = new UnifiedFactory({
val maxMarginalized = initialParser.copy(marginalFactory=initialParser.marginalFactory match {
case StandardChartFactory(ref, mm) => StandardChartFactory(ref, maxMarginal = true)
case x => x
})
val uncached = new ParserChartConstraintsFactory[AnnotatedLabel, String](maxMarginalized, {(_:AnnotatedLabel).isIntermediate})
new CachedChartConstraintsFactory[AnnotatedLabel, String](uncached)(params.cache)
})
val refGrammar = GenerativeParser.annotated(initialParser.topology, initialParser.lexicon, ann, trainTrees)
val annDevTrees = devTrees.map(ann)
val parser = {
new OracleParser(refGrammar).oracleParser(constraints, annDevTrees)
}
for(ti <- devTrees.par) try {
val m = TreeMarginal(refGrammar, ti.words, ann.localized(refGrammar.refinements.labels)(ti.tree, ti.words))
println(m.logPartition)
} catch {
case e: Exception => e.printStackTrace()
}
val name = params.name
println("Parser " + name)
{
println("Evaluating Parser on dev...")
val stats = ParserTester.evalParser(devTrees, parser, name+ "-dev", params.threads)
println("Eval finished. Results:")
println(stats)
}
}
}