IGNITE-9393: [ML] KMeans fails on complex data in cache
this closes apache#4628
zaleslaw authored and akalash committed Nov 19, 2018
1 parent d4fc7dd commit ef2be90
Showing 14 changed files with 191 additions and 159 deletions.
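The change that repeats across all 14 files is the reducer passed to dataset.compute(mapper, reducer). Over a partitioned cache, partitions that hold no usable data produce null partial results (several mappers below return null explicitly in that case), and the old one-line reducers of the form (a, b) -> a == null ? b : a still return null when both arguments are null, which then fails on unboxing or on a later dereference (e.g. final int cols = ...). The sketch below is a toy, self-contained model of that reduce step, not the Ignite API — compute(), the partition list and all names are illustrative only:

    import java.util.Arrays;
    import java.util.List;
    import java.util.function.BinaryOperator;
    import java.util.function.Function;

    public class NullSafeReduceSketch {
        /** Toy stand-in for dataset.compute(): maps each partition, then folds the partial results. */
        static <D, R> R compute(List<D> partitions, Function<D, R> mapper, BinaryOperator<R> reducer) {
            R acc = null;
            for (D part : partitions)
                acc = reducer.apply(acc, part == null ? null : mapper.apply(part));
            return acc;
        }

        public static void main(String[] args) {
            // Three partitions; the middle one is empty and contributes a null partial result.
            List<int[]> parts = Arrays.asList(new int[] {1, 2}, null, new int[] {3});

            // The old style, (a, b) -> a == null ? b : a, returns null when every partial
            // is null, and the int unboxing below would then throw a NullPointerException.
            // The commit's style handles both nulls and falls back to a default instead.
            int total = compute(parts, arr -> arr.length, (a, b) -> {
                if (a == null)
                    return b == null ? 0 : b;
                if (b == null)
                    return a;
                return a + b;
            });

            System.out.println(total); // 3
        }
    }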
@@ -65,13 +65,13 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
/**
* Trains model based on the specified data.
*
* @param datasetBuilder Dataset builder.
* @param featureExtractor Feature extractor.
* @param lbExtractor Label extractor.
* @return Model.
*/
@Override public <K, V> KMeansModel fit(DatasetBuilder<K, V> datasetBuilder,
IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
assert datasetBuilder != null;

PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
@@ -85,7 +85,14 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
(upstream, upstreamSize) -> new EmptyContext(),
partDataBuilder
)) {
- final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> a == null ? b : a);
+ final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
+     if (a == null)
+         return b == null ? 0 : b;
+     if (b == null)
+         return a;
+     return b;
+ });

centers = initClusterCentersRandomly(dataset, k);

boolean converged = false;
@@ -113,7 +120,8 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
centers[i] = newCentroids[i];
}
}
- } catch (Exception e) {
+ }
+ catch (Exception e) {
throw new RuntimeException(e);
}
return new KMeansModel(centers, distance);
@@ -124,15 +132,14 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
*
* @param centers Current centers on the current iteration.
* @param dataset Dataset.
* @param cols Amount of columns.
* @return Helper data to calculate the new centroids.
*/
private TotalCostAndCounts calcDataForNewCentroids(Vector[] centers,
Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset, int cols) {
final Vector[] finalCenters = centers;

return dataset.compute(data -> {
TotalCostAndCounts res = new TotalCostAndCounts();

for (int i = 0; i < data.rowSize(); i++) {
@@ -147,20 +154,29 @@ private TotalCostAndCounts calcDataForNewCentroids(Vector[] centers,

int finalI = i;
res.sums.compute(centroidIdx,
-     (IgniteBiFunction<Integer, Vector, Vector>) (ind, v) -> v.plus(data.getRow(finalI).features()));
+     (IgniteBiFunction<Integer, Vector, Vector>)(ind, v) -> {
+         Vector features = data.getRow(finalI).features();
+         return v == null ? features : v.plus(features);
+     });

res.counts.merge(centroidIdx, 1,
-     (IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) -> i1 + i2);
+     (IgniteBiFunction<Integer, Integer, Integer>)(i1, i2) -> i1 + i2);
}
return res;
- }, (a, b) -> a == null ? b : a.merge(b));
+ }, (a, b) -> {
+     if (a == null)
+         return b == null ? new TotalCostAndCounts() : b;
+     if (b == null)
+         return a;
+     return a.merge(b);
+ });
}

/**
* Find the closest cluster center index and distance to it from a given point.
*
* @param centers Centers to look in.
* @param pnt Point.
*/
private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] centers, LabeledVector pnt) {
double bestDistance = Double.POSITIVE_INFINITY;
@@ -180,12 +196,11 @@ private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] centers, LabeledVector pnt) {
* K cluster centers are initialized randomly.
*
* @param dataset The dataset to pick up random centers.
* @param k Amount of clusters.
* @return K cluster centers.
*/
private Vector[] initClusterCentersRandomly(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset,
int k) {
Vector[] initCenters = new DenseVector[k];

// Gets k or less vectors from each partition.
@@ -211,12 +226,19 @@ private Vector[] initClusterCentersRandomly(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset,

rndPnt.add(data.getRow(nextIdx));
}
- } else // If it's not enough vectors to pick k vectors.
+ }
+ else // If there are not enough vectors to pick k vectors.
for (int i = 0; i < data.rowSize(); i++)
rndPnt.add(data.getRow(i));
}
return rndPnt;
- }, (a, b) -> a == null ? b : Stream.concat(a.stream(), b.stream()).collect(Collectors.toList()));
+ }, (a, b) -> {
+     if (a == null)
+         return b == null ? new ArrayList<>() : b;
+     if (b == null)
+         return a;
+     return Stream.concat(a.stream(), b.stream()).collect(Collectors.toList());
+ });

// Shuffle them.
Collections.shuffle(rndPnts);
Expand All @@ -228,7 +250,8 @@ private Vector[] initClusterCentersRandomly(Dataset<EmptyContext, LabeledVectorS
rndPnts.remove(rndPnt);
initCenters[i] = rndPnt.features();
}
- } else
+ }
+ else
throw new RuntimeException("The KMeans Trainer requires more than " + k + " vectors to find " + k + " clusters");

return initCenters;
@@ -245,7 +268,6 @@ public static class TotalCostAndCounts {
/** Count of points closest to the center with a given index. */
ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();

/** Count of points with a given label closest to the center with a given index. */
ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat = new ConcurrentHashMap<>();

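Besides the reducer, calcDataForNewCentroids above fixes a second null issue: ConcurrentHashMap.compute invokes its remapping function with a null current value when the key is absent, so the old lambda v.plus(...) threw a NullPointerException for the first point assigned to each centroid. A minimal plain-JDK illustration of that contract (no Ignite types involved):

    import java.util.concurrent.ConcurrentHashMap;

    public class ComputeOnAbsentKey {
        public static void main(String[] args) {
            ConcurrentHashMap<Integer, Double> sums = new ConcurrentHashMap<>();

            // On the first call v is null because key 0 is absent, so a lambda like
            // (k, v) -> v + 1.0 would throw a NullPointerException right here.
            sums.compute(0, (k, v) -> v == null ? 1.0 : v + 1.0);
            sums.compute(0, (k, v) -> v == null ? 1.0 : v + 1.0);

            System.out.println(sums.get(0)); // 2.0
        }
    }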
@@ -149,9 +149,7 @@ private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder,
(upstream, upstreamSize) -> new EmptyContext(),
partDataBuilder
)) {
return dataset.compute(data -> {
CentroidStat res = new CentroidStat();

for (int i = 0; i < data.rowSize(); i++) {
@@ -171,15 +169,21 @@ private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder,
centroidStat.put(lb, 1);
res.centroidStat.put(centroidIdx, centroidStat);
} else {
- int cnt = centroidStat.containsKey(lb) ? centroidStat.get(lb) : 0;
+ int cnt = centroidStat.getOrDefault(lb, 0);
centroidStat.put(lb, cnt + 1);
}

res.counts.merge(centroidIdx, 1,
(IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) -> i1 + i2);
}
return res;
- }, (a, b) -> a == null ? b : a.merge(b));
+ }, (a, b) -> {
+     if (a == null)
+         return b == null ? new CentroidStat() : b;
+     if (b == null)
+         return a;
+     return a.merge(b);
+ });

} catch (Exception e) {
throw new RuntimeException(e);
@@ -17,6 +17,7 @@

package org.apache.ignite.ml.knn.classification;

+ import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
@@ -79,7 +80,13 @@ protected List<LabeledVector> findKNearestNeighbors(Vector v) {
List<LabeledVector> neighborsFromPartitions = dataset.compute(data -> {
TreeMap<Double, Set<Integer>> distanceIdxPairs = getDistances(v, data);
return Arrays.asList(getKClosestVectors(data, distanceIdxPairs));
- }, (a, b) -> a == null ? b : Stream.concat(a.stream(), b.stream()).collect(Collectors.toList()));
+ }, (a, b) -> {
+     if (a == null)
+         return b == null ? new ArrayList<>() : b;
+     if (b == null)
+         return a;
+     return Stream.concat(a.stream(), b.stream()).collect(Collectors.toList());
+ });

LabeledVectorSet<Double, LabeledVector> neighborsToFilter = buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);

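In findKNearestNeighbors above, each partition contributes its own list of candidate neighbors, the combiner concatenates those lists, and the k globally closest vectors are filtered out of the merged result afterwards (outside this hunk). A rough sketch of that merge-then-refilter idea, with plain distances standing in for labeled vectors:

    import java.util.Arrays;
    import java.util.List;
    import java.util.stream.Collectors;
    import java.util.stream.Stream;

    public class KnnMergeSketch {
        public static void main(String[] args) {
            int k = 2;

            // Per-partition candidates: each partition already picked its k closest distances.
            List<Double> part1 = Arrays.asList(0.3, 1.2);
            List<Double> part2 = Arrays.asList(0.1, 2.0);

            // Concatenate the partial lists (the combiner's job), then keep the global k best.
            List<Double> globalK = Stream.concat(part1.stream(), part2.stream())
                .sorted()
                .limit(k)
                .collect(Collectors.toList());

            System.out.println(globalK); // [0.1, 0.3]
        }
    }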
@@ -103,7 +103,13 @@ else if (b == null)
@Override protected int getColumns() {
return dataset.compute(
data -> data.getFeatures() == null ? null : data.getFeatures().length / data.getRows(),
- (a, b) -> a == null ? b : a
+ (a, b) -> {
+     if (a == null)
+         return b == null ? 0 : b;
+     if (b == null)
+         return a;
+     return b;
+ }
);
}

@@ -82,7 +82,13 @@ public LinearRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron,
if (data.getFeatures() == null)
return null;
return data.getFeatures().length / data.getRows();
- }, (a, b) -> a == null ? b : a);
+ }, (a, b) -> {
+     if (a == null)
+         return b == null ? 0 : b;
+     if (b == null)
+         return a;
+     return b;
+ });

MLPArchitecture architecture = new MLPArchitecture(cols);
architecture = architecture.withAddedLayer(1, true, Activators.LINEAR);
@@ -100,7 +106,7 @@ public LinearRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron,
seed
);

- IgniteBiFunction<K, V, double[]> lbE = (IgniteBiFunction<K, V, double[]>)(k, v) -> new double[]{lbExtractor.apply(k, v)};
+ IgniteBiFunction<K, V, double[]> lbE = (IgniteBiFunction<K, V, double[]>)(k, v) -> new double[] {lbExtractor.apply(k, v)};

MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, lbE);

@@ -64,7 +64,7 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
* @param seed Seed for random generator.
*/
public LogisticRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations,
int batchSize, int locIterations, long seed) {
this.updatesStgy = updatesStgy;
this.maxIterations = maxIterations;
this.batchSize = batchSize;
@@ -82,7 +82,13 @@ public LogisticRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron
if (data.getFeatures() == null)
return null;
return data.getFeatures().length / data.getRows();
- }, (a, b) -> a == null ? b : a);
+ }, (a, b) -> {
+     if (a == null)
+         return b == null ? 0 : b;
+     if (b == null)
+         return a;
+     return b;
+ });

MLPArchitecture architecture = new MLPArchitecture(cols);
architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID);
@@ -100,7 +106,7 @@ public LogisticRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron
seed
);

- MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, (k, v) -> new double[]{lbExtractor.apply(k, v)});
+ MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, (k, v) -> new double[] {lbExtractor.apply(k, v)});

double[] params = mlp.parameters().getStorage().data();

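The same five-line null-safe reducer is restated inline in every file this commit touches. It could just as well live in a shared helper; the combinator below is a hedged refactoring sketch, not part of the commit, and the class and method names are illustrative:

    import java.util.function.BinaryOperator;

    public final class Reducers {
        private Reducers() {
        }

        /** Wraps a combiner so null partials fall back to the non-null side, or to a default. */
        public static <T> BinaryOperator<T> nullSafe(BinaryOperator<T> combiner, T dflt) {
            return (a, b) -> {
                if (a == null)
                    return b == null ? dflt : b;
                if (b == null)
                    return a;
                return combiner.apply(a, b);
            };
        }
    }

With such a helper, the column-count reduce above would collapse to dataset.compute(..., Reducers.nullSafe((a, b) -> b, 0)).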
@@ -61,14 +61,14 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
/**
* Trains model based on the specified data.
*
* @param datasetBuilder Dataset builder.
* @param featureExtractor Feature extractor.
* @param lbExtractor Label extractor.
* @return Model.
*/
@Override public <K, V> LogRegressionMultiClassModel fit(DatasetBuilder<K, V> datasetBuilder,
IgniteBiFunction<K, V, Vector> featureExtractor,
IgniteBiFunction<K, V, Double> lbExtractor) {
List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);

LogRegressionMultiClassModel multiClsMdl = new LogRegressionMultiClassModel();
@@ -92,7 +92,8 @@ }
}

/** Iterates over the dataset and collects class labels. */
- private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) {
+ private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder,
+     IgniteBiFunction<K, V, Double> lbExtractor) {
assert datasetBuilder != null;

PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
@@ -108,14 +109,22 @@ private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder,

final double[] lbs = data.getY();

- for (double lb : lbs) locClsLabels.add(lb);
+ for (double lb : lbs)
+     locClsLabels.add(lb);

return locClsLabels;
- }, (a, b) -> a == null ? b : Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet()));
+ }, (a, b) -> {
+     if (a == null)
+         return b == null ? new HashSet<>() : b;
+     if (b == null)
+         return a;
+     return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
+ });

res.addAll(clsLabels);

- } catch (Exception e) {
+ }
+ catch (Exception e) {
throw new RuntimeException(e);
}
return res;
