
Hierarchical model training does not take the ancestor matrix into account #187

@cregouby

Description


Current situation

According to the code in hardhat.R:

tabnet/R/hardhat.R

Lines 166 to 175 in 34e143a

ancestor <- data.tree::ToDataFrameNetwork(x) %>%
mutate_if(is.character, ~.x %>% as.factor %>% as.numeric)
# TODO check correctness
# embed the M matrix in the config$ancestor variable
dims <- c(max(ancestor), max(ancestor))
ancestor_m <- Matrix::sparseMatrix(ancestor$from, ancestor$to, dims = dims, x = 1)
check_type(processed$outcomes)
config <- merge_config_and_dots(config, ...)
tabnet_bridge(processed, config = config, tabnet_model, from_epoch, task = "supervised")

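The ancestor matrix ancestor_m is computed but never passed to the tabnet model being fitted: config is built from merge_config_and_dots(config, ...) alone, so tabnet_bridge() never receives the matrix.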

Expected situation

The ancestor matrix is passed to the model through a config attribute (config$ancestor, as the TODO comment in the snippet suggests).
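A minimal sketch of the expected behaviour, assuming the edge list has the from/to shape produced by data.tree::ToDataFrameNetwork() and using a plain list as a stand-in for the real config object (both assumptions, not tabnet's actual internals). One detail related to the "TODO check correctness" comment: mutate_if() converts each character column to a factor separately, so the same node may receive different integer codes in from and in to. The sketch instead maps both columns onto one shared set of node names before building the sparse matrix.

```r
library(Matrix)

# Hypothetical edge list, shaped like the from/to output of
# data.tree::ToDataFrameNetwork(): one row per parent -> child link
edges <- data.frame(
  from = c("root", "root", "a"),
  to   = c("a", "b", "c"),
  stringsAsFactors = FALSE
)

# Index both columns against ONE shared node list, so a node gets the
# same integer whether it appears as a parent (from) or a child (to)
nodes    <- sort(unique(c(edges$from, edges$to)))
from_idx <- match(edges$from, nodes)
to_idx   <- match(edges$to, nodes)

# Square sparse matrix: entry [parent, child] is 1 for each edge
n <- length(nodes)
ancestor_m <- Matrix::sparseMatrix(
  i = from_idx, j = to_idx, x = 1,
  dims = c(n, n), dimnames = list(nodes, nodes)
)

# Stand-in config (assumption: a plain list, not tabnet's config class);
# the matrix is stored under the `ancestor` name used in the TODO comment
config <- list(epochs = 5)
config$ancestor <- ancestor_m
```

Whether tabnet_bridge() and the model then actually read config$ancestor is precisely what this issue asks for; the list above only illustrates where the matrix would travel.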

Labels: bug (Something isn't working)
