-
Notifications
You must be signed in to change notification settings - Fork 0
Functional api guide #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ndles multi input and output
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remaining comments which cannot be posted as a review comment to avoid GitHub Rate Limit
|
||
folds <- rsample::vfold_cv(iris, v = 2) | ||
params <- extract_parameter_set_dials(tune_wf) |> | ||
params <- extract_parameter_set_dials(tune_wf) |> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
params <- extract_parameter_set_dials(tune_wf) |> | |
params <- extract_parameter_set_dials(tune_wf) |> |
input_block_1 <- function(input_shape) layer_input(shape = input_shape, name = "input_1") | ||
input_block_2 <- function(input_shape) layer_input(shape = input_shape, name = "input_2") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
input_block_1 <- function(input_shape) layer_input(shape = input_shape, name = "input_1") | |
input_block_2 <- function(input_shape) layer_input(shape = input_shape, name = "input_2") | |
input_block_1 <- function(input_shape) { | |
layer_input(shape = input_shape, name = "input_1") | |
} | |
input_block_2 <- function(input_shape) { | |
layer_input(shape = input_shape, name = "input_2") | |
} |
flatten_b = inp_spec(flatten_block, "input_b"), | ||
path_a = inp_spec(dense_path, "flatten_a"), | ||
path_b = inp_spec(dense_path, "flatten_b"), | ||
concatenated = inp_spec(concat_block, c(path_a = "in_1", path_b = "in_2")), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
concatenated = inp_spec(concat_block, c(path_a = "in_1", path_b = "in_2")), | |
concatenated = inp_spec( | |
concat_block, | |
c(path_a = "in_1", path_b = "in_2") | |
), |
expect_equal(names(preds), c(".pred_class")) | ||
expect_equal(nrow(preds), 5) | ||
expect_true(is.factor(preds$.pred_class)) | ||
}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
}) | |
}) |
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
kerasnip/R/register_fit_predict.R
Lines 195 to 200 in 8267e80
combined_preds <- purrr::map2_dfc(results, names(results), function(res, name) { | |
lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels | |
if (is.null(lvls)) { | |
# Fallback if levels are not specifically named for this output | |
lvls <- paste0("class", 1:ncol(res)) # This might not be correct for classes, but a placeholder | |
} |
R/build_and_compile_model.R
Outdated
block_name | ||
]]) && | ||
length(y_processed$class_levels[[block_name]]) > 0, | ||
num_classes = if (!is.null(y_processed$class_levels[[block_name]])) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
num_classes = if (!is.null(y_processed$class_levels[[block_name]])) { | |
num_classes = if ( | |
!is.null(y_processed$class_levels[[block_name]]) | |
) { |
R/build_and_compile_model.R
Outdated
block_hyperparams$num_classes <- current_y_info$num_classes | ||
} | ||
} | ||
} else { # Single output case |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
} else { # Single output case | |
} else { | |
# Single output case |
R/register_fit_predict.R
Outdated
combined_preds <- purrr::map2_dfc(results, names(results), function(res, name) { | ||
lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels | ||
if (is.null(lvls)) { | ||
# Fallback if levels are not specifically named for this output | ||
lvls <- paste0("class", 1:ncol(res)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
combined_preds <- purrr::map2_dfc(results, names(results), function(res, name) { | |
lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels | |
if (is.null(lvls)) { | |
# Fallback if levels are not specifically named for this output | |
lvls <- paste0("class", 1:ncol(res)) | |
combined_preds <- purrr::map2_dfc( | |
results, | |
names(results), | |
function(res, name) { | |
lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels | |
if (is.null(lvls)) { | |
# Fallback if levels are not specifically named for this output | |
lvls <- paste0("class", 1:ncol(res)) | |
} | |
colnames(res) <- lvls | |
tibble::as_tibble(res, .name_repair = "unique") %>% | |
dplyr::rename_with(~ paste0(".pred_", name, "_", .x)) |
tibble::tibble(.pred_class = pred_class) %>% | ||
dplyr::rename_with(~ paste0(".pred_class_", name)) | ||
}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
tibble::tibble(.pred_class = pred_class) %>% | |
dplyr::rename_with(~ paste0(".pred_class_", name)) | |
}) | |
) |
set_engine("keras") | ||
fit_3 <- fit(spec_3, mpg ~ ., data = mtcars) | ||
model_3_layers <- fit_3 |> | ||
extract_keras_model() |> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
extract_keras_model() |> | |
extract_keras_model() |> |
output_block_1 <- function(tensor) layer_dense(tensor, units = 1, name = "output_1") | ||
output_block_2 <- function(tensor) layer_dense(tensor, units = 1, name = "output_2") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
output_block_1 <- function(tensor) layer_dense(tensor, units = 1, name = "output_1") | |
output_block_2 <- function(tensor) layer_dense(tensor, units = 1, name = "output_2") | |
output_block_1 <- function(tensor) { | |
layer_dense(tensor, units = 1, name = "output_1") | |
} | |
output_block_2 <- function(tensor) { | |
layer_dense(tensor, units = 1, name = "output_2") | |
} |
input_b = input_block_2, | ||
path_a = inp_spec(dense_path, "input_a"), | ||
path_b = inp_spec(dense_path, "input_b"), | ||
concatenated = inp_spec(concat_block, c(path_a = "in_1", path_b = "in_2")), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
concatenated = inp_spec(concat_block, c(path_a = "in_1", path_b = "in_2")), | |
concatenated = inp_spec( | |
concat_block, | |
c(path_a = "in_1", path_b = "in_2") | |
), |
expect_equal(nrow(preds), 5) | ||
expect_true(is.numeric(preds$.pred_output_1)) | ||
expect_true(is.numeric(preds$.pred_output_2)) | ||
}) No newline at end of file |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
}) | |
}) |
R/build_and_compile_model.R
Outdated
block_name | ||
]]) && | ||
length(y_processed$class_levels[[block_name]]) > 0, | ||
num_classes = if (!is.null(y_processed$class_levels[[block_name]])) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
num_classes = if (!is.null(y_processed$class_levels[[block_name]])) { | |
num_classes = if ( | |
!is.null(y_processed$class_levels[[block_name]]) | |
) { |
R/build_and_compile_model.R
Outdated
block_hyperparams$num_classes <- current_y_info$num_classes | ||
} | ||
} | ||
} else { # Single output case |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
} else { # Single output case | |
} else { | |
# Single output case |
] | ||
merged_args | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
} | |
} |
R/register_fit_predict.R
Outdated
combined_preds <- purrr::map2_dfc(results, names(results), function(res, name) { | ||
lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels | ||
if (is.null(lvls)) { | ||
# Fallback if levels are not specifically named for this output | ||
lvls <- paste0("class", 1:ncol(res)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
combined_preds <- purrr::map2_dfc(results, names(results), function(res, name) { | |
lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels | |
if (is.null(lvls)) { | |
# Fallback if levels are not specifically named for this output | |
lvls <- paste0("class", 1:ncol(res)) | |
combined_preds <- purrr::map2_dfc( | |
results, | |
names(results), | |
function(res, name) { | |
lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels | |
if (is.null(lvls)) { | |
# Fallback if levels are not specifically named for this output | |
lvls <- paste0("class", 1:ncol(res)) | |
} | |
colnames(res) <- lvls | |
tibble::as_tibble(res, .name_repair = "unique") %>% | |
dplyr::rename_with(~ paste0(".pred_", name, "_", .x)) |
colnames(res) <- lvls | ||
tibble::as_tibble(res, .name_repair = "unique") %>% | ||
dplyr::rename_with(~ paste0(".pred_", name, "_", .x)) | ||
}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
colnames(res) <- lvls | |
tibble::as_tibble(res, .name_repair = "unique") %>% | |
dplyr::rename_with(~ paste0(".pred_", name, "_", .x)) | |
}) | |
) |
set_engine("keras") | ||
fit_3 <- fit(spec_3, mpg ~ ., data = mtcars) | ||
model_3_layers <- fit_3 |> | ||
extract_keras_model() |> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
extract_keras_model() |> | |
extract_keras_model() |> |
input_block_1 <- function(input_shape) layer_input(shape = input_shape, name = "input_1") | ||
input_block_2 <- function(input_shape) layer_input(shape = input_shape, name = "input_2") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
input_block_1 <- function(input_shape) layer_input(shape = input_shape, name = "input_1") | |
input_block_2 <- function(input_shape) layer_input(shape = input_shape, name = "input_2") | |
input_block_1 <- function(input_shape) { | |
layer_input(shape = input_shape, name = "input_1") | |
} | |
input_block_2 <- function(input_shape) { | |
layer_input(shape = input_shape, name = "input_2") | |
} |
output_block_1 <- function(tensor) layer_dense(tensor, units = 1, name = "output_1") | ||
output_block_2 <- function(tensor) layer_dense(tensor, units = 1, name = "output_2") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
output_block_1 <- function(tensor) layer_dense(tensor, units = 1, name = "output_1") | |
output_block_2 <- function(tensor) layer_dense(tensor, units = 1, name = "output_2") | |
output_block_1 <- function(tensor) { | |
layer_dense(tensor, units = 1, name = "output_1") | |
} | |
output_block_2 <- function(tensor) { | |
layer_dense(tensor, units = 1, name = "output_2") | |
} |
input_b = input_block_2, | ||
path_a = inp_spec(dense_path, "input_a"), | ||
path_b = inp_spec(dense_path, "input_b"), | ||
concatenated = inp_spec(concat_block, c(path_a = "in_1", path_b = "in_2")), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
concatenated = inp_spec(concat_block, c(path_a = "in_1", path_b = "in_2")), | |
concatenated = inp_spec( | |
concat_block, | |
c(path_a = "in_1", path_b = "in_2") | |
), |
expect_equal(nrow(preds), 5) | ||
expect_true(is.numeric(preds$.pred_output_1)) | ||
expect_true(is.numeric(preds$.pred_output_2)) | ||
}) No newline at end of file |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[air] reported by reviewdog 🐶
}) | |
}) |
Refactoring functional api to handle multi-input and multi-output correctly and updating guid to reflect so