generated from opensafely/research-template
/
validate_models.R
97 lines (72 loc) · 3.05 KB
/
validate_models.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
################################################################################
# Description: Validate models on withheld test data
#
# Author: Emily S Nightingale
# Date: 09/09/2020
#
# Usage: Rscript validate_models.R <fits.rds> <test.rds>
#   args[1]: RDS file holding a list of fitted models; each is passed to
#            predict(..., type = "response") and the list names label output.
#   args[2]: RDS file holding the held-out test data, including
#            `household_id` and the `event_ahead` outcome alongside the
#            model covariates.
#
################################################################################
# sink(paste0("./log_model_validate_",x,".txt"))
###############################################################################
library(tidyverse)
library(data.table)
theme_set(theme_bw())

args <- commandArgs(trailingOnly = TRUE)
# Fail fast with a usage message rather than an obscure readRDS(NA) error.
if (length(args) < 2) {
  stop("Usage: validate_models.R <fits.rds> <test.rds>", call. = FALSE)
}
fits <- readRDS(args[1])
test <- readRDS(args[2])

# Printed output from here until the closing sink() goes to this file.
sink("./output_model_val.txt", type = "output")
print("No. care homes in testing data:")
print(n_distinct(test$household_id))
## ------------------------------- Functions -------------------------------- ##
# Brier score of a fitted model on held-out data: the mean squared
# difference between predicted risk and the observed outcome (lower is better).
#
# fit:  a fitted model object accepted by predict(..., type = "response").
# data: data frame to score. Defaults to the script-level `test` data, so
#       existing calls brier_test(fit) behave as before. Must contain the
#       model covariates and the `event_ahead` outcome (assumed coded 0/1).
# Returns a single number; NA if any prediction or outcome is NA.
brier_test <- function(fit, data = test) {
  # Score directly from the prediction vector instead of adding a column to
  # (a copy of) the whole data set.
  pred <- predict(fit, newdata = data, type = "response")
  mean((pred - data$event_ahead)^2)
}
# Plot ROC and calculate AUC as simple accuracy measure (no additional packages)
# Empirical ROC curve points.
#
# labels: 0/1 (or logical) outcome for each observation.
# scores: predicted score for each observation; higher means more likely 1.
# Returns a data frame with one row per observation, ordered from highest to
# lowest score: cumulative true-positive rate (TPR), cumulative
# false-positive rate (FPR) and the reordered labels.
simple_roc <- function(labels, scores) {
  ranked <- labels[order(scores, decreasing = TRUE)]
  tpr <- cumsum(ranked) / sum(ranked)
  fpr <- cumsum(!ranked) / sum(!ranked)
  data.frame(TPR = tpr, FPR = fpr, labels = ranked)
}
# Area under the ROC curve: the probability that a randomly chosen positive
# case scores higher than a randomly chosen negative case, ties counting 1/2.
#
# pos.scores: predicted scores for observations with the event.
# neg.scores: predicted scores for observations without the event.
# n_sample:   NULL (default) computes the AUC exactly via the Mann-Whitney
#             rank-sum identity: deterministic, tie-aware, no sampling.
#             A positive integer instead estimates it by Monte Carlo from
#             n_sample random positive/negative pairs (the old behaviour).
# Returns a single number in [0, 1], or NA if either group is empty or
# contains NA scores.
computeAUC <- function(pos.scores, neg.scores, n_sample = NULL) {
  n_pos <- length(pos.scores)
  n_neg <- length(neg.scores)
  if (n_pos == 0L || n_neg == 0L || anyNA(pos.scores) || anyNA(neg.scores)) {
    return(NA_real_)
  }
  if (!is.null(n_sample)) {
    # Index via sample.int: sample(x, ...) on a single number >= 1 samples
    # from 1:x rather than returning x, and errors on empty input.
    pos_sample <- pos.scores[sample.int(n_pos, n_sample, replace = TRUE)]
    neg_sample <- neg.scores[sample.int(n_neg, n_sample, replace = TRUE)]
    return(mean((pos_sample > neg_sample) + 0.5 * (pos_sample == neg_sample)))
  }
  # Mann-Whitney U from the positives' tie-averaged ranks in the pooled
  # sample; as.double() avoids integer overflow in n_pos * n_neg.
  ranks <- rank(c(pos.scores, neg.scores))
  u <- sum(ranks[seq_len(n_pos)]) - n_pos * (n_pos + 1) / 2
  u / (as.double(n_pos) * n_neg)
}
## ------------------------------ Prediction -------------------------------- ##
# Rank models by Brier score on the test data (lower is better). `diff` is
# each model's excess over the best score; numeric columns rounded to 6 d.p.
# vapply (rather than sapply) guarantees exactly one numeric score per model.
brier_comp <- data.frame(score = vapply(fits, brier_test, numeric(1))) %>%
  rownames_to_column(var = "Model") %>%
  mutate(diff = score - min(score)) %>%
  arrange(diff) %>%
  mutate(across(-Model, function(x) round(x, 6)))
print("Brier scores on test data:")
print(brier_comp)

# Observed-outcome masks do not depend on the model, so compute them once.
is_event <- test$event_ahead == 1
is_no_event <- test$event_ahead == 0

# For each model: a boxplot of predicted risk by observed outcome, then the
# ROC curve titled with its AUC.
pdf(file = "./test_pred_figs.pdf", height = 7, width = 10)
for (f in seq_along(fits)) {
  model_name <- names(fits)[f]
  test$pred <- predict(fits[[f]], newdata = test, type = "response")
  pos_scores <- test$pred[is_event]
  neg_scores <- test$pred[is_no_event]

  # Boxplot of predicted risk for event/no event
  print(
    ggplot(test, aes(x = as.factor(event_ahead), y = pred,
                     fill = as.factor(event_ahead))) +
      geom_boxplot() +
      labs(
        title = paste0(model_name, ": Model-predicted risk versus observed outcome"),
        y = "Predicted risk",
        x = "14-day event",
        subtitle = paste0(
          "Median predictions: ", round(median(pos_scores), 4),
          " for event = 1 and ", round(median(neg_scores), 4),
          " for event = 0."
        )
      ) +
      theme(legend.position = "none") +
      coord_flip()
  )

  # ROC: points are already ordered by decreasing score, so draw them as a
  # path in that order.
  auc <- computeAUC(pos_scores, neg_scores)
  roc <- simple_roc(test$event_ahead, test$pred)
  print(
    ggplot(roc, aes(FPR, TPR)) +
      geom_path(linetype = "dashed", colour = "blue") +
      geom_abline() +
      labs(title = paste0(model_name, ": AUC = ", round(auc, 2)))
  )
}
# invisible() keeps dev.off()'s "null device 1" out of the sunk text log.
invisible(dev.off())
################################################################################
sink()
################################################################################