Skip to content

Commit

Permalink
Add tabnet_explain.model_fit
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Haarmeyer committed Nov 3, 2022
1 parent cc78cc2 commit 2cab7d7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ S3method(predict,tabnet_fit)
S3method(print,tabnet_fit)
S3method(print,tabnet_pretrain)
S3method(tabnet_explain,default)
S3method(tabnet_explain,model_fit)
S3method(tabnet_explain,tabnet_fit)
S3method(tabnet_explain,tabnet_pretrain)
S3method(tabnet_fit,data.frame)
Expand Down
5 changes: 5 additions & 0 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ tabnet_explain.tabnet_fit <- function(object, new_data) {
#' @rdname tabnet_explain
tabnet_explain.tabnet_pretrain <- tabnet_explain.tabnet_fit

#' @export
#' @rdname tabnet_explain
tabnet_explain.model_fit <- function(object, new_data) {
tabnet_explain(parsnip::extract_fit_engine(object), new_data)
}

convert_to_df <- function(x, nms) {
x <- as.data.frame(as.matrix(x$to(device = "cpu")$detach()))
Expand Down
3 changes: 3 additions & 0 deletions man/tabnet_explain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 2cab7d7

Please sign in to comment.