Skip to content

Commit

Permalink
Fixes predictions with non-standard contrasts
Browse files Browse the repository at this point in the history
  • Loading branch information
chjackson committed Jan 17, 2024
1 parent 577b54c commit 640a58e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
6 changes: 4 additions & 2 deletions R/flexsurvreg.R
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,7 @@ flexsurvreg <- function(formula, anc=NULL, data, weights, bhazard, rtrunc, subse
mx[[i]] <- length(unlist(mx)) + seq_len(ncol(mml[[i]][,-1,drop=FALSE]))
}
X <- compress.model.matrices(mml)
contr.save <- lapply(mml, function(x)attr(x,"contrasts"))

weights <- model.extract(m, "weights")
if (is.null(weights)) weights <- m$"(weights)" <- rep(1, nrow(X))
Expand Down Expand Up @@ -1026,7 +1027,8 @@ flexsurvreg <- function(formula, anc=NULL, data, weights, bhazard, rtrunc, subse
AIC = -2*ret$loglik + 2*ret$npars,
data = dat, datameans = colMeans(X),
N=nrow(dat$Y), events=sum(dat$Y[,"status"]==1), trisk=sum(dat$Y[,"time"]),
concat.formula=f2, all.formulae=forms, dfns=dfns),
concat.formula=f2, all.formulae=forms, all.contrasts=contr.save,
dfns=dfns),
ret,
list(covdata = covdata)) # temporary position so cyclomort doesn't break
ret$BIC <- BIC.flexsurvreg(ret, cens=TRUE)
Expand Down Expand Up @@ -1105,7 +1107,7 @@ form.model.matrix <- function(object, newdata, na.action=na.pass, forms=NULL){
names(mml) <- names(forms)
forms[[1]] <- delete.response(terms(forms[[1]]))
for (i in seq_along(forms)){
mml[[i]] <- model.matrix(forms[[i]], mf)
mml[[i]] <- model.matrix(forms[[i]], mf, contrasts.arg = object$all.contrasts[[i]])
}
X <- compress.model.matrices(mml)

Expand Down
6 changes: 5 additions & 1 deletion tests/testthat/test_contrasts.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## Thanks to @https://github.com/anddis
## Thanks to Andrea Discacciati @https://github.com/anddis
## Fixed https://github.com/chjackson/flexsurv/issues/178

test_that("Non-default factor contrasts", {
Expand All @@ -19,4 +19,8 @@ test_that("Non-default factor contrasts", {

expect_equivalent(coef(fit.c), coef(fit.poi))
expect_true(coef(fit.c)[1] != coef(fit.noc)[1])

summ.noc <- summary(fit.noc, type="survival", t=10, tidy=TRUE, ci=FALSE)
summ.c <- summary(fit.c, type="survival", t=10, tidy=TRUE, ci=FALSE)
expect_equivalent(summ.noc$estimates, summ.c$estimates)
})

0 comments on commit 640a58e

Please sign in to comment.