Skip to content

Commit

Permalink
PUBDEV-7269 Multinomial AUC/AUCPR (#4923)
Browse files Browse the repository at this point in the history
*  implement multinomial AUC/AUCPR java backend
*  propagate multinomial AUC/AUCPR to Python/R
* add early stopping, add tests
* fix generic model, mojo
* implement grid search
* add demo
* test big domain
* add doc pages
* add the possibility to disable AUC calculation, set it as default
* add more sklearn code for testing
* fix automl multinomial leaderboard
  • Loading branch information
maurever committed Dec 10, 2020
1 parent 42b65a7 commit 19bfcb9
Show file tree
Hide file tree
Showing 90 changed files with 7,253 additions and 154 deletions.
10 changes: 1 addition & 9 deletions h2o-algos/src/main/java/hex/deeplearning/DeepLearningModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void set_model_info(DeepLearningModelInfo mi) {
@Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
switch(_output.getModelCategory()) {
case Binomial: return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(),domain);
case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(),domain, get_params()._auc_type);
case Regression: return new ModelMetricsRegression.MetricBuilderRegression();
case AutoEncoder: return new ModelMetricsAutoEncoder.MetricBuilderAutoEncoder(_output.nfeatures());
default: throw H2O.unimpl("Invalid ModelCategory " + _output.getModelCategory());
Expand Down Expand Up @@ -402,10 +402,6 @@ boolean doScoring(Frame fTrain, Frame fValid, Key<Job> jobKey, int iteration, bo
_output._training_metrics = mtrain;
scoringInfo.scored_train = new ScoreKeeper(mtrain);
hex.ModelMetricsSupervised mm1 = (ModelMetricsSupervised)mtrain;
if (mm1 instanceof ModelMetricsBinomial) {
ModelMetricsBinomial mm = (ModelMetricsBinomial)(mm1);
scoringInfo.training_AUC = mm._auc;
}
if (fTrain.numRows() != training_rows) {
_output._training_metrics._description = "Metrics reported on temporary training frame with " + fTrain.numRows() + " samples";
} else if (fTrain._key != null && fTrain._key.toString().contains("chunks")){
Expand All @@ -431,10 +427,6 @@ boolean doScoring(Frame fTrain, Frame fValid, Key<Job> jobKey, int iteration, bo
_output._validation_metrics = mvalid;
scoringInfo.scored_valid = new ScoreKeeper(mvalid);
if (mvalid != null) {
if (mvalid instanceof ModelMetricsBinomial) {
ModelMetricsBinomial mm = (ModelMetricsBinomial) mvalid;
scoringInfo.validation_AUC = mm._auc;
}
if (fValid.numRows() != validation_rows) {
_output._validation_metrics._description = "Metrics reported on temporary validation frame with " + fValid.numRows() + " samples";
if (get_params()._score_validation_sampling == DeepLearningParameters.ClassSamplingMethod.Stratified) {
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/gam/GAMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class GAMModel extends Model<GAMModel, GAMModel.GAMParameters, GAMModel.G
}
GLMModel.GLMWeightsFun glmf = new GLMModel.GLMWeightsFun(_parms._family, _parms._link, _parms._tweedie_variance_power,
_parms._tweedie_link_power, _parms._theta);
return new MetricBuilderGAM(domain, _ymu, glmf, _rank, true, _parms._intercept, _nclass);
return new MetricBuilderGAM(domain, _ymu, glmf, _rank, true, _parms._intercept, _nclass, _parms._auc_type);
}

public GAMModel(Key<GAMModel> selfKey, GAMParameters parms, GAMModelOutput output) {
Expand Down
6 changes: 3 additions & 3 deletions h2o-algos/src/main/java/hex/gam/MetricBuilderGAM.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class MetricBuilderGAM extends ModelMetricsSupervised.MetricBuilderSuperv
transient double[] _ds = new double[3];
transient float[] _yact = new float[1];

public MetricBuilderGAM(String[] domain, double[] ymu, GLMModel.GLMWeightsFun glmf, int rank, boolean computeMetrics, boolean intercept, int nclass) {
public MetricBuilderGAM(String[] domain, double[] ymu, GLMModel.GLMWeightsFun glmf, int rank, boolean computeMetrics, boolean intercept, int nclass, MultinomialAucType aucType) {
super(domain==null?0:domain.length, domain);
_intercept = intercept;
_computeMetrics = computeMetrics;
Expand All @@ -40,7 +40,7 @@ public MetricBuilderGAM(String[] domain, double[] ymu, GLMModel.GLMWeightsFun gl
case fractionalbinomial:
_metricBuilder = new ModelMetricsBinomial.MetricBuilderBinomial(domain); break;
case multinomial:
_metricBuilder = new ModelMetricsMultinomial.MetricBuilderMultinomial(nclass, domain); break;
_metricBuilder = new ModelMetricsMultinomial.MetricBuilderMultinomial(nclass, domain, aucType); break;
case ordinal:
_metricBuilder = new ModelMetricsOrdinal.MetricBuilderOrdinal(nclass, domain); break;
default:
Expand Down Expand Up @@ -165,7 +165,7 @@ metricsBinomial._auc, metricsBinomial._logloss, residualDeviance(), _null_devian
mm = new ModelMetricsBinomialGLM.ModelMetricsMultinomialGLM(m, f, metricsMultinomial._nobs,
metricsMultinomial._MSE, metricsMultinomial._domain, metricsMultinomial._sigma, metricsMultinomial._cm,
metricsMultinomial._hit_ratios, metricsMultinomial._logloss, residualDeviance(),_null_deviance, _aic,
nullDOF(), resDOF(), _customMetric);
nullDOF(), resDOF(), metricsMultinomial._auc, _customMetric);
} else if (_glmf._family == GLMModel.GLMParameters.Family.ordinal) { // ordinal should have a different resDOF()
ModelMetricsOrdinal metricsOrdinal = (ModelMetricsOrdinal) mm;
mm = new ModelMetricsBinomialGLM.ModelMetricsOrdinalGLM(m, f, metricsOrdinal._nobs, metricsOrdinal._MSE,
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/generic/GenericModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
case Binomial:
return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
case Multinomial:
return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain);
return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain, _parms._auc_type);
case Ordinal:
return new ModelMetricsOrdinal.MetricBuilderOrdinal(_output.nclasses(), domain);
case Regression: return new ModelMetricsRegression.MetricBuilderRegression();
Expand Down
7 changes: 5 additions & 2 deletions h2o-algos/src/main/java/hex/generic/GenericModelOutput.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,24 @@ auc, binomial._logloss, convertTable(binomial._gains_lift_table),
if (mojoMetrics instanceof MojoModelMetricsMultinomialGLM) {
assert modelAttributes instanceof ModelAttributesGLM;
final ModelAttributesGLM modelAttributesGLM = (ModelAttributesGLM) modelAttributes;
modelAttributesGLM.getModelParameters();
final MojoModelMetricsMultinomialGLM glmMultinomial = (MojoModelMetricsMultinomialGLM) mojoMetrics;
return new ModelMetricsMultinomialGLMGeneric(null, null, mojoMetrics._nobs, mojoMetrics._MSE,
_domains[_domains.length - 1], glmMultinomial._sigma,
convertTable(glmMultinomial._confusion_matrix), convertTable(glmMultinomial._hit_ratios),
glmMultinomial._logloss, new CustomMetric(mojoMetrics._custom_metric_name, mojoMetrics._custom_metric_value),
glmMultinomial._mean_per_class_error, glmMultinomial._nullDegressOfFreedom, glmMultinomial._residualDegressOfFreedom,
glmMultinomial._resDev, glmMultinomial._nullDev, glmMultinomial._AIC, convertTable(modelAttributesGLM._coefficients_table),
glmMultinomial._r2, glmMultinomial._description);
glmMultinomial._r2, convertTable(glmMultinomial._multinomial_auc), convertTable(glmMultinomial._multinomial_aucpr),
MultinomialAucType.valueOf((String)modelAttributes.getParameterValueByName("auc_type")), glmMultinomial._description);
} else {
final MojoModelMetricsMultinomial multinomial = (MojoModelMetricsMultinomial) mojoMetrics;
return new ModelMetricsMultinomialGeneric(null, null, mojoMetrics._nobs, mojoMetrics._MSE,
_domains[_domains.length - 1], multinomial._sigma,
convertTable(multinomial._confusion_matrix), convertTable(multinomial._hit_ratios),
multinomial._logloss, new CustomMetric(mojoMetrics._custom_metric_name, mojoMetrics._custom_metric_value),
multinomial._mean_per_class_error, multinomial._r2, multinomial._description);
multinomial._mean_per_class_error, multinomial._r2, convertTable(multinomial._multinomial_auc), convertTable(multinomial._multinomial_aucpr),
MultinomialAucType.valueOf((String)modelAttributes.getParameterValueByName("auc_type")), multinomial._description);
}
case Regression:
assert mojoMetrics instanceof MojoModelMetricsRegression;
Expand Down
6 changes: 3 additions & 3 deletions h2o-algos/src/main/java/hex/glm/GLMMetricBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class GLMMetricBuilder extends MetricBuilderSupervised<GLMMetricBuilder>
final boolean _intercept;
private final double [] _ymu;
final boolean _computeMetrics;
public GLMMetricBuilder(String[] domain, double [] ymu, GLMWeightsFun glmf, int rank, boolean computeMetrics, boolean intercept){
public GLMMetricBuilder(String[] domain, double [] ymu, GLMWeightsFun glmf, int rank, boolean computeMetrics, boolean intercept, MultinomialAucType aucType){
super(domain == null?0:domain.length, domain);
_glmf = glmf;
_rank = rank;
Expand All @@ -57,7 +57,7 @@ public GLMMetricBuilder(String[] domain, double [] ymu, GLMWeightsFun glmf, int
_metricBuilder = new MetricBuilderBinomial(domain);
break;
case multinomial:
_metricBuilder = new MetricBuilderMultinomial(domain.length, domain);
_metricBuilder = new MetricBuilderMultinomial(domain.length, domain, aucType);
((MetricBuilderMultinomial) _metricBuilder)._priorDistribution = ymu;
break;
case ordinal:
Expand Down Expand Up @@ -241,7 +241,7 @@ protected void computeAIC(){
metrics = new ModelMetricsBinomialGLM(m, f, metrics._nobs, metrics._MSE, _domain, metricsBinommial._sigma, metricsBinommial._auc, metricsBinommial._logloss, residualDeviance(), null_devince, _aic, nullDOF(), resDOF(), gl, _customMetric);
} else if (_glmf._family == Family.multinomial) {
ModelMetricsMultinomial metricsMultinomial = (ModelMetricsMultinomial) metrics;
metrics = new ModelMetricsMultinomialGLM(m, f, metricsMultinomial._nobs, metricsMultinomial._MSE, metricsMultinomial._domain, metricsMultinomial._sigma, metricsMultinomial._cm, metricsMultinomial._hit_ratios, metricsMultinomial._logloss, residualDeviance(), null_devince, _aic, nullDOF(), resDOF(), _customMetric);
metrics = new ModelMetricsMultinomialGLM(m, f, metricsMultinomial._nobs, metricsMultinomial._MSE, metricsMultinomial._domain, metricsMultinomial._sigma, metricsMultinomial._cm, metricsMultinomial._hit_ratios, metricsMultinomial._logloss, residualDeviance(), null_devince, _aic, nullDOF(), resDOF(), metricsMultinomial._auc, _customMetric);
} else if (_glmf._family == Family.ordinal) { // ordinal should have a different resDOF()
ModelMetricsOrdinal metricsOrdinal = (ModelMetricsOrdinal) metrics;
metrics = new ModelMetricsOrdinalGLM(m, f, metricsOrdinal._nobs, metricsOrdinal._MSE, metricsOrdinal._domain, metricsOrdinal._sigma, metricsOrdinal._cm, metricsOrdinal._hit_ratios, metricsOrdinal._logloss, residualDeviance(), null_devince, _aic, nullDOF(), resDOF(), _customMetric);
Expand Down
4 changes: 2 additions & 2 deletions h2o-algos/src/main/java/hex/glm/GLMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ private int rank(double [] ds) {
domain = binomialClassNames;
if (_parms._HGLM) {
String[] domaint = new String[]{"HGLM_" + _parms._family.toString() + "_" + _parms._rand_family[0].toString()};
return new GLMMetricBuilder(domaint, null, null, 0, true, false);
return new GLMMetricBuilder(domaint, null, null, 0, true, false, MultinomialAucType.NONE);
} else
return new GLMMetricBuilder(domain, _ymu, new GLMWeightsFun(_parms), _output.bestSubmodel().rank(), true, _parms._intercept);
return new GLMMetricBuilder(domain, _ymu, new GLMWeightsFun(_parms), _output.bestSubmodel().rank(), true, _parms._intercept, _parms._auc_type);
}

protected double [] beta_internal(){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public ModelSchemaV3 schema() {
@Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
switch(_output.getModelCategory()) {
case Binomial: return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(domain.length,domain);
case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(domain.length,domain, _parms._auc_type);
default: throw H2O.unimpl();
}
}
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/rulefit/RuleFitModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
case Binomial:
return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
case Multinomial:
return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain);
return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain, _parms._auc_type);
case Regression:
return new ModelMetricsRegression.MetricBuilderRegression();
default:
Expand Down
3 changes: 2 additions & 1 deletion h2o-algos/src/main/java/hex/schemas/DRFV3.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ public static final class DRFParametersV3 extends SharedTreeV3.SharedTreeParamet
"custom_metric_func",
"export_checkpoints_dir",
"check_constant_response",
"gainslift_bins"
"gainslift_bins",
"auc_type"
};

// Input fields
Expand Down
3 changes: 2 additions & 1 deletion h2o-algos/src/main/java/hex/schemas/DeepLearningV3.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ public static final class DeepLearningParametersV3 extends ModelParametersSchema
"elastic_averaging",
"elastic_averaging_moving_rate",
"elastic_averaging_regularization",
"export_checkpoints_dir"
"export_checkpoints_dir",
"auc_type"
};


Expand Down
3 changes: 2 additions & 1 deletion h2o-algos/src/main/java/hex/schemas/GAMV3.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ public static final class GAMParametersV3 extends ModelParametersSchemaV3<GAMMod
"gam_columns", // array: predictor column names array
"bs", // array, name of basis functions used
"scale", // array, smoothing parameter for GAM,
"keep_gam_cols"
"keep_gam_cols",
"auc_type"
};

@API(help = "Seed for pseudo random number generator (if applicable)", gridable = true)
Expand Down
8 changes: 2 additions & 6 deletions h2o-algos/src/main/java/hex/schemas/GBMV3.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,8 @@ public static final class GBMParametersV3 extends SharedTreeV3.SharedTreeParamet
"export_checkpoints_dir",
"monotone_constraints",
"check_constant_response",
"gainslift_bins"
// "use_new_histo_tsk",
// "col_block_sz",
// "min_threads",
// "shared_histo",
// "unordered"
"gainslift_bins",
"auc_type",
};

// Input fields
Expand Down
3 changes: 2 additions & 1 deletion h2o-algos/src/main/java/hex/schemas/GLMV3.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ public static final class GLMParametersV3 extends ModelParametersSchemaV3<GLMPar
"max_after_balance_size",
"max_confusion_matrix_size",
"max_runtime_secs",
"custom_metric_func"
"custom_metric_func",
"auc_type"
};

@API(help = "Seed for pseudo random number generator (if applicable)", gridable = true)
Expand Down
3 changes: 2 additions & 1 deletion h2o-algos/src/main/java/hex/schemas/NaiveBayesV3.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ public static final class NaiveBayesParametersV3 extends ModelParametersSchemaV3
"compute_metrics",
"max_runtime_secs",
"export_checkpoints_dir",
"gainslift_bins"
"gainslift_bins",
"auc_type"
};

/*Imbalanced Classes*/
Expand Down
3 changes: 2 additions & 1 deletion h2o-algos/src/main/java/hex/schemas/StackedEnsembleV99.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ public static final class StackedEnsembleParametersV99 extends ModelParametersSc
"seed",
"score_training_samples",
"keep_levelone_frame",
"export_checkpoints_dir"
"export_checkpoints_dir",
"auc_type"
};

public static class AlgorithmValuesProvider extends EnumValuesProvider<Algorithm> {
Expand Down
20 changes: 18 additions & 2 deletions h2o-algos/src/main/java/hex/tree/SharedTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -853,9 +853,13 @@ public static TwoDimTable createScoringHistoryTable(Model.Output _output,
colHeaders.add("Training pr_auc"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Training Lift"); colTypes.add("double"); colFormat.add("%.5f");
}
if (_output.getModelCategory() == ModelCategory.Binomial || _output.getModelCategory() == ModelCategory.Multinomial) {
if(_output.isClassifier()){
colHeaders.add("Training Classification Error"); colTypes.add("double"); colFormat.add("%.5f");
}
if (_output.getModelCategory() == ModelCategory.Multinomial) {
colHeaders.add("Training AUC"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Training pr_auc"); colTypes.add("double"); colFormat.add("%.5f");
}
if (hasCustomMetric) {
colHeaders.add("Training Custom"); colTypes.add("double"); colFormat.add("%.5f");
}
Expand All @@ -878,9 +882,13 @@ public static TwoDimTable createScoringHistoryTable(Model.Output _output,
colHeaders.add("Validation pr_auc"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Validation Lift"); colTypes.add("double"); colFormat.add("%.5f");
}
if (_output.isClassifier()) {
if(_output.isClassifier()){
colHeaders.add("Validation Classification Error"); colTypes.add("double"); colFormat.add("%.5f");
}
if (_output.getModelCategory() == ModelCategory.Multinomial) {
colHeaders.add("Validation AUC"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Validation pr_auc"); colTypes.add("double"); colFormat.add("%.5f");
}
if (hasCustomMetric) {
colHeaders.add("Validation Custom"); colTypes.add("double"); colFormat.add("%.5f");
}
Expand Down Expand Up @@ -921,6 +929,10 @@ public static TwoDimTable createScoringHistoryTable(Model.Output _output,
table.set(row, col++, st._lift);
}
if (_output.isClassifier()) table.set(row, col++, st._classError);
if (_output.getModelCategory() == ModelCategory.Multinomial) {
table.set(row, col++, st._AUC);
table.set(row, col++, st._pr_auc);
}
if (hasCustomMetric) table.set(row, col++, st._custom_metric);

if (_output._validation_metrics != null) {
Expand All @@ -939,6 +951,10 @@ public static TwoDimTable createScoringHistoryTable(Model.Output _output,
table.set(row, col++, st._lift);
}
if (_output.isClassifier()) table.set(row, col++, st._classError);
if (_output.getModelCategory() == ModelCategory.Multinomial) {
table.set(row, col++, st._AUC);
table.set(row, col++, st._pr_auc);
}
if (hasCustomMetric) table.set(row, col++, st._custom_metric);
}
row++;
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/tree/SharedTreeModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public Parameters getParams() {
@Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
switch(_output.getModelCategory()) {
case Binomial: return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(),domain);
case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(),domain, _parms._auc_type);
case Regression: return new ModelMetricsRegression.MetricBuilderRegression();
default: throw H2O.unimpl();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2282,7 +2282,7 @@ public void testMultinomialMNIST() {
Vec labels = small.vec("C785"); //actual
String[] fullDomain = train.vec("C785").domain(); //actual

ModelMetricsMultinomial mm = ModelMetricsMultinomial.make(preds, labels, fullDomain);
ModelMetricsMultinomial mm = ModelMetricsMultinomial.make(preds, labels, fullDomain, MultinomialAucType.NONE);
Log.info(mm.toString());
}
} catch(Throwable t) {
Expand Down Expand Up @@ -2329,7 +2329,7 @@ public void testMultinomial() {
Vec labels = train.vec("pclass"); //actual
String[] fullDomain = train.vec("pclass").domain(); //actual

ModelMetricsMultinomial mm = ModelMetricsMultinomial.make(preds, labels, fullDomain);
ModelMetricsMultinomial mm = ModelMetricsMultinomial.make(preds, labels, fullDomain, MultinomialAucType.NONE);
Log.info(mm.toString());
} finally {
if (model!=null) model.delete();
Expand Down
Loading

0 comments on commit 19bfcb9

Please sign in to comment.