Enhancement : faster getTerminalNodeIDs #90

Closed
pej opened this issue Jul 9, 2016 · 4 comments

pej commented Jul 9, 2016

Hi,
Thanks a lot for this brilliant package.
I would like to derive proximity matrices from forests built with ranger. As far as I know, there is currently no built-in way to do so.
The getTerminalNodeIDs() function could be used for this. Unfortunately, because it is implemented in plain R, it is quite slow compared with growing the forest or predicting. Could you speed it up, or could the training or prediction functions also return the terminal node ID matrix or, even better, a proximity matrix?

Thanks
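
For readers unfamiliar with the idea, here is a minimal sketch of how a proximity matrix could be derived from a matrix of terminal node IDs. It assumes the usual definition (the share of trees in which two observations fall into the same terminal node). The function name proximity_from_nodes and the hand-made input are hypothetical and not part of ranger.

```r
# Proximity of two observations = share of trees in which both land in the
# same terminal node. `nodes` is an n x num.trees matrix of terminal node IDs.
# proximity_from_nodes is a hypothetical helper, not a ranger function.
proximity_from_nodes <- function(nodes) {
  nodes <- as.matrix(nodes)
  n <- nrow(nodes)
  prox <- matrix(0, n, n)
  for (t in seq_len(ncol(nodes))) {
    # outer() flags every pair of rows that share a node ID in tree t
    prox <- prox + outer(nodes[, t], nodes[, t], "==")
  }
  prox / ncol(nodes)
}

# Hand-made input: 3 observations, 2 trees
nodes <- cbind(c(4, 4, 7), c(2, 5, 5))
proximity_from_nodes(nodes)
```

The result is a dense n x n matrix, so memory grows quadratically with the number of observations; for large data a sparse or blockwise variant would probably be needed.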

mnwright (Member) commented

You are right: there is no option for a proximity matrix yet, and getTerminalNodeIDs() is very slow. One option would be to add this to the predict() function, e.g. by setting type = "terminalNodeID" or something similar.

pej commented Jul 12, 2016

Yes, it would be great to have such functionality!

In the meantime, for those who would like to get the terminal node IDs faster, here is an Rcpp-based getTerminalNodeIDs function (it should work on numeric data):

library(Rcpp)
cppFunction('
NumericVector getTerminalNodeIDsCPP(NumericVector childNodesIDs1,
                                    NumericVector childNodesIDs2,
                                    NumericVector splitValues,
                                    NumericVector splitVarIDs,
                                    NumericMatrix mat)
{
  // One terminal node ID per row of mat
  NumericVector res(mat.nrow());
  for (int i = 0; i < mat.nrow(); i++) {
    // nodeID is 1-based here; the stored child IDs are 0-based
    int nodeID = 1;
    while (true) {
      // A node with no children is terminal
      if (childNodesIDs1[nodeID - 1] == 0 && childNodesIDs2[nodeID - 1] == 0) {
        break;
      }
      int splitVarID = splitVarIDs[nodeID - 1];
      double value = mat(i, splitVarID - 1);
      // Go left if value <= split value, otherwise go right
      if (value <= splitValues[nodeID - 1]) {
        nodeID = childNodesIDs1[nodeID - 1] + 1;
      } else {
        nodeID = childNodesIDs2[nodeID - 1] + 1;
      }
    }
    res[i] = nodeID;
  }
  return res;
}
')

getTerminalNodeIDs2 <- function(data, rf)
{
  # Convert once instead of once per tree
  mat <- as.matrix(data)
  # One column of terminal node IDs per tree
  res <- sapply(seq_len(rf$num.trees), function(tree) {
    getTerminalNodeIDsCPP(rf$forest$child.nodeIDs[[tree]][[1]],
                          rf$forest$child.nodeIDs[[tree]][[2]],
                          as.double(rf$forest$split.values[[tree]]),
                          rf$forest$split.varIDs[[tree]],
                          mat)
  })
  return(res)
}

# Example
library(ranger)
rf <- ranger(Species ~ ., data = iris, write.forest = TRUE)
y1 <- getTerminalNodeIDs2(iris[, -5], rf)
y2 <- getTerminalNodeIDs(rf, iris[, -5])
identical(y1, y2)
# [1] TRUE

system.time(getTerminalNodeIDs(rf, iris[, -5]))
#    user  system elapsed
#   1.551   0.016   1.613

system.time(getTerminalNodeIDs2(iris[, -5], rf))
#    user  system elapsed
#   0.033   0.001   0.034
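
To make the traversal in the code above easier to follow, here is a small pure-R sketch on a hand-built tree. The four vectors only mimic the layout that the Rcpp function assumes (0-based child IDs, 1-based column indices, both child IDs equal to 0 for a terminal node). The tree itself and the name drop_down_tree are hypothetical.

```r
# Hypothetical 3-node tree: the root splits on column 1 at 2.5;
# the other two nodes are terminal (both child IDs are 0).
left_child  <- c(1, 0, 0)    # 0-based ID of the left child of each node
right_child <- c(2, 0, 0)    # 0-based ID of the right child of each node
split_var   <- c(1, 0, 0)    # 1-based column index used at each node
split_val   <- c(2.5, 0, 0)  # go left when value <= threshold

drop_down_tree <- function(x) {
  node <- 1  # 1-based position of the root in the vectors above
  while (left_child[node] != 0 || right_child[node] != 0) {
    go_left <- x[split_var[node]] <= split_val[node]
    # Child IDs are 0-based, so add 1 to index the vectors again
    node <- (if (go_left) left_child[node] else right_child[node]) + 1
  }
  node
}

X <- rbind(c(1.0, 9), c(3.0, 9))
apply(X, 1, drop_down_tree)
# [1] 2 3
```

The first row goes left at the root (1.0 <= 2.5) and the second goes right, so they end in different terminal nodes.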

mnwright (Member) commented

A new option, type = "terminalNodes", has been added to predict(), as described above.
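
A minimal usage sketch of the new option, assuming a ranger version that includes it and that the node IDs are returned in the $predictions element of the result (one row per observation, one column per tree). The num.trees value is chosen only to keep the example small.

```r
library(ranger)

# Grow a small forest; num.trees = 50 is arbitrary
rf <- ranger(Species ~ ., data = iris, num.trees = 50)

# Ask predict() for terminal node IDs instead of class predictions
pred <- predict(rf, data = iris[, -5], type = "terminalNodes")
nodes <- pred$predictions

# Expected layout: observations in rows, trees in columns
dim(nodes)
```

The resulting matrix can be used directly as input for a proximity computation.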

pej commented Sep 22, 2016

Thanks a lot!

On 21 September 2016 at 22:33, Marvin N. Wright wrote:

New option "terminalNodes" added to predict(), as described above.