This repository has been archived by the owner on Feb 19, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 82
/
SegmentationEval.scala
54 lines (44 loc) · 2.02 KB
/
SegmentationEval.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
package epic.sequences
import epic.framework.EvaluationResult
import breeze.util.SerializableLogging
/**
* Object for evaluating [[epic.sequences.Segmentation]]s. Returned metrics
* are precision, recall, and f1
*
* @author dlwh
*/
object SegmentationEval extends SerializableLogging {

  /**
   * Decodes every example with `crf` and sums the per-example segment counts.
   *
   * The examples are traversed through `.par`, so per-example log lines may be
   * interleaved in any order.
   *
   * @param crf           model whose `bestSequence` produces the guessed segmentation
   * @param examples      gold segmentations to score the model against
   * @param logOnlyErrors when true, only examples whose guessed segments differ from the
   *                      gold segments are logged at info level; when false, all are logged
   * @return the accumulated [[Stats]] over all examples
   */
  def eval[L, W](crf: SemiCRF[L, W], examples: IndexedSeq[Segmentation[L, W]], logOnlyErrors: Boolean = true): Stats = {
    examples.par.aggregate(new Stats(0, 0, 0))({ (stats, gold) =>
      val guess = crf.bestSequence(gold.words, gold.id + "-guess")

      // Best-effort diagnostics only: computing the marginals may fail, and that
      // must not abort the evaluation.
      try {
        if (guess.label != gold.label)
          logger.trace(s"gold = $gold guess = $guess " +
            s"guess logPartition = ${crf.goldMarginal(guess.segments, guess.words).logPartition} " +
            s"gold logPartition =${crf.goldMarginal(gold.segments, gold.words).logPartition}")
      } catch {
        // NonFatal lets VM errors and interrupts propagate; the cause is kept in the message.
        case scala.util.control.NonFatal(ex) => logger.debug(s"Can't recover gold for $gold: $ex")
      }

      val myStats = evaluateExample(Set(), guess, gold)

      // Decide "is this an error?" from the integer counts rather than `f1 < 1.0`:
      // with zero correct segments the floating-point F1 is degenerate, and a NaN
      // comparison would hide exactly the worst examples.
      val hasError = myStats.nRight != myStats.nGuess || myStats.nRight != myStats.nGold
      if (!logOnlyErrors || hasError)
        logger.info(s"Guess:\n${guess.render}\n Gold:\n${gold.render}\n$myStats")

      stats + myStats
    }, { _ + _ })
  }

  /**
   * Scores a single guessed segmentation against its gold segmentation.
   *
   * Segments whose label is in `outsideLabel` are dropped from both sides; the
   * remaining segments are compared as sets, so a segment counts as right only
   * if it is equal to a gold segment.
   *
   * @param outsideLabel labels to exclude from scoring
   * @param guess        predicted segmentation
   * @param gold         reference segmentation
   * @return counts of matching, guessed, and gold segments
   */
  def evaluateExample[W, L](outsideLabel: Set[L], guess: Segmentation[L, W], gold: Segmentation[L, W]): SegmentationEval.Stats = {
    val guessSet = guess.segments.filterNot(a => outsideLabel(a._1)).toSet
    val goldSet = gold.segments.filterNot(a => outsideLabel(a._1)).toSet
    new Stats((guessSet & goldSet).size, guessSet.size, goldSet.size)
  }

  /**
   * Segment-level counts with derived precision, recall, and F1.
   *
   * All three metrics are defined as 0.0 when their denominator is zero, so they
   * never evaluate to NaN.
   *
   * @param nRight number of guessed segments that also appear in the gold set
   * @param nGuess number of guessed segments
   * @param nGold  number of gold segments
   */
  class Stats(val nRight: Int = 0, val nGuess: Int = 0, val nGold: Int = 0) extends EvaluationResult[Stats] {

    /** Fraction of guessed segments that are right; 0.0 when nothing was guessed. */
    def precision: Double = if (nGuess == 0) 0.0 else nRight.toDouble / nGuess

    /** Fraction of gold segments that were found; 0.0 when there are no gold segments. */
    def recall: Double = if (nGold == 0) 0.0 else nRight.toDouble / nGold

    /** Harmonic mean of precision and recall; 0.0 when both are zero. */
    def f1: Double = {
      val p = precision
      val r = recall
      if (p + r == 0.0) 0.0 else 2 * p * r / (p + r)
    }

    /** Component-wise sum of two sets of counts. */
    def +(stats: Stats): Stats =
      new Stats(nRight + stats.nRight, nGuess + stats.nGuess, nGold + stats.nGold)

    override def toString: String =
      f"Evaluation Result: P=$precision%.4f R=$recall%.4f F=$f1%.4f"
  }
}