Skip to content

Commit

Permalink
Calling erase() in every dataset that we stop using.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Apr 16, 2015
1 parent 5ea6823 commit c21ac59
Show file tree
Hide file tree
Showing 36 changed files with 294 additions and 120 deletions.
2 changes: 1 addition & 1 deletion TODO.txt
Expand Up @@ -3,7 +3,7 @@ CODE IMPROVEMENTS

- CSV reader and converter to Dataset.

- Serialization improvements: set serialVersionUID?
- Improve Serialization by setting the serialVersionUID in every serializable class?
- Create better Exceptions and Exception messages.
- Add multithreading support.
- Check out the Sparse Matrices/Vectors in Apache Math3 library. Use them in GaussianDPMM, MultinomialDPMM and MatrixLinearRegression.
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/datumbox/applications/nlp/CETR.java
Expand Up @@ -23,12 +23,10 @@
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.utilities.MapFunctions;
import com.datumbox.common.utilities.PHPfunctions;
import com.datumbox.common.utilities.RandomSingleton;
import com.datumbox.framework.machinelearning.clustering.Kmeans;
import com.datumbox.framework.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.utilities.text.cleaners.HTMLCleaner;
import com.datumbox.framework.utilities.text.cleaners.StringCleaner;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
Expand Down Expand Up @@ -187,6 +185,8 @@ private List<Integer> selectRows(List<String> rows, Parameters parameters) {
}
}

dataset.erase();

return selectedRows;
}

Expand Down
18 changes: 6 additions & 12 deletions src/main/java/com/datumbox/applications/nlp/TextClassifier.java
Expand Up @@ -15,9 +15,7 @@
*/
package com.datumbox.applications.nlp;

import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.machinelearning.common.bases.featureselection.CategoricalFeatureSelection;
Expand All @@ -27,17 +25,8 @@
import com.datumbox.framework.machinelearning.common.bases.datatransformation.DataTransformer;
import com.datumbox.framework.utilities.dataset.DatasetBuilder;
import com.datumbox.framework.utilities.text.extractors.TextExtractor;
import com.datumbox.framework.utilities.text.cleaners.StringCleaner;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
Expand Down Expand Up @@ -110,6 +99,8 @@ public void fit(Map<Object, URI> dataset, TrainingParameters trainingParameters)

_fit(trainingDataset);

trainingDataset.erase();

//store database
knowledgeBase.save();
}
Expand Down Expand Up @@ -191,8 +182,11 @@ public BaseMLmodel.ValidationMetrics validate(Map<Object, URI> dataset) {
//build the testDataset
Dataset testDataset = DatasetBuilder.parseFromTextFiles(dataset, textExtractor, dbConf);

BaseMLmodel.ValidationMetrics vm = getPredictions(testDataset);

testDataset.erase();

return getPredictions(testDataset);
return vm;
}

protected BaseMLmodel.ValidationMetrics getPredictions(Dataset testDataset) {
Expand Down
18 changes: 13 additions & 5 deletions src/main/java/com/datumbox/common/dataobjects/Dataset.java
Expand Up @@ -33,11 +33,11 @@ public final class Dataset implements Serializable, Iterable<Integer> {
public static final String yColumnName = "~Y";
public static final String constantColumnName = "~CONSTANT";

private final Map<Integer, Record> recordList;
private Map<Integer, Record> recordList;

private TypeInference.DataType yDataType;
/* Stores columnName=> DataType */
private final Map<Object, TypeInference.DataType> xDataTypes;
private Map<Object, TypeInference.DataType> xDataTypes;

private transient String dbName;
private transient DatabaseConnector dbc;
Expand Down Expand Up @@ -299,13 +299,21 @@ public void _set(Integer rId, Record r) {
}

/**
* Clears the Dataset and removes the internal variables.
* Erases the Dataset and removes all internal variables.
*/
public void clear() {
yDataType = null;
//Releases the Dataset's persistent storage and nulls every internal reference.
//NOTE(review): this replaces the old clear() method; the object is intentionally unusable afterwards.
public void erase() {
//drop the two BigMap-backed structures from the storage backend, then the database itself
dbc.dropBigMap("tmp_xColumnTypes", xDataTypes);
dbc.dropBigMap("tmp_recordList", recordList);
dbc.dropDatabase();

//release the connector/configuration handles so nothing can reopen the dropped database
dbName = null;
dbc = null;
dbConf = null;

//Ensures that the Dataset can't be used after erase() is called.
yDataType = null;
xDataTypes = null;
recordList = null;
}

/**
Expand Down
Expand Up @@ -226,6 +226,7 @@ protected void _fit(Dataset trainingData) {
sampledTrainingDataset = sampledTrainingDataset.copy();
}
mlclassifier.fit(sampledTrainingDataset, weakClassifierTrainingParameters);
sampledTrainingDataset.erase();
sampledTrainingDataset = null;


Expand All @@ -237,7 +238,9 @@ protected void _fit(Dataset trainingData) {
mlclassifier = null;

Status status = updateObservationAndClassifierWeights(validationDataset, observationWeights, sampledIDs);

if(copyData) {
validationDataset.erase();
}
validationDataset = null;

if(status==Status.STOP) {
Expand Down
Expand Up @@ -230,7 +230,8 @@ protected static void denormalizeY(Dataset data, Map<Object, Double> minColumnVa
return;
}

if(data.getYDataType()==TypeInference.DataType.NUMERICAL) {
TypeInference.DataType dataType = data.getYDataType();
if(dataType==TypeInference.DataType.NUMERICAL || dataType==null) {

for(Integer rId : data) {
Record r = data.get(rId);
Expand All @@ -242,13 +243,17 @@ protected static void denormalizeY(Dataset data, Map<Object, Double> minColumnVa
Object denormalizedY = null;
Object denormalizedYPredicted = null;
if(min.equals(max)) {
denormalizedY = min;
if(r.getY()!=null) {
denormalizedY = min;
}
if(r.getYPredicted()!=null) {
denormalizedYPredicted = min;
}
}
else {
denormalizedY = TypeInference.toDouble(r.getY())*(max-min) + min;
if(r.getY()!=null) {
denormalizedY = TypeInference.toDouble(r.getY())*(max-min) + min;
}

Double YPredicted = TypeInference.toDouble(r.getYPredicted());
if(YPredicted!=null) {
Expand Down
Expand Up @@ -109,6 +109,7 @@ public VM kFoldCrossValidation(Dataset dataset, int k, String dbName, DatabaseCo
trainingData = trainingData.copy();
}
mlmodel.fit(trainingData, trainingParameters);
trainingData.erase();
trainingData = null;


Expand All @@ -118,6 +119,7 @@ public VM kFoldCrossValidation(Dataset dataset, int k, String dbName, DatabaseCo
}
//fetch validation metrics
VM entrySample = mlmodel.validate(validationData);
validationData.erase();
validationData = null;


Expand Down
Expand Up @@ -179,6 +179,7 @@ protected void _fit(Dataset trainingData) {
mlregressor = BaseMLmodel.newInstance(trainingParameters.getRegressionClass(), dbName, knowledgeBase.getDbConf());

mlregressor.fit(copiedTrainingData, trainingParameters.getRegressionTrainingParameters());
copiedTrainingData.erase();
copiedTrainingData = null;
}

Expand Down
Expand Up @@ -63,32 +63,12 @@ public static Map<Object, List<String>> stringListsFromTextFiles(Map<Object, URI
return listsMap;
}

public static Dataset parseFromTextLists(Map<Object, List<String>> dataset, TextExtractor textExtractor, DatabaseConfiguration dbConf) {
Dataset data = new Dataset(dbConf);
public static Dataset parseFromTextFiles(Map<Object, URI> dataFiles, TextExtractor textExtractor, DatabaseConfiguration dbConf) {
Dataset dataset = new Dataset(dbConf);
Logger logger = LoggerFactory.getLogger(DatasetBuilder.class);

//loop through the map and process each category file
for(Map.Entry<Object, List<String>> entry : dataset.entrySet()) {
Object theClass = entry.getKey();
List<String> textList = entry.getValue();

logger.info("Dataset Parsing " + theClass + " class");

for(String text : textList) {
//extract features of the string and add every keyword combination in X map
data.add(new Record(new AssociativeArray(textExtractor.extract(StringCleaner.clear(text))), theClass));
}
}

return data;
}

public static Dataset parseFromTextFiles(Map<Object, URI> dataset, TextExtractor textExtractor, DatabaseConfiguration dbConf) {
Dataset data = new Dataset(dbConf);
Logger logger = LoggerFactory.getLogger(DatasetBuilder.class);

//loop through the map and process each category file
for(Map.Entry<Object, URI> entry : dataset.entrySet()) {
for(Map.Entry<Object, URI> entry : dataFiles.entrySet()) {
Object theClass = entry.getKey();
URI datasetURI = entry.getValue();

Expand All @@ -99,15 +79,15 @@ public static Dataset parseFromTextFiles(Map<Object, URI> dataset, TextExtractor
//read strings one by one
for(String line; (line = br.readLine()) != null; ) {
//extract features of the string and add every keyword combination in X map
data.add(new Record(new AssociativeArray(textExtractor.extract(StringCleaner.clear(line))), theClass));
dataset.add(new Record(new AssociativeArray(textExtractor.extract(StringCleaner.clear(line))), theClass));
}
}
catch (IOException ex) {
throw new RuntimeException(ex);
}
}

return data;
return dataset;
}

}
Expand Up @@ -54,7 +54,7 @@ public void testTrainAndValidate() {
Dataset[] data = Datasets.carsNumeric(dbConf);
Dataset trainingData = data[0];

Dataset newData = data[1];
Dataset validationData = data[1];


String dbName = "JUnit";
Expand Down Expand Up @@ -97,20 +97,23 @@ public void testTrainAndValidate() {

instance = new Modeler(dbName, dbConf);

instance.validate(newData);
instance.validate(validationData);



Map<Integer, Object> expResult = new HashMap<>();
Map<Integer, Object> result = new HashMap<>();
for(Integer rId : newData) {
Record r = newData.get(rId);
for(Integer rId : validationData) {
Record r = validationData.get(rId);
expResult.put(rId, r.getY());
result.put(rId, r.getYPredicted());
}
assertEquals(expResult, result);

instance.erase();

trainingData.erase();
validationData.erase();
}

}
Expand Up @@ -111,22 +111,23 @@ public void testTrainAndPredict() throws URISyntaxException, MalformedURLExcepti


instance = new TextClassifier(dbName, dbConf);
Dataset testDataset = null;
Dataset validationDataset = null;
try {
testDataset = instance.predict(TestUtils.getRemoteFile(new URL("http://www.datumbox.com/files/datasets/example.test")));
validationDataset = instance.predict(TestUtils.getRemoteFile(new URL("http://www.datumbox.com/files/datasets/example.test")));
}
catch(Exception ex) {
TestUtils.log(this.getClass(), "Unable to download datasets, skipping test.");
return;
}

List<Object> expResult = Arrays.asList("negative","positive");
for(Integer rId : testDataset) {
Record r = testDataset.get(rId);
for(Integer rId : validationDataset) {
Record r = validationDataset.get(rId);
assertEquals(expResult.get(rId), r.getYPredicted());
}

instance.erase();
validationDataset.erase();
}

}

0 comments on commit c21ac59

Please sign in to comment.