This repository has been archived by the owner on Feb 19, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 82
/
TrainPosTagger.scala
58 lines (46 loc) · 2.11 KB
/
TrainPosTagger.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
package epic.sequences
import java.io._
import breeze.config.{CommandLineParser, Configuration}
import breeze.optimize.FirstOrderMinimizer.OptParams
import breeze.util.SerializableLogging
import epic.trees.{AnnotatedLabel, ProcessedTreebank}
/**
 * Command-line entry point that trains a CRF part-of-speech tagger from a treebank,
 * writes the trained model to disk, and prints evaluation statistics computed on the
 * treebank's dev trees.
 *
 * @author dlwh
 */
object TrainPosTagger extends SerializableLogging {

  /**
   * Command-line configuration, populated by `CommandLineParser.readIn`.
   *
   * @param opt      optimizer settings forwarded to `CRF.buildSimple`
   * @param treebank source of the training trees and the dev trees used for evaluation
   * @param modelOut file the trained model is written to (defaults to `pos-model.ser.gz`)
   */
  case class Params(opt: OptParams, treebank: ProcessedTreebank, modelOut: File = new File("pos-model.ser.gz"))

  /**
   * Trains the tagger, saves it, then evaluates it.
   *
   * @param args command-line arguments parsed into [[Params]]
   */
  def main(args: Array[String]): Unit = {
    val params = CommandLineParser.readIn[Params](args)
    // Log the fully resolved configuration so the exact run can be reproduced later.
    logger.info("Command line arguments for recovery:\n" + Configuration.fromObject(params).toCommandLineString)

    // Reduce every label to its base annotated label and flatten each tree to a tagged
    // sequence in a single pass (no intermediate collection of relabelled trees).
    val train = params.treebank.trainTrees.map(tree => tree.mapLabels(_.baseAnnotatedLabel).asTaggedSequence)
    val test = params.treebank.devTrees.map(tree => tree.mapLabels(_.baseAnnotatedLabel).asTaggedSequence)

    // NOTE(review): the second argument is presumably the sequence start symbol -- confirm
    // against CRF.buildSimple.
    val crf = CRF.buildSimple(train, AnnotatedLabel("TOP"), opt = params.opt)

    // Persist the model before evaluating, so a failure during evaluation does not
    // discard the trained model.
    breeze.util.writeObject(params.modelOut, crf)

    val stats = TaggedSequenceEval.eval(crf, test)
    println("Final Stats: " + stats)
    println("Confusion Matrix:\n" + stats.confusion)
  }
}
/**
 * Mostly for debugging SemiCRFs. Just uses a SemiCRF as a CRF: each tree is turned into a
 * tagged sequence, that sequence is converted to a segmentation, and a SemiCRF is trained
 * and evaluated on the result.
 *
 * @author dlwh
 */
object SemiPOSTagger extends SerializableLogging {

  /**
   * Command-line configuration, populated by `CommandLineParser.readIn`.
   *
   * @param opt      optimizer settings forwarded to `SemiCRF.buildSimple`
   * @param treebank source of the training trees and the dev trees used for evaluation
   */
  case class Params(opt: OptParams, treebank: ProcessedTreebank)

  /**
   * Trains a SemiCRF on the training trees and prints evaluation statistics for the dev trees.
   *
   * @param args command-line arguments parsed into [[Params]]
   */
  def main(args: Array[String]): Unit = {
    val params = CommandLineParser.readIn[Params](args)
    // Log the fully resolved configuration so the exact run can be reproduced later.
    logger.info("Command line arguments for recovery:\n" + Configuration.fromObject(params).toCommandLineString)

    val train = params.treebank.trainTrees.map(_.asTaggedSequence.asSegmentation)
    val test = params.treebank.devTrees.map(_.asTaggedSequence.asSegmentation)
    val crf = SemiCRF.buildSimple(train, opt = params.opt)

    // Debugging aid (disabled): dump every weight whose magnitude exceeds 1E-6 to weights.txt.
    // The downcast is kept inside the disabled snippet so that a normal run never performs an
    // unchecked cast whose result it does not use.
    // val inf = crf.asInstanceOf[SemiCRFInference[_, _]]
    // val out = new PrintWriter(new BufferedOutputStream(new FileOutputStream("weights.txt")))
    // Encoder.fromIndex(inf.featureIndex).decode(inf.weights).iterator foreach {case (x, v) if v.abs > 1E-6 => out.println(x -> v) case _ => }
    // out.close()

    val stats = SegmentationEval.eval(crf, test)
    println("Final Stats: " + stats)
  }
}