-
Notifications
You must be signed in to change notification settings - Fork 31
/
plot.R
419 lines (397 loc) · 13 KB
/
plot.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
#' Plot EpiNow2 Credible Intervals
#'
#' @description `r lifecycle::badge("stable")`
#' Adds a shaded ribbon for each user specified credible interval. The first
#' interval in `CrIs` is drawn with a fixed alpha of 0.2 and outline width
#' `linewidth`; every remaining interval is drawn without an outline colour
#' and with alpha `alpha / (length(CrIs) - 1)`.
#'
#' @param plot A `{ggplot2}` plot
#'
#' @param CrIs Numeric list of credible intervals present in the data. As
#' produced by [extract_CrIs()]. For each value `x` the plot data must contain
#' the columns `lower_x` and `upper_x`.
#'
#' @param alpha Numeric, total alpha shared between the credible interval
#' ribbons drawn after the first one.
#'
#' @param linewidth Numeric, line width of the first credible interval ribbon.
#'
#' @return A `{ggplot2}` plot.
plot_CrIs <- function(plot, CrIs, alpha, linewidth) {
  alpha_per_CrI <- alpha / (length(CrIs) - 1)
  for (index in seq_along(CrIs)) {
    bottom <- paste0("lower_", CrIs[[index]])
    top <- paste0("upper_", CrIs[[index]])
    # aes() captures its arguments lazily: without `!!` every layer would look
    # up `bottom`/`top` when the plot is built and so use the column names from
    # the final loop iteration. `!!` injects the current names immediately.
    if (index == 1) {
      ribbon <- ggplot2::geom_ribbon(
        ggplot2::aes(ymin = .data[[!!bottom]], ymax = .data[[!!top]]),
        alpha = 0.2, linewidth = linewidth
      )
    } else {
      ribbon <- ggplot2::geom_ribbon(
        ggplot2::aes(
          ymin = .data[[!!bottom]], ymax = .data[[!!top]],
          col = NULL
        ),
        alpha = alpha_per_CrI
      )
    }
    plot <- plot + ribbon
  }
  plot
}
#' Plot Estimates
#'
#' @description `r lifecycle::badge("questioning")`
#' Allows users to plot the output from [estimate_infections()] easily.
#' In future releases it may be deprecated in favour of increasing the
#' functionality of the S3 plot methods.
#'
#' @param estimate A `<data.table>` of estimates containing the following
#' variables: date, type (must contain "estimate", "estimate based on partial
#' data" and optionally "forecast").
#'
#' @param reported A `<data.table>` of reported cases with the following
#' variables: date, confirm.
#'
#' @param ylab Character string. Title for the plot y
#' axis.
#'
#' @param hline Numeric, if supplied gives the horizontal intercept for an
#' indicator line.
#'
#' @param obs_as_col Logical, defaults to `TRUE`. Should observed data, if
#' supplied, be plotted using columns or as points (linked using a line).
#'
#' @param max_plot Numeric, defaults to 10. A multiplicative upper bound on the
#' number of cases shown on the plot. Based on the maximum number of reported
#' cases.
#'
#' @param estimate_type Character vector indicating the type of data to plot.
#' Defaults to all types with supported options being: "Estimate", "Estimate
#' based on partial data", and "Forecast". Only estimates of the selected
#' types are drawn; the dashed vertical line marking the end of the "Estimate
#' based on partial data" period is always placed using the full input.
#'
#' @return A `ggplot2` object
#' @export
#' @importFrom ggplot2 ggplot aes geom_col geom_line geom_point geom_vline
#' @importFrom ggplot2 geom_hline geom_ribbon scale_y_continuous theme_bw
#' @importFrom scales comma
#' @importFrom data.table setDT fifelse copy as.data.table
#' @importFrom purrr map
#' @importFrom rlang arg_match
#' @examples
#' # get example model results
#' out <- readRDS(system.file(
#'   package = "EpiNow2", "extdata", "example_estimate_infections.rds"
#' ))
#'
#' # plot infections
#' plot_estimates(
#'   estimate = out$summarised[variable == "infections"],
#'   reported = out$observations,
#'   ylab = "Cases", max_plot = 2
#' ) + ggplot2::facet_wrap(~type, scales = "free_y")
#'
#' # plot reported cases estimated via Rt
#' plot_estimates(
#'   estimate = out$summarised[variable == "reported_cases"],
#'   reported = out$observations,
#'   ylab = "Cases"
#' )
#'
#' # plot Rt estimates
#' plot_estimates(
#'   estimate = out$summarised[variable == "R"],
#'   ylab = "Effective Reproduction No.",
#'   hline = 1
#' )
#'
#' # plot Rt estimates without forecasts
#' plot_estimates(
#'   estimate = out$summarised[variable == "R"],
#'   ylab = "Effective Reproduction No.",
#'   hline = 1, estimate_type = "Estimate"
#' )
plot_estimates <- function(estimate, reported, ylab, hline,
                           obs_as_col = TRUE, max_plot = 10,
                           estimate_type = c(
                             "Estimate", "Estimate based on partial data",
                             "Forecast"
                           )) {
  # convert input to data.table
  estimate <- data.table::as.data.table(estimate)
  if (!missing(reported)) {
    reported <- data.table::as.data.table(reported)
  }
  # map type to presentation form by upper-casing its first character
  to_sentence <- function(x) {
    substr(x, 1, 1) <- toupper(substr(x, 1, 1))
    x
  }
  estimate <- estimate[, type := to_sentence(type)]
  # keep every type so the partial-data marker below can still be placed when
  # those estimates are not among the requested types
  orig_estimate <- copy(estimate)
  estimate_type <- arg_match(estimate_type, multiple = TRUE)
  # restrict the plotted estimates to the requested types
  estimate <- estimate[type %in% estimate_type]
  # scale plot values based on reported cases
  if (!missing(reported) && !is.na(max_plot)) {
    sd_cols <- c(
      grep("lower_", colnames(estimate), value = TRUE, fixed = TRUE),
      grep("upper_", colnames(estimate), value = TRUE, fixed = TRUE)
    )
    # extra columns in `reported` are used as grouping/join keys below
    cols <- setdiff(colnames(reported), c("date", "confirm", "breakpoint"))
    # NOTE(review): a single extra column falls through to the global cap in
    # the `else` branch; confirm whether `length(cols) > 0` was intended.
    if (length(cols) > 1) {
      max_cases_to_plot <- data.table::copy(reported)[,
        .(max = round(max(confirm, na.rm = TRUE) * max_plot, 0)),
        by = cols
      ]
      estimate <- estimate[max_cases_to_plot, on = cols]
    } else {
      max_cases_to_plot <- round(
        max(reported$confirm, na.rm = TRUE) * max_plot, 0
      )
      estimate <- estimate[, max := max_cases_to_plot]
    }
    # clip every CrI bound at the `max` column
    estimate <- estimate[, lapply(.SD, pmin, max),
      by = setdiff(colnames(estimate), sd_cols), .SDcols = sd_cols
    ]
  }
  # initialise plot
  plot <- ggplot2::ggplot(
    estimate, ggplot2::aes(x = date, col = type, fill = type)
  )
  # add in reported data if present (either as column or as a line)
  if (!missing(reported)) {
    if (obs_as_col) {
      plot <- plot +
        ggplot2::geom_col(
          data = reported[date >= min(estimate$date, na.rm = TRUE) &
            date <= max(estimate$date, na.rm = TRUE)],
          ggplot2::aes(y = confirm), fill = "grey", col = "white",
          show.legend = FALSE, na.rm = TRUE
        )
    } else {
      plot <- plot +
        ggplot2::geom_line(
          data = reported,
          ggplot2::aes(y = confirm, fill = NULL),
          linewidth = 1.1, alpha = 0.5, col = "black", na.rm = TRUE
        ) +
        # points are sized with `size`; `linewidth` does not apply to them
        ggplot2::geom_point(
          data = reported,
          ggplot2::aes(y = confirm, fill = NULL),
          size = 1.1, alpha = 1, col = "black",
          show.legend = FALSE, na.rm = TRUE
        )
    }
  }
  # mark the last date of the "Estimate based on partial data" period
  plot <- plot +
    ggplot2::geom_vline(
      xintercept = orig_estimate[
        type == "Estimate based on partial data"
      ][date == max(date)]$date,
      linetype = 2
    )
  # plot CrIs
  plot <- plot_CrIs(plot, extract_CrIs(estimate),
    alpha = 0.6, linewidth = 0.05
  )
  # add plot theming
  plot <- plot +
    ggplot2::theme_bw() +
    ggplot2::theme(legend.position = "bottom") +
    ggplot2::scale_color_brewer(palette = "Dark2") +
    ggplot2::scale_fill_brewer(palette = "Dark2") +
    ggplot2::labs(y = ylab, x = "Date", col = "Type", fill = "Type") +
    ggplot2::expand_limits(y = 0) +
    ggplot2::scale_x_date(date_breaks = "1 week", date_labels = "%b %d") +
    ggplot2::scale_y_continuous(labels = scales::comma) +
    ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90))
  # add in a horizontal line if required
  if (!missing(hline)) {
    plot <- plot +
      ggplot2::geom_hline(yintercept = hline, linetype = 2)
  }
  plot
}
#' Plot a Summary of the Latest Results
#'
#' @description `r lifecycle::badge("questioning")`
#' Used to return a summary plot across regions (using results generated by
#' [summarise_results()]).
#'
#' May be deprecated in later releases in favour of enhanced S3 methods.
#'
#' @param summary_results A data.table as returned by [summarise_results()] (the
#' `data` object).
#'
#' @param x_lab A character string giving the label for the x axis, defaults to
#' region.
#'
#' @param log_cases Logical, should cases be shown on a logged scale. Defaults
#' to `FALSE`.
#'
#' @param max_cases Numeric, optional. The maximum number of cases to plot.
#' The cases panel is never capped above one plus the largest upper bound of
#' the highest credible interval; if `max_cases` is missing that value is used
#' as the cap.
#'
#' @return A `{ggplot2}`/`{patchwork}` object with the cases panel stacked
#' above the effective reproduction number panel.
#' @export
#' @importFrom ggplot2 ggplot aes geom_linerange geom_hline facet_wrap
#' @importFrom ggplot2 theme guides labs expand_limits guide_legend
#' @importFrom ggplot2 scale_color_manual .data coord_cartesian
#' @importFrom ggplot2 theme_bw element_blank scale_y_continuous
#' @importFrom scales comma
#' @importFrom patchwork plot_layout
#' @importFrom data.table setDT
plot_summary <- function(summary_results,
                         x_lab = "Region",
                         log_cases = FALSE,
                         max_cases) {
  # convert to data.table without converting the caller's object by reference
  summary_results <- data.table::as.data.table(summary_results)
  # extract CrIs
  CrIs <- extract_CrIs(summary_results)
  max_CrI <- max(CrIs)
  # generic plotting function shared by the cases and Rt panels
  inner_plot <- function(df) {
    plot <- ggplot2::ggplot(df, ggplot2::aes(
      x = region,
      col = `Expected change in daily reports`
    ))
    # plot CrIs: the first at alpha 0.4, the rest sharing a total of 0.8
    alpha_per_CrI <- 0.8 / (length(CrIs) - 1)
    for (index in seq_along(CrIs)) {
      bottom <- paste0("lower_", CrIs[[index]])
      top <- paste0("upper_", CrIs[[index]])
      plot <- plot +
        ggplot2::geom_linerange(
          # `!!` injects the current column names; aes() is otherwise
          # evaluated lazily and every layer would use the last CrI's columns
          ggplot2::aes(ymin = .data[[!!bottom]], ymax = .data[[!!top]]),
          alpha = if (index == 1) 0.4 else alpha_per_CrI,
          linewidth = 4
        )
    }
    plot +
      ggplot2::geom_hline(yintercept = 1, linetype = 2) +
      ggplot2::facet_wrap(~metric, ncol = 1, scales = "free_y") +
      ggplot2::theme_bw() +
      ggplot2::scale_color_manual(values = c(
        Increasing = "#e75f00",
        "Likely increasing" = "#fd9e49",
        "Likely decreasing" = "#5fa2ce",
        Decreasing = "#1170aa",
        Stable = "#7b848f"
      ), drop = FALSE)
  }
  # cap the cases panel at one above the largest upper bound of the highest CrI
  upper_CrI <- paste0("upper_", max_CrI) # nolint
  max_upper <- max(
    summary_results[metric == "New infections per day"][, ..upper_CrI],
    na.rm = TRUE
  )
  if (missing(max_cases)) {
    max_cases <- max_upper + 1
  } else {
    max_cases <- min(c(max_cases, max_upper + 1), na.rm = TRUE)
  }
  # cases plot
  cases_plot <-
    inner_plot(
      summary_results[metric == "New infections per day"]
    ) +
    ggplot2::labs(x = x_lab, y = "") +
    ggplot2::expand_limits(y = 0) +
    ggplot2::theme(
      axis.title.x = ggplot2::element_blank(),
      axis.text.x = ggplot2::element_blank()
    ) +
    ggplot2::theme(legend.position = "none")
  cases_scale <- if (log_cases) {
    ggplot2::scale_y_log10
  } else {
    ggplot2::scale_y_continuous
  }
  # values above the cap are squished onto it rather than dropped
  cases_plot <- cases_plot +
    cases_scale(
      labels = scales::comma,
      limits = c(NA, max_cases),
      oob = scales::squish
    )
  # rt plot, with the y axis limited to at most 4
  rt_data <- summary_results[metric == "Effective reproduction no."]
  uppers <- grepl("upper_", colnames(rt_data), fixed = TRUE) # nolint
  max_rt <- max(rt_data[, ..uppers], na.rm = TRUE)
  rt_plot <-
    inner_plot(rt_data) +
    ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90)) +
    ggplot2::theme(legend.position = "bottom") +
    ggplot2::guides(col = ggplot2::guide_legend(nrow = 2)) +
    ggplot2::labs(x = x_lab, y = "") +
    ggplot2::expand_limits(y = c(0, min(max_rt, 4))) +
    ggplot2::coord_cartesian(ylim = c(0, min(max_rt, 4)))
  # join plots together
  cases_plot + rt_plot + patchwork::plot_layout(ncol = 1)
}
#' Plot method for estimate_infections
#'
#' @description `r lifecycle::badge("maturing")`
#' `plot` method for class `<estimate_infections>`.
#'
#' @param x A list of output as produced by `estimate_infections`
#'
#' @param type A character vector indicating the name of the plot to return.
#' Defaults to "summary" with supported options being "infections", "reports",
#' "R", "growth_rate", "summary", "all". If "all" is supplied all plots are
#' generated.
#'
#' @param ... Pass additional arguments to report_plots
#' @importFrom rlang arg_match
#'
#' @seealso plot report_plots estimate_infections
#' @aliases plot
#' @method plot estimate_infections
#' @return A single plot when one `type` is requested, otherwise a list of
#' plots as produced by [report_plots()]. Returns `NULL` invisibly if
#' [report_plots()] returns `NULL`.
#' @export
plot.estimate_infections <- function(x,
                                     type = c(
                                       "summary", "infections", "reports", "R",
                                       "growth_rate", "all"
                                     ), ...) {
  # validate `type` before calling report_plots() so bad input fails fast
  type <- arg_match(type)
  if (type == "all") {
    type <- c("summary", "infections", "reports", "R", "growth_rate")
  }
  out <- report_plots(
    summarised_estimates = x$summarised,
    reported = x$observations, ...
  )
  if (is.null(out)) {
    return(invisible(NULL))
  }
  out <- out[type]
  # unwrap the list when only a single plot was requested
  if (length(type) == 1) {
    out <- out[[1]]
  }
  out
}
#' Plot method for epinow
#'
#' @description `r lifecycle::badge("maturing")`
#' `plot` method for class `<epinow>`. Passes the `estimates` element of `x`,
#' together with `type` and any further arguments, to
#' [plot.estimate_infections()].
#' @param x A list of output as produced by [epinow()].
#' @inheritParams plot.estimate_infections
#' @seealso plot plot.estimate_infections report_plots estimate_infections
#' @method plot epinow
#' @return List of plots as produced by [report_plots()]
#' @export
plot.epinow <- function(x, type = "summary", ...) {
  estimates <- x$estimates
  plot.estimate_infections(x = estimates, type = type, ...)
}