-
Notifications
You must be signed in to change notification settings - Fork 24.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML] Stratified cross validation split for classification (#54087)
As classification now works for multiple classes, randomly picking training/test data frame rows is not good enough. This commit introduces a stratified cross validation splitter that maintains, in the sample used for training the model, the proportion of each class in the dataset.
- Loading branch information
1 parent
81d8510
commit af7b95b
Showing
6 changed files
with
403 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
90 changes: 90 additions & 0 deletions
90
...csearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; | ||
|
||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; | ||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Random; | ||
|
||
/** | ||
* Given a dependent variable, randomly splits the dataset trying | ||
* to preserve the proportion of each class in the training sample. | ||
*/ | ||
public class StratifiedCrossValidationSplitter implements CrossValidationSplitter { | ||
|
||
private final int dependentVariableIndex; | ||
private final double samplingRatio; | ||
private final Random random; | ||
private final Map<String, ClassSample> classSamples; | ||
|
||
public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCardinalities, | ||
double trainingPercent, long randomizeSeed) { | ||
assert trainingPercent >= 1.0 && trainingPercent <= 100.0; | ||
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable); | ||
this.samplingRatio = trainingPercent / 100.0; | ||
this.random = new Random(randomizeSeed); | ||
this.classSamples = new HashMap<>(); | ||
classCardinalities.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new ClassSample(entry.getValue()))); | ||
} | ||
|
||
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) { | ||
int dependentVariableIndex = fieldNames.indexOf(dependentVariable); | ||
if (dependentVariableIndex < 0) { | ||
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames); | ||
} | ||
return dependentVariableIndex; | ||
} | ||
|
||
@Override | ||
public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) { | ||
|
||
if (canBeUsedForTraining(row) == false) { | ||
incrementTestDocs.run(); | ||
return; | ||
} | ||
|
||
String classValue = row[dependentVariableIndex]; | ||
ClassSample sample = classSamples.get(classValue); | ||
if (sample == null) { | ||
throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + classSamples.keySet()); | ||
} | ||
|
||
// The idea here is that the probability increases as the chances we have to get the target proportion | ||
// for a class decreases. | ||
double p = (samplingRatio * sample.cardinality - sample.training) / (sample.cardinality - sample.observed); | ||
|
||
boolean isTraining = random.nextDouble() <= p; | ||
|
||
sample.observed++; | ||
if (isTraining) { | ||
sample.training++; | ||
incrementTrainingDocs.run(); | ||
} else { | ||
row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE; | ||
incrementTestDocs.run(); | ||
} | ||
} | ||
|
||
private boolean canBeUsedForTraining(String[] row) { | ||
return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE; | ||
} | ||
|
||
private static class ClassSample { | ||
|
||
private final long cardinality; | ||
private long training; | ||
private long observed; | ||
|
||
private ClassSample(long cardinality) { | ||
this.cardinality = cardinality; | ||
} | ||
} | ||
} |
Oops, something went wrong.