Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PUBDEV-7909 fix how gain is calculated in xgbfi for gbm #5141

Merged
merged 8 commits into from
Nov 23, 2020

Conversation

koniecsveta
Copy link
Contributor

@koniecsveta koniecsveta commented Nov 19, 2020

No description provided.

@koniecsveta koniecsveta force-pushed the zuzana/PUBDEV-7909/fix_xgbfi_for_gbm branch 2 times, most recently from 15681bd to 9aa549d Compare November 19, 2020 15:22
}

@Test
public void testXGBoostFeatureInteractionsCheckRanksVsVarimp() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XGBoost?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
}

DKV.remove(f._key);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed since already tracked by scope, will remove

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


SharedTreeSubgraph treeSubgraph = model.getSharedTreeSubgraph(0, 0);

String[] keysToCheck = new String[]{"DPROS", "ID", "PSA", "GLEASON", "VOL"};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should ID column be in ignored_columns?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -609,4 +610,17 @@ public final int getRightChildIndex() {
return rightChild != null ? rightChild.internalId : -1;
}

public float getGain() {
if (!Float.isNaN(gain)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this because the same code handles XGB and GBM?

if so it would be better to have a flag, eg.: useSEforGain and switch on that

Copy link
Contributor Author

@koniecsveta koniecsveta Nov 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, useSEforGain will be better

}

@Test
public void testXGBoostFeatureInteractionsCheckRanksVsVarimp() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

confusing test name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


DKV.remove(f._key);
} finally {
if( model != null ) model.delete();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use scope?

@@ -609,4 +610,15 @@ public final int getRightChildIndex() {
return rightChild != null ? rightChild.internalId : -1;
}

public float getGain(boolean useSEforGain) {
if (useSEforGain) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is useSEforGain defined? I dont think useSquaredErrorForGain is too long, but much more readable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it depends on whether you call GBMModel.getFeatureInteractions(...) or XGBoostModel.getFeatureInteractions(). for XGB, gain is prepared in boosterbyte and set in XGBoostModel.constructSubgraph(...) so useSquaredErrorForGain = false, for GBM it's not prepared and it has to be calculated from squared error, so useSquaredErrorForGain = true

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, stupid question, I missed the method arg boolean useSEforGain
knit: would also use the long name here

@koniecsveta koniecsveta force-pushed the zuzana/PUBDEV-7909/fix_xgbfi_for_gbm branch from 839f6e3 to a95a715 Compare November 19, 2020 16:54
Copy link
Contributor

@honzasterba honzasterba left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good, minor improvement suggestions

try {
Frame f = Scope.track(parse_test_file("smalldata/logreg/prostate.csv"));
f.replace(f.find("CAPSULE"), f.vec("CAPSULE").toNumericVec());
DKV.put(f);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unnecessary I think

@@ -609,4 +610,15 @@ public final int getRightChildIndex() {
return rightChild != null ? rightChild.internalId : -1;
}

public float getGain(boolean useSEforGain) {
if (useSEforGain) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, stupid question, I missed the method arg boolean useSEforGain
knit: would also use the long name here

for (Map.Entry<String, FeatureInteraction> featureInteraction : featureInteractions1.entrySet()) {
list1.add(new KeyValue(featureInteraction.getKey(), featureInteraction.getValue().gain));
}
list1.sort((a,b) -> a.getValue() < b.getValue() ? -1 : a.getValue() == b.getValue() ? 0 : 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could be simplified with Comparator.comparing(KeyValue::getValue)

}
list2.sort((a,b) -> a.getValue() < b.getValue() ? -1 : a.getValue() == b.getValue() ? 0 : 1);

List<String> sortedKeys1 = list1.stream().map(KeyValue::getKey).collect(Collectors.toList());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorting could be merged into the stream().sorted(Comparator.comparing().map()...

@koniecsveta koniecsveta force-pushed the zuzana/PUBDEV-7909/fix_xgbfi_for_gbm branch from 442a5e1 to d9f409f Compare November 20, 2020 16:23
@koniecsveta koniecsveta merged commit bcdc149 into rel-zermelo Nov 23, 2020
@koniecsveta koniecsveta deleted the zuzana/PUBDEV-7909/fix_xgbfi_for_gbm branch November 23, 2020 08:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants