Skip to content

Commit

Permalink
Updates to support perceptron models.
Browse files Browse the repository at this point in the history
  • Loading branch information
tsmorton committed Nov 6, 2008
1 parent 9d3a343 commit d1123a6
Show file tree
Hide file tree
Showing 34 changed files with 1,279 additions and 214 deletions.
8 changes: 4 additions & 4 deletions src/java/opennlp/maxent/BasicEventStream.java
Expand Up @@ -17,8 +17,8 @@

package opennlp.maxent;

import opennlp.model.AbstractEventStream;
import opennlp.model.Event;
import opennlp.model.EventStream;

/**
* A object which can deliver a stream of training events assuming
Expand All @@ -30,9 +30,9 @@
* <p> cp_1 cp_2 ... cp_n outcome
*
* @author Jason Baldridge
* @version $Revision: 1.4 $, $Date: 2008/09/28 18:03:47 $
* @version $Revision: 1.5 $, $Date: 2008/11/06 19:59:44 $
*/
public class BasicEventStream implements EventStream {
public class BasicEventStream extends AbstractEventStream {
ContextGenerator cg = new BasicContextGenerator();
DataStream ds;
Event next;
Expand All @@ -48,7 +48,7 @@ public BasicEventStream (DataStream ds) {
*
* @return the Event object which is next in this EventStream
*/
public Event nextEvent () {
public Event next () {
while (next == null && this.ds.hasNext())
next = createEvent((String)this.ds.nextToken());

Expand Down
20 changes: 10 additions & 10 deletions src/java/opennlp/maxent/GISModel.java
Expand Up @@ -33,7 +33,7 @@
* Iterative Scaling procedure (implemented in GIS.java).
*
* @author Tom Morton and Jason Baldridge
* @version $Revision: 1.23 $, $Date: 2008/09/28 18:03:50 $
* @version $Revision: 1.24 $, $Date: 2008/11/06 19:59:44 $
*/
public final class GISModel extends AbstractModel {
/**
Expand Down Expand Up @@ -77,11 +77,11 @@ public GISModel (Context[] params, String[] predLabels, String[] outcomeNames, i
* getOutcome(int i).
*/
public final double[] eval(String[] context) {
return(eval(context,new double[evalParams.numOutcomes]));
return(eval(context,new double[evalParams.getNumOutcomes()]));
}

public final double[] eval(String[] context, float[] values) {
return(eval(context,values,new double[evalParams.numOutcomes]));
return(eval(context,values,new double[evalParams.getNumOutcomes()]));
}

public final double[] eval(String[] context, double[] outsums) {
Expand Down Expand Up @@ -144,8 +144,8 @@ public static double[] eval(int[] context, double[] prior, EvalParameters model)
* getOutcome(int i).
*/
public static double[] eval(int[] context, float[] values, double[] prior, EvalParameters model) {
Context[] params = model.params;
int numfeats[] = new int[model.numOutcomes];
Context[] params = model.getParams();
int numfeats[] = new int[model.getNumOutcomes()];
int[] activeOutcomes;
double[] activeParameters;
double value = 1;
Expand All @@ -166,17 +166,17 @@ public static double[] eval(int[] context, float[] values, double[] prior, EvalP
}

double normal = 0.0;
for (int oid = 0; oid < model.numOutcomes; oid++) {
if (model.correctionParam != 0) {
prior[oid] = Math.exp(prior[oid]*model.constantInverse+((1.0 - ((double) numfeats[oid] / model.correctionConstant)) * model.correctionParam));
for (int oid = 0; oid < model.getNumOutcomes(); oid++) {
if (model.getCorrectionParam() != 0) {
prior[oid] = Math.exp(prior[oid]*model.getConstantInverse()+((1.0 - ((double) numfeats[oid] / model.getCorrectionConstant())) * model.getCorrectionParam()));
}
else {
prior[oid] = Math.exp(prior[oid]*model.constantInverse);
prior[oid] = Math.exp(prior[oid]*model.getConstantInverse());
}
normal += prior[oid];
}

for (int oid = 0; oid < model.numOutcomes; oid++) {
for (int oid = 0; oid < model.getNumOutcomes(); oid++) {
prior[oid] /= normal;
}
return prior;
Expand Down
10 changes: 5 additions & 5 deletions src/java/opennlp/maxent/GISTrainer.java
Expand Up @@ -45,7 +45,7 @@
*
* @author Tom Morton
* @author Jason Baldridge
* @version $Revision: 1.30 $, $Date: 2008/09/28 18:03:38 $
* @version $Revision: 1.31 $, $Date: 2008/11/06 19:59:44 $
*/
class GISTrainer {

Expand Down Expand Up @@ -364,7 +364,7 @@ else if (useSimpleSmoothing) {
findParameters(iterations);

/*************** Create and return the model ******************/
return new GISModel(params, predLabels, outcomeLabels, correctionConstant, evalParams.correctionParam);
return new GISModel(params, predLabels, outcomeLabels, correctionConstant, evalParams.getCorrectionParam());

}

Expand Down Expand Up @@ -467,7 +467,7 @@ private double nextIteration() {
}
}
if (useSlackParameter)
CFMOD += (evalParams.correctionConstant - contexts[ei].length) * numTimesEventsSeen[ei];
CFMOD += (evalParams.getCorrectionConstant() - contexts[ei].length) * numTimesEventsSeen[ei];

loglikelihood += Math.log(modelDistribution[outcomeList[ei]]) * numTimesEventsSeen[ei];
numEvents += numTimesEventsSeen[ei];
Expand All @@ -493,7 +493,7 @@ private double nextIteration() {
int[] activeOutcomes = params[pi].getOutcomes();
for (int aoi=0;aoi<activeOutcomes.length;aoi++) {
if (useGaussianSmoothing) {
params[pi].updateParameter(aoi,gaussianUpdate(pi,aoi,numEvents,evalParams.correctionConstant));
params[pi].updateParameter(aoi,gaussianUpdate(pi,aoi,numEvents,evalParams.getCorrectionConstant()));
}
else {
if (model[aoi] == 0) {
Expand All @@ -505,7 +505,7 @@ private double nextIteration() {
}
}
if (CFMOD > 0.0 && useSlackParameter)
evalParams.correctionParam += (cfObservedExpect - Math.log(CFMOD));
evalParams.setCorrectionParam(evalParams.getCorrectionParam() + (cfObservedExpect - Math.log(CFMOD)));

display(". loglikelihood=" + loglikelihood + "\t" + ((double) numCorrect / numEvents) + "\n");
return (loglikelihood);
Expand Down
7 changes: 4 additions & 3 deletions src/java/opennlp/maxent/RealBasicEventStream.java
Expand Up @@ -17,11 +17,12 @@

package opennlp.maxent;

import opennlp.model.AbstractEventStream;
import opennlp.model.Event;
import opennlp.model.EventStream;
import opennlp.model.RealValueFileEventStream;

public class RealBasicEventStream implements EventStream {
public class RealBasicEventStream extends AbstractEventStream {
ContextGenerator cg = new BasicContextGenerator();
DataStream ds;
Event next;
Expand All @@ -33,7 +34,7 @@ public RealBasicEventStream(DataStream ds) {

}

public Event nextEvent() {
public Event next() {
while (next == null && this.ds.hasNext())
next = createEvent((String)this.ds.nextToken());

Expand Down Expand Up @@ -67,7 +68,7 @@ private Event createEvent(String obs) {
public static void main(String[] args) throws java.io.IOException {
EventStream es = new RealBasicEventStream(new PlainTextByLineDataStream(new java.io.FileReader(args[0])));
while (es.hasNext()) {
System.out.println(es.nextEvent());
System.out.println(es.next());
}
}
}
42 changes: 5 additions & 37 deletions src/java/opennlp/maxent/io/BinaryGISModelReader.java
Expand Up @@ -17,17 +17,17 @@

package opennlp.maxent.io;

import java.io.*;
import java.util.zip.*;
import java.io.DataInputStream;

import opennlp.model.BinaryFileDataReader;

/**
* A reader for GIS models stored in binary format.
*
* @author Jason Baldridge
* @version $Revision: 1.2 $, $Date: 2008/09/28 18:04:24 $
* @version $Revision: 1.3 $, $Date: 2008/11/06 19:59:44 $
*/
public class BinaryGISModelReader extends GISModelReader {
protected DataInputStream input;

/**
* Constructor which directly instantiates the DataInputStream containing
Expand All @@ -36,38 +36,6 @@ public class BinaryGISModelReader extends GISModelReader {
* @param dis The DataInputStream containing the model information.
*/
public BinaryGISModelReader (DataInputStream dis) {
input = dis;
}

/**
* Constructor which takes a File and creates a reader for it. Detects
* whether the file is gzipped or not based on whether the suffix contains
* ".gz"
*
* @param f The File in which the model is stored.
*/
public BinaryGISModelReader (File f) throws IOException {

if (f.getName().endsWith(".gz")) {
input = new DataInputStream(
new GZIPInputStream(new FileInputStream(f)));
}
else {
input = new DataInputStream(new FileInputStream(f));
}

super(new BinaryFileDataReader(dis));
}

public int readInt () throws java.io.IOException {
return input.readInt();
}

public double readDouble () throws java.io.IOException {
return input.readDouble();
}

public String readUTF () throws java.io.IOException {
return input.readUTF();
}

}
49 changes: 30 additions & 19 deletions src/java/opennlp/maxent/io/GISModelReader.java
Expand Up @@ -17,18 +17,31 @@

package opennlp.maxent.io;

import java.io.File;
import java.io.IOException;

import opennlp.maxent.GISModel;
import opennlp.model.AbstractModel;
import opennlp.model.AbstractModelReader;
import opennlp.model.Context;
import opennlp.model.DataReader;

/**
* Abstract parent class for readers of GISModels.
*
* @author Jason Baldridge
* @version $Revision: 1.8 $, $Date: 2008/09/28 18:04:22 $
* @version $Revision: 1.9 $, $Date: 2008/11/06 19:59:44 $
*/
public abstract class GISModelReader extends AbstractModelReader {
public class GISModelReader extends AbstractModelReader {

public GISModelReader(File file) throws IOException {
super(file);
}

public GISModelReader(DataReader dataReader) {
super(dataReader);
}

/**
* Retrieve a model from disk. It assumes that models are saved in the
* following sequence:
Expand All @@ -51,24 +64,22 @@ public abstract class GISModelReader extends AbstractModelReader {
*
* @return The GISModel stored in the format and location specified to
* this GISModelReader (usually via its the constructor).
*/
public AbstractModel getModel () throws java.io.IOException {
checkModelType();
int correctionConstant = getCorrectionConstant();
double correctionParam = getCorrectionParameter();
String[] outcomeLabels = getOutcomes();
int[][] outcomePatterns = getOutcomePatterns();
String[] predLabels = getPredicates();
Context[] params = getParameters(outcomePatterns);

return new GISModel(params,
predLabels,
outcomeLabels,
correctionConstant,
correctionParam);

}
*/
public AbstractModel constructModel() throws IOException {
int correctionConstant = getCorrectionConstant();
double correctionParam = getCorrectionParameter();
String[] outcomeLabels = getOutcomes();
int[][] outcomePatterns = getOutcomePatterns();
String[] predLabels = getPredicates();
Context[] params = getParameters(outcomePatterns);

return new GISModel(params,
predLabels,
outcomeLabels,
correctionConstant,
correctionParam);
}

public void checkModelType() throws java.io.IOException {
String modelType = readUTF();
if (!modelType.equals("GIS"))
Expand Down
21 changes: 4 additions & 17 deletions src/java/opennlp/maxent/io/ObjectGISModelReader.java
Expand Up @@ -17,9 +17,10 @@

package opennlp.maxent.io;

import java.io.IOException;
import java.io.ObjectInputStream;

import opennlp.model.ObjectDataReader;

public class ObjectGISModelReader extends GISModelReader {

protected ObjectInputStream input;
Expand All @@ -31,21 +32,7 @@ public class ObjectGISModelReader extends GISModelReader {
* @param dis The DataInputStream containing the model information.
*/

public ObjectGISModelReader(ObjectInputStream dis) {
super();
input = dis;
}

public int readInt() throws IOException {
return input.readInt();
}

public double readDouble() throws IOException {
return input.readDouble();
}

public String readUTF() throws IOException {
return input.readUTF();
public ObjectGISModelReader(ObjectInputStream ois) {
super(new ObjectDataReader(ois));
}

}
35 changes: 8 additions & 27 deletions src/java/opennlp/maxent/io/PlainTextGISModelReader.java
Expand Up @@ -17,17 +17,19 @@

package opennlp.maxent.io;

import java.io.*;
import java.util.zip.*;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;

import opennlp.model.PlainTextFileDataReader;

/**
* A reader for GIS models stored in plain text format.
*
* @author Jason Baldridge
* @version $Revision: 1.2 $, $Date: 2008/09/28 18:04:30 $
* @version $Revision: 1.3 $, $Date: 2008/11/06 19:59:44 $
*/
public class PlainTextGISModelReader extends GISModelReader {
private BufferedReader input;

/**
* Constructor which directly instantiates the BufferedReader containing
Expand All @@ -36,7 +38,7 @@ public class PlainTextGISModelReader extends GISModelReader {
* @param br The BufferedReader containing the model information.
*/
public PlainTextGISModelReader (BufferedReader br) {
input = br;
super(new PlainTextFileDataReader(br));
}

/**
Expand All @@ -47,27 +49,6 @@ public PlainTextGISModelReader (BufferedReader br) {
* @param f The File in which the model is stored.
*/
public PlainTextGISModelReader (File f) throws IOException {

if (f.getName().endsWith(".gz")) {
input = new BufferedReader(new InputStreamReader(
new GZIPInputStream(new FileInputStream(f))));
}
else {
input = new BufferedReader(new FileReader(f));
}

}

public int readInt () throws IOException {
return Integer.parseInt(input.readLine());
super(f);
}

public double readDouble () throws IOException {
return Double.parseDouble(input.readLine());
}

public String readUTF () throws IOException {
return input.readLine();
}

}

0 comments on commit d1123a6

Please sign in to comment.