
Add the ability to determine which leaf nodes a test point falls into. #661

@yifan-cui


First, incorporate a find_leaf_node function into the package. It walks a single tree from the root to the leaf that contains the given sample:

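```r
# Walk a single tree from the root (node 1) down to the leaf containing `sample`.
# Split nodes carry split_variable, split_value, left_child and right_child;
# leaves have is_leaf = TRUE.
find_leaf_node <- function(tree, sample) {
  node <- 1
  while (!tree$nodes[[node]]$is_leaf) {
    split_var <- tree$nodes[[node]]$split_variable
    split_value <- tree$nodes[[node]]$split_value
    # Values at or below the split value go left, larger values go right.
    if (sample[[split_var]] <= split_value) {
      node <- tree$nodes[[node]]$left_child
    } else {
      node <- tree$nodes[[node]]$right_child
    }
  }
  node  # index into tree$nodes (external numbering)
}
```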
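Routing one sample at a time is fine for a few points; for a whole test matrix, a vectorized variant can move every row down the tree together, one level per iteration. The sketch below assumes the same node layout that find_leaf_node reads (a nodes list with the root at index 1, and is_leaf, split_variable, split_value, left_child, right_child fields), that split_variable is a column index into X, and that X has no missing values in the split columns. The name find_leaf_nodes and the toy tree are hypothetical.

```r
# Hypothetical vectorized companion to find_leaf_node: routes every row of X
# down one tree at once and returns the index (into tree$nodes) of each row's leaf.
# Assumes split_variable is a column index into X and X has no missing values.
find_leaf_nodes <- function(tree, X) {
  X <- as.matrix(X)
  # Per-node lookup vectors; leaves have no split fields, so they get NA.
  field <- function(name) {
    vapply(tree$nodes, function(n) {
      if (is.null(n[[name]])) NA_real_ else as.numeric(n[[name]])
    }, numeric(1))
  }
  is_leaf <- vapply(tree$nodes, function(n) isTRUE(n$is_leaf), logical(1))
  split_var <- field("split_variable")
  split_value <- field("split_value")
  left <- field("left_child")
  right <- field("right_child")

  node <- rep(1, nrow(X))  # every row starts at the root (node 1)
  active <- !is_leaf[node]
  while (any(active)) {
    idx <- which(active)
    cur <- node[idx]
    x <- X[cbind(idx, split_var[cur])]
    # Same rule as find_leaf_node: <= split_value goes left, otherwise right.
    node[idx] <- ifelse(x <= split_value[cur], left[cur], right[cur])
    active <- !is_leaf[node]
  }
  as.integer(node)
}

# Toy tree: node 1 splits on column 1 at 0.5; node 3 splits on column 2 at 0.
toy_tree <- list(nodes = list(
  list(is_leaf = FALSE, split_variable = 1, split_value = 0.5,
       left_child = 2, right_child = 3),
  list(is_leaf = TRUE),
  list(is_leaf = FALSE, split_variable = 2, split_value = 0,
       left_child = 4, right_child = 5),
  list(is_leaf = TRUE),
  list(is_leaf = TRUE)
))
X_test <- rbind(c(0.2, 5), c(0.9, -1), c(0.9, 3), c(0.5, 0))
find_leaf_nodes(toy_tree, X_test)
# [1] 2 4 5 2
```

Each pass of the while loop advances only the rows that are still at a split node, so the number of passes equals the depth of the deepest leaf reached rather than the number of test points.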

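Second, find_leaf_node returns a node index in the external numbering (its position in tree$nodes), whereas s.forest$`_leaf_samples`[[1]] is indexed by the internal node numbering, so the result of find_leaf_node cannot be used to look up leaf samples directly.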
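One way to reconcile the two numberings is to rebuild the external order from the internal child arrays and keep a lookup table. This sketch is hypothetical throughout: it assumes each internal tree is stored as a 0-based root id plus left/right child vectors (0 meaning "no child", which only works if the root is never anyone's child), and that the external numbering is a breadth-first walk from the root, left child first, starting at 1. Both assumptions should be checked against how the package actually extracts its trees; the names external_to_internal, left, right and leaf_samples are illustrative only.

```r
# Hypothetical layout of one internal tree (an assumption, not the package's
# documented format): `root` is a 0-based internal node id, and `left`/`right`
# hold each node's child ids, indexed by internal id + 1, with 0 meaning "no child".
# Assumed external numbering: breadth-first from the root, left child first, from 1.
external_to_internal <- function(root, left, right) {
  ext_to_int <- integer(0)
  queue <- as.integer(root)
  while (length(queue) > 0) {
    id <- queue[1]
    queue <- queue[-1]
    ext_to_int <- c(ext_to_int, id)  # external id = position in this vector
    l <- left[id + 1]
    r <- right[id + 1]
    if (l != 0 || r != 0) {  # split node: visit both children later
      queue <- c(queue, as.integer(c(l, r)))
    }
  }
  ext_to_int  # element k = internal id of external node k
}

# Toy tree numbered depth-first internally: 0 -> (1, 4), 1 -> (2, 3).
left  <- c(1L, 2L, 0L, 0L, 0L)
right <- c(4L, 3L, 0L, 0L, 0L)
ext_to_int <- external_to_internal(0L, left, right)
ext_to_int
# [1] 0 1 4 2 3

# Stand-in for s.forest$`_leaf_samples`[[1]], indexed by internal id + 1.
leaf_samples <- list(integer(0), integer(0), c(3L, 7L), c(1L, 5L), c(2L, 4L, 6L))
# External node 3 (as find_leaf_node would report it) is internal node 4:
leaf_samples[[ext_to_int[3] + 1]]
# [1] 2 4 6
```

Building the table once per tree keeps each lookup a constant-time index, so routing many test points through find_leaf_node does not require re-walking the internal arrays.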
