Skip to content

Commit

Permalink
[ML] Stratified cross validation split for classification (#54087)
Browse files Browse the repository at this point in the history
As classification now works for multiple classes, randomly
picking training/test data frame rows is not good enough.
This commit introduces a stratified cross validation splitter
that maintains the proportion of each class in the dataset
in the sample that is used for training the model.
  • Loading branch information
dimitris-athanasiou committed Mar 24, 2020
1 parent 81d8510 commit af7b95b
Show file tree
Hide file tree
Showing 6 changed files with 403 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class Classification implements DataFrameAnalysis {
/**
* The max number of classes classification supports
*/
private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;

private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
Expand Down Expand Up @@ -162,7 +161,7 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
try {
writeHeaderRecord(dataExtractor, process);
writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker(),
writeDataRows(dataExtractor, process, config, task.getStatsHolder().getProgressTracker(),
task.getStatsHolder().getDataCountsTracker());
processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()),
DataCounts::documentId);
Expand Down Expand Up @@ -214,11 +213,12 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
}
}

private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process, DataFrameAnalysis analysis,
ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) throws IOException {
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
DataFrameAnalyticsConfig config, ProgressTracker progressTracker, DataCountsTracker dataCountsTracker)
throws IOException {

CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames())
.create(analysis);
CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(client, config, dataExtractor.getFieldNames())
.create();

// The extra fields are for the doc hash and the control field (should be an empty string)
String[] record = new String[dataExtractor.getFieldNames().size() + 2];
Expand Down Expand Up @@ -324,7 +324,8 @@ private void refreshIndices(String jobId) {
);
refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen());

LOGGER.debug("[{}] Refreshing indices {}", jobId, Arrays.toString(refreshRequest.indices()));
LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}",
jobId, Arrays.toString(refreshRequest.indices())));

try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
client.admin().indices().refresh(refreshRequest).actionGet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,81 @@
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class CrossValidationSplitterFactory {

private static final Logger LOGGER = LogManager.getLogger(CrossValidationSplitterFactory.class);

private final Client client;
private final DataFrameAnalyticsConfig config;
private final List<String> fieldNames;

public CrossValidationSplitterFactory(List<String> fieldNames) {
public CrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
this.client = Objects.requireNonNull(client);
this.config = Objects.requireNonNull(config);
this.fieldNames = Objects.requireNonNull(fieldNames);
}

public CrossValidationSplitter create(DataFrameAnalysis analysis) {
if (analysis instanceof Regression) {
Regression regression = (Regression) analysis;
return new RandomCrossValidationSplitter(
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
public CrossValidationSplitter create() {
if (config.getAnalysis() instanceof Regression) {
return createRandomSplitter();
}
if (analysis instanceof Classification) {
Classification classification = (Classification) analysis;
return new RandomCrossValidationSplitter(
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
if (config.getAnalysis() instanceof Classification) {
return createStratifiedSplitter((Classification) config.getAnalysis());
}
return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run();
}

private CrossValidationSplitter createRandomSplitter() {
Regression regression = (Regression) config.getAnalysis();
return new RandomCrossValidationSplitter(
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
}

private CrossValidationSplitter createStratifiedSplitter(Classification classification) {
String aggName = "dependent_variable_terms";
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
.setSize(0)
.setAllowPartialSearchResults(false)
.addAggregation(AggregationBuilders.terms(aggName)
.field(classification.getDependentVariable())
.size(Classification.MAX_DEPENDENT_VARIABLE_CARDINALITY));

try {
SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
searchRequestBuilder::get);
Aggregations aggs = searchResponse.getAggregations();
Terms terms = aggs.get(aggName);
Map<String, Long> classCardinalities = new HashMap<>();
for (Terms.Bucket bucket : terms.getBuckets()) {
classCardinalities.put(String.valueOf(bucket.getKey()), bucket.getDocCount());
}

return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCardinalities,
classification.getTrainingPercent(), classification.getRandomizeSeed());
} catch (Exception e) {
ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId());
LOGGER.error(msg, e);
throw new ElasticsearchException(msg.getFormattedMessage(), e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@ class RandomCrossValidationSplitter implements CrossValidationSplitter {
private boolean isFirstRow = true;

/**
 * @param fieldNames        the extracted field names, in row order; must contain {@code dependentVariable}
 * @param dependentVariable the field whose value is predicted; its position in {@code fieldNames}
 *                          is resolved here and cached
 * @param trainingPercent   percentage of rows to use for training; callers validate it to [1, 100]
 * @param randomizeSeed     seed for the random split, making the split reproducible
 */
RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
// Range is enforced upstream by config validation; assert documents the invariant.
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
this.trainingPercent = trainingPercent;
this.random = new Random(randomizeSeed);
}

private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
for (int i = 0; i < fieldNames.size(); i++) {
if (fieldNames.get(i).equals(dependentVariable)) {
return i;
}
int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
if (dependentVariableIndex < 0) {
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
}
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
return dependentVariableIndex;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;

import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Randomly assigns dataset rows to the training or test set while aiming to keep,
 * for each class of the dependent variable, the same proportion in the training
 * sample as in the whole dataset (stratified sampling).
 */
public class StratifiedCrossValidationSplitter implements CrossValidationSplitter {

    // Position of the dependent variable within each extracted row.
    private final int dependentVariableIndex;
    // Target fraction of each class that should end up in the training set.
    private final double samplingRatio;
    private final Random random;
    // Per-class bookkeeping: total count plus how many rows we have seen / trained on so far.
    private final Map<String, ClassSample> classSamples;

    /**
     * @param fieldNames         the extracted field names, in row order
     * @param dependentVariable  the field being predicted; must appear in {@code fieldNames}
     * @param classCardinalities total document count per class value, as returned by a
     *                           terms aggregation on the dependent variable
     * @param trainingPercent    percentage of rows to use for training, in [1, 100]
     * @param randomizeSeed      seed making the split reproducible
     */
    public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCardinalities,
                                             double trainingPercent, long randomizeSeed) {
        assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
        this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
        this.samplingRatio = trainingPercent / 100.0;
        this.random = new Random(randomizeSeed);
        this.classSamples = new HashMap<>();
        classCardinalities.forEach((classValue, cardinality) -> classSamples.put(classValue, new ClassSample(cardinality)));
    }

    // Locates the dependent variable in the field list; a miss is a server-side bug.
    private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
        int index = fieldNames.indexOf(dependentVariable);
        if (index < 0) {
            throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
        }
        return index;
    }

    @Override
    public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
        // Rows without a dependent variable value cannot be trained on.
        if (canBeUsedForTraining(row) == false) {
            incrementTestDocs.run();
            return;
        }

        String dependentVariableValue = row[dependentVariableIndex];
        ClassSample classSample = classSamples.get(dependentVariableValue);
        if (classSample == null) {
            throw new IllegalStateException(
                "Unknown class [" + dependentVariableValue + "]; expected one of " + classSamples.keySet());
        }

        // Probability of assigning this row to training. It grows as the remaining
        // opportunities to hit the target proportion for this class shrink.
        // NOTE(review): if observed ever reaches cardinality (e.g. approximate terms-agg
        // counts) the denominator is zero — confirm upstream counts are exact.
        double trainingProbability = (samplingRatio * classSample.cardinality - classSample.training)
            / (classSample.cardinality - classSample.observed);

        boolean assignToTraining = random.nextDouble() <= trainingProbability;

        classSample.observed++;
        if (assignToTraining) {
            classSample.training++;
            incrementTrainingDocs.run();
        } else {
            // Blank out the dependent variable so the analytics process treats the row as test data.
            row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
            incrementTestDocs.run();
        }
    }

    // A row is trainable only if its dependent variable is present.
    private boolean canBeUsedForTraining(String[] row) {
        return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
    }

    /** Mutable per-class counters used to steer the stratified split. */
    private static class ClassSample {

        private final long cardinality;
        private long training;
        private long observed;

        private ClassSample(long cardinality) {
            this.cardinality = cardinality;
        }
    }
}
Loading

0 comments on commit af7b95b

Please sign in to comment.