Permalink
Browse files

Faster CKY parsing

  • Loading branch information...
1 parent 43682c1 commit 0d326ecc2366e61a5fa9e6f5a0458ad3631ad351 @gangeli committed Apr 6, 2012
View
12 src/org/goobs/exec/Log.java
@@ -3,14 +3,14 @@
*/
package org.goobs.exec;
+import org.goobs.util.Stopwatch;
+
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
import java.util.Stack;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
-import java.io.IOException;
-import java.io.FileWriter;
-import java.io.File;
-
-import org.goobs.util.Stopwatch;
public final class Log {
@@ -130,7 +130,7 @@ private boolean shouldPrint(String toPrint, boolean force){
@Option(gloss="Print debugging log entries")
- private static boolean logDebug = false;
+ public static boolean logDebug = false;
private static final Stack <LogInfo> levels = new Stack<LogInfo>();
private static LogInfo currentInfo = null;
View
125 src/org/goobs/nlp/CKYParser.scala
@@ -1332,13 +1332,21 @@ class CKYParser (
//-----
// Values
//-----
- private val binaryRules:Array[CKYRule]
- = ruleProbIndex.keys.filter{ !_.isUnary }.toArray
- private val unaryRules:Array[CKYUnary]
+ private val binaryRules:Array[(CKYBinary,(Int,Int),Int,Int)]
+ = ruleProbIndex.keys
+ .filter{ !_.isUnary }
+ .map{ case (r:CKYBinary) =>
+ (r,ruleProbIndex(r), nodeTypeIndex(r.leftChild), nodeTypeIndex(r.rightChild)) }
+ .toArray
+ private val unaryRules:Array[(CKYUnary,(Int,Int),Int)]
= ruleProbIndex.keys
.filter{ _.isUnary }
- .map{ _.asInstanceOf[CKYUnary] }
+ .map{ case (r:CKYUnary) =>
+ (r,ruleProbIndex(r), nodeTypeIndex(r.child)) }
.toArray
+ private val closureRules:Array[(CKYUnary,Array[(Int,Int)],Int)]
+ = closures.map{ case (r:CKYClosure) =>
+ (r, r.chain.map{ u => ruleProbIndex(u) }.toArray, nodeTypeIndex(r.child)) }
//-----
@@ -2094,10 +2102,13 @@ class CKYParser (
}
//(probability of a rule: resolve it to its (head,index) pair via ruleProbIndex, then delegate)
def ruleProb(rule:CKYRule):Double = {
val (head,index) = ruleProbIndex(rule)
+ ruleProb(head,index)
+ }
//(probability stored at position `index` of the distribution for `head`)
+ def ruleProb(head:Int,index:Int):Double = {
//(sanity check: the stored value must not be NaN)
assert(!ruleProb(head).prob(index).isNaN,"NaN rule probability")
ruleProb(head).prob(index)
}
- def ruleLogProb(rule:CKYRule):Double = safeLn(ruleProb(rule))
+ def ruleLogProb(head:Int,index:Int):Double = safeLn(ruleProb(head,index))
def lexProb(rule:CKYUnary,w:Int,normalizeTo:Double=1.0):Double = {
if(w >= 0 && w < numWords) {
assert(!lexProb(lexProbIndex(rule)).prob(w).isNaN, "NaN lex probability")
@@ -2110,11 +2121,9 @@ class CKYParser (
def lexLogProb(rule:CKYUnary,w:Int,normalizeTo:Double=1.0):Double = safeLn(lexProb(rule,w,normalizeTo))
def lexProbs(rule:CKYUnary):Array[Double] = {
- lexProbDomain
- .zipWithIndex
- .map{ case (rule:CKYUnary,w:Int) =>
+ (0 until numWords).map{ (w:Int) =>
lexProb(rule,w)
- }
+ }.toArray
}
def sortedLexProbs(rule:CKYUnary):Array[(Double,Int)] = {
@@ -2243,10 +2252,11 @@ class CKYParser (
/**
Access a chart element
*/
- private def gram(chart:Chart,begin:Int,end:Int,parent:NodeType,t:Int
+ private def gram(chart:Chart,begin:Int,end:Int,parent:NodeType,t:Int):Beam
+ = gram(chart,begin,end, nodeTypeIndex(parent), t)
+ private def gram(chart:Chart,begin:Int,end:Int,head:Int,t:Int
):Beam = {
- val head:Int = nodeTypeIndex(parent)
- if(end == begin+1){ return lex(chart,begin,parent,t) }
+ if(end == begin+1){ return lex(chart,begin,head,t) }
//(asserts)
assert(end > begin+1, "Chart access error: bad end: " + begin + ", " + end)
assert(begin >= 0, "Chart access error: negative values: " + begin)
@@ -2256,13 +2266,20 @@ class CKYParser (
//(access)
chart(begin)(end-begin-1)(t)(head)
}
+ private def gram(chart:Chart,begin:Int,end:Int,parent:NodeType):Array[ChartElem] = {
+ //(fetch the unary and the binary beam for this span and parent)
+ val unary:Array[ChartElem] = gram(chart,begin,end,parent,CKYParser.UNARY).toArray
+ val binary:Array[ChartElem] = gram(chart,begin,end,parent,CKYParser.BINARY).toArray
+ //(merge them, highest log probability first)
+ Array.concat(unary, binary).sortWith{ (x:ChartElem,y:ChartElem) => x.logProb > y.logProb }
+ }
+
/**
Access a lexical element
*/
- private def lex(chart:Chart,elem:Int,parent:NodeType,t:Int=CKYParser.BINARY
- ):Beam = {
- val head:Int = nodeTypeIndex(parent)
- //(asserts)
+ private def lex(chart:Chart,elem:Int,parent:NodeType,t:Int=CKYParser.BINARY):Beam = {
+ lex(chart,elem,nodeTypeIndex(parent),t);
+ }
+ private def lex(chart:Chart,elem:Int,head:Int,t:Int):Beam = {
assert(elem >= 0, "Chart access error: negative value: " + elem)
assert(head >= 0, "Chart access error: bad head: " + head)
assert(head < index2NodeType.length, "Chart access error: bad head: "+head)
@@ -2273,7 +2290,7 @@ class CKYParser (
/**
k-best CKY Algorithm implementation
*/
- def parse(sent:Sentence, beam:Int):Array[EvalTree[Any]] = {
+ private def cky(sent:Sentence, beam:Int):Chart = {
//--Asserts
assert(sent.length > 0, "Sentence of length 0 cannot be parsed")
//--Get Lexical Entries
@@ -2308,41 +2325,44 @@ class CKYParser (
lastLogProb = logProb
}
}
+ //--Indexing
//--Grammar
for(length <- 1 to sent.length) { // length
for(begin <- 0 to sent.length-length) { // begin
//(overhead)
val end:Int = begin+length
assert(end <= sent.length, "end is out of bounds")
- def addUnary(term:CKYUnary, ruleLProb:Double){
+ def addUnary(term:CKYUnary, child:Int, ruleLProb:Double){
assert(ruleLProb <= 0.0, "Log probability of >0: " + ruleLProb)
assert(term.isUnary, "Unary rule should be unary")
- val child:Beam = gram(chart,begin,end,term.child,CKYParser.BINARY)
- gram(chart,begin,end,term.parent,CKYParser.UNARY).combine(term,child,
+ val childBeam:Beam = gram(chart,begin,end,child,CKYParser.BINARY)
+ gram(chart,begin,end,term.parent,CKYParser.UNARY).combine(term,childBeam,
(left:ChartElem,right:ChartElem) => { ruleLProb })
}
//(unaries)
if(length == 1){
- unaryRules.foreach{ u => addUnary(u,ruleLogProb(u)) }
- closures.foreach{ c => addUnary(c,c.logProb(ruleLogProb(_))) }
+ unaryRules.foreach{ case (term,(head,index),child) =>
+ addUnary(term, child,ruleLogProb(head,index))
+ }
+ closureRules.foreach{ case (term,chain,child) =>
+ addUnary( term, child, chain.map{ case (h,i) => ruleLogProb(h,i) }.sum )
+ }
}
//(binaries)
- binaryRules.foreach{ (term:CKYRule) => // rules [binary]
- val ruleLProb:Double = ruleLogProb(term)
+ binaryRules.foreach{ case (term,(parent,index),childLeft,childRight) => // rules [binary]
+ val ruleLProb:Double = ruleLogProb(parent,index)
assert(ruleLProb <= 0.0, "Log probability of >0: " + ruleLProb)
assert(!term.isUnary, "Binary rule should be binary")
- val ruleLeft = term.leftChild
- val ruleRight = term.rightChild
for(split <- (begin+1) to (end-1)){ // splits
//((get variables))
val leftU:Beam
- = gram(chart, begin,split,ruleLeft, CKYParser.UNARY)
+ = gram(chart, begin,split,childLeft, CKYParser.UNARY)
val rightU:Beam
- = gram(chart,split,end, ruleRight,CKYParser.UNARY)
+ = gram(chart,split,end, childRight,CKYParser.UNARY)
val leftB:Beam
- = gram(chart, begin,split,ruleLeft, CKYParser.BINARY)
+ = gram(chart, begin,split,childLeft, CKYParser.BINARY)
val rightB:Beam
- = gram(chart,split,end, ruleRight,CKYParser.BINARY)
+ = gram(chart,split,end, childRight,CKYParser.BINARY)
assert(leftU != leftB && rightU != rightB, ""+begin+" to "+end)
val output = gram(chart,begin,end,term.parent,CKYParser.BINARY)
val score = (left:ChartElem,right:ChartElem) => { ruleLProb }
@@ -2361,8 +2381,12 @@ class CKYParser (
}
//(unaries)
if(length > 1){
- unaryRules.foreach{ u => addUnary(u,ruleLogProb(u)) }
- closures.foreach{ c => addUnary(c,c.logProb(ruleLogProb(_))) }
+ unaryRules.foreach{ case (term,(head,index),child) =>
+ addUnary(term, child,ruleLogProb(head,index))
+ }
+ closureRules.foreach{ case (term,chain,child) =>
+ addUnary( term, child, chain.map{ case (h,i) => ruleLogProb(h,i) }.sum )
+ }
}
//(post-update tasks)
if(kbestCKYAlgorithm < 3) {
@@ -2374,10 +2398,14 @@ class CKYParser (
}
}
//--Return
- Array.concat(
- gram(chart,0,sent.length,factory.ROOT,CKYParser.UNARY).toArray,
- gram(chart,0,sent.length,factory.ROOT,CKYParser.BINARY).toArray
- ).map{ x => x.deepclone(sent) }
+ chart
+ }
+
+ def parse(sent:Sentence,beam:Int):Array[EvalTree[Any]] = {
+ //(fill the chart, then read the ROOT entries covering the whole sentence)
+ val roots = gram(cky(sent,beam), 0, sent.length, factory.ROOT)
+ //(deepclone every root entry against the sentence)
+ roots.map{ elem => elem.deepclone(sent) }
+ }
def parse(sent:Sentence):EvalTree[Any] = {
@@ -2394,6 +2422,31 @@ class CKYParser (
//(function-call syntax: parser(sent) forwards to parse(sent))
def apply(sent:Sentence):EvalTree[Any] = parse(sent)
+ /**
+ Score the best chart entry for every requested (node, begin, end) span.
+ The chart is filled by cky with a beam of 1. Spans satisfy
+ b <= begin < end <= min(sent.length, e). A span/node pair with no
+ chart entry is left out of the result.
+ @param sent The sentence to parse
+ @param b The smallest begin index to report (default: 0)
+ @param e The largest end index to report; capped at the sentence length
+ @param nodes The node types to report; if empty, factory.all is used
+ @return A map from (node, begin, end) to the log probability of the
+ highest scoring entry for that node over that span
+ */
+ def chart(sent:Sentence, b:Int=0, e:Int=Int.MaxValue, nodes:Iterable[NodeType]=Nil
+ ):scala.collection.immutable.Map[(NodeType,Int,Int),Double] = {
+ //--Fill Chart
+ val crt = cky(sent,1)
+ //--Create Return
+ val nodesToReturn:Iterable[NodeType] = if(nodes.isEmpty){ factory.all } else { nodes }
+ //(the cap on indices does not change between iterations)
+ val last:Int = math.min(sent.length,e)
+ //(for each begin)
+ (b until last).flatMap{ (begin:Int) =>
+ //(for each end)
+ (begin+1 to last).flatMap{ (end:Int) =>
+ //(for each node type)
+ nodesToReturn.flatMap{ (node:NodeType) =>
+ //(candidates come back sorted best-first; keep the head, if there is one)
+ gram(crt,begin,end,node).headOption.map{ (best:ChartElem) =>
+ ((node,begin,end),best.logProb)
+ }
+ }
+ }
+ }.toMap
+ }
+
//-----
// EM
//-----
View
8 src/org/goobs/stanford/StanfordExecutionLogInterface.java
@@ -1,11 +1,11 @@
package org.goobs.stanford;
import edu.stanford.nlp.util.logging.Redwood;
+import edu.stanford.nlp.util.logging.Redwood.Util;
import edu.stanford.nlp.util.logging.StanfordRedwoodConfiguration;
import org.goobs.exec.Execution;
import org.goobs.exec.ExitCode;
-
-import edu.stanford.nlp.util.logging.Redwood.Util;
+import org.goobs.exec.Log;
import java.io.IOException;
import java.util.Properties;
@@ -31,7 +31,9 @@ public void setup(){
//(init stanford)
StanfordRedwoodConfiguration.apply(props);
//(tweaks)
- Redwood.hideOnlyChannels(Redwood.DBG);
+ if(!Log.logDebug){
+ Redwood.hideOnlyChannels(Redwood.DBG);
+ }
}
@Override
View
63 test/src/org/goobs/tests/CKYTest.scala
@@ -8,7 +8,7 @@ import org.scalatest.Spec
import org.scalatest.matchers.ShouldMatchers
import org.goobs.nlp._
import java.io._
-import org.goobs.util.{TrackedObjectOutputStream, SingletonIterator}
+import org.goobs.util.{TrackedObjectOutputStream, SingletonIterator, Stopwatch}
import org.goobs.stats._
object Grammars {
@@ -581,7 +581,6 @@ class CKYParserSpec extends Spec with ShouldMatchers {
val dir = Dirichlet.fromMap(lexPrior)
val mult = dir.posterior( new Multinomial[java.lang.Integer](CountStores.ARRAY(mathtyp2str.length)).initUniform().asInstanceOf[Multinomial[Int]] )
mult.prob(0) should be (1.0 / 13.0)
- println(mult)
val fixedParser = CKYParser(mathtyp2str.length, MATH_TYPE.map{ (_, 0.0) }, NodeType.defaultFactory,
(n:NodeType) => Dirichlet.fromMap(lexPrior))
//(test lexProb)
@@ -827,4 +826,64 @@ class CKYParserSpec extends Spec with ShouldMatchers {
}
}
}
+
+ describe("Long Sentences"){
+ /**
+ Build a fresh parser, then parse `text` ten times, requiring
+ watch.lap.toInt to stay below `limit` after every parse.
+ NOTE(review): the limits look like milliseconds, judging by the
+ test names -- confirm against Stopwatch.
+ */
+ def parseEachLapUnder(text:String, limit:Int):Unit = {
+ //(variables)
+ val sent = MSent(text)
+ val parser = CKYParser.apply(math2str.length, MATH.map{ (_,0.0) })
+ //(parse)
+ val watch = Stopwatch.time()
+ for( i <- 0 until 10){
+ parser.parse(sent)
+ watch.lap.toInt should be < limit
+ }
+ }
+ it("[9] should parse < 0.005s"){
+ parseEachLapUnder("1 + 2 + 3 + 4 + 5", 5)
+ }
+ it("[19] should parse < 0.02s"){
+ parseEachLapUnder("1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10", 20)
+ }
+ it("[29] should parse < 0.1s"){
+ parseEachLapUnder("1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 + 13 + 14 + 15", 100)
+ }
+ it("[39] should parse < 0.2s"){
+ parseEachLapUnder("1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 + 13 + 14 + 15 + 16 + 17 + 18 + 19 + 20", 200)
+ }
+ it("[49] should parse < 0.5s"){
+ parseEachLapUnder("1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 + 13 + 14 + 15 + 16 + 17 + 18 + 19 + 20 + 21 + 22 + 23 + 24 + 25", 500)
+ }
+ }
}

0 comments on commit 0d326ec

Please sign in to comment.