Skip to content

Commit

Permalink
PUBDEV-6447 change long precision to double, prepare weights to CV, prepare doc
Browse files Browse the repository at this point in the history
  • Loading branch information
maurever committed Dec 2, 2019
1 parent 962792d commit c656003
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 171 deletions.
31 changes: 17 additions & 14 deletions h2o-algos/src/main/java/hex/kmeans/KMeans.java
Original file line number Diff line number Diff line change
Expand Up @@ -300,22 +300,21 @@ public void computeImpl() {
Log.info("Cutoff for relative improvement in within_cluster_sum_of_squares: " + rel_improvement_cutoff);

Vec[] vecs2;
long csum = 0;
if(!constrained) {
vecs2 = Arrays.copyOf(vecs, vecs.length+1);
vecs2[vecs2.length-1] = vecs2[0].makeCon(-1);
} else {
int newVecLength = vecs.length + centers.length + 3;
vecs2 = Arrays.copyOf(vecs, newVecLength);

for (int i = vecs.length; i < newVecLength; i++) {
vecs2[i] = vecs2[0].makeCon(-1);
vecs2[i] = vecs2[0].makeCon(Double.MAX_VALUE);
}
// Check sum of constraints
long csum = 0;
for(int i = 0; i<_parms._cluster_size_constraints.length; i++){
assert _parms._cluster_size_constraints[i] > 0: "The value of constraint should be higher then zero.";
csum += _parms._cluster_size_constraints[i];
assert csum <= vecs[0].length(): "The sum of constraints is higher than the number of points.";
assert csum <= vecs[0].length(): "The sum of constraints is higher than the number of data rows.";
}
}

Expand All @@ -333,11 +332,12 @@ public void computeImpl() {
} else {
// Constrained K-means

// Get distances
CalculateDistancesTask countDistancesTask = new CalculateDistancesTask(centers, means, mults, impute_cat, _isCats, k, hasWeightCol(), _parms._constrained_kmeans_precision).doAll(vecs2);
// Get distances and aggregated values
CalculateDistancesTask countDistancesTask = new CalculateDistancesTask(centers, means, mults, impute_cat, _isCats, k, hasWeightCol()).doAll(vecs2);
assert !hasWeightCol() || csum <= countDistancesTask._non_zero_weights : "The sum of constraints is higher than the number of data rows with non zero weights.";

// Calculate center assignments
KMeansSimplexSolver solver = new KMeansSimplexSolver(_parms._cluster_size_constraints, new Frame(vecs2), countDistancesTask._sum, hasWeightCol(), _parms._constrained_kmeans_precision);
KMeansSimplexSolver solver = new KMeansSimplexSolver(_parms._cluster_size_constraints, new Frame(vecs2), countDistancesTask._sum, hasWeightCol(), countDistancesTask._non_zero_weights);

// Get cluster assignments
Frame result = solver.assignClusters();
Expand Down Expand Up @@ -815,10 +815,10 @@ private static class CalculateDistancesTask extends MRTask<CalculateDistancesTas
final int _k;
boolean _hasWeight;
final String[][] _isCats;
long _sum;
int _precision;
double _sum;
long _non_zero_weights;

CalculateDistancesTask(double[][] centers, double[] means, double[] mults, int[] modes, String[][] isCats, int k, boolean hasWeight, int precision) {
CalculateDistancesTask(double[][] centers, double[] means, double[] mults, int[] modes, String[][] isCats, int k, boolean hasWeight) {
_centers = centers;
_means = means;
_mults = mults;
Expand All @@ -827,32 +827,35 @@ private static class CalculateDistancesTask extends MRTask<CalculateDistancesTas
_hasWeight = hasWeight;
_isCats = isCats;
_sum = 0;
_precision = precision;
_non_zero_weights = 0;
}

/**
 * Computes, for every row of this chunk whose weight is not zero, the distance to
 * each center, stores those distances in the per-cluster distance columns and
 * accumulates the running totals.
 *
 * <p>Effects visible in this method: the distance columns of {@code cs} are
 * overwritten for each processed row, {@code _sum} grows by every stored distance,
 * and {@code _non_zero_weights} is incremented once per processed row.
 *
 * @param cs chunks of the columns the task runs over; the first {@code N} are read
 *           as data columns, followed by the optional weight column and the
 *           distance columns (one per center)
 */
@Override
public void map(Chunk[] cs) {
    int N = cs.length - (_hasWeight ? 1 : 0) - 3 - _centers.length /*data + weight column + distances + old assignment + new assignment */;
    assert _centers[0].length == N;
    // The weight column, when present, is read from index N below, so the distance
    // columns must start one past it; writing at N+cluster would clobber the weights.
    int vecsStart = _hasWeight ? N+1 : N;

    double[] values = new double[N]; // Temp data to hold row as doubles
    for (int row = 0; row < cs[0]._len; row++) {
        double weight = _hasWeight ? cs[N].atd(row) : 1;
        if (weight == 0) continue; //skip holdout rows
        _non_zero_weights++;
        assert (weight == 1); //K-Means only works for weight 1 (or weight 0 for holdout)
        data(values, cs, row, _means, _mults, _modes); // Load row as doubles
        double[] distances = getDistances(_centers, values, _isCats);
        for(int cluster=0; cluster<distances.length; cluster++){
            double tmpDist = distances[cluster];
            // Store each distance exactly once and accumulate it unscaled; the
            // earlier fixed-point scaling by _precision is no longer used.
            cs[vecsStart+cluster].set(row, tmpDist);
            _sum += tmpDist;
        }
    }
}

/**
 * Folds the partial results accumulated by another task instance into this one.
 *
 * @param mrt the task whose {@code _sum} and {@code _non_zero_weights} are added
 *            to this task's own accumulators
 */
@Override
public void reduce(CalculateDistancesTask mrt) {
    // The two accumulators are independent of each other; each is merged by addition.
    _non_zero_weights += mrt._non_zero_weights;
    _sum += mrt._sum;
}
}

Expand Down Expand Up @@ -905,7 +908,7 @@ private static class CalculateMetricTask extends IterationTask {
}
}
}
assert cluster != -1 : "cluster "+cluster+"is not set for row:"+row; // No broken rows
assert cluster != -1 : "cluster "+cluster+" is not set for row "+row; // No broken rows
_cSqr[cluster] += distance;

// Add values and increment counter for chosen cluster
Expand Down
1 change: 0 additions & 1 deletion h2o-algos/src/main/java/hex/kmeans/KMeansModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ public static class KMeansParameters extends ClusteringModel.ClusteringParameter
// Ex: k = 4, cluster = 3 -> [0, 0, 1, 0]
public boolean _estimate_k = false; // If enabled, iteratively find up to _k clusters
public int[] _cluster_size_constraints = null;
public int _constrained_kmeans_precision = 1000;
}

public static class KMeansOutput extends ClusteringModel.ClusteringOutput {
Expand Down
Loading

0 comments on commit c656003

Please sign in to comment.