Skip to content
This repository has been archived by the owner on Oct 1, 2020. It is now read-only.

Commit

Permalink
distr6 fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Raphael committed Aug 5, 2020
1 parent d918385 commit 34e6d29
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 29 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3learners.flexsurv
Title: Learners from the {flexsurv} package for 'mlr3'
Version: 0.1.1.9000
Version: 0.1.2
Authors@R:
person(given = "Raphael",
family = "Sonabend",
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mlr3learners.flexsurv 0.1.2

- distr6 patch

# mlr3learners.flexsurv 0.1.1

- Removed remotes dependencies
Expand Down
44 changes: 18 additions & 26 deletions R/predict_flexsurvreg.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,35 +38,31 @@ predict_flexsurvreg <- function(object, task, ...) {
# Define the d/p/q/r methods using the d/p/q/r methods that are automatically generated in the
# fitted object. The parameters referenced are defined below and are based on the gamma
# parameters above.
pdf = function(x1) {
}
pdf = function(x) {} # nolint
body(pdf) = substitute({
fn = func
args = as.list(subset(as.data.table(self$parameters()), select = "value"))$value
names(args) = unname(unlist(as.data.table(self$parameters())[, 1]))
do.call(fn, c(list(x = x1), args))
do.call(fn, c(list(x = x), args))
}, list(func = object$dfns$d))

cdf = function(x1) {
}
cdf = function(x) {} # nolint
body(cdf) = substitute({
fn = func
args = as.list(subset(as.data.table(self$parameters()), select = "value"))$value
names(args) = unname(unlist(as.data.table(self$parameters())[, 1]))
do.call(fn, c(list(q = x1), args))
do.call(fn, c(list(q = x), args))
}, list(func = object$dfns$p))

quantile = function(p) {
}
quantile = function(p) {} # nolint
body(quantile) = substitute({
fn = func
args = as.list(subset(as.data.table(self$parameters()), select = "value"))$value
names(args) = unname(unlist(as.data.table(self$parameters())[, 1]))
do.call(fn, c(list(p = p), args))
}, list(func = object$dfns$q))

rand = function(n) {
}
rand = function(n) {} # nolint
body(rand) = substitute({
fn = func
args = as.list(subset(as.data.table(self$parameters()), select = "value"))$value
Expand Down Expand Up @@ -94,18 +90,6 @@ predict_flexsurvreg <- function(object, task, ...) {
pargs = data.table::data.table(matrix(args, ncol = ncol(pars), nrow = length(args)))
pars = rbind(pars, pargs)

params = lapply(pars, function(x) {
x = as.list(x)
names(x) = c(object$dlist$pars, names(args))
yparams = parameters$clone(deep = TRUE)
ind = match(yparams$.__enclos_env__$private$.parameters$id, names(x))
yparams$.__enclos_env__$private$.parameters$value = x[ind]

yparams
})

params = lapply(params, function(x) list(parameters = x))

shared_params = list(
name = "Flexible Parameteric",
short_name = "Flexsurv",
Expand All @@ -115,13 +99,21 @@ predict_flexsurvreg <- function(object, task, ...) {
variateForm = "univariate",
description = "Royston/Parmar Flexible Parametric Survival Model",
.suppressChecks = TRUE,
suppressMoments = TRUE,
pdf = pdf, cdf = cdf, quantile = quantile, rand = rand
)

distr = distr6::VectorDistribution$new(
distribution = "Distribution", params = params,
shared_params = shared_params, decorators = c("CoreStatistics", "ExoticStatistics"))
distlist = lapply(pars, function(x) {
x = as.list(x)
names(x) = c(object$dlist$pars, names(args))
yparams = parameters$clone(deep = TRUE)
ind = match(yparams$.__enclos_env__$private$.parameters$id, names(x))
yparams$.__enclos_env__$private$.parameters$value = x[ind]

do.call(distr6::Distribution$new, c(list(parameters = yparams), shared_params))
})

distr = distr6::VectorDistribution$new(distlist,
decorators = c("CoreStatistics", "ExoticStatistics"))

return(list(distr = distr, lp = lp))
}
5 changes: 3 additions & 2 deletions tests/testthat/test_surv_flexible.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
context("surv.flexible")

task = tsk("rats")

test_that("autotest", {
set.seed(200)
learner = lrn("surv.flexible", k = 0, scale = "normal", inits = c(1, 1, 1))
expect_learner(learner)
# there's no reason why sanity in particular is excluded except that because of multiple
# experiments and changing seeds, I've found it isn't possible to remove the "`vnmin` is
# not finite" error
result = run_autotest(learner, exclude = "sanity")
result = run_autotest(learner, exclude = "sanity", check_replicable = FALSE)
expect_true(result, info = result$error)
})

task = tsk("rats")
test_that("manualtest", {
set.seed(15)
learn = lrn("surv.flexible")
Expand Down

0 comments on commit 34e6d29

Please sign in to comment.