/* Copyright (C) 2008-2016 University of Massachusetts Amherst.
This file is part of "FACTORIE" (Factor graphs, Imperative, Extensible)
http://factorie.cs.umass.edu, http://github.com/factorie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

package cc.factorie.tutorial

import java.io.File
import cc.factorie._
import cc.factorie.infer.{BP, BPSummary, GibbsSampler, IteratedConditionalModes}
import cc.factorie.model.{DotTemplateWithStatistics1, DotTemplateWithStatistics2, Parameters, TemplateModel}
import cc.factorie.variable._

/** A demonstration of training a linear-chain CRF for named entity recognition.
    Prints various diagnostics suitable for a demo.
    @author Andrew McCallum */
object ChainNERDemo {

  // The variable classes
  object TokenDomain extends CategoricalVectorDomain[String]
  class Token(val word: String, features: Seq[String], labelString: String) extends BinaryFeatureVectorVariable[String] with ChainLink[Token, Sentence] {
    def domain = TokenDomain
    val label: Label = new Label(labelString, this)
    this ++= features
  }
  object LabelDomain extends CategoricalDomain[String]
  class Label(labelname: String, val token: Token) extends LabeledCategoricalVariable(labelname) {
    def domain = LabelDomain
    def hasNext = token.hasNext && token.next.label != null
    def hasPrev = token.hasPrev && token.prev.label != null
    def next = token.next.label
    def prev = token.prev.label
  }
  class Sentence extends Chain[Sentence, Token]

  // The model
  val model = new TemplateModel with Parameters {
    // Bias term on each individual label
    object bias extends DotTemplateWithStatistics1[Label] {
      val weights = Weights(new la.DenseTensor1(LabelDomain.size))
    }
    // Transition factors between two successive labels
    object transition extends DotTemplateWithStatistics2[Label, Label] {
      val weights = Weights(new la.DenseTensor2(LabelDomain.size, LabelDomain.size))
      def unroll1(label: Label) = if (label.hasPrev) Factor(label.prev, label) else Nil
      def unroll2(label: Label) = if (label.hasNext) Factor(label, label.next) else Nil
    }
    // Factor between label and observed token
    object evidence extends DotTemplateWithStatistics2[Label, Token] {
      val weights = Weights(new la.DenseTensor2(LabelDomain.size, TokenDomain.dimensionSize))
      def unroll1(label: Label) = Factor(label, label.token)
      def unroll2(token: Token) = throw new Error("Token values shouldn't change")
    }
    this += evidence
    this += bias
    this += transition
  }
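
  // Informally (a sketch of the math, not FACTORIE API): for tokens x_1..x_n
  // with labels y_1..y_n, the templates above assign the log-linear score
  //   score(y, x) = sum_i [ bias(y_i) + evidence(y_i, x_i) + transition(y_{i-1}, y_i) ]
  // where each term is the dot product of that template's weights with the
  // statistics of the factor's neighboring variable values.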

  // The training objective
  val objective = new HammingTemplate[Label, Label#TargetType]
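  // (HammingTemplate scores a Label by agreement with its labeled target, so
  // objective.accuracy(labels) below reports the fraction of labels whose
  // current value matches the gold annotation.)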

  def main(args: Array[String]): Unit = {
    implicit val random = new scala.util.Random(0)
    if (args.length != 2) throw new Error("Usage: ChainNERDemo trainfile testfile")
    // Read in the data
    val trainSentences = load(args(0))
    val testSentences = load(args(1))
    // Get the variables to be inferred
    val trainLabels = trainSentences.flatMap(_.links.map(_.label)).take(50000)
    // Add features from next and previous tokens
    // println("Adding offset features...")
    trainLabels.map(_.token).foreach(t => {
      if (t.hasPrev) t ++= t.prev.activeCategories.filter(!_.contains('@')).map(_ + "@-1")
      if (t.hasNext) t ++= t.next.activeCategories.filter(!_.contains('@')).map(_ + "@+1")
    })
    // Freeze the domain now that we've added all feature values observed in training data
    TokenDomain.freeze()
    println("Using " + TokenDomain.dimensionSize + " observable features.")
    // Compute the same features on the test data
    val testLabels = testSentences.flatMap(_.links.map(_.label))
    testLabels.map(_.token).foreach(t => {
      if (t.hasPrev) t ++= t.prev.activeCategories.filter(!_.contains('@')).map(_ + "@-1")
      if (t.hasNext) t ++= t.next.activeCategories.filter(!_.contains('@')).map(_ + "@+1")
    })
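    // For example, in the phrase "New York" the Token "York" gains offset copies
    // of its neighbor's base features, such as "W=New@-1" and "CAPITALIZED@-1";
    // the filter on '@' keeps offset features from being offset a second time.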
    // Print some significant features
    //println("Most predictive features:")
    //val pllo = new cc.factorie.app.classify.PerLabelLogOdds(trainSentences.flatMap(_.map(_.label)), (label:Label) => label.token)
    //for (label <- LabelDomain.values) println(label.category+": "+pllo.top(label, 20))
    // Sample and learn!
    val startTime = System.currentTimeMillis
    (trainLabels ++ testLabels).foreach(_.setRandomly)
    val learner = new optimize.SampleRankTrainer(new GibbsSampler(model, objective) { temperature = 0.1 }, new cc.factorie.optimize.AdaGrad)
    val predictor = new IteratedConditionalModes(model, null)
    for (i <- 1 to 3) {
      // println("Iteration " + i)
      learner.processContexts(trainLabels)
      predictor.processAll(testLabels); predictor.processAll(trainLabels)
      trainLabels.take(20).foreach(printLabel _); println(); println()
      printDiagnostic(trainLabels.take(400))
      //trainLabels.take(20).foreach(label => println("%30s %s %s %f".format(label.token.word, label.targetCategory, label.categoryValue, objective.currentScore(label))))
      //println("Tr50 accuracy = " + objective.accuracy(trainLabels.take(20)))
      println("Train accuracy = " + objective.accuracy(trainLabels))
      println("Test accuracy = " + objective.accuracy(testLabels))
    }
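    // In brief (a paraphrase, not FACTORIE documentation): SampleRankTrainer
    // applies an AdaGrad update whenever the Gibbs sampler ranks two candidate
    // label settings differently than the Hamming objective does, and
    // IteratedConditionalModes then greedily sets each Label to its locally
    // best value, so the printed accuracies reflect a deterministic decoding.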
    if (false) {
      // Use BP Viterbi for prediction (disabled; flip the condition above to try it)
      for (sentence <- testSentences)
        BP.inferChainMax(sentence.asSeq.map(_.label), model).setToMaximize(null)
        //BP.inferChainSum(sentence.asSeq.map(_.label), model).setToMaximize(null) // max-marginal inference
      for (sentence <- trainSentences.take(10)) {
        println("---SumProduct---")
        printTokenMarginals(sentence.asSeq, BP.inferChainSum(sentence.asSeq.map(_.label), model))
        println("---MaxProduct---")
        // printTokenMarginals(sentence.asSeq, BP.inferChainMax(sentence.asSeq.map(_.label), model))
        println("---IteratedConditionalModes---")
        predictor.processAll(testLabels, 2)
        sentence.asSeq.foreach(token => printLabel(token.label))
      }
    } else {
      // Use IteratedConditionalModes for prediction
      //predictor.temperature *= 0.1
      predictor.processAll(testLabels, 2)
    }
println ("Final Test accuracy = "+ objective.accuracy(testLabels))
//println("norm " + model.weights.twoNorm)
println("Finished in " + ((System.currentTimeMillis - startTime) / 1000.0) + " seconds")
//for (sentence <- testSentences) BP.inferChainMax(sentence.asSeq.map(_.label), model); println ("MaxBP Test accuracy = "+ objective.accuracy(testLabels))
//for (sentence <- testSentences) BP.inferChainSum(sentence.asSeq.map(_.label), model).setToMaximize(null); println ("SumBP Test accuracy = "+ objective.accuracy(testLabels))
//predictor.processAll(testLabels, 2); println ("Gibbs Test accuracy = "+ objective.accuracy(testLabels))
}

  def printTokenMarginals(tokens: Seq[Token], summary: BPSummary): Unit = {
    for (token <- tokens)
      println(token.word + " " + LabelDomain.categories.zip(summary.marginal(token.label).proportions.asSeq).sortBy(_._2).reverse.mkString(" "))
    println()
  }

  // Feature extraction
  def wordToFeatures(word: String, initialFeatures: String*): Seq[String] = {
    import scala.collection.mutable.ArrayBuffer
    val f = new ArrayBuffer[String]
    f += "W=" + word
    f ++= initialFeatures
    if (word.length > 3) f += "PRE=" + word.substring(0, 3)
    if (Capitalized.findFirstMatchIn(word).isDefined) f += "CAPITALIZED"
    if (Numeric.findFirstMatchIn(word).isDefined) f += "NUMERIC"
    if (Punctuation.findFirstMatchIn(word).isDefined) f += "PUNCTUATION"
    f
  }
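  // For example, wordToFeatures("Clinton", "POS=NNP") yields
  // Seq("W=Clinton", "POS=NNP", "PRE=Cli", "CAPITALIZED"), given the regexes below.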
  val Capitalized = "^[A-Z].*".r
  val Numeric = "^[0-9]+$".r
  val Punctuation = "[-,\\.;:?!()]+".r
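  // Punctuation is unanchored, so a token like "U.S." is both CAPITALIZED and
  // PUNCTUATION, while "1984" matches only Numeric.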

  def printLabel(label: Label): Unit = {
    println("%-16s TRUE=%-8s PRED=%-8s %s".format(label.token.word, label.target.categoryValue, label.value.category, label.token.toString))
  }

  // Print each predicted (non-"O") entity segment, showing the gold label
  // alongside wherever the prediction disagrees with it
  def printDiagnostic(labels: Seq[Label]): Unit = {
    for (label <- labels; if label.intValue != label.domain.index("O")) {
      if (!label.hasPrev || label.value != label.prev.value)
        print("%-7s %-7s ".format(if (label.value != label.target.value) label.target.value.category.drop(2) else " ", label.value.category.drop(2)))
      print(label.token.word + " ")
      if (!label.hasNext || label.value != label.next.value) println()
    }
    println()
  }

  // Data loading
  def load(filename: String): Seq[Sentence] = {
    import scala.collection.mutable.ArrayBuffer
    import scala.io.Source
    var wordCount = 0
    val sentences = new ArrayBuffer[Sentence]
    val source = Source.fromFile(new File(filename))
    var sentence = new Sentence
    for (line <- source.getLines()) {
      if (line.length < 2) { // Sentence boundary
        sentences += sentence
        sentence = new Sentence
      } else if (line.startsWith("-DOCSTART-")) {
        // Skip document boundaries
      } else {
        val fields = line.split(' ')
        assert(fields.length == 4)
        val word = fields(0)
        val pos = fields(1)
        val label = fields(3).stripLineEnd
        sentence += new Token(word, wordToFeatures(word, "POS=" + pos), label)
        wordCount += 1
      }
    }
    source.close()
    println("Loaded " + sentences.length + " sentences with " + wordCount + " words total from file " + filename)
    sentences
  }
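  // The loader expects one token per line with four space-separated fields;
  // this matches the CoNLL-2003 shared-task layout, e.g.
  //   U.N. NNP I-NP I-ORG
  // with blank lines between sentences and "-DOCSTART-" lines between documents.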
}