Commit c21ac59

Calling erase() in every dataset that we stop using.

datumbox committed Apr 16, 2015
1 parent 5ea6823 commit c21ac59
Showing 36 changed files with 294 additions and 120 deletions.
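The commit applies one pattern throughout: every temporary Dataset is explicitly erased as soon as it stops being used, which drops its BigMaps and backing database immediately instead of leaving the cleanup to garbage collection. Below is a minimal sketch of that pattern, assuming only the Dataset and DatasetBuilder APIs visible in the diffs on this page; the class and method names here are illustrative, not part of the commit.

    import java.net.URI;
    import java.util.Map;

    import com.datumbox.common.dataobjects.Dataset;
    import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
    import com.datumbox.framework.utilities.dataset.DatasetBuilder;
    import com.datumbox.framework.utilities.text.extractors.TextExtractor;

    //Hypothetical sketch of the cleanup pattern this commit enforces.
    public class EraseUsageSketch {

        public static void trainAndCleanUp(Map<Object, URI> dataFiles,
                                           TextExtractor textExtractor,
                                           DatabaseConfiguration dbConf) {
            //build a temporary Dataset backed by BigMaps and a database
            Dataset trainingDataset = DatasetBuilder.parseFromTextFiles(dataFiles, textExtractor, dbConf);

            // ... fit a model on trainingDataset ...

            //erase the temporary Dataset the moment we stop using it:
            //this drops its BigMaps and backing database right away
            //instead of waiting for the garbage collector
            trainingDataset.erase();
            trainingDataset = null;
        }
    }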
2 changes: 1 addition & 1 deletion TODO.txt
@@ -3,7 +3,7 @@ CODE IMPROVEMENTS

 - CSV reader and converter to Dataset.

-- Serialization improvements: set serialVersionUID?
+- Improve Serialization by setting the serialVersionUID in every serializable class?
 - Create better Exceptions and Exception messages.
 - Add multithreading support.
 - Check out the Sparse Matrices/Vectors in Apache Math3 library. Use them in GaussianDPMM, MultinomialDPMM and MatrixLinearRegression.
4 changes: 2 additions & 2 deletions src/main/java/com/datumbox/applications/nlp/CETR.java
@@ -23,12 +23,10 @@
 import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
 import com.datumbox.common.utilities.MapFunctions;
 import com.datumbox.common.utilities.PHPfunctions;
-import com.datumbox.common.utilities.RandomSingleton;
 import com.datumbox.framework.machinelearning.clustering.Kmeans;
 import com.datumbox.framework.statistics.descriptivestatistics.Descriptives;
 import com.datumbox.framework.utilities.text.cleaners.HTMLCleaner;
 import com.datumbox.framework.utilities.text.cleaners.StringCleaner;
-import java.math.BigInteger;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
@@ -187,6 +185,8 @@ private List<Integer> selectRows(List<String> rows, Parameters parameters) {
             }
         }

+        dataset.erase();
+
         return selectedRows;
     }
18 changes: 6 additions & 12 deletions src/main/java/com/datumbox/applications/nlp/TextClassifier.java
@@ -15,9 +15,7 @@
  */
 package com.datumbox.applications.nlp;

-import com.datumbox.common.dataobjects.AssociativeArray;
 import com.datumbox.common.dataobjects.Dataset;
-import com.datumbox.common.dataobjects.Record;
 import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
 import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
 import com.datumbox.framework.machinelearning.common.bases.featureselection.CategoricalFeatureSelection;
@@ -27,17 +25,8 @@
 import com.datumbox.framework.machinelearning.common.bases.datatransformation.DataTransformer;
 import com.datumbox.framework.utilities.dataset.DatasetBuilder;
 import com.datumbox.framework.utilities.text.extractors.TextExtractor;
-import com.datumbox.framework.utilities.text.cleaners.StringCleaner;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileReader;
-import java.io.IOException;
-import java.io.InputStreamReader;
 import java.net.URI;
 import java.util.HashMap;
-import java.util.LinkedList;
-import java.util.List;
 import java.util.Map;

 /**
@@ -110,6 +99,8 @@ public void fit(Map<Object, URI> dataset, TrainingParameters trainingParameters)

         _fit(trainingDataset);

+        trainingDataset.erase();
+
         //store database
         knowledgeBase.save();
     }
@@ -191,8 +182,11 @@ public BaseMLmodel.ValidationMetrics validate(Map<Object, URI> dataset) {
         //build the testDataset
         Dataset testDataset = DatasetBuilder.parseFromTextFiles(dataset, textExtractor, dbConf);

-        return getPredictions(testDataset);
+        BaseMLmodel.ValidationMetrics vm = getPredictions(testDataset);
+
+        testDataset.erase();
+
+        return vm;
     }

     protected BaseMLmodel.ValidationMetrics getPredictions(Dataset testDataset) {
18 changes: 13 additions & 5 deletions src/main/java/com/datumbox/common/dataobjects/Dataset.java
@@ -33,11 +33,11 @@ public final class Dataset implements Serializable, Iterable<Integer> {
     public static final String yColumnName = "~Y";
     public static final String constantColumnName = "~CONSTANT";

-    private final Map<Integer, Record> recordList;
+    private Map<Integer, Record> recordList;

     private TypeInference.DataType yDataType;
     /* Stores columnName=> DataType */
-    private final Map<Object, TypeInference.DataType> xDataTypes;
+    private Map<Object, TypeInference.DataType> xDataTypes;

     private transient String dbName;
     private transient DatabaseConnector dbc;
@@ -299,13 +299,21 @@ public void _set(Integer rId, Record r) {
     }

     /**
-     * Clears the Dataset and removes the internal variables.
+     * Erases the Dataset and removes all internal variables.
      */
-    public void clear() {
-        yDataType = null;
+    public void erase() {
         dbc.dropBigMap("tmp_xColumnTypes", xDataTypes);
         dbc.dropBigMap("tmp_recordList", recordList);
         dbc.dropDatabase();
+
+        dbName = null;
+        dbc = null;
+        dbConf = null;
+
+        //Ensures that the Dataset can't be used after erase() is called.
+        yDataType = null;
+        xDataTypes = null;
+        recordList = null;
     }

     /**
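Note the ordering inside the new erase(): the BigMaps are dropped while dbc is still set, and only then is every field nulled. The nulled fields act as a guard, so any later use of an erased Dataset fails fast instead of silently touching dropped storage. A hypothetical illustration, not part of the commit:

    Dataset d = new Dataset(dbConf); //dbConf: any DatabaseConfiguration
    d.erase();  //drops "tmp_xColumnTypes", "tmp_recordList" and the database, then nulls the fields
    d.erase();  //second call fails fast: dbc is already null, so this throws a NullPointerException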
@@ -226,6 +226,7 @@ protected void _fit(Dataset trainingData) {
                 sampledTrainingDataset = sampledTrainingDataset.copy();
             }
             mlclassifier.fit(sampledTrainingDataset, weakClassifierTrainingParameters);
+            sampledTrainingDataset.erase();
             sampledTrainingDataset = null;

@@ -237,7 +238,9 @@
             mlclassifier = null;

             Status status = updateObservationAndClassifierWeights(validationDataset, observationWeights, sampledIDs);
-
+            if(copyData) {
+                validationDataset.erase();
+            }
             validationDataset = null;

             if(status==Status.STOP) {
@@ -230,7 +230,8 @@ protected static void denormalizeY(Dataset data, Map<Object, Double> minColumnVa
             return;
         }

-        if(data.getYDataType()==TypeInference.DataType.NUMERICAL) {
+        TypeInference.DataType dataType = data.getYDataType();
+        if(dataType==TypeInference.DataType.NUMERICAL || dataType==null) {

             for(Integer rId : data) {
                 Record r = data.get(rId);
@@ -242,13 +243,17 @@
                 Object denormalizedY = null;
                 Object denormalizedYPredicted = null;
                 if(min.equals(max)) {
-                    denormalizedY = min;
+                    if(r.getY()!=null) {
+                        denormalizedY = min;
+                    }
                     if(r.getYPredicted()!=null) {
                         denormalizedYPredicted = min;
                     }
                 }
                 else {
-                    denormalizedY = TypeInference.toDouble(r.getY())*(max-min) + min;
+                    if(r.getY()!=null) {
+                        denormalizedY = TypeInference.toDouble(r.getY())*(max-min) + min;
+                    }

                     Double YPredicted = TypeInference.toDouble(r.getYPredicted());
                     if(YPredicted!=null) {
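The two new r.getY()!=null guards matter because records that are only being predicted carry a null response; without the guard, the denormalization line would hit the null inside TypeInference.toDouble(r.getY())*(max-min) and throw a NullPointerException. The dataType==null branch similarly admits Datasets whose Y type was never inferred. The transformation itself is the usual min-max inverse: a hypothetical record with min=10.0, max=30.0 and a normalized value of 0.25 denormalizes to 0.25*(30.0-10.0) + 10.0 = 15.0.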
@@ -109,6 +109,7 @@ public VM kFoldCrossValidation(Dataset dataset, int k, String dbName, DatabaseCo
                 trainingData = trainingData.copy();
             }
             mlmodel.fit(trainingData, trainingParameters);
+            trainingData.erase();
             trainingData = null;

@@ -118,6 +119,7 @@
             }
             //fetch validation metrics
             VM entrySample = mlmodel.validate(validationData);
+            validationData.erase();
             validationData = null;

@@ -179,6 +179,7 @@ protected void _fit(Dataset trainingData) {
         mlregressor = BaseMLmodel.newInstance(trainingParameters.getRegressionClass(), dbName, knowledgeBase.getDbConf());

         mlregressor.fit(copiedTrainingData, trainingParameters.getRegressionTrainingParameters());
+        copiedTrainingData.erase();
         copiedTrainingData = null;
     }
@@ -63,32 +63,12 @@ public static Map<Object, List<String>> stringListsFromTextFiles(Map<Object, URI
         return listsMap;
     }

-    public static Dataset parseFromTextLists(Map<Object, List<String>> dataset, TextExtractor textExtractor, DatabaseConfiguration dbConf) {
-        Dataset data = new Dataset(dbConf);
-        Logger logger = LoggerFactory.getLogger(DatasetBuilder.class);
-
-        //loop throw the map and process each category file
-        for(Map.Entry<Object, List<String>> entry : dataset.entrySet()) {
-            Object theClass = entry.getKey();
-            List<String> textList = entry.getValue();
-
-            logger.info("Dataset Parsing " + theClass + " class");
-
-            for(String text : textList) {
-                //extract features of the string and add every keyword combination in X map
-                data.add(new Record(new AssociativeArray(textExtractor.extract(StringCleaner.clear(text))), theClass));
-            }
-        }
-
-        return data;
-    }
-
-    public static Dataset parseFromTextFiles(Map<Object, URI> dataset, TextExtractor textExtractor, DatabaseConfiguration dbConf) {
-        Dataset data = new Dataset(dbConf);
-        Logger logger = LoggerFactory.getLogger(DatasetBuilder.class);
-
-        //loop throw the map and process each category file
-        for(Map.Entry<Object, URI> entry : dataset.entrySet()) {
+    public static Dataset parseFromTextFiles(Map<Object, URI> dataFiles, TextExtractor textExtractor, DatabaseConfiguration dbConf) {
+        Dataset dataset = new Dataset(dbConf);
+        Logger logger = LoggerFactory.getLogger(DatasetBuilder.class);
+
+        //loop throw the map and process each category file
+        for(Map.Entry<Object, URI> entry : dataFiles.entrySet()) {
             Object theClass = entry.getKey();
             URI datasetURI = entry.getValue();
@@ -99,15 +79,15 @@ public static Dataset parseFromTextFiles(Map<Object, URI> dataset, TextExtractor
                 //read strings one by one
                 for(String line; (line = br.readLine()) != null; ) {
                     //extract features of the string and add every keyword combination in X map
-                    data.add(new Record(new AssociativeArray(textExtractor.extract(StringCleaner.clear(line))), theClass));
+                    dataset.add(new Record(new AssociativeArray(textExtractor.extract(StringCleaner.clear(line))), theClass));
                 }
             }
             catch (IOException ex) {
                 throw new RuntimeException(ex);
             }
         }

-        return data;
+        return dataset;
     }

 }
@@ -54,7 +54,7 @@ public void testTrainAndValidate() {
         Dataset[] data = Datasets.carsNumeric(dbConf);
         Dataset trainingData = data[0];

-        Dataset newData = data[1];
+        Dataset validationData = data[1];

         String dbName = "JUnit";
@@ -97,20 +97,23 @@ public void testTrainAndValidate() {

         instance = new Modeler(dbName, dbConf);

-        instance.validate(newData);
+        instance.validate(validationData);

         Map<Integer, Object> expResult = new HashMap<>();
         Map<Integer, Object> result = new HashMap<>();
-        for(Integer rId : newData) {
-            Record r = newData.get(rId);
+        for(Integer rId : validationData) {
+            Record r = validationData.get(rId);
             expResult.put(rId, r.getY());
             result.put(rId, r.getYPredicted());
         }
         assertEquals(expResult, result);

         instance.erase();
+
+        trainingData.erase();
+        validationData.erase();
     }

 }
@@ -111,22 +111,23 @@ public void testTrainAndPredict() throws URISyntaxException, MalformedURLExcepti

         instance = new TextClassifier(dbName, dbConf);
-        Dataset testDataset = null;
+        Dataset validationDataset = null;
         try {
-            testDataset = instance.predict(TestUtils.getRemoteFile(new URL("http://www.datumbox.com/files/datasets/example.test")));
+            validationDataset = instance.predict(TestUtils.getRemoteFile(new URL("http://www.datumbox.com/files/datasets/example.test")));
         }
         catch(Exception ex) {
             TestUtils.log(this.getClass(), "Unable to download datasets, skipping test.");
             return;
         }

         List<Object> expResult = Arrays.asList("negative","positive");
-        for(Integer rId : testDataset) {
-            Record r = testDataset.get(rId);
+        for(Integer rId : validationDataset) {
+            Record r = validationDataset.get(rId);
             assertEquals(expResult.get(rId), r.getYPredicted());
         }

         instance.erase();
+        validationDataset.erase();
     }

 }
