Commit

All methods that have copy() now implement the Copyable interface. Added the multi-thread support on the predict method of most ML models. Updated the MapDBConnector to call compact on primary db before close. Updated the DatabaseConnector.getBigMap() to support thread-safe maps. Massive restructuring of packages and class names.
datumbox committed Jan 8, 2016
1 parent b1e2bea commit d3aa1ba
Showing 205 changed files with 2,062 additions and 1,554 deletions.
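
The commit message mentions three related ideas: a Copyable interface, multi-threaded prediction, and thread-safe maps. The sketch below is only an illustration of how those pieces can fit together in plain Java. It is not the Datumbox implementation: `Copyable<T>` is written here as a guess at the shape of such an interface, and `SimpleRecord`, `ParallelPredictSketch`, and `predictAll` are hypothetical names that do not appear in the commit. The thread-safe map is a standard `ConcurrentHashMap`, standing in for whatever `DatabaseConnector.getBigMap()` actually returns.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

/** Hypothetical stand-in for the Copyable interface named in the commit message. */
interface Copyable<T> {
    T copy();
}

/** Hypothetical record type; the framework's own Record class is not reproduced here. */
final class SimpleRecord implements Copyable<SimpleRecord> {
    final int id;
    final double[] features;

    SimpleRecord(int id, double[] features) {
        this.id = id;
        this.features = features;
    }

    @Override
    public SimpleRecord copy() {
        // Clone the array so the copy does not share mutable state with the original.
        return new SimpleRecord(id, features.clone());
    }
}

class ParallelPredictSketch {

    /**
     * Scores every record in parallel and stores the results in a thread-safe map.
     * The model is passed in as a plain function (an assumption made for brevity).
     */
    static Map<Integer, Double> predictAll(List<SimpleRecord> records,
                                           Function<SimpleRecord, Double> model) {
        // ConcurrentHashMap tolerates concurrent put() calls from the stream's worker threads.
        Map<Integer, Double> predictions = new ConcurrentHashMap<>();
        records.parallelStream().forEach(r -> predictions.put(r.id, model.apply(r)));
        return predictions;
    }

    public static void main(String[] args) {
        List<SimpleRecord> records = Arrays.asList(
                new SimpleRecord(1, new double[]{1.0, 2.0}),
                new SimpleRecord(2, new double[]{3.0, 4.0}));

        // Toy "model": the prediction is the sum of the features.
        Map<Integer, Double> out = predictAll(records, r -> Arrays.stream(r.features).sum());
        System.out.println(out);
    }
}
```

Writing to a plain `HashMap` from a parallel stream would be a data race, which is presumably why the commit pairs parallel prediction with thread-safe maps.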
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -51,7 +51,11 @@ Version 0.7.0-SNAPSHOT - Build 20160108
- KnowledgeBase is no longer serializable. Its serializable fields are stored individually into the Database.
- Restructuring the framework to remove all FindBug warnings.
- KnowledgeBase is now an interface, while the implementation moved to StandardKnowledgeBase. The interface contains a static factory method to produce any KB. This enables us to define the knowledgeBase field of BaseTrainable private and final.
- Moved SensitivityRates and StepwiseCompatible to the machinelearning/common/ package.
- All methods that have copy() now implement the Copyable interface.
- Added the multi-thread support on the predict method of most ML models.
- Updated the MapDBConnector to call compact on primary db before close.
- Updated the DatabaseConnector.getBigMap() to support thread-safe maps.
- Massive restructuring of packages and class names.

Version 0.6.1 - Build 20160102
------------------------------
22 changes: 17 additions & 5 deletions TODO.txt
@@ -1,8 +1,10 @@
CODE IMPROVEMENTS
=================

- Add a predictRecord() method to BaseMLmodel and refactor the code to be implemented by every algorithm.
- Add multithreading support.
- Validate that calling compact on MapDB is not reducing the speed too much.
- See if we can make LatentDirichletAllocation use the parallel prediction.
- Add multithreading support on the training part and on some parts of DataTransformation, FeatureSelection.

- Update all maven plugins and dependencies to their latest versions.
- Add support for MapDB 2.0 once a stable version is released.
- Upgrade maven-surefire to 2.19.1.
@@ -28,8 +30,18 @@ NEW ALGORITHMS
- Add the ability to search through the configuration space and find the best performing algorithmic configuration.


CHECK OUT HUGE COLLECTION LIBS, DBS AND STORAGE
===============================================
TO CHECK OUT
============

Linear Algebra
--------------

- JBLAS - Linear Algebra for Java:
https://github.com/mikiobraun/jblas
http://jblas.org/

Huge Collection libs, DBs and Storage
-------------------------------------

- Java StoredMap + BerkeleyDB:
http://docs.oracle.com/cd/E17277_02/html/java/com/sleepycat/collections/StoredMap.html
@@ -48,4 +60,4 @@ CHECK OUT HUGE COLLECTION LIBS, DBS AND STORAGE
https://github.com/OpenHFT/Chronicle-Map/

- H2 Database:
http://www.h2database.com/html/main.html
http://www.h2database.com/html/main.html
2 changes: 1 addition & 1 deletion pom.xml
@@ -172,7 +172,7 @@
<configuration>
<threadCount>3</threadCount>
<parallel>classes</parallel>
<argLine>-Xmx512M</argLine>
<argLine>-Xmx512M -Dfile.encoding=${project.build.sourceEncoding}</argLine>
</configuration>
</plugin>
<plugin>
35 changes: 18 additions & 17 deletions src/main/java/com/datumbox/applications/datamodeling/Modeler.java
@@ -18,10 +18,10 @@
import com.datumbox.common.dataobjects.Dataframe;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.machinelearning.common.bases.featureselection.FeatureSelection;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel;
import com.datumbox.framework.machinelearning.common.bases.wrappers.BaseWrapper;
import com.datumbox.framework.machinelearning.common.bases.datatransformation.DataTransformer;
import com.datumbox.framework.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector;
import com.datumbox.framework.machinelearning.common.abstracts.modelers.AbstractAlgorithm;
import com.datumbox.framework.machinelearning.common.abstracts.wrappers.AbstractWrapper;
import com.datumbox.framework.machinelearning.common.abstracts.datatransformers.AbstractTransformer;

/**
* Modeler is a convenience class which can be used to train Machine Learning
@@ -30,12 +30,12 @@
*
* @author Vasilis Vryniotis <bbriniotis@datumbox.com>
*/
public class Modeler extends BaseWrapper<Modeler.ModelParameters, Modeler.TrainingParameters> {
public class Modeler extends AbstractWrapper<Modeler.ModelParameters, Modeler.TrainingParameters> {

/**
* It contains all the Model Parameters which are learned during the training.
*/
public static class ModelParameters extends BaseWrapper.ModelParameters {
public static class ModelParameters extends AbstractWrapper.ModelParameters {
private static final long serialVersionUID = 1L;

/**
@@ -51,7 +51,7 @@ protected ModelParameters(DatabaseConnector dbc) {
/**
* It contains the Training Parameters of the Modeler.
*/
public static class TrainingParameters extends BaseWrapper.TrainingParameters<DataTransformer, FeatureSelection, BaseMLmodel> {
public static class TrainingParameters extends AbstractWrapper.TrainingParameters<AbstractTransformer, AbstractFeatureSelector, AbstractAlgorithm> {
private static final long serialVersionUID = 1L;

}
@@ -85,12 +85,13 @@ public void predict(Dataframe newData) {
* @param testData
* @return
*/
public BaseMLmodel.ValidationMetrics validate(Dataframe testData) {
public AbstractAlgorithm.ValidationMetrics validate(Dataframe testData) {
logger.info("validate()");

return evaluateData(testData, true);
}


/** {@inheritDoc} */
@Override
protected void _fit(Dataframe trainingData) {

@@ -104,7 +105,7 @@ protected void _fit(Dataframe trainingData) {

boolean transformData = (dtClass!=null);
if(transformData) {
dataTransformer = DataTransformer.<DataTransformer>newInstance(dtClass, dbName, dbConf);
dataTransformer = AbstractTransformer.<AbstractTransformer>newInstance(dtClass, dbName, dbConf);
dataTransformer.fit_transform(trainingData, trainingParameters.getDataTransformerTrainingParameters());
}

@@ -114,7 +115,7 @@ protected void _fit(Dataframe trainingData) {

boolean selectFeatures = (fsClass!=null);
if(selectFeatures) {
featureSelection = FeatureSelection.<FeatureSelection>newInstance(fsClass, dbName, dbConf);
featureSelection = AbstractFeatureSelector.<AbstractFeatureSelector>newInstance(fsClass, dbName, dbConf);
featureSelection.fit_transform(trainingData, trainingParameters.getFeatureSelectionTrainingParameters());
}

@@ -123,7 +124,7 @@ protected void _fit(Dataframe trainingData) {

//initialize mlmodel
Class mlClass = trainingParameters.getMLmodelClass();
mlmodel = BaseMLmodel.<BaseMLmodel>newInstance(mlClass, dbName, dbConf);
mlmodel = AbstractAlgorithm.<AbstractAlgorithm>newInstance(mlClass, dbName, dbConf);

//train the mlmodel on the whole dataset
mlmodel.fit(trainingData, trainingParameters.getMLmodelTrainingParameters());
@@ -133,7 +134,7 @@ protected void _fit(Dataframe trainingData) {
}
}

private BaseMLmodel.ValidationMetrics evaluateData(Dataframe data, boolean estimateValidationMetrics) {
private AbstractAlgorithm.ValidationMetrics evaluateData(Dataframe data, boolean estimateValidationMetrics) {
//ensure db loaded
kb().load();
Modeler.TrainingParameters trainingParameters = kb().getTrainingParameters();
@@ -145,7 +146,7 @@ private BaseMLmodel.ValidationMetrics evaluateData(Dataframe data, boolean estim
boolean transformData = (dtClass!=null);
if(transformData) {
if(dataTransformer==null) {
dataTransformer = DataTransformer.<DataTransformer>newInstance(dtClass, dbName, dbConf);
dataTransformer = AbstractTransformer.<AbstractTransformer>newInstance(dtClass, dbName, dbConf);
}
dataTransformer.transform(data);
}
@@ -155,7 +156,7 @@ private BaseMLmodel.ValidationMetrics evaluateData(Dataframe data, boolean estim
boolean selectFeatures = (fsClass!=null);
if(selectFeatures) {
if(featureSelection==null) {
featureSelection = FeatureSelection.<FeatureSelection>newInstance(fsClass, dbName, dbConf);
featureSelection = AbstractFeatureSelector.<AbstractFeatureSelector>newInstance(fsClass, dbName, dbConf);
}

//remove unnecessary features
@@ -166,12 +167,12 @@ private BaseMLmodel.ValidationMetrics evaluateData(Dataframe data, boolean estim
//initialize mlmodel
if(mlmodel==null) {
Class mlClass = trainingParameters.getMLmodelClass();
mlmodel = BaseMLmodel.<BaseMLmodel>newInstance(mlClass, dbName, dbConf);
mlmodel = AbstractAlgorithm.<AbstractAlgorithm>newInstance(mlClass, dbName, dbConf);
}

//call predict of the mlmodel for the new dataset

BaseMLmodel.ValidationMetrics vm = null;
AbstractAlgorithm.ValidationMetrics vm = null;
if(estimateValidationMetrics) {
//run validate which calculates validation metrics. It is used by validate() method
vm = mlmodel.validate(data);
10 changes: 5 additions & 5 deletions src/main/java/com/datumbox/applications/nlp/CETR.java
@@ -19,10 +19,10 @@
import com.datumbox.common.dataobjects.Dataframe;
import com.datumbox.common.dataobjects.FlatDataCollection;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.objecttypes.Parameterizable;
import com.datumbox.common.interfaces.Parameterizable;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.utilities.MapFunctions;
import com.datumbox.common.utilities.PHPfunctions;
import com.datumbox.common.utilities.MapMethods;
import com.datumbox.common.utilities.PHPMethods;
import com.datumbox.framework.machinelearning.clustering.Kmeans;
import com.datumbox.framework.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.utilities.text.cleaners.HTMLCleaner;
@@ -224,7 +224,7 @@ private List<Integer> selectRows(List<String> rows, Parameters parameters) {
}

//fetch the cluster with the smallest average.
Map.Entry<Object, Double> entry = MapFunctions.selectMinKeyValue(avgTTRscorePerCluster);
Map.Entry<Object, Double> entry = MapMethods.selectMinKeyValue(avgTTRscorePerCluster);

//this cluster is considered the non-content cluster
Integer nonContentClusterId = (Integer)entry.getKey();
@@ -401,7 +401,7 @@ private List<String> extractRows(String text) {

private String clearText(String text) {
text = HTMLCleaner.removeNonTextTagsAndAttributes(text); //remove all the irrelevant HTML Tags that are not related to the text (such as forms, scripts etc)
if(PHPfunctions.substr_count(text, '\n')<=1) { //if the document is in a single line (no spaces), then break it in order for this algorithm to work
if(PHPMethods.substr_count(text, '\n')<=1) { //if the document is in a single line (no spaces), then break it in order for this algorithm to work
text = text.replace(">", ">\n");
}
