This repository has been archived by the owner on Feb 19, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 82
/
OneBestInferenceAdaptor.scala
52 lines (40 loc) · 1.8 KB
/
OneBestInferenceAdaptor.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
package epic.framework
import breeze.util.Index
import breeze.linalg.DenseVector
/**
 * Wraps an [[AnnotatingInference]] and replaces its "guess" marginal with the gold
 * marginal of the example re-annotated from the wrapped inference's own marginal.
 * Scorers, gold marginals and the `Scorer`/`Marginal` types come straight from the
 * wrapped inference.
 *
 * NOTE(review): judging by the name, `annotate` presumably decodes a single best
 * structure, which would make training one-best (perceptron-style) rather than
 * expectation-based. That depends on the wrapped `annotate`; TODO confirm.
 *
 * @param inference the wrapped inference that supplies scoring, marginals and annotation
 * @tparam Datum the example type
 * @author dlwh
 **/
class OneBestInferenceAdaptor[Datum](val inference: AnnotatingInference[Datum]) extends Inference[Datum] {
  type Scorer = inference.Scorer
  type Marginal = inference.Marginal

  /** Builds a scorer for `v` by delegating to the wrapped inference. */
  def scorer(v: Datum): Scorer = inference.scorer(v)

  /** Computes the gold marginal of `v` by delegating to the wrapped inference. */
  def goldMarginal(scorer: Scorer, v: Datum): Marginal = inference.goldMarginal(scorer, v)

  /**
   * Produces the "guess" marginal for `v`. It first computes the wrapped inference's
   * marginal, then annotates `v` with it, and finally returns the gold marginal of
   * that annotated example.
   *
   * @param scorer scorer built for `v`; it is reused unchanged for the annotated example
   * @param v the example
   * @return the gold marginal of `v` as annotated from the wrapped inference's marginal
   */
  def marginal(scorer: Scorer, v: Datum): Marginal = {
    val guess = inference.marginal(scorer, v)
    val decoded = inference.annotate(v, guess)
    // NOTE(review): reusing `scorer` (built for `v`) assumes `annotate` changes only
    // the labels of the example, not its input. TODO confirm.
    goldMarginal(scorer, decoded)
  }
}
/**
 * Adapts a [[Model]] whose inference is an [[AnnotatingInference]] so that the
 * inference it hands out is a [[OneBestInferenceAdaptor]] around the wrapped model's
 * inference. Counts, features, initial weights and the objective all delegate
 * to `model`.
 *
 * @param model the wrapped model; its `Inference` must be an `AnnotatingInference[Datum]`
 * @tparam Datum the example type
 */
class OneBestModelAdaptor[Datum](val model: Model[Datum] { type Inference <: AnnotatingInference[Datum]}) extends Model[Datum] {
  type ExpectedCounts = model.ExpectedCounts
  type Marginal = model.Marginal
  type Scorer = model.Scorer
  type Inference = OneBestInferenceAdaptor[Datum] { type Marginal = model.Marginal; type Scorer = model.Scorer }

  /** Delegates to the wrapped model. */
  def emptyCounts: ExpectedCounts = model.emptyCounts

  /**
   * Accumulates counts by delegating to the wrapped model. The model receives the
   * inference wrapped inside `inf` rather than the adaptor itself.
   *
   * @param inf adaptor whose inner inference is passed to the wrapped model
   * @param s scorer for `d`
   * @param d the example
   * @param m marginal to accumulate from
   * @param accum counts accumulator, passed through to the wrapped model
   * @param scale scale factor, passed through to the wrapped model
   */
  def accumulateCounts(inf: Inference, s: Scorer, d: Datum, m: Marginal, accum: ExpectedCounts, scale: Double): Unit = {
    // `inf.inference` is statically only an AnnotatingInference[Datum]. The cast
    // restores model.Inference and is unchecked beyond that upper bound (abstract
    // types are erased). It assumes `inf` came from inferenceFromWeights below.
    model.accumulateCounts(inf.inference.asInstanceOf[model.Inference], s, d, m, accum, scale)
  }

  /** Delegates to the wrapped model. */
  def featureIndex: Index[Feature] = model.featureIndex

  /** Delegates to the wrapped model. */
  def initialValueForFeature(f: Feature): Double = model.initialValueForFeature(f)

  /**
   * Builds the wrapped model's inference for `weights` and wraps it in a
   * [[OneBestInferenceAdaptor]].
   */
  def inferenceFromWeights(weights: DenseVector[Double]): Inference = {
    // The adaptor's constructor only sees an AnnotatingInference[Datum], so the
    // compiler cannot prove its Marginal/Scorer equal model.Marginal/model.Scorer.
    // The cast re-attaches that refinement. It has no runtime effect beyond the
    // class check (refinements are erased) and relies on model.Inference using the
    // model's own Marginal/Scorer types. TODO confirm against the Model contract.
    new OneBestInferenceAdaptor[Datum](model.inferenceFromWeights(weights)).asInstanceOf[Inference]
  }

  /** Delegates to the wrapped model. */
  def expectedCountsToObjective(ecounts: ExpectedCounts): (Double, DenseVector[Double]) = {
    model.expectedCountsToObjective(ecounts)
  }
}