Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PUBDEV-8927: Remove redundant predictors if found for backward mode #6446

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions h2o-algos/src/main/java/hex/modelselection/ModelSelection.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import water.Key;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.PrettyPrint;

import java.lang.reflect.Field;
Expand All @@ -19,7 +18,6 @@
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static hex.gam.MatrixFrameUtils.GamUtils.copy2DArray;
import static hex.genmodel.utils.MathUtils.combinatorial;
import static hex.glm.GLMModel.GLMParameters.Family.*;
import static hex.modelselection.ModelSelectionModel.ModelSelectionParameters.Mode.*;
Expand Down Expand Up @@ -380,17 +378,15 @@ void buildMaxRModels(ModelSelectionModel model) {
* https://h2oai.atlassian.net/browse/PUBDEV-8428
*/
private int buildBackwardModels(ModelSelectionModel model) {
List<String> coefNames = new ArrayList<>(Arrays.asList(_predictorNames));
List<Integer> coefIndice = IntStream.rangeClosed(0, coefNames.size()-1).boxed().collect(Collectors.toList());
List<String> predNames = new ArrayList<>(Arrays.asList(_predictorNames));
Frame train = DKV.getGet(_parms._train);
List<String> numPredNames = coefNames.stream().filter(x -> train.vec(x).isNumeric()).collect(Collectors.toList());
List<String> catPredNames = coefNames.stream().filter(x -> !numPredNames.contains(x)).collect(Collectors.toList());
List<String> numPredNames = predNames.stream().filter(x -> train.vec(x).isNumeric()).collect(Collectors.toList());
List<String> catPredNames = predNames.stream().filter(x -> !numPredNames.contains(x)).collect(Collectors.toList());
int numModelsBuilt = 0;
String[] coefName = coefNames.toArray(new String[0]);
String[] coefName = predNames.toArray(new String[0]);
for (int predNum = _numPredictors; predNum >= _parms._min_predictor_number; predNum--) {
int modelIndex = predNum-1;
int[] coefInd = coefIndice.stream().mapToInt(Integer::intValue).toArray();
Frame trainingFrame = generateOneFrame(coefInd, _parms, coefName, _foldColumn);
Frame trainingFrame = generateOneFrame(null, _parms, coefName, _foldColumn);
DKV.put(trainingFrame);
GLMModel.GLMParameters[] glmParam = generateGLMParameters(new Frame[]{trainingFrame}, _parms,
_glmNFolds, _foldColumn, _foldAssignment);
Expand All @@ -399,17 +395,20 @@ private int buildBackwardModels(ModelSelectionModel model) {

// evaluate which variable to drop for next round of testing and store corresponding values
// if p_values_threshold is specified, model building may stop
model._output.extractPredictors4NextModel(glmModel, modelIndex, coefNames, coefIndice, numPredNames,
model._output.extractPredictors4NextModel(glmModel, modelIndex, predNames, numPredNames,
catPredNames);
numModelsBuilt++;
DKV.remove(trainingFrame._key);
coefName = predNames.toArray(new String[0]);
_job.update(predNum, "Finished building all models with "+predNum+" predictors.");
if (_parms._p_values_threshold > 0) { // check if p-values are used to stop model building
if (DoubleStream.of(model._output._coef_p_values[modelIndex])
.limit(model._output._coef_p_values[modelIndex].length-1)
.allMatch(x -> x <= _parms._p_values_threshold))
break;
}
if (predNames.size() == 0) // no more predictors available to build models with
break;
}
return numModelsBuilt;
}
Expand Down
98 changes: 85 additions & 13 deletions h2o-algos/src/main/java/hex/modelselection/ModelSelectionModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
import water.util.TwoDimTable;

import java.io.Serializable;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
Expand Down Expand Up @@ -328,6 +331,15 @@ void updateBestModels(GLMModel bestModel, int index) {
updateAddedRemovedPredictors(index);
}

/**
 * Record, at slot {@code index}, the coefficient names and the predictor subset of a GLM model.
 *
 * @param model GLM model whose coefficient names and column names are read
 * @param index position in {@code _coefficient_names} and {@code _best_predictors_subset} to fill
 */
void extractCoeffs(GLMModel model, int index) {
  // Keep a private copy of every coefficient name the model reports. Nothing is filtered out
  // here: whatever coefficientNames() returns is stored as is.
  _coefficient_names[index] = model._output.coefficientNames().clone();
  // The predictor subset is the model's column names minus the response column.
  List<String> predNames = new ArrayList<>(Arrays.asList(model.names()));
  predNames.remove(model._parms._response_column);
  _best_predictors_subset[index] = predNames.toArray(new String[0]);
}

void updateBestModels(String[] predictorNames, List<String> allCoefNames, int index, boolean hasIntercept,
int actualCPMSize, int[] predsubset, double[][] lastCPM, double r2Scale,
CoeffNormalization coeffN, int[][] pred2CPMIndex, DataInfo dinfo) {
Expand Down Expand Up @@ -453,28 +465,88 @@ void updateAddedRemovedPredictors(int index) {
_predictors_added_per_step[index] = new String[]{""};
}

void extractCoeffs(GLMModel model, int index) {
_coefficient_names[index] = model._output.coefficientNames().clone(); // all coefficients
ArrayList<String> coeffNames = new ArrayList<>(Arrays.asList(model._output.coefficientNames()));
_coefficient_names[index] = coeffNames.toArray(new String[0]); // without intercept
List<String> predNames = Stream.of(model.names()).collect(Collectors.toList());
predNames.remove(model._parms._response_column);
_best_predictors_subset[index] = predNames.stream().toArray(String[]::new);
/**
 * Entry point used at the start of backward mode to drop redundant predictors.
 *
 * Reads the index array {@code idxs} of the model's best submodel, maps each index to its
 * coefficient name and hands the resulting names to {@link #resetAllPreds}. A {@code null}
 * index array is treated as "no redundant predictors" and leaves all three lists untouched.
 *
 * @param model        GLM model supplying coefficient names and the best submodel's indices
 * @param predNames    all predictor names; may be rewritten in place
 * @param numPredNames numerical predictor names; may be rewritten in place
 * @param catPredNames categorical predictor names; may be rewritten in place
 */
void resetCoeffs(GLMModel model, List<String> predNames, List<String> numPredNames, List<String> catPredNames) {
  final String[] allCoefNames = model._output.coefficientNames();
  final int[] activeIdxs = model._output.bestSubmodel().idxs;
  if (activeIdxs != null) {
    // translate the retained coefficient indices into coefficient names
    List<String> activeCoefNames = new ArrayList<>(activeIdxs.length);
    for (int activeIdx : activeIdxs)
      activeCoefNames.add(allCoefNames[activeIdx]);
    resetAllPreds(predNames, catPredNames, numPredNames, model, activeCoefNames); // remove redundant preds
  }
}

/**
 * Shrink the predictor name lists when the best submodel kept fewer coefficients than the model has.
 *
 * When the length of the best submodel's {@code idxs} equals the number of model coefficients,
 * nothing is changed. Otherwise the numerical and categorical lists are filtered first. If together
 * they now hold fewer names than {@code predNames}, {@code predNames} is rebuilt as the categorical
 * names followed by the numerical names.
 *
 * @param predNames    all predictor names; rewritten in place when predictors were dropped
 * @param catPredNames categorical predictor names; filtered in place
 * @param numPredNames numerical predictor names; filtered in place
 * @param model        GLM model supplying the best submodel, coefficients and data info
 * @param coeffNames   names of the coefficients retained by the best submodel
 */
void resetAllPreds(List<String> predNames, List<String> catPredNames, List<String> numPredNames,
                   GLMModel model, List<String> coeffNames) {
  final int activeCoefCount = model._output.bestSubmodel().idxs.length;
  final int totalCoefCount = model.coefficients().size();
  if (activeCoefCount != totalCoefCount) { // equal counts mean no redundant predictors
    resetNumPredNames(numPredNames, coeffNames);
    resetCatPredNames(model.dinfo(), model._output.bestSubmodel().idxs, catPredNames);
    final int keptPredCount = numPredNames.size() + catPredNames.size();
    if (keptPredCount < predNames.size()) {
      // categorical names go first, then numerical names
      predNames.clear();
      predNames.addAll(catPredNames);
      predNames.addAll(numPredNames);
    }
  }
}

/**
 * Keep only those numerical predictor names that also appear in {@code coeffNames}.
 * The list is modified in place; the relative order of the surviving names is unchanged.
 *
 * @param numPredNames numerical predictor names to filter in place
 * @param coeffNames   coefficient names that are allowed to remain
 */
public void resetNumPredNames(List<String> numPredNames, List<String> coeffNames) {
  List<String> retained = new ArrayList<>();
  for (String name : numPredNames) {
    if (coeffNames.contains(name))
      retained.add(name);
  }
  numPredNames.clear();
  numPredNames.addAll(retained);
}

/**
 * Keep only the categorical predictors that have at least one coefficient index in {@code idxs}.
 *
 * {@code dinfo._catOffsets} is read as a boundary array: entries {@code k} and {@code k+1} delimit
 * the half-open coefficient index range of the k-th categorical column. A column is kept when any
 * index in its range occurs in {@code idxs} and the last index of its range is below the final offset.
 * {@code catPredNames} is only rewritten when at least one name was dropped.
 *
 * NOTE(review): the name kept for range {@code k} is {@code catPredNames.get(k)}, so the list is
 * assumed to be in the same order as the categorical columns in {@code dinfo} - confirm with callers.
 *
 * @param dinfo        data info providing {@code _catOffsets}
 * @param idxs         coefficient indices retained by the model
 * @param catPredNames categorical predictor names; filtered in place
 */
public void resetCatPredNames(DataInfo dinfo, int[] idxs, List<String> catPredNames) {
  List<String> keptCatNames = new ArrayList<>();
  List<Integer> activeIdxs = new ArrayList<>();
  for (int idx : idxs)
    activeIdxs.add(idx);
  final int[] offsets = dinfo._catOffsets;
  final int offsetCount = offsets.length;
  final int lastOffset = offsets[offsetCount - 1];
  for (int col = 0; col + 1 < offsetCount; col++) {
    final int levelStart = offsets[col];
    final int levelEnd = offsets[col + 1]; // exclusive end of this column's coefficient range
    boolean anyActive = false;
    for (int coefIdx = levelStart; coefIdx < levelEnd && !anyActive; coefIdx++)
      anyActive = activeIdxs.contains(coefIdx);
    if (anyActive && (levelEnd - 1) < lastOffset)
      keptCatNames.add(catPredNames.get(col));
  }
  if (keptCatNames.size() < catPredNames.size()) {
    catPredNames.clear();
    catPredNames.addAll(keptCatNames);
  }
}

/***
* Eliminate predictors with lowest z-value (z-score) magnitude as described in III of
* ModelSelectionTutorial.pdf in https://h2oai.atlassian.net/browse/PUBDEV-8428
*/
void extractPredictors4NextModel(GLMModel model, int index, List<String> predNames, List<Integer> predIndices,
List<String> numPredNames, List<String> catPredNames) {
void extractPredictors4NextModel(GLMModel model, int index, List<String> predNames, List<String> numPredNames,
List<String> catPredNames) {
boolean firstRun = (index+1) == predNames.size();
List<String> oldPredNames = firstRun ? new ArrayList<>(predNames) : null;
extractCoeffs(model, index);
_best_model_ids[index] = model.getKey();
int predIndex2Remove = findMinZValue(model, numPredNames, catPredNames, predNames);
_predictors_removed_per_step[index] = new String[] {predNames.get(predIndex2Remove)};
predIndices.remove(predIndices.indexOf(predIndex2Remove));
String pred2Remove = predNames.get(predIndex2Remove);
if (firstRun) // remove redundant predictors if present
resetCoeffs(model, predNames, numPredNames, catPredNames);
List<String> redundantPred = firstRun ?
oldPredNames.stream().filter(x -> !predNames.contains(x)).collect(Collectors.toList()) : null;
_best_model_ids[index] = model.getKey();

if (redundantPred != null && redundantPred.size() > 0) {
redundantPred = redundantPred.stream().map(x -> x+"(redundant_predictor)").collect(Collectors.toList());
redundantPred.add(pred2Remove);
_predictors_removed_per_step[index] = redundantPred.stream().toArray(String[]::new);
} else {
_predictors_removed_per_step[index] = new String[]{pred2Remove};
}

_z_values[index] = model._output.zValues().clone();
_coef_p_values[index] = model._output.pValues().clone();
predNames.remove(pred2Remove);
if (catPredNames.contains(pred2Remove))
catPredNames.remove(pred2Remove);
else
numPredNames.remove(pred2Remove);
}
}

Expand Down
60 changes: 27 additions & 33 deletions h2o-algos/src/main/java/hex/modelselection/ModelSelectionUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ public static Frame generateOneFrame(int[] predIndices, Model.Parameters parms,
String foldColumn) {
final Frame predVecs = new Frame(Key.make());
final Frame train = parms.train();
int numPreds = predIndices.length;
boolean usePredIndices = predIndices != null;
int numPreds = usePredIndices? predIndices.length : predNames.length;
for (int index = 0; index < numPreds; index++) {
int predVecNum = predIndices[index];
int predVecNum = usePredIndices ? predIndices[index] : index;
predVecs.add(predNames[predVecNum], train.vec(predNames[predVecNum]));
}
if (parms._weights_column != null)
Expand Down Expand Up @@ -802,10 +803,8 @@ public static int findMinZValue(GLMModel model, List<String> numPredNames, List<

// choose the min z-value from numerical and categorical predictors and return its index in predNames
if (categoricalPred._minZVal >= 0 && categoricalPred._minZVal < numericalPred._minZVal) { // categorical pred has minimum z-value
catPredNames.remove(catPredNames.indexOf(categoricalPred._predName));
return predNames.indexOf(categoricalPred._predName);
} else { // numerical pred has minimum z-value
numPredNames.remove(numPredNames.indexOf(numericalPred._predName));
return predNames.indexOf(numericalPred._predName);
}
}
Expand All @@ -819,8 +818,8 @@ public static PredNameMinZVal findNumMinZVal(List<String> numPredNames, List<Dou
int eleInd = coeffNames.indexOf(predName);
double oneZValue = zValList.get(eleInd);
if (Double.isNaN(oneZValue)) {
zValList.set(eleInd, 0.0);
numZValues.add(0.0); // NaN corresponds to coefficient of 0.0
zValList.set(eleInd, Double.POSITIVE_INFINITY);
numZValues.add(Double.POSITIVE_INFINITY); // NaN corresponds to inactive predictors
} else {
numZValues.add(oneZValue);
}
Expand All @@ -847,25 +846,30 @@ public static PredNameMinZVal findCatMinZVal(GLMModel model, List<Double> zValLi
String catPredMinZ = null;
if (catOffsets != null) {
minCatVal = Double.MAX_VALUE;
int numCatCol = catOffsets.length-1;
int numCatCol = catOffsets.length - 1;

for (int catInd = 0; catInd < numCatCol; catInd++) { // go through each categorical column
List<Double> catZValues = new ArrayList<>();
int nextCatOffset = catOffsets[catInd+1];
for (int eleInd = catOffsets[catInd]; eleInd < nextCatOffset; eleInd++) { // check z-value for each level
double oneZVal = zValList.get(eleInd);
if (Double.isNaN(oneZVal)) {
zValList.set(eleInd, 0.0);
catZValues.add(0.0);
} else {
catZValues.add(oneZVal);
int numNaN = (int) zValList.stream().filter(x -> Double.isNaN(x)).count();
if (numNaN == zValList.size()) { // if all levels are NaN, this predictor is redundant
new PredNameMinZVal(catPredMinZ, Double.POSITIVE_INFINITY);
} else {
for (int catInd = 0; catInd < numCatCol; catInd++) { // go through each categorical column
List<Double> catZValues = new ArrayList<>();
int nextCatOffset = catOffsets[catInd + 1];
for (int eleInd = catOffsets[catInd]; eleInd < nextCatOffset; eleInd++) { // check z-value for each level
double oneZVal = zValList.get(eleInd);
if (Double.isNaN(oneZVal)) { // one level is inactivity, let other levels be used
zValList.set(eleInd, 0.0);
catZValues.add(0.0);
} else {
catZValues.add(oneZVal);
}
}
}
if (catZValues.size() > 0) {
double oneCatMinZ = catZValues.stream().max(Double::compare).get(); // choose the best z-value here
if (oneCatMinZ < minCatVal) {
minCatVal = oneCatMinZ;
catPredMinZ = columnNames[catInd];
if (catZValues.size() > 0) {
double oneCatMinZ = catZValues.stream().max(Double::compare).get(); // choose the best z-value here
if (oneCatMinZ < minCatVal) {
minCatVal = oneCatMinZ;
catPredMinZ = columnNames[catInd];
}
}
}
}
Expand All @@ -892,16 +896,6 @@ public static List<String> extraModelColumnNames(List<String> coefNames, GLMMode
}
return coefUsed;
}

public static void updateValidSubset(List<Integer> validSubset, List<Integer> originalSubset,
List<Integer> currSubsetIndices) {
List<Integer> onlyInOriginal = new ArrayList<>(originalSubset);
onlyInOriginal.removeAll(currSubsetIndices);
List<Integer> onlyInCurr = new ArrayList<>(currSubsetIndices);
onlyInCurr.removeAll(originalSubset);
validSubset.addAll(onlyInOriginal);
validSubset.removeAll(onlyInCurr);
}

/***
* Given a predictor subset and the complete CPM, we extract the CPM associated with the predictors
Expand Down
Loading