Skip to content

Commit

Permalink
Refactored GaussianDPMM to use Vectors and Matrices instead of double…
Browse files Browse the repository at this point in the history
… arrays.
  • Loading branch information
datumbox committed Dec 29, 2016
1 parent 1a825d8 commit b76dc5e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 65 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -64,6 +64,7 @@ Version 0.8.0-SNAPSHOT - Build 20161229
- Added minAbsolute and maxAbsolute methods in Descriptives. - Added minAbsolute and maxAbsolute methods in Descriptives.
- All the preprocessors have the ability to limit the transformed columns. - All the preprocessors have the ability to limit the transformed columns.
- Refactored PCA to use Vectors and Matrices instead of double arrays. - Refactored PCA to use Vectors and Matrices instead of double arrays.
- Refactored GaussianDPMM to use Vectors and Matrices instead of double arrays.
- Flattening the featureselection package. - Flattening the featureselection package.


Version 0.7.0 - Build 20160319 Version 0.7.0 - Build 20160319
Expand Down
1 change: 0 additions & 1 deletion TODO.txt
@@ -1,7 +1,6 @@
CODE IMPROVEMENTS CODE IMPROVEMENTS
================= =================


- Remove double arrays from GaussianDPMM with Vectors and Matrices.
- Rewrite the FeatureSelection package: - Rewrite the FeatureSelection package:
- Improve the API of Feature Selection and how we handle different data types. - Improve the API of Feature Selection and how we handle different data types.
- Refactor AbstractCategoricalFeatureSelector and simplify the method calls. - Refactor AbstractCategoricalFeatureSelector and simplify the method calls.
Expand Down
Expand Up @@ -19,7 +19,6 @@
import com.datumbox.framework.common.dataobjects.MatrixDataframe; import com.datumbox.framework.common.dataobjects.MatrixDataframe;
import com.datumbox.framework.common.dataobjects.Record; import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.storageengines.interfaces.StorageEngine; import com.datumbox.framework.common.storageengines.interfaces.StorageEngine;
import com.datumbox.framework.common.utilities.PHPMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer; import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractDPMM; import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractDPMM;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClusterer; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClusterer;
Expand Down Expand Up @@ -57,19 +56,19 @@ public static class Cluster extends AbstractDPMM.AbstractCluster {
private final int kappa0; private final int kappa0;
private final int nu0; private final int nu0;
private final RealVector mu0; private final RealVector mu0;
private final OpenMapRealMatrix psi0; private final RealMatrix psi0;


//cluster parameters //cluster parameters
private RealVector mean; private RealVector mean;
private OpenMapRealMatrix covariance; private RealMatrix covariance;


//validation - confidence interval vars //validation - confidence interval vars
private OpenMapRealMatrix meanError; private RealMatrix meanError;
private int meanDf; private int meanDf;


//internal vars for calculation //internal vars for calculation
private RealVector xi_sum; private RealVector xi_sum;
private OpenMapRealMatrix xi_square_sum; private RealMatrix xi_square_sum;


//Cache //Cache
private volatile Double cache_covariance_determinant; private volatile Double cache_covariance_determinant;
Expand All @@ -84,32 +83,27 @@ public static class Cluster extends AbstractDPMM.AbstractCluster {
* @param psi0 * @param psi0
* @see AbstractClusterer.AbstractCluster * @see AbstractClusterer.AbstractCluster
*/ */
protected Cluster(Integer clusterId, int dimensions, int kappa0, int nu0, RealVector mu0, OpenMapRealMatrix psi0) { protected Cluster(Integer clusterId, int dimensions, int kappa0, int nu0, RealVector mu0, RealMatrix psi0) {
super(clusterId); super(clusterId);


if(nu0<dimensions) { if(nu0<dimensions) {
nu0 = dimensions; nu0 = dimensions;
} }


if(mu0==null) {
mu0 = new OpenMapRealVector(dimensions); //0 vector
}

if(psi0==null) {
psi0 = createIdentityOpenMapRealMatrix(dimensions); //identity matrix
}

mean = new OpenMapRealVector(dimensions); mean = new OpenMapRealVector(dimensions);
covariance = createIdentityOpenMapRealMatrix(dimensions); covariance = new OpenMapRealMatrix(dimensions, dimensions);
for(int i=0;i<dimensions;i++) {
covariance.setEntry(i, i, 1.0);
}


meanError = calculateMeanError(psi0, kappa0, nu0); meanError = calculateMeanError(psi0, kappa0, nu0);
meanDf = nu0-dimensions+1; meanDf = nu0-dimensions+1;




this.kappa0 = kappa0; this.kappa0 = kappa0;
this.nu0 = nu0; this.nu0 = nu0;
this.mu0 = mu0; this.mu0 = new OpenMapRealVector(mu0);
this.psi0 = psi0; this.psi0 = new OpenMapRealMatrix(dimensions, dimensions).add(psi0);
this.dimensions = dimensions; this.dimensions = dimensions;


xi_sum = new OpenMapRealVector(dimensions); xi_sum = new OpenMapRealVector(dimensions);
Expand All @@ -118,20 +112,6 @@ protected Cluster(Integer clusterId, int dimensions, int kappa0, int nu0, RealVe
cache_covariance_inverse = null; cache_covariance_inverse = null;
} }


/**
* Creates an Identity matrix of OpenMapRealMatrix type.
*
* @param n
* @return
*/
private OpenMapRealMatrix createIdentityOpenMapRealMatrix(int n) {
OpenMapRealMatrix id = new OpenMapRealMatrix(n,n);
for(int i=0;i<n;i++) {
id.setEntry(i,i,1.0);
}
return id;
}

/** /**
* Ensure that the cluster parameters can be modified. * Ensure that the cluster parameters can be modified.
*/ */
Expand Down Expand Up @@ -203,8 +183,8 @@ protected void add(Record r) {
RealVector rv = MatrixDataframe.parseRecord(r, featureIds); RealVector rv = MatrixDataframe.parseRecord(r, featureIds);


//update cluster clusterParameters //update cluster clusterParameters
xi_sum=xi_sum.add(rv); xi_sum = xi_sum.add(rv);
xi_square_sum= (OpenMapRealMatrix) xi_square_sum.add(rv.outerProduct(rv)); xi_square_sum = xi_square_sum.add(rv.outerProduct(rv));


size++; size++;


Expand All @@ -230,9 +210,9 @@ protected void remove(Record r) {
updateClusterParameters(); updateClusterParameters();
} }


private OpenMapRealMatrix calculateMeanError(OpenMapRealMatrix Psi, int kappa, int nu) { private RealMatrix calculateMeanError(RealMatrix Psi, int kappa, int nu) {
//Reference: page 18, equation 228 at http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf //Reference: page 18, equation 228 at http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf
return (OpenMapRealMatrix)Psi.scalarMultiply(1.0/(kappa*(nu-dimensions+1.0))); return Psi.scalarMultiply(1.0/(kappa*(nu-dimensions+1.0)));
} }


/** {@inheritDoc} */ /** {@inheritDoc} */
Expand All @@ -258,14 +238,12 @@ protected void updateClusterParameters() {


RealMatrix C = xi_square_sum.subtract( ( mu.outerProduct(mu) ).scalarMultiply(size) ); RealMatrix C = xi_square_sum.subtract( ( mu.outerProduct(mu) ).scalarMultiply(size) );


OpenMapRealMatrix psi = (OpenMapRealMatrix)psi0.add( C.add( ( mu_mu0.outerProduct(mu_mu0) ).scalarMultiply(kappa0*size/(double)kappa_n) )); RealMatrix psi = psi0.add( C.add( ( mu_mu0.outerProduct(mu_mu0) ).scalarMultiply(kappa0*size/(double)kappa_n) ));//
//C = null;
//mu_mu0 = null;


mean = ( mu0.mapMultiply(kappa0) ).add( mu.mapMultiply(size) ).mapDivide(kappa_n); mean = ( mu0.mapMultiply(kappa0) ).add( mu.mapMultiply(size) ).mapDivide(kappa_n);


synchronized(this) { synchronized(this) {
covariance = (OpenMapRealMatrix)psi.scalarMultiply( (kappa_n+1.0)/(kappa_n*(nu - dimensions + 1.0)) ); covariance = psi.scalarMultiply( (kappa_n+1.0)/(kappa_n*(nu - dimensions + 1.0)) );
cache_covariance_determinant = null; cache_covariance_determinant = null;
cache_covariance_inverse = null; cache_covariance_inverse = null;
} }
Expand All @@ -291,13 +269,13 @@ protected ModelParameters(StorageEngine storageEngine) {


/** {@inheritDoc} */ /** {@inheritDoc} */
public static class TrainingParameters extends AbstractDPMM.AbstractTrainingParameters { public static class TrainingParameters extends AbstractDPMM.AbstractTrainingParameters {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 2L;


private int kappa0 = 0; private int kappa0 = 0;
private int nu0 = 1; private int nu0 = 1;
private double[] mu0; private RealVector mu0;


private double[][] psi0; private RealMatrix psi0;


/** /**
* Getter for Kappa0 hyperparameter. * Getter for Kappa0 hyperparameter.
Expand Down Expand Up @@ -340,35 +318,35 @@ public void setNu0(int nu0) {
* *
* @return * @return
*/ */
public double[] getMu0() { public RealVector getMu0() {
return PHPMethods.array_clone(mu0); return mu0;
} }


/** /**
* Getter for Mu0 hyperparameter. * Getter for Mu0 hyperparameter.
* *
* @param mu0 * @param mu0
*/ */
public void setMu0(double[] mu0) { public void setMu0(RealVector mu0) {
this.mu0 = PHPMethods.array_clone(mu0); this.mu0 = mu0;
} }


/** /**
* Getter for Psi0 hyperparameter. * Getter for Psi0 hyperparameter.
* *
* @return * @return
*/ */
public double[][] getPsi0() { public RealMatrix getPsi0() {
return PHPMethods.array_clone(psi0); return psi0;
} }


/** /**
* Setter for Psi0 hyperparameter. * Setter for Psi0 hyperparameter.
* *
* @param psi0 * @param psi0
*/ */
public void setPsi0(double[][] psi0) { public void setPsi0(RealMatrix psi0) {
this.psi0 = PHPMethods.array_clone(psi0); this.psi0 = psi0;
} }


} }
Expand Down Expand Up @@ -397,22 +375,14 @@ protected GaussianDPMM(String storageName, Configuration configuration) {
protected Cluster createNewCluster(Integer clusterId) { protected Cluster createNewCluster(Integer clusterId) {
ModelParameters modelParameters = knowledgeBase.getModelParameters(); ModelParameters modelParameters = knowledgeBase.getModelParameters();
TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
double[] mu0 = trainingParameters.getMu0();
double[][] psi0 = trainingParameters.getPsi0();

OpenMapRealMatrix mPsi0 = null;
if(psi0 != null) {
mPsi0 = new OpenMapRealMatrix(psi0.length, psi0.length);
mPsi0.setSubMatrix(psi0, 0, 0);
}


Cluster c = new Cluster( Cluster c = new Cluster(
clusterId, clusterId,
modelParameters.getD(), modelParameters.getD(),
trainingParameters.getKappa0(), trainingParameters.getKappa0(),
trainingParameters.getNu0(), trainingParameters.getNu0(),
mu0!=null?new OpenMapRealVector(mu0):null, trainingParameters.getMu0(),
mPsi0 trainingParameters.getPsi0()


); );
c.setFeatureIds(modelParameters.getFeatureIds()); c.setFeatureIds(modelParameters.getFeatureIds());
Expand Down
Expand Up @@ -24,6 +24,8 @@
import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Constants;
import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.Datasets;
import com.datumbox.framework.tests.abstracts.AbstractTest; import com.datumbox.framework.tests.abstracts.AbstractTest;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.OpenMapRealVector;
import org.junit.Test; import org.junit.Test;


import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
Expand Down Expand Up @@ -58,8 +60,8 @@ public void testPredict() {
param.setInitializationMethod(GaussianDPMM.TrainingParameters.Initialization.ONE_CLUSTER_PER_RECORD); param.setInitializationMethod(GaussianDPMM.TrainingParameters.Initialization.ONE_CLUSTER_PER_RECORD);
param.setKappa0(0); param.setKappa0(0);
param.setNu0(1); param.setNu0(1);
param.setMu0(new double[]{0.0, 0.0}); param.setMu0(new OpenMapRealVector(2));
param.setPsi0(new double[][]{{1.0,0.0},{0.0,1.0}}); param.setPsi0(MatrixUtils.createRealIdentityMatrix(2));


GaussianDPMM instance = MLBuilder.create(param, configuration); GaussianDPMM instance = MLBuilder.create(param, configuration);
instance.fit(trainingData); instance.fit(trainingData);
Expand Down Expand Up @@ -106,8 +108,8 @@ public void testKFoldCrossValidation() {
param.setInitializationMethod(GaussianDPMM.TrainingParameters.Initialization.ONE_CLUSTER_PER_RECORD); param.setInitializationMethod(GaussianDPMM.TrainingParameters.Initialization.ONE_CLUSTER_PER_RECORD);
param.setKappa0(0); param.setKappa0(0);
param.setNu0(1); param.setNu0(1);
param.setMu0(new double[]{0.0, 0.0}); param.setMu0(new OpenMapRealVector(2));
param.setPsi0(new double[][]{{1.0,0.0},{0.0,1.0}}); param.setPsi0(MatrixUtils.createRealIdentityMatrix(2));


ClusteringMetrics vm = new Validator<>(ClusteringMetrics.class, configuration) ClusteringMetrics vm = new Validator<>(ClusteringMetrics.class, configuration)
.validate(new KFoldSplitter(k).split(trainingData), param); .validate(new KFoldSplitter(k).split(trainingData), param);
Expand Down

0 comments on commit b76dc5e

Please sign in to comment.