Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Changed abstract model to better support creating full model for re-t…

…agging.
  • Loading branch information...
commit c402d297e8eeb1f867f44ab5df2833742eb9c773 1 parent d482514
tsmorton authored
View
4 src/main/java/opennlp/maxent/GISModel.java
@@ -33,7 +33,7 @@
* Iterative Scaling procedure (implemented in GIS.java).
*
* @author Tom Morton and Jason Baldridge
- * @version $Revision: 1.1 $, $Date: 2009/01/22 23:23:34 $
+ * @version $Revision: 1.2 $, $Date: 2009/03/19 13:04:31 $
*/
public final class GISModel extends AbstractModel {
/**
@@ -60,7 +60,7 @@ public GISModel (Context[] params, String[] predLabels, String[] outcomeNames, i
public GISModel (Context[] params, String[] predLabels, String[] outcomeNames, int correctionConstant,double correctionParam, Prior prior) {
super(params,predLabels,outcomeNames,correctionConstant,correctionParam);
this.prior = prior;
- prior.setLabels(ocNames, predLabels);
+ prior.setLabels(outcomeNames, predLabels);
modelType = ModelType.Maxent;
}
View
32 src/main/java/opennlp/model/AbstractModel.java
@@ -23,10 +23,10 @@
public abstract class AbstractModel implements MaxentModel {
- /** Maping between predicates/contexts and an integer representing them. */
+ /** Mapping between predicates/contexts and an integer representing them. */
protected Map<String,Integer> pmap;
/** The names of the outcomes. */
- protected String[] ocNames;
+ protected String[] outcomeNames;
/** Parameters for the model. */
protected EvalParameters evalParams;
/** Prior distribution for this model. */
@@ -36,15 +36,21 @@
/** The type of the model. */
protected ModelType modelType;
+
+ public AbstractModel(Context[] params, String[] predLabels, Map<String,Integer> pmap, String[] outcomeNames) {
+ this.pmap = pmap;
+ this.outcomeNames = outcomeNames;
+ this.evalParams = new EvalParameters(params,outcomeNames.length);
+ }
public AbstractModel(Context[] params, String[] predLabels, String[] outcomeNames) {
init(predLabels,outcomeNames);
- this.evalParams = new EvalParameters(params,ocNames.length);
+ this.evalParams = new EvalParameters(params,outcomeNames.length);
}
public AbstractModel(Context[] params, String[] predLabels, String[] outcomeNames, int correctionConstant,double correctionParam) {
init(predLabels,outcomeNames);
- this.evalParams = new EvalParameters(params,correctionParam,correctionConstant,ocNames.length);
+ this.evalParams = new EvalParameters(params,correctionParam,correctionConstant,outcomeNames.length);
}
private void init(String[] predLabels, String[] outcomeNames){
@@ -52,7 +58,7 @@ private void init(String[] predLabels, String[] outcomeNames){
for (int i=0; i<predLabels.length; i++) {
pmap.put(predLabels[i], i);
}
- this.ocNames = outcomeNames;
+ this.outcomeNames = outcomeNames;
}
@@ -68,7 +74,7 @@ public final String getBestOutcome(double[] ocs) {
int best = 0;
for (int i = 1; i<ocs.length; i++)
if (ocs[i] > ocs[best]) best = i;
- return ocNames[best];
+ return outcomeNames[best];
}
public ModelType getModelType(){
@@ -88,15 +94,15 @@ public ModelType getModelType(){
* for each one.
*/
public final String getAllOutcomes(double[] ocs) {
- if (ocs.length != ocNames.length) {
+ if (ocs.length != outcomeNames.length) {
return "The double array sent as a parameter to GISModel.getAllOutcomes() must not have been produced by this model.";
}
else {
DecimalFormat df = new DecimalFormat("0.0000");
StringBuffer sb = new StringBuffer(ocs.length*2);
- sb.append(ocNames[0]).append("[").append(df.format(ocs[0])).append("]");
+ sb.append(outcomeNames[0]).append("[").append(df.format(ocs[0])).append("]");
for (int i = 1; i<ocs.length; i++) {
- sb.append(" ").append(ocNames[i]).append("[").append(df.format(ocs[i])).append("]");
+ sb.append(" ").append(outcomeNames[i]).append("[").append(df.format(ocs[i])).append("]");
}
return sb.toString();
}
@@ -109,7 +115,7 @@ public final String getAllOutcomes(double[] ocs) {
* @return The name of the outcome associated with that id.
*/
public final String getOutcome(int i) {
- return ocNames[i];
+ return outcomeNames[i];
}
/**
@@ -121,8 +127,8 @@ public final String getOutcome(int i) {
* model, -1 if it does not.
**/
public int getIndex(String outcome) {
- for (int i=0; i<ocNames.length; i++) {
- if (ocNames[i].equals(outcome))
+ for (int i=0; i<outcomeNames.length; i++) {
+ if (outcomeNames[i].equals(outcome))
return i;
}
return -1;
@@ -156,7 +162,7 @@ public int getNumOutcomes() {
Object[] data = new Object[5];
data[0] = evalParams.getParams();
data[1] = pmap;
- data[2] = ocNames;
+ data[2] = outcomeNames;
data[3] = new Integer((int)evalParams.getCorrectionConstant());
data[4] = new Double(evalParams.getCorrectionParam());
return data;
View
5 src/main/java/opennlp/perceptron/PerceptronModel.java
@@ -21,6 +21,7 @@
import java.io.File;
import java.io.InputStreamReader;
import java.text.DecimalFormat;
+import java.util.Map;
import opennlp.model.AbstractModel;
import opennlp.model.Context;
@@ -28,6 +29,10 @@
public class PerceptronModel extends AbstractModel {
+ public PerceptronModel(Context[] params, String[] predLabels, Map<String,Integer> pmap, String[] outcomeNames) {
+ super(params,predLabels,outcomeNames);
+ modelType = ModelType.Perceptron;
+ }
public PerceptronModel(Context[] params, String[] predLabels, String[] outcomeNames) {
super(params,predLabels,outcomeNames);
View
2  src/main/java/opennlp/perceptron/SimplePerceptronSequenceTrainer.java
@@ -180,7 +180,7 @@ public void nextIteration(int iteration) {
int oei=0;
int si=0;
for (Sequence sequence : sequenceStream) {
- Event[] taggerEvents = sequenceStream.updateContext(sequence, new PerceptronModel(params,predLabels,outcomeLabels));
+ Event[] taggerEvents = sequenceStream.updateContext(sequence, new PerceptronModel(params,predLabels,pmap,outcomeLabels));
Event[] events = sequence.getEvents();
for (int ei=0;ei<events.length;ei++,oei++) {
String[] contextStrings = events[ei].getContext();
Please sign in to comment.
Something went wrong with that request. Please try again.