PUBDEV-4940 implement UpliftRandomForest #5224

Merged
merged 11 commits on Nov 30, 2021
34 changes: 23 additions & 11 deletions h2o-algos/src/main/java/hex/DataInfo.java
@@ -19,7 +19,7 @@
* Supports sparse data; sparse columns can be transformed to sparse rows on the fly with some (significant) memory overhead,
* as the data of the whole chunk(s) will be copied.
*
*/
*/
public class DataInfo extends Keyed<DataInfo> {
public int [] _activeCols;
public Frame _adaptedFrame; // the modified DataInfo frame (columns sorted by largest categorical -> least then all numerical columns)
@@ -65,7 +65,6 @@ public void addResponse(String [] names, Vec[] vecs) {

public double[] numNAFill() {return _numNAFill; }
public double numNAFill(int nid) {return _numNAFill[nid];}

public void setCatNAFill(int[] catNAFill) {
_catNAFill = catNAFill;
}
@@ -129,17 +128,18 @@ public boolean isSigmaScaled(){
public boolean _offset;
public boolean _weights;
public boolean _fold;
public boolean _treatment;
public Model.InteractionPair[] _interactions; // raw set of interactions
public Model.InteractionSpec _interactionSpec; // formal specification of interactions
public int _interactionVecs[]; // the interaction columns appearing in _adaptedFrame
public int[] _numOffsets; // offset column indices used by numerical interactions: total number of numerical columns is given by _numOffsets[_nums] - _numOffsets[0]
public int responseChunkId(int n){return n + _cats + _nums + (_weights?1:0) + (_offset?1:0) + (_fold?1:0);}
public int responseChunkId(int n){return n + _cats + _nums + (_weights?1:0) + (_offset?1:0) + (_fold?1:0) + (_treatment?1:0);}
public int treatmentChunkId(){return _cats + _nums + (_weights?1:0) + (_offset?1:0) + (_fold?1:0);}
public int foldChunkId(){return _cats + _nums + (_weights?1:0) + (_offset?1:0);}

public int offsetChunkId(){return _cats + _nums + (_weights ?1:0);}
public int weightChunkId(){return _cats + _nums;}
public int outputChunkId() { return outputChunkId(0);}
public int outputChunkId(int n) { return n + _cats + _nums + (_weights?1:0) + (_offset?1:0) + (_fold?1:0) + _responses;}
public int outputChunkId(int n) { return n + _cats + _nums + (_weights?1:0) + (_offset?1:0) + (_fold?1:0) + (_treatment?1:0) + _responses;}
public void addOutput(String name, Vec v) {_adaptedFrame.add(name,v);}
public Vec getOutputVec(int i) {return _adaptedFrame.vec(outputChunkId(i));}
public void setResponse(String name, Vec v){ setResponse(name,v,0);}
@@ -151,7 +151,7 @@
public final int [][] _catLvls; // cat lvls post filter (e.g. by strong rules)
public final int [][] _intLvls; // interaction lvls post filter (e.g. by strong rules)

private DataInfo() { _intLvls=null; _catLvls = null; _skipMissing = true; _imputeMissing = false; _valid = false; _offset = false; _weights = false; _fold = false; }
private DataInfo() { _intLvls=null; _catLvls = null; _skipMissing = true; _imputeMissing = false; _valid = false; _offset = false; _weights = false; _fold = false; _treatment=false;}
public String[] _coefNames;
public int[] _coefOriginalIndices; //
@Override protected long checksum_impl() {throw H2O.unimpl();} // don't really need checksum
@@ -168,7 +168,15 @@ public DataInfo(Frame train, Frame valid, int nResponses, boolean useAllFactorLevels
}

public DataInfo(Frame train, Frame valid, int nResponses, boolean useAllFactorLevels, TransformType predictor_transform, TransformType response_transform, boolean skipMissing, boolean imputeMissing, boolean missingBucket, boolean weight, boolean offset, boolean fold, Model.InteractionSpec interactions) {
this(train, valid, nResponses, useAllFactorLevels, predictor_transform, response_transform, skipMissing, imputeMissing, new MeanImputer(), missingBucket, weight, offset, fold, interactions);
this(train, valid, nResponses, useAllFactorLevels, predictor_transform, response_transform, skipMissing, imputeMissing, new MeanImputer(), missingBucket, weight, offset, fold, false, interactions);
}

public DataInfo(Frame train, Frame valid, int nResponses, boolean useAllFactorLevels, TransformType predictor_transform, TransformType response_transform, boolean skipMissing, boolean imputeMissing, boolean missingBucket, boolean weight, boolean offset, boolean fold, boolean treatment, Model.InteractionSpec interactions) {
this(train, valid, nResponses, useAllFactorLevels, predictor_transform, response_transform, skipMissing, imputeMissing, new MeanImputer(), missingBucket, weight, offset, fold, treatment, interactions);
}

public DataInfo(Frame train, Frame valid, int nResponses, boolean useAllFactorLevels, TransformType predictor_transform, TransformType response_transform, boolean skipMissing, boolean imputeMissing, Imputer imputer, boolean missingBucket, boolean weight, boolean offset, boolean fold, Model.InteractionSpec interactions) {
this(train, valid, nResponses, useAllFactorLevels, predictor_transform, response_transform, skipMissing, imputeMissing, imputer, missingBucket, weight, offset, fold, false, interactions);
}
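The overload chain above keeps existing call sites source-compatible: any caller that omits the new flag is routed to the full constructor with treatment = false. A hypothetical call for illustration (argument values made up, not taken from the PR):

    // A pre-existing 13-arg call site compiles unchanged after this change...
    DataInfo di = new DataInfo(train, valid, 1, true,
            DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE,
            true,  /* skipMissing   */
            false, /* imputeMissing */
            false, /* missingBucket */
            true,  /* weight        */
            false, /* offset        */
            false, /* fold          */
            null   /* interactions  */);
    // ...and it is equivalent to calling the new 14-arg variant with treatment = false,
    // so every index that now adds (_treatment?1:0) evaluates exactly as before.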

/**
@@ -191,14 +199,15 @@ public DataInfo(Frame train, Frame valid, int nResponses, boolean useAllFactorLevels
* A. As a list of pairs of column indices.
* B. As a list of pairs of column indices with limited enums.
*/
public DataInfo(Frame train, Frame valid, int nResponses, boolean useAllFactorLevels, TransformType predictor_transform, TransformType response_transform, boolean skipMissing, boolean imputeMissing, Imputer imputer, boolean missingBucket, boolean weight, boolean offset, boolean fold, Model.InteractionSpec interactions) {
public DataInfo(Frame train, Frame valid, int nResponses, boolean useAllFactorLevels, TransformType predictor_transform, TransformType response_transform, boolean skipMissing, boolean imputeMissing, Imputer imputer, boolean missingBucket, boolean weight, boolean offset, boolean fold, boolean treatment, Model.InteractionSpec interactions) {
super(Key.<DataInfo>make());
assert predictor_transform != null;
assert response_transform != null;
_valid = valid != null;
_offset = offset;
_weights = weight;
_fold = fold;
_treatment = treatment;
assert !(skipMissing && imputeMissing) : "skipMissing and imputeMissing cannot both be true";
_skipMissing = skipMissing;
_imputeMissing = imputeMissing;
@@ -228,7 +237,7 @@ public DataInfo(Frame train, Frame valid, int nResponses, boolean useAllFactorLe
final Vec[] tvecs = train.vecs();

// Count categorical-vs-numerical
final int n = tvecs.length-_responses - (offset?1:0) - (weight?1:0) - (fold?1:0);
final int n = tvecs.length-_responses - (offset?1:0) - (weight?1:0) - (fold?1:0) - (treatment?1:0);
int [] nums = MemoryManager.malloc4(n);
int [] cats = MemoryManager.malloc4(n);
int nnums = 0, ncats = 0;
@@ -323,7 +332,7 @@ public DataInfo(Frame train, Frame valid, int nResponses, boolean useAllFactorLe
numIdx++;
}
}
for(int i = names.length-nResponses - (weight?1:0) - (offset?1:0) - (fold?1:0); i < names.length; ++i) {
for(int i = names.length-nResponses - (weight?1:0) - (offset?1:0) - (fold?1:0) - (treatment?1:0); i < names.length; ++i) {
names[i] = train._names[i];
tvecs2[i] = train.vec(i);
}
@@ -423,6 +432,7 @@ private DataInfo(DataInfo dinfo,Frame fr, double [] normMul, double [] normSub,
_offset = dinfo._offset;
_weights = dinfo._weights;
_fold = dinfo._fold;
_treatment = dinfo._treatment;
_valid = false;
_interactions = null;
ArrayList<Integer> interactionVecs = new ArrayList<>();
@@ -588,7 +598,7 @@ public DataInfo filterExpandedColumns(int [] cols){
normMul[k-id] = _normMul[cols[k]-off];
}
DataInfo dinfo = new DataInfo(this,f,normMul,normSub,catLvls,intLvls,catModes,cols);
dinfo._nums=f.numCols()-dinfo._cats - dinfo._responses - (dinfo._offset?1:0) - (dinfo._weights?1:0) - (dinfo._fold?1:0);
dinfo._nums=f.numCols()-dinfo._cats - dinfo._responses - (dinfo._offset?1:0) - (dinfo._weights?1:0) - (dinfo._fold?1:0) - (dinfo._treatment?1:0);
dinfo._numMeans=new double[nnums];
dinfo._numNAFill=new double[nnums];
int colsSize = id+nnums; // small optimization
@@ -735,6 +745,7 @@ public boolean isInteractionVec(int colid) {
* weight column
* offset column
* fold column
* treatment column
*
* @return expanded number of columns in the underlying frame
*/
@@ -1450,6 +1461,7 @@ public DataInfo scoringInfo(String[] names, Frame adaptFrame, int nResponses, bo
res._weights = _weights && adaptFrame.find(names[weightChunkId()]) != -1;
res._offset = _offset && adaptFrame.find(names[offsetChunkId()]) != -1;
res._fold = _fold && adaptFrame.find(names[foldChunkId()]) != -1;
res._treatment = _treatment && adaptFrame.find(names[treatmentChunkId()]) != -1;
if (nResponses != -1) {
res._responses = nResponses;
} else {
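All of the (_treatment?1:0) terms added in this file follow from DataInfo's fixed column order in _adaptedFrame: categoricals, numericals, then weights, offset, fold, treatment, and responses. A standalone sketch of that index arithmetic (the counts are invented for illustration; the formulas mirror the chunk-id methods above):

    public class ChunkIdSketch {
        public static void main(String[] args) {
            int cats = 3, nums = 5;   // hypothetical predictor counts
            boolean weights = true, offset = false, fold = false, treatment = true;

            int weightId    = cats + nums;                        // 8  (weightChunkId)
            int offsetId    = weightId + (weights ? 1 : 0);       // 9  (offsetChunkId; no offset column here)
            int foldId      = offsetId + (offset ? 1 : 0);        // 9  (foldChunkId; no fold column here)
            int treatmentId = foldId + (fold ? 1 : 0);            // 9  (treatmentChunkId -> the treatment column)
            int responseId  = treatmentId + (treatment ? 1 : 0);  // 10 (responseChunkId(0) -> first response)

            System.out.printf("weight=%d treatment=%d response=%d%n", weightId, treatmentId, responseId);
        }
    }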
1 change: 1 addition & 0 deletions h2o-algos/src/main/java/hex/api/RegisterAlgos.java
@@ -34,6 +34,7 @@ public void registerEndPoints(RestApiContext context) {
new ANOVAGLM(true),
new PSVM(true),
new hex.rulefit .RuleFit (true),
new hex.tree.uplift.UpliftDRF (true),
new hex.maxrglm .MaxRGLM (true)
};

2 changes: 2 additions & 0 deletions h2o-algos/src/main/java/hex/generic/GenericModel.java
@@ -143,6 +143,8 @@ public String getResponseColumn() {
return genModel.isSupervised() ? genModel.getResponseName() : null;
}
@Override
public String getTreatmentColumn() {return null;}
@Override
public double missingColumnsType() {
return Double.NaN;
}
26 changes: 26 additions & 0 deletions h2o-algos/src/main/java/hex/schemas/UpliftDRFModelV3.java
@@ -0,0 +1,26 @@
package hex.schemas;

import hex.tree.uplift.UpliftDRFModel;

public class UpliftDRFModelV3 extends SharedTreeModelV3<UpliftDRFModel,
UpliftDRFModelV3,
UpliftDRFModel.UpliftDRFParameters,
UpliftDRFV3.UpliftDRFParametersV3,
UpliftDRFModel.UpliftDRFOutput,
UpliftDRFModelV3.UpliftDRFModelOutputV3> {

public static final class UpliftDRFModelOutputV3 extends SharedTreeModelV3.SharedTreeModelOutputV3<UpliftDRFModel.UpliftDRFOutput, UpliftDRFModelOutputV3> {}

public UpliftDRFV3.UpliftDRFParametersV3 createParametersSchema() { return new UpliftDRFV3.UpliftDRFParametersV3(); }
public UpliftDRFModelOutputV3 createOutputSchema() { return new UpliftDRFModelOutputV3(); }

//==========================
// Custom adapters go here

// Version&Schema-specific filling into the impl
@Override public UpliftDRFModel createImpl() {
UpliftDRFV3.UpliftDRFParametersV3 p = this.parameters;
UpliftDRFModel.UpliftDRFParameters parms = p.createImpl();
return new UpliftDRFModel( model_id.key(), parms, new UpliftDRFModel.UpliftDRFOutput(null) );
}
}
73 changes: 73 additions & 0 deletions h2o-algos/src/main/java/hex/schemas/UpliftDRFV3.java
@@ -0,0 +1,73 @@
package hex.schemas;

import hex.AUUC;
import hex.tree.uplift.UpliftDRF;
import hex.tree.uplift.UpliftDRFModel.UpliftDRFParameters;
import water.api.API;

public class UpliftDRFV3 extends SharedTreeV3<UpliftDRF, UpliftDRFV3, UpliftDRFV3.UpliftDRFParametersV3> {

public static final class UpliftDRFParametersV3 extends SharedTreeV3.SharedTreeParametersV3<UpliftDRFParameters, UpliftDRFParametersV3> {
static public String[] fields = new String[]{
"model_id",
"training_frame",
"validation_frame",
"score_each_iteration",
"score_tree_interval",
"response_column",
"ignored_columns",
"ignore_const_cols",
"balance_classes",
"class_sampling_factors",
"max_after_balance_size",
"ntrees",
"max_depth",
"min_rows",
"nbins",
"nbins_top_level",
"nbins_cats",
"max_runtime_secs",
"seed",
"mtries",
"sample_rate",
"sample_rate_per_class",
"checkpoint",
"col_sample_rate_change_per_level",
"col_sample_rate_per_tree",
"histogram_type",
"categorical_encoding",
"calibrate_model",
"calibration_frame",
"distribution",
"custom_metric_func",
"export_checkpoints_dir",
"check_constant_response",
"treatment_column",
"uplift_metric",
"auuc_type",
"auuc_nbins"
};

// Input fields
@API(help = "Number of variables randomly sampled as candidates at each split. If set to -1, defaults to sqrt{p} for classification and p/3 for regression (where p is the # of predictors", gridable = true)
public int mtries;

@API(help = "Row sample rate per tree (from 0.0 to 1.0)", gridable = true)
public double sample_rate;

@API(help = "Define column which will be use for computing uplift gain to select best split for a tree. The column has to devide dataset into treatment (value 1) and control (value 0) group.", gridable = false, level = API.Level.secondary, required = true,
is_member_of_frames = {"training_frame", "validation_frame"},
is_mutually_exclusive_with = {"ignored_columns","response_column", "weights_column"})
public String treatment_column;

@API(help = "Divergence metric used to find best split when building an upplift tree.", level = API.Level.secondary, values = { "AUTO", "KL", "Euclidean", "ChiSquared"})
public UpliftDRFParameters.UpliftMetricType uplift_metric;

@API(help = "AUUC metric used to calculate Area under Uplift.", level = API.Level.secondary, values = { "AUTO", "Qini", "Lift", "Gain"})
public AUUC.AUUCType auuc_type;

@API(help = "Number of bins to calculate Area under Uplift.", level = API.Level.secondary)
public int auuc_nbins;

}
}
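The parameter schema above is what the REST layer exposes; by H2O's usual schema convention, each field maps to an underscore-prefixed field on UpliftDRFParameters. A minimal training sketch in Java, assuming that convention holds here (the field names _treatment_column, _uplift_metric, _auuc_type and the column names "conversion"/"treatment" are inferred for illustration, not verified against the full PR):

    import hex.AUUC;
    import hex.tree.uplift.UpliftDRF;
    import hex.tree.uplift.UpliftDRFModel;
    import water.fvec.Frame;

    public class UpliftDRFSketch {
        // Trains an uplift random forest on a prepared Frame containing a binomial
        // response "conversion" and a 0/1 treatment indicator "treatment".
        static UpliftDRFModel train(Frame train) {
            UpliftDRFModel.UpliftDRFParameters parms = new UpliftDRFModel.UpliftDRFParameters();
            parms._train = train._key;
            parms._response_column = "conversion";
            parms._treatment_column = "treatment";   // required, per the schema above
            parms._uplift_metric = UpliftDRFModel.UpliftDRFParameters.UpliftMetricType.KL;
            parms._auuc_type = AUUC.AUUCType.AUTO;   // "AUTO" is in the schema's value list
            parms._ntrees = 50;
            return new UpliftDRF(parms).trainModel().get();   // blocks until the forest is built
        }
    }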