Skip to content

Commit

Permalink
Merge pull request #294 from mrc-ide/mrc-4277
Browse files Browse the repository at this point in the history
Proof-of-concept data/compare in odin
  • Loading branch information
weshinsley committed Jun 15, 2023
2 parents 938cb5d + 17925a0 commit e390448
Show file tree
Hide file tree
Showing 10 changed files with 346 additions and 24 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin
Title: ODE Generation and Integration
Version: 1.4.7
Version: 1.5.0
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Thibaut", "Jombart", role = "ctb"),
Expand Down
5 changes: 3 additions & 2 deletions R/common.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ RING <- "odin_ring"
## variables) but that needs checking too. Not 100% sure this is done
## on the lhs index bits. Probably need to standardise that at some
## point.
SPECIAL_LHS <- c("initial", "deriv", "update", "output", "dim", "config")
SPECIAL_RHS <- c("user", "interpolate", "delay")
SPECIAL_LHS <- c("initial", "deriv", "update", "output", "dim", "config",
"compare")
SPECIAL_RHS <- c("user", "interpolate", "delay", "data")
INDEX <- c("i", "j", "k", "l", "i5", "i6", "i7", "i8") # TODO: make open
INTERNAL <- "internal"

Expand Down
3 changes: 3 additions & 0 deletions R/generate_c.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ generate_c_code <- function(dat, options, package) {
if (dat$features$mixed) {
stop("Models that mix deriv() and update() are not supported")
}
if (dat$features$has_compare || dat$features$has_data) {
stop("data() and compare() not supported")
}

if (dat$features$has_delay) {
dat$data$elements[[dat$meta$c$use_dde]] <-
Expand Down
3 changes: 3 additions & 0 deletions R/generate_js.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ generate_js <- function(ir, options) {
if (dat$features$mixed) {
stop("Models that mix deriv() and update() are not supported")
}
if (dat$features$has_compare || dat$features$has_data) {
stop("data() and compare() not supported")
}

rewrite <- function(x) {
generate_js_sexp(x, dat$data, dat$meta)
Expand Down
3 changes: 3 additions & 0 deletions R/generate_r.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ generate_r <- function(dat, options) {
if (dat$features$mixed) {
stop("Models that mix deriv() and update() are not supported")
}
if (dat$features$has_compare || dat$features$has_data) {
stop("data() and compare() not supported")
}

if (dat$features$has_delay) {
## We're going to need an additional bit of internal data here,
Expand Down
153 changes: 134 additions & 19 deletions R/ir_parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ ir_parse_data <- function(eqs, packing, stage, source) {
is_alloc <- vlapply(eqs, function(x) {
x$type == "alloc" && x$name != x$lhs$name_lhs
})
i <- !(is_alloc | type %in% c("copy", "config"))
i <- !(is_alloc | type %in% c("copy", "config", "compare"))

elements <- lapply(eqs[i], ir_parse_data_element, stage)
names(elements) <- vcapply(elements, "[[", "name")
Expand All @@ -180,7 +180,9 @@ ir_parse_data_element <- function(x, stage) {

stage <- stage[[x$name]]

if (is.null(x$lhs$special)) {
if (x$type == "data") {
location <- "data"
} else if (is.null(x$lhs$special)) {
if (rank == 0L && stage == STAGE_TIME) {
location <- "transient"
} else {
Expand Down Expand Up @@ -355,8 +357,12 @@ ir_parse_stage <- function(eqs, dependencies, variables, time_name, source) {
!is.null(x$rhs$delay) || !is.null(x$rhs$interpolate) ||
isTRUE(x$stochastic) || x$type == "delay_continuous"
}
is_data <- function(x) {
x$type == "data"
}

stage[names_if(vlapply(eqs, is_user))] <- STAGE_USER
stage[names_if(vlapply(eqs, is_data))] <- STAGE_TIME
stage[names_if(vlapply(eqs, is_time_dependent))] <- STAGE_TIME
stage[time_name] <- STAGE_TIME
stage[variables] <- STAGE_TIME
Expand Down Expand Up @@ -481,6 +487,8 @@ ir_parse_features <- function(eqs, debug, config, source) {
is_delay <- vlapply(eqs, function(x) !is.null(x$delay))
is_interpolate <- vlapply(eqs, function(x) !is.null(x$interpolate))
is_stochastic <- vlapply(eqs, function(x) isTRUE(x$stochastic))
is_data <- vlapply(eqs, function(x) !is.null(x$data))
is_compare <- vlapply(eqs, function(x) identical(x$lhs$special, "compare"))

## We'll support other debugging bits later, I imagine.
is_debug_print <- vlapply(debug, function(x) x$type == "print")
Expand All @@ -499,6 +507,8 @@ ir_parse_features <- function(eqs, debug, config, source) {
has_delay = any(is_delay),
has_interpolate = any(is_interpolate),
has_stochastic = any(is_stochastic),
has_data = any(is_data),
has_compare = any(is_compare),
has_include = !is.null(config$include),
has_debug = any(is_debug_print),
initial_time_dependent = NULL)
Expand All @@ -510,6 +520,7 @@ ir_parse_components <- function(eqs, dependencies, variables, stage,
eqs_constant <- intersect(names_if(stage == STAGE_CONSTANT), names(eqs))
eqs_user <- intersect(names_if(stage == STAGE_USER), names(eqs))
eqs_time <- intersect(names_if(stage == STAGE_TIME), names(eqs))
eqs_data <- names_if(vlapply(eqs, function(x) x$type == "data"))

## NOTE: we need the equation name here, not the variable name
rhs_special <- if (common$continuous) "deriv" else "update"
Expand Down Expand Up @@ -545,9 +556,16 @@ ir_parse_components <- function(eqs, dependencies, variables, stage,
variables_update_stochastic <- character(0)
}

compare <- names_if(vlapply(eqs, function(x) {
identical(x$lhs$special, "compare")
}))
v <- unique(c(compare, unlist(dependencies[compare], use.names = FALSE)))
eqs_compare <- setdiff(intersect(eqs_time, v), eqs_data)
variables_compare <- intersect(variables, v)

type <- vcapply(eqs, "[[", "type")
core <- unique(c(initial, rhs, output, eqs_initial, eqs_rhs, eqs_output,
eqs_update_stochastic))
eqs_update_stochastic, eqs_compare))

used_in_delay <- unlist(lapply(eqs[type == "delay_continuous"],
function(x) x$delay$depends$variables),
Expand All @@ -560,14 +578,28 @@ ir_parse_components <- function(eqs, dependencies, variables, stage,
ir_parse_check_unused(eqs, dependencies, core, stage, source)
}

list(
ret <- list(
create = list(variables = character(0), equations = eqs_constant),
user = list(variables = character(0), equations = eqs_user),
initial = list(variables = character(0), equations = eqs_initial),
rhs = list(variables = variables_rhs, equations = eqs_rhs),
update_stochastic = list(variables = variables_update_stochastic,
equations = eqs_update_stochastic),
output = list(variables = variables_output, equations = eqs_output))
output = list(variables = variables_output, equations = eqs_output),
compare = list(variables = variables_compare, equations = eqs_compare))

for (i in names(ret)) {
err <- intersect(ret[[i]]$equations, eqs_data)
if (length(err) > 0) {
## TODO: this should be a proper parse error and we should find
## the chain here that drags it in, not actually that easy!
stop(sprintf(
"Data (%s) may only be referred to in compare expressions",
paste(squote(err), collapse = ", ")))
}
}

ret
}


Expand Down Expand Up @@ -640,24 +672,34 @@ ir_parse_exprs <- function(exprs) {
expr_is_debug <- function(x) {
is_call(x, "print")
}

expr_is_compare <- function(x) {
length(x) == 3L && is_call(x[[2]], "compare")
}
is_assignment <- vlapply(exprs, expr_is_assignment)
is_debug <- vlapply(exprs, expr_is_debug)
err <- !(is_assignment | is_debug)
is_compare <- vlapply(exprs, expr_is_compare)
err <- !(is_assignment | is_debug | is_compare)
if (any(err)) {
ir_parse_error(
"Every line must contain an assignment (or a debug statement)",
paste("Every line must contain an assignment, a compare statement",
"or a debug statement"),
unlist(lines[err]), src)
}

## The comparison equations are different enough that we do them
## separately, then join back together:
compare <- Map(ir_parse_compare, exprs[is_compare], lines[is_compare],
MoreArgs = list(source = src))
names(compare) <- vcapply(compare, "[[", "name")

eqs <- Map(ir_parse_expr, exprs[is_assignment], lines[is_assignment],
MoreArgs = list(source = src))
names(eqs) <- vcapply(eqs, "[[", "name")

debug <- Map(ir_parse_debug, exprs[is_debug], lines[is_debug],
MoreArgs = list(source = src))

list(eqs = eqs, source = src, debug = debug)
list(eqs = c(eqs, compare), source = src, debug = debug)
}


Expand All @@ -669,6 +711,8 @@ ir_parse_expr <- function(expr, line, source) {

if (!is.null(rhs$user)) {
type <- "user"
} else if (!is.null(rhs$data)) {
type <- "data"
} else if (!is.null(rhs$interpolate)) {
type <- "interpolate"
} else if (!is.null(rhs$delay)) {
Expand Down Expand Up @@ -800,7 +844,7 @@ ir_parse_expr_lhs <- function(lhs, line, source) {
name_lhs <- name_equation
} else if (special %in% c("initial", "dim")) {
name_lhs <- name_equation
} else if (special %in% c("deriv", "output", "update", "config")) {
} else if (special %in% c("deriv", "output", "update", "config", "compare")) {
name_lhs <- name_data
} else {
stop("odin bug") # nocov
Expand Down Expand Up @@ -911,14 +955,16 @@ ir_parse_expr_check_lhs_name <- function(lhs, special, line, source) {

name <- deparse(lhs)

if (name %in% RESERVED) {
ir_parse_error(sprintf("Reserved name '%s' for lhs", name), line, source)
}
re <- sprintf("^(%s)_.*", paste(RESERVED_PREFIX, collapse = "|"))
if (grepl(re, name)) {
ir_parse_error(sprintf("Variable name cannot start with '%s_'",
sub(re, "\\1", name)),
line, source)
if (!identical(special, "config")) {
if (name %in% RESERVED) {
ir_parse_error(sprintf("Reserved name '%s' for lhs", name), line, source)
}
re <- sprintf("^(%s)_.*", paste(RESERVED_PREFIX, collapse = "|"))
if (grepl(re, name)) {
ir_parse_error(sprintf("Variable name cannot start with '%s_'",
sub(re, "\\1", name)),
line, source)
}
}

name
Expand All @@ -931,6 +977,8 @@ ir_parse_expr_rhs <- function(rhs, line, source) {
ir_parse_expr_rhs_delay(rhs, line, source)
} else if (is_call(rhs, quote(user))) {
ir_parse_expr_rhs_user(rhs, line, source)
} else if (is_call(rhs, quote(data))) {
ir_parse_expr_rhs_data(rhs, line, source)
} else if (is_call(rhs, quote(interpolate))) {
ir_parse_expr_rhs_interpolate(rhs, line, source)
} else {
Expand Down Expand Up @@ -1025,6 +1073,17 @@ ir_parse_expr_rhs_user <- function(rhs, line, source) {
}


ir_parse_expr_rhs_data <- function(rhs, line, source) {
args <- as.list(rhs[-1L])

if (length(args) != 0) {
ir_parse_error("Calls to data() must have no arguments", line, source)
}
data <- list(type = "real_type")
list(data = data)
}


ir_parse_expr_rhs_interpolate <- function(rhs, line, source) {
na <- length(rhs) - 1L
if (na < 2L || na > 3L) {
Expand Down Expand Up @@ -1115,7 +1174,7 @@ ir_parse_expr_rhs_delay <- function(rhs, line, source) {

ir_parse_equations <- function(eqs) {
type <- vcapply(eqs, "[[", "type")
eqs[!(type %in% c("null", "config"))]
eqs[!(type %in% c("null", "config", "data"))]
}


Expand Down Expand Up @@ -1867,3 +1926,59 @@ ir_parse_debug_print <- function(eq, data, source) {

ret
}


ir_parse_compare <- function(eq, line, source) {
if (!is_call(eq, "~")) {
ir_parse_error(
"All compare() expressions must use '~' and not '<-' or '='",
line, source)
}

## TODO: we need to compare here against expected
lhs <- ir_parse_expr_lhs(eq[[2L]], line, source)
rhs <- ir_parse_compare_rhs(eq[[3L]], line, source)

depends <- rhs$depends
depends$variables <- union(depends$variables, lhs$name_data)

list(name = lhs$name_equation,
type = "compare",
lhs = lhs,
rhs = rhs,
depends = depends,
source = line)
}


## TODO: these (for now at least) don't sanitise args. The list
## effectively depends on dust, but that feels suboptimal (e.g., for
## the js target). Things like negative binomial is a trick
## because there are two different forms. We'll expand support in
## dust fairly promptly I'd expect.
ir_parse_compare_rhs <- function(expr, line, source) {
if (!is.recursive(expr)) {
ir_parse_error(
"Expected rhs of compare() expression to be a call",
line, source)
}
stopifnot(is.name(expr[[1]]))
distribution <- deparse(expr[[1]])
valid <- c("normal", "binomial", "negative_binomial_mu",
"negative_binomial_prob", "beta_binomial",
"poisson")
if (!(distribution %in% valid)) {
ir_parse_error(
sprintf("Expected rhs to be a valid distribution (%s)",
paste(squote(valid), collapse = ", ")),
line, source)
}

ir_parse_expr_rhs_check_usage(expr, line, source)

args <- as.list(expr[-1])
depends <- join_deps(lapply(args, find_symbols))
list(distribution = distribution,
args = args,
depends = depends)
}
12 changes: 11 additions & 1 deletion R/ir_serialise.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ ir_serialise_equation <- function(eq) {
expression_inplace = ir_serialise_equation_expression_inplace(eq),
interpolate = ir_serialise_equation_interpolate(eq),
user = ir_serialise_equation_user(eq),
stop("odin bug"))
data = ir_serialise_equation_data(eq),
compare = ir_serialise_equation_compare(eq),
stop(sprintf("Can't serialise unknown equation type '%s' [odin bug]",
eq$type)))
c(base, extra)
}

Expand Down Expand Up @@ -196,6 +199,13 @@ ir_serialise_equation_user <- function(eq) {
}


ir_serialise_equation_compare <- function(eq) {
compare <- list(distribution = scalar(eq$rhs$distribution),
args = lapply(eq$rhs$args, ir_serialise_expression))
list(compare = compare)
}


ir_serialise_delay_continuous <- function(eq) {
f_contents <- function(x) {
list(name = scalar(x$name),
Expand Down
Loading

0 comments on commit e390448

Please sign in to comment.