From 21abb88be334542aea7d072a28186cf5cd06e39d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Veronika=20Maurerov=C3=A1?= Date: Thu, 12 Dec 2019 20:46:35 +0100 Subject: [PATCH] PUBDEV-6447 constrained kmeans POC (#4067) * constrained kmeans implementation * add new parameters to V3 schema, fix bug, add doc * enhance junit test * implement cross validation * improve find minimal reduced weight loop, remove demo * remove cluster_size_constraints parameter from python/R api * remove constrained kmeans from doc, remove python test * remove constrained kmeans from doc * set constrained kmeans test aside * ignore chicago dataset test --- .../src/main/java/hex/kmeans/KMeans.java | 339 ++++++-- .../src/main/java/hex/kmeans/KMeansModel.java | 1 + .../java/hex/kmeans/KMeansSimplexSolver.java | 755 ++++++++++++++++++ .../src/test/java/hex/kmeans/KMeansTest.java | 15 +- .../hex/kmeans/KmeansConstrainedTest.java | 262 ++++++ h2o-docs/src/product/data-science/k-means.rst | 2 - 6 files changed, 1306 insertions(+), 68 deletions(-) create mode 100644 h2o-algos/src/main/java/hex/kmeans/KMeansSimplexSolver.java create mode 100644 h2o-algos/src/test/java/hex/kmeans/KmeansConstrainedTest.java diff --git a/h2o-algos/src/main/java/hex/kmeans/KMeans.java b/h2o-algos/src/main/java/hex/kmeans/KMeans.java index a63f13140cdd..d10c9cd0a3b9 100755 --- a/h2o-algos/src/main/java/hex/kmeans/KMeans.java +++ b/h2o-algos/src/main/java/hex/kmeans/KMeans.java @@ -73,13 +73,16 @@ public enum Initialization { Random, PlusPlus, Furthest, User } error("_user_y", "User-specified points do not refer to a valid frame"); else if (user_points.numCols() != _train.numCols() - numSpecialCols()) error("_user_y", "The user-specified points must have the same number of columns (" + (_train.numCols() - - numSpecialCols()) + ") as the training observations"); + numSpecialCols()) + ") as the training observations"); else if( user_points.numRows() != _parms._k) error("_user_y", "The number of rows in the user-specified points is not equal to k = " + _parms._k); } if (_parms._estimate_k) { if (_parms._user_points!=null) error("_estimate_k", "Cannot estimate k if user_points are provided."); + if(_parms._cluster_size_constraints != null){ + error("_estimate_k", "Cannot estimate k if cluster_size_constraints are provided."); + } info("_seed", "seed is ignored when estimate_k is enabled."); info("_init", "Initialization scheme is ignored when estimate_k is enabled - algorithm is deterministic."); if (expensive) { @@ -95,6 +98,11 @@ else if( user_points.numRows() != _parms._k) } } } + if(_parms._cluster_size_constraints != null){ + if(_parms._cluster_size_constraints.length != _parms._k){ + error("_cluster_size_constraints", "\"The number of cluster size constraints is not equal to k = \" + _parms._k"); + } + } if (expensive && error_count() == 0) checkMemoryFootPrint(); } @@ -174,7 +182,7 @@ private final class KMeansDriver extends Driver { transient private int _reinit_attempts; // Handle the case where some centers go dry. 
Rescue only 1 cluster // per iteration ('cause we only tracked the 1 worst row) - boolean cleanupBadClusters( LloydsIterationTask task, final Vec[] vecs, final double[][] centers, final double[] means, final double[] mults, final int[] modes ) { + boolean cleanupBadClusters( IterationTask task, final Vec[] vecs, final double[][] centers, final double[] means, final double[] mults, final int[] modes ) { // Find any bad clusters int clu; for( clu=0; clu 0: "The value of constraint should be higher then zero."; + csum += _parms._cluster_size_constraints[i]; + assert csum <= vecs[0].length(): "The sum of constraints is higher than the number of data rows."; + } + } + for (int k = startK; k <= _parms._k; ++k) { - Log.info("Running Lloyds iteration for " + k + " centroids."); + if(!constrained){ + Log.info("Running Lloyds iteration for " + k + " centroids."); + } else { + Log.info("Running Constrained K-means iteration for " + k + " centroids."); + } model._output._iterations = 0; // Loop ends only when iterations > max_iterations with strict inequality double[][] lo=null, hi=null; boolean stop = false; - do { //Lloyds algorithm + do { assert(centers.length == k); - LloydsIterationTask task = new LloydsIterationTask(centers, means, mults, impute_cat, _isCats, k, hasWeightCol()).doAll(vecs2); //1 PASS OVER THE DATA + IterationTask task; + if(!constrained) { + //Lloyds algorithm + task = new LloydsIterationTask(centers, means, mults, impute_cat, _isCats, k, hasWeightCol()).doAll(vecs2); //1 PASS OVER THE DATA + } else { + // Constrained K-means + + // Get distances and aggregated values + CalculateDistancesTask countDistancesTask = new CalculateDistancesTask(centers, means, mults, impute_cat, _isCats, k, hasWeightCol()).doAll(vecs2); + + // Check if the constraint setting does not break cross validation setting + assert !hasWeightCol() || csum <= countDistancesTask._non_zero_weights : "The sum of constraints ("+csum+") is higher than the number of data rows with non zero weights ("+countDistancesTask._non_zero_weights+") because cross validation is set."; + + // Calculate center assignments + // Experimental code. Polynomial implementation - slow performance. Need to be parallelize! + KMeansSimplexSolver solver = new KMeansSimplexSolver(_parms._cluster_size_constraints, new Frame(vecs2), countDistancesTask._sum, hasWeightCol(), countDistancesTask._non_zero_weights); + + // Get cluster assignments + Frame result = solver.assignClusters(); + + // Count statistics and result task + task = new CalculateMetricTask(centers, means, mults, impute_cat, _isCats, k, hasWeightCol()).doAll(result); + } + // Pick the max categorical level for cluster center max_cats(task._cMeans, task._cats, _isCats); // Handle the case where some centers go dry. 
Rescue only 1 cluster // per iteration ('cause we only tracked the 1 worst row) - if( !_parms._estimate_k && cleanupBadClusters(task,vecs,centers,means,mults,impute_cat) ) continue; + // If constrained K-means is set, clusters with zero points are allowed + if(!_parms._estimate_k && _parms._cluster_size_constraints == null && cleanupBadClusters(task,vecs,centers,means,mults,impute_cat) ) continue; // Compute model stats; update standardized cluster centers centers = computeStatsFillModel(task, model, vecs, means, mults, impute_cat, k); - if (model._parms._score_each_iteration) + if (model._parms._score_each_iteration) Log.info(model._output._model_summary); lo = task._lo; hi = task._hi; if (work_unit_iter) { model.update(_job); // Update model in K/V store - _job.update(1); //1 more Lloyds iteration + _job.update(1); //1 more iteration } stop = (task._reassigned_count < Math.max(1,train().numRows()*TOLERANCE) || model._output._iterations >= _parms._max_iterations || stop_requested()); if (stop) { if (model._output._iterations < _parms._max_iterations) - Log.info("Lloyds converged after " + model._output._iterations + " iterations."); + Log.info("K-means converged after " + model._output._iterations + " iterations."); else - Log.info("Lloyds stopped after " + model._output._iterations + " iterations."); + Log.info("K-means stopped after " + model._output._iterations + " iterations."); } } while (!stop); @@ -355,11 +409,18 @@ public void computeImpl() { centers = splitLargestCluster(centers, lo, hi, means, mults, impute_cat, vecs2, k); } //k-finder vecs2[vecs2.length-1].remove(); - + // Create metrics by scoring on training set otherwise scores are based on last Lloyd iteration - model.score(_train).delete(); - model._output._training_metrics = ModelMetrics.getFromDKV(model,_train); + // These lines cause the training metrics to be recalculated on strange model values. 
+ // Especially for Constrained Kmeans, it returns a result that does not meet the constraints set + // because scoring is based on calculated centroids and does not preserve the constraints + // There is a JIRA to explore this part of code: https://0xdata.atlassian.net/browse/PUBDEV-7097 + if(!constrained) { + model.score(_train).delete(); + model._output._training_metrics = ModelMetrics.getFromDKV(model, _train); + } + model.update(_job); // Update model in K/V store Log.info(model._output._model_summary); Log.info(model._output._scoring_history); Log.info(((ModelMetricsClustering)model._output._training_metrics).createCentroidStatsTable().toString()); @@ -619,15 +680,8 @@ private static class Sampler extends MRTask { _sampled = ArrayUtils.append(_sampled, other._sampled); } } - - // --------------------------------------- - // A Lloyd's pass: - // Find nearest cluster center for every point - // Compute new mean/center & variance & rows for each cluster - // Compute distance between clusters - // Compute total sqr distance - - private static class LloydsIterationTask extends MRTask { + + public static class IterationTask extends MRTask { // IN double[][] _centers; double[] _means, _mults; // Standardization @@ -646,7 +700,7 @@ private static class LloydsIterationTask extends MRTask { long _worst_row; // Row with max err double _worst_err; // Max-err-row's max-err - LloydsIterationTask(double[][] centers, double[] means, double[] mults, int[] modes, String[][] isCats, int k, boolean hasWeight ) { + IterationTask(double[][] centers, double[] means, double[] mults, int[] modes, String[][] isCats, int k, boolean hasWeight ) { _centers = centers; _means = means; _mults = mults; @@ -655,6 +709,20 @@ private static class LloydsIterationTask extends MRTask { _k = k; _hasWeight = hasWeight; } + } + + // --------------------------------------- + // A Lloyd's pass: + // Find nearest cluster center for every point + // Compute new mean/center & variance & rows for each cluster + // Compute distance between clusters + // Compute total sqr distance + + private static class LloydsIterationTask extends IterationTask { + + LloydsIterationTask(double[][] centers, double[] means, double[] mults, int[] modes, String[][] isCats, int k, boolean hasWeight ) { + super(centers, means, mults, modes, isCats, k, hasWeight); + } @Override public void map(Chunk[] cs) { int N = cs.length - (_hasWeight ? 
1:0) - 1 /*clusterassignment*/; @@ -705,7 +773,7 @@ private static class LloydsIterationTask extends MRTask { for( int col = 0; col < N; col++ ) if( _isCats[col] != null ) _cats[clu][col][(int)values[col]]++; // Histogram the cats - else + else _cMeans[clu][col] += values[col]; // Sum the column centers _size[clu]++; // Track worst row @@ -719,7 +787,153 @@ private static class LloydsIterationTask extends MRTask { _modes = null; } - @Override public void reduce(LloydsIterationTask mr) { + @Override public void reduce(IterationTask mr) { + _reassigned_count += mr._reassigned_count; + for( int clu = 0; clu < _k; clu++ ) { + long ra = _size[clu]; + long rb = mr._size[clu]; + double[] ma = _cMeans[clu]; + double[] mb = mr._cMeans[clu]; + for( int c = 0; c < ma.length; c++ ) // Recursive mean + if( ra+rb > 0 ) ma[c] = (ma[c] * ra + mb[c] * rb) / (ra + rb); + } + ArrayUtils.add(_cats, mr._cats); + ArrayUtils.add(_cSqr, mr._cSqr); + ArrayUtils.add(_size, mr._size); + for( int clu=0; clu< _k; clu++ ) { + for( int col=0; col<_lo[clu].length; col++ ) { + _lo[clu][col] = Math.min(mr._lo[clu][col], _lo[clu][col]); + _hi[clu][col] = Math.max(mr._hi[clu][col], _hi[clu][col]); + } + } + // track global worst-row + if( _worst_err < mr._worst_err) { _worst_err = mr._worst_err; _worst_row = mr._worst_row; } + } + } + + private static class CalculateDistancesTask extends MRTask { + // IN + double[][] _centers; + double[] _means, _mults; // Standardization + int[] _modes; // Imputation of missing categoricals + final int _k; + boolean _hasWeight; + final String[][] _isCats; + double _sum; + long _non_zero_weights; + + CalculateDistancesTask(double[][] centers, double[] means, double[] mults, int[] modes, String[][] isCats, int k, boolean hasWeight) { + _centers = centers; + _means = means; + _mults = mults; + _modes = modes; + _k = k; + _hasWeight = hasWeight; + _isCats = isCats; + _sum = 0; + _non_zero_weights = 0; + } + + @Override + public void map(Chunk[] cs) { + int N = cs.length - (_hasWeight ? 1 : 0) - 3 - 2*_centers.length /*data + weight column + distances + edge indices + old assignment + new assignment */; + assert _centers[0].length == N; + int vecsStart = _hasWeight ? N+1 : N; + + double[] values = new double[N]; // Temp data to hold row as doubles + for (int row = 0; row < cs[0]._len; row++) { + double weight = _hasWeight ? cs[N].atd(row) : 1; + if (weight == 0) continue; //skip holdout rows + _non_zero_weights++; + assert (weight == 1); //K-Means only works for weight 1 (or weight 0 for holdout) + data(values, cs, row, _means, _mults, _modes); // Load row as doubles + double[] distances = getDistances(_centers, values, _isCats); + for(int cluster=0; cluster _worst_err) { _worst_err = distance; _worst_row = cs[0].start()+row; } + } + // Scale back down to local mean + for( int clu = 0; clu < _k; clu++ ) + if( _size[clu] != 0 ) ArrayUtils.div(_cMeans[clu], _size[clu]); + _centers = null; + _means = _mults = null; + _modes = null; + } + + @Override public void reduce(IterationTask mr) { _reassigned_count += mr._reassigned_count; for( int clu = 0; clu < _k; clu++ ) { long ra = _size[clu]; @@ -774,6 +988,15 @@ private static ClusterDist closest(double[][] centers, double[] point, String[][ return cd; // Return for flow-coding } + /** Return square-distance of point to all clusters. 
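+ * Unlike closest(...), which keeps only the nearest center, this helper returns the distance to every center;
+ * CalculateDistancesTask stores these per-cluster distances in the working frame so that KMeansSimplexSolver
+ * can use them as edge costs.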
*/ + private static double[] getDistances(double[][] centers, double[] point, String[][] isCats) { + double[] distances = new double[centers.length]; + for( int cluster = 0; cluster < centers.length; cluster++ ) { + distances[cluster] = hex.genmodel.GenModel.KMeans_distance(centers[cluster],point,isCats); + } + return distances; + } + // KMeans++ re-clustering private static double[][] recluster(double[][] points, Random rand, int N, Initialization init, String[][] isCats) { double[][] res = new double[N][]; @@ -781,38 +1004,38 @@ private static double[][] recluster(double[][] points, Random rand, int N, Initi int count = 1; ClusterDist cd = new ClusterDist(); switch( init ) { - case Random: - break; - case PlusPlus: { // k-means++ - while( count < res.length ) { - double sum = 0; - for (double[] point1 : points) sum += minSqr(res, point1, isCats, cd, count); - - for (double[] point : points) { - if (minSqr(res, point, isCats, cd, count) >= rand.nextDouble() * sum) { - res[count++] = point; - break; + case Random: + break; + case PlusPlus: { // k-means++ + while( count < res.length ) { + double sum = 0; + for (double[] point1 : points) sum += minSqr(res, point1, isCats, cd, count); + + for (double[] point : points) { + if (minSqr(res, point, isCats, cd, count) >= rand.nextDouble() * sum) { + res[count++] = point; + break; + } } } + break; } - break; - } - case Furthest: { // Takes cluster center further from any already chosen ones - while( count < res.length ) { - double max = 0; - int index = 0; - for( int i = 0; i < points.length; i++ ) { - double sqr = minSqr(res, points[i], isCats, cd, count); - if( sqr > max ) { - max = sqr; - index = i; + case Furthest: { // Takes cluster center further from any already chosen ones + while( count < res.length ) { + double max = 0; + int index = 0; + for( int i = 0; i < points.length; i++ ) { + double sqr = minSqr(res, points[i], isCats, cd, count); + if( sqr > max ) { + max = sqr; + index = i; + } } + res[count++] = points[index]; } - res[count++] = points[index]; + break; } - break; - } - default: throw H2O.fail(); + default: throw H2O.fail(); } return res; } @@ -959,3 +1182,5 @@ private static class SplitTask extends MRTask { } } } + + diff --git a/h2o-algos/src/main/java/hex/kmeans/KMeansModel.java b/h2o-algos/src/main/java/hex/kmeans/KMeansModel.java index fc3d6a1881ef..227769ad2694 100755 --- a/h2o-algos/src/main/java/hex/kmeans/KMeansModel.java +++ b/h2o-algos/src/main/java/hex/kmeans/KMeansModel.java @@ -39,6 +39,7 @@ public static class KMeansParameters extends ClusteringModel.ClusteringParameter public boolean _pred_indicator = false; // For internal use only: generate indicator cols during prediction // Ex: k = 4, cluster = 3 -> [0, 0, 1, 0] public boolean _estimate_k = false; // If enabled, iteratively find up to _k clusters + public int[] _cluster_size_constraints = null; } public static class KMeansOutput extends ClusteringModel.ClusteringOutput { diff --git a/h2o-algos/src/main/java/hex/kmeans/KMeansSimplexSolver.java b/h2o-algos/src/main/java/hex/kmeans/KMeansSimplexSolver.java new file mode 100644 index 000000000000..5f911087e507 --- /dev/null +++ b/h2o-algos/src/main/java/hex/kmeans/KMeansSimplexSolver.java @@ -0,0 +1,755 @@ +package hex.kmeans; + +import water.Iced; +import water.MRTask; +import water.fvec.Chunk; +import water.fvec.Frame; +import water.fvec.Vec; + +import java.util.ArrayList; +import java.util.Collections; + +/** + * Experimental code. Polynomial implementation - slow performance. Need to be parallelize! 
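+ *
+ * The size constraints are encoded as a minimum cost flow problem: each data point is a node with demand -1
+ * (it supplies one unit of flow), each cluster node's demand equals its minimum size constraint, one extra
+ * leader node absorbs the remaining flow, and the cost of every point-to-cluster edge is the distance computed
+ * by CalculateDistancesTask. A rough usage sketch, mirroring the call site in KMeans (variable names here are
+ * illustrative only):
+ * <pre>
+ *   KMeansSimplexSolver solver = new KMeansSimplexSolver(clusterSizeConstraints, new Frame(vecs2),
+ *           sumOfAllDistances, hasWeightColumn, nonZeroWeightRows);
+ *   Frame assigned = solver.assignClusters();
+ * </pre>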
+ * Calculate Minimal Cost Flow problem using simplex method with go through spanning tree. + * Used to solve minimal cluster size constraints in K-means. + */ +class KMeansSimplexSolver { + public int _constraintsLength; + public long _numberOfPoints; + public long _edgeSize; + public long _nodeSize; + public long _resultSize; + + // Input graph to store K-means configuration + public Vec _demands; // store demand of all nodes (-1 for data points, constraints values for constraints nodes, ) + public Vec _capacities; // store capacities of all edges + edges from all node to leader node + public Frame _weights; // input data + weight column, calculated distances from all points to all centres + columns to store result of cluster assignments + public Vec _additiveWeights; // additive weight vector to store edge weight from constraints nodes to additive and leader nodes + public double _sumWeights; // calculated sum of all weights to calculate maximal capacity value + public double _maxAbsDemand; // maximal absolute demand to calculate maximal capacity value + public boolean _hasWeightsColumn; // weight column existence flag + public long _numberOfNonZeroWeightPoints; //if weights columns is set, how many rows has non zero weight + + // Spanning tree to calculate min cost flow + public SpanningTree tree; + + /** + * Construct K-means simplex solver. + * @param constrains + * @param weights + * @param sumDistances + * @param hasWeights + */ + public KMeansSimplexSolver(int[] constrains, Frame weights, double sumDistances, boolean hasWeights, long numberOfNonZeroWeightPoints) { + this._numberOfPoints = weights.numRows(); + this._nodeSize = this._numberOfPoints + constrains.length + 1; + this._edgeSize = _numberOfPoints * constrains.length + constrains.length; + this._constraintsLength = constrains.length; + + this._demands = Vec.makeCon(0, _nodeSize, Vec.T_NUM); + this._capacities = Vec.makeCon(0, _edgeSize + _nodeSize, Vec.T_NUM); + this._resultSize = this._numberOfPoints * _constraintsLength; + this._hasWeightsColumn = hasWeights; + this._numberOfNonZeroWeightPoints = numberOfNonZeroWeightPoints; + + this._weights = weights; + this._additiveWeights = Vec.makeCon(0, _nodeSize + _constraintsLength, Vec.T_NUM); + this._sumWeights = sumDistances; + + long constraintsSum = 0; + _maxAbsDemand = Double.MIN_VALUE; + for (long i = 0; i < _nodeSize; i++) { + if (i < _numberOfPoints) { + _demands.set(i, -1); + } else { + long tmpDemand; + if (i < _nodeSize - 1) { + tmpDemand = constrains[(int)(i - _numberOfPoints)]; + constraintsSum += constrains[(int)(i - _numberOfPoints)]; + } else { + tmpDemand = _numberOfNonZeroWeightPoints - constraintsSum; + } + _demands.set(i, tmpDemand); + if (tmpDemand > Math.abs(_maxAbsDemand)) { + _maxAbsDemand = Math.abs(tmpDemand); + } + } + } + + int edgeIndexStart = _weights.numCols() - 3 - _constraintsLength; + long edgeIndex = 0; + for (long i = 0; i < _weights.numRows(); i++) { + for(int j=0; j < _constraintsLength; j++){ + _weights.vec(edgeIndexStart+j).set(i, edgeIndex++); + } + } + + this.tree = new SpanningTree(_nodeSize, _edgeSize, _constraintsLength); + } + + /** + * Initialize graph and spanning tree. + */ + public void init() { + // always start with infinity _capacities + for (long i = 0; i < _edgeSize; i++) { + _capacities.set(i, Long.MAX_VALUE); + } + + // find maximum value for capacity + double maxCapacity = 3 * (_sumWeights > _maxAbsDemand ? 
_sumWeights : _maxAbsDemand); + + // fill max capacity from the leader node to all others _nodes + for (long i = 0; i < _nodeSize; i++) { + _additiveWeights.set(i + _constraintsLength, maxCapacity); + _capacities.set(i + _edgeSize, maxCapacity); + } + + tree.init(_numberOfPoints, maxCapacity, _demands); + } + + /** + * Get weight base on edge index from weights data or from additive weights. + * @param edgeIndex + * @return + */ + public double getWeight(long edgeIndex) { + long numberOfFrameWeights = this._numberOfPoints * this._constraintsLength; + if (edgeIndex < numberOfFrameWeights) { + long i = Math.round(edgeIndex / _constraintsLength); + int j = _weights.numCols() - 2 * _constraintsLength - 3 + (int)(edgeIndex % _constraintsLength); + return _weights.vec(j).at(i); + } + return _additiveWeights.at(edgeIndex - numberOfFrameWeights); + } + + /** + * Get weight base on edge index from weights data or from additive weights. + * @param edgeIndex + * @return + */ + public boolean isNonZeroWeight(long edgeIndex) { + if(_hasWeightsColumn) { + long numberOfFrameWeights = this._numberOfPoints * this._constraintsLength; + if (edgeIndex < numberOfFrameWeights) { + long i = Math.round(edgeIndex / _constraintsLength); + int j = _weights.numCols() - 1 - 2 * _constraintsLength - 3; + return _weights.vec(j).at8(i) == 1; + } + } + return true; + } + + /** + * Check edges flow where constraints flows and additive node flow should be zero at the end of calculation. + * @return true if the flows are not zero yet + */ + private boolean checkIfContinue() { + for (long i = tree._edgeFlow.length() - 1; i > tree._edgeFlow.length() - _constraintsLength - 2; i--) { + if (tree._edgeFlow.at8(i) != 0) { + return true; + } + } + return false; + } + + public long findMinimalReducedWeight() { + FindMinimalWeightTask t = new FindMinimalWeightTask(tree, _hasWeightsColumn, _constraintsLength).doAll(_weights); + double minimalWeight = t.minimalWeight; + long minimalIndex = t.minimalIndex; + long additiveEdgesIndexStart = _weights.vec(0).length() * _constraintsLength; + for(long i = additiveEdgesIndexStart; i < _edgeSize; i++){ + double tmpWeight = tree.reduceWeight(i, getWeight(i)); + boolean countValue = !_hasWeightsColumn || isNonZeroWeight(i); + if (countValue && tmpWeight < minimalWeight) { + minimalWeight = tmpWeight; + minimalIndex = i; + } + } + return minimalIndex; + } + + /** + * Find next entering edge to find cycle. + * @return index of the edge + */ + public Edge findNextEnteringEdge() { + if(checkIfContinue()) { + long minimalIndex = findMinimalReducedWeight(); + if (tree._edgeFlow.at8(minimalIndex) == 0) { + return new Edge(minimalIndex, tree._sources.at8(minimalIndex), tree._targets.at8(minimalIndex)); + } else { + return new Edge(minimalIndex, tree._targets.at8(minimalIndex), tree._sources.at8(minimalIndex)); + } + } + return null; + } + + /** + * Find cycle from the edge defined by source and target nodes to leader node and back. 
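+ * The cycle is assembled from the tree paths of both endpoints up to their common ancestor: the source-side
+ * path is reversed, the entering edge is appended, and the target-side path (with the shared ancestor dropped)
+ * is concatenated at the end.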
+ * @param edgeIndex + * @param sourceIndex + * @param targetIndex + * @return + */ + public NodesEdgesObject getCycle(long edgeIndex, long sourceIndex, long targetIndex) { + long ancestor = tree.findAncestor(sourceIndex, targetIndex); + NodesEdgesObject resultPath = tree.getPath(sourceIndex, ancestor); + resultPath.reverseNodes(); + resultPath.reverseEdges(); + if (resultPath.edgeSize() != 1 || resultPath.getEdge(0) != edgeIndex) { + resultPath.addEdge(edgeIndex); + } + NodesEdgesObject resultPathBack = tree.getPath(targetIndex, ancestor); + resultPathBack.removeLastNode(); + resultPath.addAllNodes(resultPathBack.getNodes()); + resultPath.addAllEdges(resultPathBack.getEdges()); + return resultPath; + } + + /** + * Find the leaving edge with minimal residual capacity. + * @param cycle + * @return the edge with minimal residual capacity + */ + public Edge getLeavingEdge(NodesEdgesObject cycle) { + cycle.reverseNodes(); + cycle.reverseEdges(); + double minResidualCapacity = Double.MAX_VALUE; + int minIndex = -1; + for (int i = 0; i < cycle.edgeSize(); i++) { + double tmpResidualCapacity = tree.getResidualCapacity(cycle.getEdge(i), cycle.getNode(i), _capacities.at(cycle.getEdge(i))); + boolean countValue = !_hasWeightsColumn || isNonZeroWeight(cycle.getEdge(i)); + if (countValue && tmpResidualCapacity < minResidualCapacity) { + minResidualCapacity = tmpResidualCapacity; + minIndex = i; + } + } + assert minIndex != -1; + long nodeIndex = cycle.getNode(minIndex); + long edgeIndex = cycle.getEdge(minIndex); + return new Edge(edgeIndex, nodeIndex, nodeIndex == tree._sources.at8(edgeIndex) ? tree._targets.at8(edgeIndex) : tree._sources.at8(edgeIndex)); + } + + /** + * Loop over all entering edges to find minimal cost flow in spanning tree. + */ + public void pivotLoop() { + Edge edge = findNextEnteringEdge(); + while (edge != null) { + long enteringEdgeIndex = edge.getEdgeIndex(); + long enteringEdgeSourceIndex = edge.getSourceIndex(); + long enteringEdgeTargetIndex = edge.getTargetIndex(); + NodesEdgesObject cycle = getCycle(enteringEdgeIndex, enteringEdgeSourceIndex, enteringEdgeTargetIndex); + Edge leavingEdge = getLeavingEdge(cycle); + long leavingEdgeIndex = leavingEdge.getEdgeIndex(); + long leavingEdgeSourceIndex = leavingEdge.getSourceIndex(); + long leavingEdgeTargetIndex = leavingEdge.getTargetIndex(); + double residualCap = tree.getResidualCapacity(leavingEdgeIndex, leavingEdgeSourceIndex, _capacities.at(leavingEdgeIndex)); + if(residualCap != 0) { + tree.augmentFlow(cycle, residualCap); + } + if (enteringEdgeIndex != leavingEdgeIndex) { + if (leavingEdgeSourceIndex != tree._parents.at8(leavingEdgeTargetIndex)) { + long tmpS = leavingEdgeSourceIndex; + leavingEdgeSourceIndex = leavingEdgeTargetIndex; + leavingEdgeTargetIndex = tmpS; + } + if (cycle.indexOfEdge(enteringEdgeIndex) < cycle.indexOfEdge(leavingEdgeIndex)) { + long tmpP = enteringEdgeSourceIndex; + enteringEdgeSourceIndex = enteringEdgeTargetIndex; + enteringEdgeTargetIndex = tmpP; + } + tree.removeParentEdge(leavingEdgeSourceIndex, leavingEdgeTargetIndex); + tree.makeRoot(enteringEdgeTargetIndex); + tree.addEdge(enteringEdgeIndex, enteringEdgeSourceIndex, enteringEdgeTargetIndex); + tree.updatePotentials(enteringEdgeIndex, enteringEdgeSourceIndex, enteringEdgeTargetIndex, getWeight(enteringEdgeIndex)); + } + edge = findNextEnteringEdge(); + } + } + + public void checkInfeasibility() { + assert !tree.isInfeasible(): "Spanning tree to calculate K-means cluster assignments is not infeasible."; + } + + public void 
checkConstraintsCondition(int[] numberOfPointsInCluster){ + for(int i = 0; i<_constraintsLength; i++){ + assert numberOfPointsInCluster[i] >= _demands.at8(_numberOfPoints+i) : String.format("Cluster %d has %d assigned points however should has assigned at least %d points.", i+1, numberOfPointsInCluster[i], _demands.at8(_numberOfPoints+i)); + } + } + + /** + * Initialize graph and working spanning tree, calculate minimal cost flow and check if result flow is correct. + */ + private void calculateMinimalCostFlow() { + init(); + pivotLoop(); + checkInfeasibility(); + } + + /** + * Calculate minimal cost flow and based on flow assign cluster to all data points. + * @return input data with new cluster assignments + */ + public Frame assignClusters() { + calculateMinimalCostFlow(); + int distanceAssignmentIndex = _weights.numCols() - 3; + int oldAssignmentIndex = _weights.numCols() - 2; + int newAssignmentIndex = _weights.numCols() - 1; + int dataStopLength = _weights.numCols() - (_hasWeightsColumn ? 1 : 0) - 2 * _constraintsLength - 3; + + int[] numberOfPointsInCluster = new int[_constraintsLength]; + for(int i = 0; i<_constraintsLength; i++){ + numberOfPointsInCluster[i] = 0; + } + + for (long i = 0; i < _weights.numRows(); i++) { + if(!_hasWeightsColumn || _weights.vec(dataStopLength).at8(i) == 1) { + for (int j = 0; j < _constraintsLength; j++) { + //long edgeIndex = i + j * _weights.numRows(); + long edgeIndex = i * _constraintsLength + j; + if (tree._edgeFlow.at8(edgeIndex) == 1) { + // old assignment + _weights.vec(oldAssignmentIndex).set(i, _weights.vec(newAssignmentIndex).at(i)); + // new assignment + _weights.vec(newAssignmentIndex).set(i, j); + // distances + _weights.vec(distanceAssignmentIndex).set(i, _weights.vec(dataStopLength + j + (_hasWeightsColumn ? 1 : 0)).at(i)); + numberOfPointsInCluster[j]++; + break; + } + } + } + } + checkConstraintsCondition(numberOfPointsInCluster); + //remove distances columns + for(int i = 0; i < 2 * _constraintsLength; i++) { + _weights.remove(dataStopLength+(_hasWeightsColumn ? 1 : 0)); + } + return _weights; + } +} + +/** + * Experimental + * Class to store calculation of flow in minimal cost flow problem. 
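+ * Edges are stored column-wise in Vecs (_sources, _targets, _edgeFlow): the real graph edges come first,
+ * followed by one artificial edge per node connecting it to an artificial root, which forms the initial
+ * feasible tree. _parents, _parentEdges, _subtreeSize, _nextDepthFirst, _previousNodes and _lastDescendants
+ * thread the current spanning tree in depth-first order, and _nodePotentials holds the dual values used by
+ * reduceWeight to price candidate entering edges.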
+ */ +class SpanningTree extends Iced { + + public long _nodeSize; + public long _edgeSize; + public long _secondLayerSize; + + + public Vec _sources; // edge size + node size + public Vec _targets; // edge size + node size + + public Vec _edgeFlow; // edge size + node size, integer + public Vec _nodePotentials; // node size, long + public Vec _parents; // node size + 1, integer + public Vec _parentEdges; // node size + 1, integer + public Vec _subtreeSize; // node size + 1, integer + public Vec _nextDepthFirst; // node size + 1, integer + public Vec _previousNodes; // node size + 1, integer + public Vec _lastDescendants; // node size + 1, integer + + SpanningTree(long nodeSize, long edgeSize, long secondLayerSize){ + this._nodeSize = nodeSize; + this._edgeSize = edgeSize; + this._secondLayerSize = secondLayerSize; + + this._sources = Vec.makeCon(0, edgeSize+nodeSize, Vec.T_NUM); + this._targets = Vec.makeCon(0, edgeSize+nodeSize, Vec.T_NUM); + + this._edgeFlow = Vec.makeCon(0, edgeSize+nodeSize, Vec.T_NUM); + this._nodePotentials = Vec.makeCon(0, nodeSize, Vec.T_NUM); + this._parents = Vec.makeCon(0, nodeSize+1, Vec.T_NUM); + this._parentEdges = Vec.makeCon(0, nodeSize+1, Vec.T_NUM); + this._subtreeSize = Vec.makeCon(1, nodeSize+1, Vec.T_NUM); + this._nextDepthFirst = Vec.makeCon(0, nodeSize+1, Vec.T_NUM); + this._previousNodes = Vec.makeCon(0, nodeSize+1, Vec.T_NUM); + this._lastDescendants = Vec.makeCon(0, nodeSize+1, Vec.T_NUM); + } + + public void init(long numberOfPoints, double maxCapacity, Vec demands){ + for (long i = 0; i < _nodeSize; i++) { + if (i < numberOfPoints) { + for (int j = 0; j < _secondLayerSize; j++) { + _sources.set(i * _secondLayerSize + j, i); + _targets.set(i * _secondLayerSize + j, numberOfPoints + j); + } + } else { + if (i < _nodeSize - 1) { + _sources.set(numberOfPoints* _secondLayerSize +i-numberOfPoints, i); + _targets.set(numberOfPoints* _secondLayerSize +i-numberOfPoints, _nodeSize - 1); + } + } + } + + for (long i = 0; i < _nodeSize; i++) { + long demand = demands.at8(i); + if (demand >= 0) { + _sources.set(_edgeSize + i, _nodeSize); + _targets.set(_edgeSize + i, i); + } else { + _sources.set(_edgeSize + i, i); + _targets.set(_edgeSize + i, _nodeSize); + } + if (i < _nodeSize - 1) { + _nextDepthFirst.set(i, i + 1); + } + _edgeFlow.set(_edgeSize +i, Math.abs(demand)); + _nodePotentials.set(i, demand < 0 ? 
maxCapacity : -maxCapacity); + _parents.set(i, _nodeSize); + _parentEdges.set(i, i + _edgeSize); + _previousNodes.set(i, i - 1); + _lastDescendants.set(i, i); + } + _parents.set(_nodeSize, -1); + _subtreeSize.set(_nodeSize, _nodeSize + 1); + _nextDepthFirst.set(_nodeSize - 1, _nodeSize); + _previousNodes.set(0, _nodeSize); + _previousNodes.set(_nodeSize, _nodeSize - 1); + _lastDescendants.set(_nodeSize, _nodeSize - 1); + } + + public boolean isInfeasible() { + for(long i = _edgeFlow.length() - _secondLayerSize + 1; i < _edgeFlow.length(); i++) { + if(_edgeFlow.at8(i) > 0){ + return true; + } + } + return false; + } + + public long findAncestor(long sourceIndex, long targetIndex) { + long subtreeSizeSource = _subtreeSize.at8(sourceIndex); + long subtreeSizeTarget = _subtreeSize.at8(targetIndex); + while (true) { + while (subtreeSizeSource < subtreeSizeTarget) { + sourceIndex = _parents.at8(sourceIndex); + subtreeSizeSource = _subtreeSize.at8(sourceIndex); + } + while (subtreeSizeSource > subtreeSizeTarget) { + targetIndex = _parents.at8(targetIndex); + subtreeSizeTarget = _subtreeSize.at8(targetIndex); + } + if (subtreeSizeSource == subtreeSizeTarget) { + if (sourceIndex !=targetIndex) { + sourceIndex = _parents.at8(sourceIndex); + subtreeSizeSource = _subtreeSize.at8(sourceIndex); + targetIndex = _parents.at8(targetIndex); + subtreeSizeTarget = _subtreeSize.at8(targetIndex); + } else { + return sourceIndex; + } + } + } + } + + public double reduceWeight(long edgeIndex, double weight) { + double newWeight = weight - _nodePotentials.at(_sources.at8(edgeIndex)) + _nodePotentials.at(_targets.at8(edgeIndex)); + return _edgeFlow.at8(edgeIndex) == 0 ? newWeight : -newWeight; + } + + public NodesEdgesObject getPath(long node, long ancestor) { + NodesEdgesObject result = new NodesEdgesObject(); + result.addNode(node); + while (node != ancestor) { + result.addEdge(_parentEdges.at8(node)); + node = _parents.at8(node); + result.addNode(node); + } + return result; + } + + public double getResidualCapacity(long edgeIndex, long nodeIndex, double capacity) { + return nodeIndex == _sources.at8(edgeIndex) ? 
capacity - _edgeFlow.at(edgeIndex) : _edgeFlow.at(edgeIndex); + } + + public void augmentFlow(NodesEdgesObject nodesEdges, double flow) { + for (int i = 0; i < nodesEdges.edgeSize(); i++) { + long edge = nodesEdges.getEdge(i); + long node = nodesEdges.getNode(i); + if (node == _sources.at8(edge)) { + _edgeFlow.set(edge, _edgeFlow.at(edge) + flow); + } else { + _edgeFlow.set(edge, _edgeFlow.at(edge) - flow); + } + } + } + + public void removeParentEdge(long sourceIndex, long targetIndex) { + long subtreeSizeTarget = _subtreeSize.at8(targetIndex); + long previousTargetIndex = _previousNodes.at8(targetIndex); + long lastTargetIndex = _lastDescendants.at8(targetIndex); + long nextTargetIndex = _nextDepthFirst.at8(lastTargetIndex); + + _parents.set(targetIndex, -1); + _parentEdges.set(targetIndex, -1); + _nextDepthFirst.set(previousTargetIndex, nextTargetIndex); + _previousNodes.set(nextTargetIndex, previousTargetIndex); + _nextDepthFirst.set(lastTargetIndex, targetIndex); + _previousNodes.set(targetIndex, lastTargetIndex); + while (sourceIndex != -1) { + _subtreeSize.set(sourceIndex, _subtreeSize.at8(sourceIndex) - subtreeSizeTarget); + if (lastTargetIndex == _lastDescendants.at8(sourceIndex)) { + _lastDescendants.set(sourceIndex, previousTargetIndex); + } + sourceIndex = _parents.at8(sourceIndex); + } + } + + public void makeRoot(long nodeIndex) { + ArrayList ancestors = new ArrayList<>(); + while (nodeIndex != -1) { + ancestors.add(nodeIndex); + nodeIndex = _parents.at8(nodeIndex); + } + Collections.reverse(ancestors); + for (int i = 0; i < ancestors.size() - 1; i++) { + long sourceIndex = ancestors.get(i); + long targetIndex = ancestors.get(i + 1); + long subtreeSizeSource = _subtreeSize.at8(sourceIndex); + long lastSourceIndex = _lastDescendants.at8(sourceIndex); + long prevTargetIndex = _previousNodes.at8(targetIndex); + long lastTargetIndex = _lastDescendants.at8(targetIndex); + long nextTargetIndex = _nextDepthFirst.at8(lastTargetIndex); + + _parents.set(sourceIndex, targetIndex); + _parents.set(targetIndex, -1); + _parentEdges.set(sourceIndex, _parentEdges.at8(targetIndex)); + _parentEdges.set(targetIndex, -1); + _subtreeSize.set(sourceIndex, subtreeSizeSource - _subtreeSize.at8(targetIndex)); + _subtreeSize.set(targetIndex, subtreeSizeSource); + + _nextDepthFirst.set(prevTargetIndex, nextTargetIndex); + _previousNodes.set(nextTargetIndex, prevTargetIndex); + _nextDepthFirst.set(lastTargetIndex, targetIndex); + _previousNodes.set(targetIndex, lastTargetIndex); + + if (lastSourceIndex == lastTargetIndex) { + _lastDescendants.set(sourceIndex, prevTargetIndex); + lastSourceIndex = prevTargetIndex; + } + _previousNodes.set(sourceIndex, lastTargetIndex); + _nextDepthFirst.set(lastTargetIndex, sourceIndex); + _nextDepthFirst.set(lastSourceIndex, targetIndex); + _previousNodes.set(targetIndex, lastSourceIndex); + _lastDescendants.set(targetIndex, lastSourceIndex); + } + } + + public void addEdge(long edgeIndex, long sourceIndex, long targetIndex) { + long lastSourceIndex = _lastDescendants.at8(sourceIndex); + long nextSourceIndex = _nextDepthFirst.at8(lastSourceIndex); + long subtreeSizeTarget = _subtreeSize.at8(targetIndex); + long lastTargetIndex = _lastDescendants.at8(targetIndex); + + _parents.set(targetIndex, sourceIndex); + _parentEdges.set(targetIndex, edgeIndex); + + _nextDepthFirst.set(lastSourceIndex, targetIndex); + _previousNodes.set(targetIndex, lastSourceIndex); + _previousNodes.set(nextSourceIndex, lastTargetIndex); + _nextDepthFirst.set(lastTargetIndex, nextSourceIndex); + + 
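+    // Walk from the source node up to the root: every ancestor's subtree grows by the size of the newly
+    // attached subtree, and ancestors whose last descendant used to be the source's last descendant now end
+    // at the target's last descendant.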
while (sourceIndex != -1) { + _subtreeSize.set(sourceIndex, _subtreeSize.at8(sourceIndex) + subtreeSizeTarget); + if (lastSourceIndex == _lastDescendants.at8(sourceIndex)) { + _lastDescendants.set(sourceIndex, lastTargetIndex); + } + sourceIndex = _parents.at8(sourceIndex); + + } + } + + + public void updatePotentials(long edgeIndex, long sourceIndex, long targetIndex, double weight) { + double potential; + if (targetIndex == _targets.at8(edgeIndex)) { + potential = _nodePotentials.at(sourceIndex) - weight - _nodePotentials.at(targetIndex); + } else { + potential = _nodePotentials.at(sourceIndex) + weight - _nodePotentials.at(targetIndex); + } + _nodePotentials.set(targetIndex, _nodePotentials.at(targetIndex) + potential); + long last = _lastDescendants.at8(targetIndex); + while (targetIndex != last) { + targetIndex = _nextDepthFirst.at8(targetIndex); + _nodePotentials.set(targetIndex, _nodePotentials.at(targetIndex) + potential); + } + } +} + +class Edge { + + private long _edgeIndex; + private long _sourceIndex; + private long _targetIndex; + + public Edge(long edgeIndex, long sourceIndex, long targetIndex) { + this._edgeIndex = edgeIndex; + this._sourceIndex = sourceIndex; + this._targetIndex = targetIndex; + } + + public long getEdgeIndex() { + return _edgeIndex; + } + + public long getSourceIndex() { + return _sourceIndex; + } + + public long getTargetIndex() { + return _targetIndex; + } + + @Override + public String toString() { + return _edgeIndex+" "+_sourceIndex+" "+_targetIndex; + } +} + +class NodesEdgesObject { + + private ArrayList _nodes; + private ArrayList _edges; + + public NodesEdgesObject() { + this._nodes = new ArrayList<>(); + this._edges = new ArrayList<>(); + } + + public void addNode(long node){ + _nodes.add(node); + } + + public void removeLastNode(){ + _nodes.remove(_nodes.size()-1); + } + + public long getNode(int index){ + return _nodes.get(index); + } + + public ArrayList getNodes() { + return _nodes; + } + + public int nodeSize(){ + return _nodes.size(); + } + + public void addEdge(long edge){ + _edges.add(edge); + } + + public long getEdge(int index){ + return _edges.get(index); + } + + public ArrayList getEdges() { + return _edges; + } + + public int edgeSize(){ + return _edges.size(); + } + + public int indexOfEdge(long value){ + return _edges.indexOf(value); + } + + + public ArrayList getReversedNodes() { + ArrayList reversed = new ArrayList<>(_nodes); + Collections.reverse(reversed); + return reversed; + } + + public ArrayList getReversedEdges() { + ArrayList reversed = new ArrayList<>(_edges); + Collections.reverse(reversed); + return reversed; + } + + public void reverseNodes(){ + Collections.reverse(_nodes); + } + + public void reverseEdges(){ + Collections.reverse(_edges); + } + + public void addAllNodes(ArrayList newNodes){ + _nodes.addAll(newNodes); + } + + public void addAllEdges(ArrayList newEdges){ + _edges.addAll(newEdges); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("NEO: nodes: "); + for (long i: _nodes) { + sb.append(i+" "); + } + sb.append("edges: "); + for (long i: _edges) { + sb.append(i+" "); + } + sb.deleteCharAt(sb.length()-1); + return sb.toString(); + } +} + + +/** + * Map Reduce task to find minimal reduced weight (distance). 
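+ * For every point-to-cluster edge in a chunk the task computes the reduced cost via SpanningTree.reduceWeight
+ * (edge weight minus the source potential plus the target potential, negated when the edge already carries
+ * flow) and keeps the global minimum in the reduce step; rows with a zero weight are skipped when a weight
+ * column is present.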
+ */ +class FindMinimalWeightTask extends MRTask { + // IN + SpanningTree _tree; + boolean _hasWeightsColumn; + int _constraintsLength; + + //OUT + double minimalWeight = Double.MAX_VALUE; + long minimalIndex = -1; + + FindMinimalWeightTask(SpanningTree tree, boolean hasWeightsColumn, int constraintsLength) { + _tree = tree; + _hasWeightsColumn = hasWeightsColumn; + _constraintsLength = constraintsLength; + } + + @Override + public void map(Chunk[] cs) { + int startDistancesIndex = cs.length - 2 * _constraintsLength - 3; + int startEdgeIndex = cs.length - 3 - _constraintsLength; + for (int i = 0; i < cs[0]._len; i++) { + for (int j = 0; j < _constraintsLength; j++) { + double weight = cs[startDistancesIndex + j].atd(i); + long edgeIndex = cs[startEdgeIndex + j].at8(i); + double tmpWeight = _tree.reduceWeight(edgeIndex, weight); + boolean countValue = !_hasWeightsColumn || cs[startDistancesIndex-1].at8(i) == 1; + if (countValue && tmpWeight < minimalWeight) { + minimalWeight = tmpWeight; + minimalIndex = edgeIndex; + } + } + } + } + + @Override + public void reduce(FindMinimalWeightTask mrt) { + if (mrt.minimalWeight < minimalWeight) { + minimalIndex = mrt.minimalIndex; + minimalWeight = mrt.minimalWeight; + } + } +} diff --git a/h2o-algos/src/test/java/hex/kmeans/KMeansTest.java b/h2o-algos/src/test/java/hex/kmeans/KMeansTest.java index 141c6519e4c4..9ecd4baf3bda 100755 --- a/h2o-algos/src/test/java/hex/kmeans/KMeansTest.java +++ b/h2o-algos/src/test/java/hex/kmeans/KMeansTest.java @@ -20,9 +20,7 @@ import java.util.Comparator; import java.util.Random; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; public class KMeansTest extends TestUtil { public final double threshold = 1e-6; @@ -34,8 +32,8 @@ private static KMeansModel doSeed( KMeansModel.KMeansParameters parms, long seed KMeans job = new KMeans(parms); KMeansModel kmm = job.trainModel().get(); checkConsistency(kmm); - for( int i=0; imake(), new String[]{"x", "y"}, new Vec[]{Vec.makeVec(new long[]{1, 2, 4}, null, Vec.newKey()), + Vec.makeVec(new long[]{1, 2, 3}, null, Vec.newKey())}); + DKV.put(fr); + + points = new Frame(Key.make(), new String[]{"x", "y"}, new Vec[]{Vec.makeVec(new long[]{1, 3}, null, Vec.newKey()), + Vec.makeVec(new long[]{2, 4}, null, Vec.newKey())}); + DKV.put(points); + + KMeansModel.KMeansParameters parms = new KMeansModel.KMeansParameters(); + parms._train = fr._key; + parms._k = 2; + parms._standardize = false; + parms._max_iterations = 10; + parms._cluster_size_constraints = new int[]{1, 1}; + parms._user_points = points._key; + KMeans job = new KMeans(parms); + kmm = job.trainModel().get(); + + // Done building model; produce a score column with cluster choices + fr2 = kmm.score(fr); + for(int i=0; i= parms._cluster_size_constraints[i] : "Minimal size of cluster "+(i+1)+" should be "+parms._cluster_size_constraints[i]+" but is "+kmm._output._size[i]+"."; + } + + } finally { + if( fr != null ) fr.delete(); + if( fr2 != null ) fr2.delete(); + if( points != null ) points.delete(); + if( kmm != null ) kmm.delete(); + } + } + + @Test + public void testNfoldsConstrained() { + Frame tfr = null, points = null; + KMeansModel kmeans = null; + + Scope.enter(); + try { + points = ArrayUtils.frame(ard( + ard(6.0,2.2,4.0,1.0,0), + ard(5.2,3.4,1.4,0.2,1), + ard(6.9,3.1,5.4,2.1,2) + )); + + tfr = parse_test_file("smalldata/iris/iris_wheader.csv"); + DKV.put(tfr); + KMeansModel.KMeansParameters parms 
= new KMeansModel.KMeansParameters(); + parms._train = tfr._key; + parms._seed = 0xdecaf; + parms._k = 3; + parms._cluster_size_constraints = new int[]{20, 20, 20}; + parms._nfolds = 3; + parms._user_points = points._key; + + KMeans job = new KMeans(parms); + kmeans = job.trainModel().get(); + + ModelMetricsClustering mm = (ModelMetricsClustering)kmeans._output._cross_validation_metrics; + assertNotNull(mm); + } finally { + if (tfr != null) tfr.remove(); + if (kmeans != null) { + kmeans.deleteCrossValidationModels(); + kmeans.delete(); + } + if(points != null) points.remove(); + Scope.exit(); + } + } + + @Test + public void testIrisConstrained() { + KMeansModel kmm = null, kmm2 = null, kmm3 = null, kmm4 = null; + Frame fr = null, points=null, predict1=null, predict2=null, predict3=null, predict4=null; + try { + Scope.enter(); + fr = Scope.track(parse_test_file("smalldata/iris/iris_wheader.csv")); + + points = ArrayUtils.frame(ard( + ard(4.9, 3.0, 1.4, 0.2), + ard(5.6, 2.5, 3.9, 1.1), + ard(6.5, 3.0, 5.2, 2.0) + )); + + KMeansModel.KMeansParameters parms = new KMeansModel.KMeansParameters(); + parms._train = fr._key; + parms._k = 3; + parms._standardize = true; + parms._max_iterations = 10; + parms._user_points = points._key; + parms._cluster_size_constraints = new int[]{49, 46, 55}; + parms._score_each_iteration = true; + parms._ignored_columns = new String[]{"class"}; + + System.out.println("Constrained Kmeans strandardize true (CKT)"); + KMeans job = new KMeans(parms); + kmm = (KMeansModel) Scope.track_generic(job.trainModel().get()); + + for(int i=0; i="+parms._cluster_size_constraints[i]); + assert kmm._output._size[i] >= parms._cluster_size_constraints[i] : "Minimal size of cluster "+(i+1)+" should be "+parms._cluster_size_constraints[i]+" but is "+kmm._output._size[i]+"."; + } + + KMeansModel.KMeansParameters parms3 = new KMeansModel.KMeansParameters(); + parms3._train = fr._key; + parms3._k = 3; + parms3._standardize = true; + parms3._max_iterations = 10; + parms3._user_points = points._key; + parms3._score_each_iteration = true; + parms3._ignored_columns = new String[]{"class"}; + + System.out.println("Loyd Kmeans strandardize true (FKT)"); + KMeans job3 = new KMeans(parms3); + kmm3 = (KMeansModel) Scope.track_generic(job3.trainModel().get()); + + KMeansModel.KMeansParameters parms2 = new KMeansModel.KMeansParameters(); + parms2._train = fr._key; + parms2._k = 3; + parms2._standardize = false; + parms2._max_iterations = 10; + parms2._user_points = points._key; + parms2._score_each_iteration = true; + parms2._ignored_columns = new String[]{"class"}; + parms2._cluster_size_constraints = new int[]{50, 61, 39}; + + System.out.println("Constrained Kmeans strandardize false (CKF)"); + KMeans job2 = new KMeans(parms2); + kmm2 = (KMeansModel) Scope.track_generic(job2.trainModel().get()); + + for(int i=0; i="+parms2._cluster_size_constraints[i]); + assert kmm2._output._size[i] >= parms2._cluster_size_constraints[i] : "Minimal size of cluster "+(i+1)+" should be "+parms2._cluster_size_constraints[i]+" but is "+kmm2._output._size[i]+"."; + } + + KMeansModel.KMeansParameters parms4 = new KMeansModel.KMeansParameters(); + parms4._train = fr._key; + parms4._k = 3; + parms4._standardize = false; + parms4._max_iterations = 10; + parms4._user_points = points._key; + parms4._score_each_iteration = true; + parms4._ignored_columns = new String[]{"class"}; + + System.out.println("Loyd Kmeans strandardize false (FKF)"); + KMeans job4 = new KMeans(parms4); + kmm4 = (KMeansModel) 
Scope.track_generic(job4.trainModel().get()); + + + predict1 = kmm.score(fr); + predict2 = kmm2.score(fr); + predict3 = kmm3.score(fr); + predict4 = kmm4.score(fr); + + System.out.println("\nPredictions:"); + System.out.println(" | CKT | FKT | CKF | FKF |"); + for (int i=0; i="+parms._cluster_size_constraints[i]); + assert kmm._output._size[i] >= parms._cluster_size_constraints[i] : "Minimal size of cluster "+(i+1)+" should be "+parms._cluster_size_constraints[i]+" but is "+kmm._output._size[i]+"."; + } + + parms._standardize = false; + KMeans job2 = new KMeans(parms); + kmm2 = (KMeansModel) Scope.track_generic(job2.trainModel().get()); + + for(int i=0; i="+parms._cluster_size_constraints[i]); + assert kmm2._output._size[i] >= parms._cluster_size_constraints[i] : "Minimal size of cluster "+(i+1)+" should be "+parms._cluster_size_constraints[i]+" but is "+kmm2._output._size[i]+"."; + } + + } finally { + if( fr != null ) fr.delete(); + if( points != null ) points.delete(); + if( kmm != null ) kmm.delete(); + if( kmm2 != null ) kmm2.delete(); + Scope.exit(); + } + } + +} diff --git a/h2o-docs/src/product/data-science/k-means.rst b/h2o-docs/src/product/data-science/k-means.rst index c41057321c45..be25a5b41e53 100644 --- a/h2o-docs/src/product/data-science/k-means.rst +++ b/h2o-docs/src/product/data-science/k-means.rst @@ -118,8 +118,6 @@ H2O stops splitting when :math:`PRE` falls below a :math:`threshold`, which is a :math:`\big[0.02 + \frac{10}{number\_of\_training\_rows} + \frac{2.5}{number\_of\_model\_features^{2}}\big]` - - FAQ ~~~