Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-16046 UpliftDRF Implement Cross Validation #16048

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/generic/GenericModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
case AnomalyDetection:
return new ModelMetricsAnomaly.MetricBuilderAnomaly();
case BinomialUplift:
return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, null);
return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, null, _parms._auuc_nbins,AUUC.calculateProbs(_parms._auuc_nbins));
default:
throw H2O.unimpl();
}
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/tree/SharedTreeModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public boolean forceStrictlyReproducibleHistograms() {
case Binomial: return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(),domain, _parms._auc_type);
case Regression: return new ModelMetricsRegression.MetricBuilderRegression();
case BinomialUplift: return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, ((UpliftDRFModel.UpliftDRFOutput)_output)._defaultAuucThresholds);
case BinomialUplift: return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, ((UpliftDRFModel.UpliftDRFOutput)_output)._defaultAuucThresholds, _parms._auuc_nbins, AUUC.calculateProbs( _parms._auuc_nbins));
default: throw H2O.unimpl();
}
}
Expand Down
8 changes: 3 additions & 5 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDRF.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,9 @@ public ModelCategory[] can_build() {
if (hasOffsetCol())
error("_offset_column", "Offsets are not yet supported for Uplift DRF.");
if (hasWeightCol())
error("_weight_column", "Weights are not yet supported for Uplift DRF.");
if (hasFoldCol())
error("_fold_column", "Cross-validation is not yet supported for Uplift DRF.");
if (_parms._nfolds > 0)
error("_nfolds", "Cross-validation is not yet supported for Uplift DRF.");
if(!_parms._weights_column.equals("__internal_cv_weights__")) {
error("_weight_column", "Weights are not yet supported for Uplift DRF.");
}
if (_nclass == 1)
error("_distribution", "UpliftDRF currently support binomial classification problems only.");
if (_nclass > 2 || _parms._distribution.equals(DistributionFamily.multinomial))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void initActualParamValues() {
}

@Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, _output._defaultAuucThresholds);
return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, _output._defaultAuucThresholds, _parms._auuc_nbins, AUUC.calculateProbs(_parms._auuc_nbins));
}

@Override
Expand Down
82 changes: 49 additions & 33 deletions h2o-algos/src/test/java/hex/tree/uplift/UpliftDRFTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ public void testBasicTrain() {
p._response_column = "conversion";
p._seed = 0xDECAF;
p._ntrees = 3;
p._score_each_iteration = true;
p._auuc_nbins = 450;

UpliftDRF udrf = new UpliftDRF(p);
UpliftDRFModel model = udrf.trainModel().get();
Expand Down Expand Up @@ -153,17 +155,16 @@ public void testBasicTrainErrorDoNotSupportMultinomialResponseColumn() {
}
}


@Test(expected = H2OModelBuilderIllegalArgumentException.class)
public void testBasicTrainErrorDoNotSupportNfolds() {
public void testBasicTrainErrorDoNotSupportOffset() {
try {
Scope.enter();
Frame train = generateFrame();
UpliftDRFModel.UpliftDRFParameters p = new UpliftDRFModel.UpliftDRFParameters();
p._train = train._key;
p._treatment_column = "treatment";
p._response_column = "conversion";
p._nfolds = 10;
p._offset_column = "C1";

UpliftDRF udrf = new UpliftDRF(p);
udrf.trainModel().get();
Expand All @@ -173,61 +174,75 @@ public void testBasicTrainErrorDoNotSupportNfolds() {
}

@Test(expected = H2OModelBuilderIllegalArgumentException.class)
public void testBasicTrainErrorDoNotSupportFoldColumn() {
public void testBasicTrainErrorDoNotSupportDistribution() {
try {
Scope.enter();
Frame train = generateFrame();
UpliftDRFModel.UpliftDRFParameters p = new UpliftDRFModel.UpliftDRFParameters();
p._train = train._key;
p._treatment_column = "treatment";
p._response_column = "conversion";
p._fold_column = "C0";
p._distribution = DistributionFamily.multinomial;

UpliftDRF udrf = new UpliftDRF(p);
udrf.trainModel().get();
} finally {
Scope.exit();
}
}

@Test(expected = H2OModelBuilderIllegalArgumentException.class)
public void testBasicTrainErrorDoNotSupportOffset() {
@Test
public void testBasicTrainSupportEarlyStoppingAUUC() {
try {
Scope.enter();
Frame train = generateFrame();
int ntrees = 42;
UpliftDRFModel.UpliftDRFParameters p = new UpliftDRFModel.UpliftDRFParameters();
p._train = train._key;
p._treatment_column = "treatment";
p._response_column = "conversion";
p._offset_column = "C1";
p._stopping_metric = ScoreKeeper.StoppingMetric.AUUC;
p._stopping_rounds = 2;
p._ntrees = ntrees;
p._score_each_iteration = true;

UpliftDRF udrf = new UpliftDRF(p);
udrf.trainModel().get();
UpliftDRFModel model = udrf.trainModel().get();
Scope.track_generic(model);
assertNotNull(model);
assertTrue(model._output._treeStats._num_trees < ntrees);
} finally {
Scope.exit();
}
}

@Test(expected = H2OModelBuilderIllegalArgumentException.class)
public void testBasicTrainErrorDoNotSupportDistribution() {
@Test
public void testBasicTrainSupportEarlyStoppingATE() {
try {
Scope.enter();
Frame train = generateFrame();
int ntrees = 42;
UpliftDRFModel.UpliftDRFParameters p = new UpliftDRFModel.UpliftDRFParameters();
p._train = train._key;
p._treatment_column = "treatment";
p._response_column = "conversion";
p._distribution = DistributionFamily.multinomial;
p._stopping_metric = ScoreKeeper.StoppingMetric.ATE;
p._stopping_rounds = 2;
p._ntrees = ntrees;
p._score_each_iteration = true;

UpliftDRF udrf = new UpliftDRF(p);
udrf.trainModel().get();
UpliftDRFModel model = udrf.trainModel().get();
Scope.track_generic(model);
assertNotNull(model);
assertTrue(model._output._treeStats._num_trees < ntrees);
} finally {
Scope.exit();
}
}

@Test
public void testBasicTrainSupportEarlyStoppingAUUC() {
public void testBasicTrainSupportEarlyStoppingATT() {
try {
Scope.enter();
Frame train = generateFrame();
Expand All @@ -236,7 +251,7 @@ public void testBasicTrainSupportEarlyStoppingAUUC() {
p._train = train._key;
p._treatment_column = "treatment";
p._response_column = "conversion";
p._stopping_metric = ScoreKeeper.StoppingMetric.AUUC;
p._stopping_metric = ScoreKeeper.StoppingMetric.ATT;
p._stopping_rounds = 2;
p._ntrees = ntrees;
p._score_each_iteration = true;
Expand All @@ -252,7 +267,7 @@ public void testBasicTrainSupportEarlyStoppingAUUC() {
}

@Test
public void testBasicTrainSupportEarlyStoppingATE() {
public void testBasicTrainSupportEarlyStoppingATC() {
try {
Scope.enter();
Frame train = generateFrame();
Expand All @@ -261,7 +276,7 @@ public void testBasicTrainSupportEarlyStoppingATE() {
p._train = train._key;
p._treatment_column = "treatment";
p._response_column = "conversion";
p._stopping_metric = ScoreKeeper.StoppingMetric.ATE;
p._stopping_metric = ScoreKeeper.StoppingMetric.ATC;
p._stopping_rounds = 2;
p._ntrees = ntrees;
p._score_each_iteration = true;
Expand All @@ -277,7 +292,7 @@ public void testBasicTrainSupportEarlyStoppingATE() {
}

@Test
public void testBasicTrainSupportEarlyStoppingATT() {
public void testBasicTrainSupportEarlyStoppingQini() {
try {
Scope.enter();
Frame train = generateFrame();
Expand All @@ -286,7 +301,7 @@ public void testBasicTrainSupportEarlyStoppingATT() {
p._train = train._key;
p._treatment_column = "treatment";
p._response_column = "conversion";
p._stopping_metric = ScoreKeeper.StoppingMetric.ATT;
p._stopping_metric = ScoreKeeper.StoppingMetric.qini;
p._stopping_rounds = 2;
p._ntrees = ntrees;
p._score_each_iteration = true;
Expand All @@ -302,50 +317,51 @@ public void testBasicTrainSupportEarlyStoppingATT() {
}

@Test
public void testBasicTrainSupportEarlyStoppingATC() {
public void testBasicTrainSupportCV() {
try {
Scope.enter();
Frame train = generateFrame();
int ntrees = 42;
int ntrees = 10;
UpliftDRFModel.UpliftDRFParameters p = new UpliftDRFModel.UpliftDRFParameters();
p._train = train._key;
p._treatment_column = "treatment";
p._response_column = "conversion";
p._stopping_metric = ScoreKeeper.StoppingMetric.ATC;
p._stopping_rounds = 2;
p._ntrees = ntrees;
p._score_each_iteration = true;
p._nfolds = 3;
p._auuc_nbins = 400;

UpliftDRF udrf = new UpliftDRF(p);
UpliftDRFModel model = udrf.trainModel().get();
Scope.track_generic(model);
assertNotNull(model);
assertTrue(model._output._treeStats._num_trees < ntrees);
} finally {
Scope.exit();
}
}

@Test
public void testBasicTrainSupportEarlyStoppingQini() {
public void testSupportCVCriteo() {
try {
Scope.enter();
Frame train = generateFrame();
int ntrees = 42;
Frame train = Scope.track(parseTestFile("smalldata/uplift/criteo_uplift_13k.csv"));
train.toCategoricalCol("treatment");
train.toCategoricalCol("conversion");
UpliftDRFModel.UpliftDRFParameters p = new UpliftDRFModel.UpliftDRFParameters();
p._train = train._key;
p._ignored_columns = new String[]{"visit", "exposure"};
p._treatment_column = "treatment";
p._response_column = "conversion";
p._stopping_metric = ScoreKeeper.StoppingMetric.qini;
p._stopping_rounds = 2;
p._ntrees = ntrees;
p._seed = 0xDECAF;
p._ntrees = 11;
p._score_each_iteration = true;
p._nfolds = 3;
p._auuc_nbins = 50;

UpliftDRF udrf = new UpliftDRF(p);
UpliftDRFModel model = udrf.trainModel().get();
Scope.track_generic(model);
assertNotNull(model);
assertTrue(model._output._treeStats._num_trees < ntrees);
} finally {
Scope.exit();
}
Expand Down
Loading
Loading