Make setters on ModelParameters protected, setN() and setD() methods are called on fit() and the parameters are now part of BaseModelParameters. Rename print_r() to var_export().
datumbox committed Apr 28, 2015
1 parent ef61b10 commit b1ed4c4
Showing 51 changed files with 211 additions and 489 deletions.
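In brief, the refactoring works like this: the record count n and variable count d move into the shared BaseModelParameters, their setters (and the other per-model setters) become protected, and the framework's fit() populates the sizes once before delegating to each learner's _fit(); the class count c, whose setC() calls also disappear below, is handled the same way. A minimal, self-contained sketch of that shape — not the actual datumbox source; Dataset is stubbed with a plain list and the KnowledgeBase plumbing is omitted:

    import java.util.List;

    abstract class BaseModelParameters {
        private int n; // number of training records
        private int d; // number of variables (features)

        public int getN() { return n; }
        public int getD() { return d; }

        // Protected setters: framework code may set the sizes, client code may not.
        protected void setN(int n) { this.n = n; }
        protected void setD(int d) { this.d = d; }
    }

    abstract class BaseTrainable<MP extends BaseModelParameters> {
        protected MP modelParameters; // supplied by the framework in the real code

        public final void fit(List<double[]> trainingData) {
            // The base class records the dataset sizes once, centrally...
            modelParameters.setN(trainingData.size());
            modelParameters.setD(trainingData.isEmpty() ? 0 : trainingData.get(0).length);
            // ...so concrete learners read getN()/getD() in _fit() instead of
            // recomputing and storing them, as the diffs below show.
            _fit(trainingData);
        }

        protected abstract void _fit(List<double[]> trainingData);
    }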
6 changes: 6 additions & 0 deletions pom.xml
@@ -128,6 +128,12 @@
     <target>${java-version}</target>
     <showDeprecation>true</showDeprecation>
     <encoding>UTF-8</encoding>
+    <compilerArgs>
+        <!--
+        <arg>-verbose</arg>
+        <arg>-Xlint:all,-options,-path</arg>
+        -->
+    </compilerArgs>
 </configuration>
 </plugin>
 <plugin>
@@ -37,7 +37,7 @@ public class Modeler extends BaseWrapper<Modeler.ModelParameters, Modeler.Traini
      */
     public static class ModelParameters extends BaseWrapper.ModelParameters {

-        public ModelParameters(DatabaseConnector dbc) {
+        protected ModelParameters(DatabaseConnector dbc) {
             super(dbc);
         }

20 changes: 1 addition & 19 deletions src/main/java/com/datumbox/applications/nlp/TextClassifier.java
@@ -47,7 +47,7 @@ public class TextClassifier extends BaseWrapper<TextClassifier.ModelParameters,
      */
     public static class ModelParameters extends BaseWrapper.ModelParameters {

-        public ModelParameters(DatabaseConnector dbc) {
+        protected ModelParameters(DatabaseConnector dbc) {
             super(dbc);
         }

@@ -119,24 +119,6 @@ public TextClassifier(String dbName, DatabaseConfiguration dbConf) {
         super(dbName, dbConf, TextClassifier.ModelParameters.class, TextClassifier.TrainingParameters.class);
     }

-    /**
-     * Trains a Machine Learning model using the provided training data.
-     *
-     * @param trainingData
-     * @param trainingParameters
-     */
-    @Override
-    public void fit(Dataset trainingData, TrainingParameters trainingParameters) {
-        logger.info("fit()");
-
-        initializeTrainingConfiguration(trainingParameters);
-
-        _fit(trainingData);
-
-        //store database
-        knowledgeBase.save();
-    }
-
     /**
      * Trains a Machine Learning model using the provided dataset files. The data
      * map should have as keys the names of each class and as values the URIs
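Reducing the constructors from public to protected does not break construction inside the framework: the wrapper hands the class tokens (TextClassifier.ModelParameters.class, TextClassifier.TrainingParameters.class) to its superclass, and a protected constructor remains reachable via reflection. A hypothetical sketch of that mechanism — this factory is an assumption for illustration, not datumbox's actual code:

    import java.lang.reflect.Constructor;

    final class ReflectiveFactory {
        // Instantiates a class through its single-argument constructor even
        // when that constructor is protected, which is how a framework that
        // holds only a Class token can keep creating parameter objects after
        // their constructors lose public visibility.
        static <T> T newInstance(Class<T> clazz, Class<?> argType, Object arg)
                throws ReflectiveOperationException {
            Constructor<T> ctor = clazz.getDeclaredConstructor(argType);
            ctor.setAccessible(true); // lift the protected access restriction
            return ctor.newInstance(arg);
        }
    }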
@@ -259,7 +259,7 @@ public static <T> void shuffle(T[] array) {
      * @param object
      * @return
      */
-    public static <T> String print_r(T object) {
+    public static <T> String var_export(T object) {
         return ToStringBuilder.reflectionToString(object);
     }

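The renamed helper keeps delegating to Commons Lang's reflective toString, dumping an object's fields in the spirit of PHP's var_export(). A standalone usage sketch — the wrapper class and demo below are hypothetical, only the method body matches the diff:

    import org.apache.commons.lang3.builder.ToStringBuilder;

    class VarExportDemo {
        // Same body as the renamed method above: render an object's fields
        // via reflection.
        static <T> String var_export(T object) {
            return ToStringBuilder.reflectionToString(object);
        }

        public static void main(String[] args) {
            System.out.println(var_export(new java.awt.Point(3, 4)));
            // prints something like: java.awt.Point@1b6d3586[x=3,y=4]
        }
    }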
@@ -51,11 +51,11 @@ public static class ModelParameters extends BaseNaiveBayes.ModelParameters {
         private Map<Object, Double> sumOfLog1minusProb; //the Sum Of Log(1-prob) for each class. This is used to optimize the speed of validation. Instead of looping through all the keywords by having this Sum we are able to loop only through the features of the observation

         /**
-         * Public constructor which accepts as argument the DatabaseConnector.
+         * Protected constructor which accepts as argument the DatabaseConnector.
          *
          * @param dbc
          */
-        public ModelParameters(DatabaseConnector dbc) {
+        protected ModelParameters(DatabaseConnector dbc) {
             super(dbc);
         }

@@ -73,7 +73,7 @@ public Map<Object, Double> getSumOfLog1minusProb() {
          *
          * @param sumOfLog1minusProb
          */
-        public void setSumOfLog1minusProb(Map<Object, Double> sumOfLog1minusProb) {
+        protected void setSumOfLog1minusProb(Map<Object, Double> sumOfLog1minusProb) {
             this.sumOfLog1minusProb = sumOfLog1minusProb;
         }
     }
@@ -186,22 +186,18 @@ protected void predictDataset(Dataset newData) {

     @Override
     protected void _fit(Dataset trainingData) {
+        ModelParameters modelParameters = knowledgeBase.getModelParameters();
+        int n = modelParameters.getN();
+        int d = modelParameters.getD();

         knowledgeBase.getTrainingParameters().setMultiProbabilityWeighted(false);

-        ModelParameters modelParameters = knowledgeBase.getModelParameters();
-
         Map<List<Object>, Double> likelihoods = modelParameters.getLogLikelihoods();
         Map<Object, Double> logPriors = modelParameters.getLogPriors();
         Set<Object> classesSet = modelParameters.getClasses();
         Map<Object, Double> sumOfLog1minusProb = modelParameters.getSumOfLog1minusProb();

-        int n = trainingData.getRecordNumber();
-        int d = trainingData.getVariableNumber();
-
-        //initialization
-        modelParameters.setN(n);
-        modelParameters.setD(d);
-
         //calculate first statistics about the classes
         AssociativeArray totalFeatureOccurrencesForEachClass = new AssociativeArray();
@@ -261,10 +257,6 @@ protected void _fit(Dataset trainingData) {

         }

-        int c = classesSet.size();
-        modelParameters.setC(c);
-
         //calculate prior log probabilities
         for(Map.Entry<Object, Double> entry : logPriors.entrySet()) {
             Object theClass = entry.getKey();
@@ -40,11 +40,11 @@ public class BinarizedNaiveBayes extends BaseNaiveBayes<BinarizedNaiveBayes.Mode
     public static class ModelParameters extends BaseNaiveBayes.ModelParameters {

         /**
-         * Public constructor which accepts as argument the DatabaseConnector.
+         * Protected constructor which accepts as argument the DatabaseConnector.
          *
          * @param dbc
          */
-        public ModelParameters(DatabaseConnector dbc) {
+        protected ModelParameters(DatabaseConnector dbc) {
             super(dbc);
         }

@@ -54,11 +54,11 @@ public static class ModelParameters extends BaseMLclassifier.ModelParameters {
         private Map<List<Object>, Double> lambdas; //the lambda parameters of the model

         /**
-         * Public constructor which accepts as argument the DatabaseConnector.
+         * Protected constructor which accepts as argument the DatabaseConnector.
          *
          * @param dbc
          */
-        public ModelParameters(DatabaseConnector dbc) {
+        protected ModelParameters(DatabaseConnector dbc) {
             super(dbc);
         }

@@ -76,7 +76,7 @@ public Map<List<Object>, Double> getLambdas() {
          *
          * @param lambdas
          */
-        public void setLambdas(Map<List<Object>, Double> lambdas) {
+        protected void setLambdas(Map<List<Object>, Double> lambdas) {
             this.lambdas = lambdas;
         }

@@ -151,14 +151,7 @@ protected void predictDataset(Dataset newData) {
     @SuppressWarnings("unchecked")
     protected void _fit(Dataset trainingData) {
         ModelParameters modelParameters = knowledgeBase.getModelParameters();
+        int n = modelParameters.getN();

-        int n = trainingData.getRecordNumber();
-        int d = trainingData.getVariableNumber();
-
-        //initialization
-        modelParameters.setN(n);
-        modelParameters.setD(d);
-
         Map<List<Object>, Double> lambdas = modelParameters.getLambdas();
@@ -171,10 +164,6 @@ protected void _fit(Dataset trainingData) {

             classesSet.add(theClass);
         }

-        int c = classesSet.size();
-        modelParameters.setC(c);
-
         //create a temporary map for the observed probabilities in training set
         DatabaseConnector dbc = knowledgeBase.getDbc();
@@ -38,11 +38,11 @@ public class MultinomialNaiveBayes extends BaseNaiveBayes<MultinomialNaiveBayes.
     public static class ModelParameters extends BaseNaiveBayes.ModelParameters {

         /**
-         * Public constructor which accepts as argument the DatabaseConnector.
+         * Protected constructor which accepts as argument the DatabaseConnector.
          *
          * @param dbc
          */
-        public ModelParameters(DatabaseConnector dbc) {
+        protected ModelParameters(DatabaseConnector dbc) {
             super(dbc);
         }

@@ -67,11 +67,11 @@ public static class ModelParameters extends BaseMLclassifier.ModelParameters {
         private Map<Object, Double> thitas;

         /**
-         * Public constructor which accepts as argument the DatabaseConnector.
+         * Protected constructor which accepts as argument the DatabaseConnector.
          *
          * @param dbc
          */
-        public ModelParameters(DatabaseConnector dbc) {
+        protected ModelParameters(DatabaseConnector dbc) {
             super(dbc);
         }

@@ -89,7 +89,7 @@ public Map<Object, Double> getWeights() {
          *
          * @param weights
          */
-        public void setWeights(Map<Object, Double> weights) {
+        protected void setWeights(Map<Object, Double> weights) {
             this.weights = weights;
         }

@@ -107,7 +107,7 @@ public Map<Object, Double> getThitas() {
          *
          * @param thitas
          */
-        public void setThitas(Map<Object, Double> thitas) {
+        protected void setThitas(Map<Object, Double> thitas) {
             this.thitas = thitas;
         }

@@ -240,17 +240,9 @@ protected void predictDataset(Dataset newData) {
     @Override
     @SuppressWarnings("unchecked")
     protected void _fit(Dataset trainingData) {
-
-        int n = trainingData.getRecordNumber();
-        int d = trainingData.getVariableNumber();
-
         ModelParameters modelParameters = knowledgeBase.getModelParameters();
         TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();

-        //initialization
-        modelParameters.setN(n);
-        modelParameters.setD(d);
-
         Map<Object, Double> weights = modelParameters.getWeights();
         Map<Object, Double> thitas = modelParameters.getThitas();
@@ -265,9 +257,6 @@ protected void _fit(Dataset trainingData) {
         Set<Object> classesSet = modelParameters.getClasses();
         classesSet.addAll(sortedClasses);

-        int c = classesSet.size();
-        modelParameters.setC(c);
-
         //we initialize the weights and thitas to zero
         for(Object feature: trainingData.getXDataTypes().keySet()) {
             weights.put(feature, 0.0);
@@ -54,11 +54,11 @@ public static class ModelParameters extends BaseMLclassifier.ModelParameters {
         private Map<List<Object>, Double> thitas; //the thita parameters of the model

         /**
-         * Public constructor which accepts as argument the DatabaseConnector.
+         * Protected constructor which accepts as argument the DatabaseConnector.
          *
          * @param dbc
          */
-        public ModelParameters(DatabaseConnector dbc) {
+        protected ModelParameters(DatabaseConnector dbc) {
             super(dbc);
         }

@@ -76,7 +76,7 @@ public Map<List<Object>, Double> getThitas() {
          *
          * @param thitas
          */
-        public void setThitas(Map<List<Object>, Double> thitas) {
+        protected void setThitas(Map<List<Object>, Double> thitas) {
             this.thitas = thitas;
         }
     }
@@ -209,17 +209,9 @@ protected void predictDataset(Dataset newData) {
     @Override
     @SuppressWarnings("unchecked")
     protected void _fit(Dataset trainingData) {
-
-        int n = trainingData.getRecordNumber();
-        int d = trainingData.getVariableNumber();
-
         ModelParameters modelParameters = knowledgeBase.getModelParameters();
         TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();

-        //initialization
-        modelParameters.setN(n);
-        modelParameters.setD(d);
-
         Map<List<Object>, Double> thitas = modelParameters.getThitas();
         Set<Object> classesSet = modelParameters.getClasses();
@@ -231,9 +223,6 @@ protected void _fit(Dataset trainingData) {

             classesSet.add(theClass);
         }

-        int c = classesSet.size();
-        modelParameters.setC(c);
-
         //we initialize the thitas to zero for all features and all classes compinations
         for(Object theClass : classesSet) {
