Skip to content

Commit

Permalink
Merge pull request #40 from astamm/karcher_mean
Browse files Browse the repository at this point in the history
Karcher mean
  • Loading branch information
jdtuck committed Feb 13, 2024
2 parents 97fe21c + be4b87b commit c9bf800
Show file tree
Hide file tree
Showing 8 changed files with 387 additions and 125 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ export(bootTB)
export(calc_shape_dist)
export(curve_boxplot)
export(curve_depth)
export(curve_dist)
export(curve_geodesic)
export(curve_karcher_cov)
export(curve_karcher_mean)
Expand Down
11 changes: 6 additions & 5 deletions R/calc_shape_dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,11 @@ calc_shape_dist <- function(beta1, beta2,
scale = scale
)

q1dotq2 <- innerprod_q2(q1, out$q2best)

if (q1dotq2 > 1) q1dotq2 <- 1
if (q1dotq2 < -1) q1dotq2 <- -1

# Compute amplitude distance
if (scale) {
q1dotq2 <- innerprod_q2(q1, out$q2best)
if (q1dotq2 > 1) q1dotq2 <- 1
if (q1dotq2 < -1) q1dotq2 <- -1
if (include.length)
d <- sqrt(acos(q1dotq2) ^ 2 + log(lenq1 / lenq2) ^ 2)
else
Expand All @@ -91,13 +90,15 @@ calc_shape_dist <- function(beta1, beta2,
d <- sqrt(innerprod_q2(v, v))
}

# Compute phase distance
gam <- out$gambest
time1 <- seq(0, 1, length.out = T1)
binsize <- mean(diff(time1))
psi <- sqrt(gradient(gam, binsize))
v <- inv_exp_map(rep(1, length(gam)), psi)
dx <- sqrt(trapz(time1, v ^ 2))

# Return results
list(
d = d,
dx = dx,
Expand Down
87 changes: 87 additions & 0 deletions R/curve_dist.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#' Distance Matrix Computation
#'
#' Computes the pairwise distance matrix between a set of curves using the
#' elastic shape distance as computed by [`calc_shape_dist()`].
#'
#' @param beta A numeric array of shape \eqn{L \times M \times N} specifying the
#' set of \eqn{N} curves of length \eqn{M} in \eqn{L}-dimensional space.
#' @inheritParams calc_shape_dist
#' @param ncores An integer value specifying the number of cores to use for
#' parallel computation. If `ncores` is greater than the number of available
#' cores, a warning is issued and the maximum number of available cores is
#' used. Defaults to `1L`.
#'
#' @return A list of two objects, `Da` and `Dp`, each of class `dist` containing
#' the amplitude and phase distances, respectively.
#' @export
#'
#' @examples
#' out <- curve_dist(beta[, , 1, 1:4])
curve_dist <- function(beta,
mode = "O",
rotation = FALSE,
scale = FALSE,
include.length = FALSE,
ncores = 1L) {
navail <- max(parallel::detectCores() - 1, 1)

if (ncores > navail) {
cli::cli_alert_warning(
"The number of requested cores ({ncores}) is larger than the number of
available cores ({navail}). Using the maximum number of available cores..."
)
ncores <- navail
}

if (ncores > 1L) {
cl <- parallel::makeCluster(ncores)
doParallel::registerDoParallel(cl)
on.exit(parallel::stopCluster(cl))
} else
foreach::registerDoSEQ()

dims <- dim(beta)
L <- dims[1]
M <- dims[2]
N <- dims[3]
K <- N * (N - 1) / 2

k <- NULL
out <- foreach::foreach(k = 0:(K - 1), .combine = cbind, .packages = "fdasrvf") %dopar% {
# Compute indices i and j of distance matrix from linear index k
i <- N - 2 - floor(sqrt(-8 * k + 4 * N * (N - 1) - 7) / 2.0 - 0.5)
j <- k + i + 1 - N * (N - 1) / 2 + (N - i) * ((N - i) - 1) / 2
# Increment indices as previous ones are 0-based while R expects 1-based
res <- calc_shape_dist(
beta1 = beta[, , i + 1],
beta2 = beta[, , j + 1],
mode = mode,
rotation = rotation,
scale = scale,
include.length = include.length
)
matrix(c(res$d, res$dx), ncol = 1)
}

Da <- out[1, ]
attributes(Da) <- NULL
attr(Da, "Labels") <- 1:N
attr(Da, "Size") <- N
attr(Da, "Diag") <- FALSE
attr(Da, "Upper") <- FALSE
attr(Da, "call") <- match.call()
attr(Da, "method") <- "calc_shape_dist (amplitude)"
class(Da) <- "dist"

Dp <- out[2, ]
attributes(Dp) <- NULL
attr(Dp, "Labels") <- 1:N
attr(Dp, "Size") <- N
attr(Dp, "Diag") <- FALSE
attr(Dp, "Upper") <- FALSE
attr(Dp, "call") <- match.call()
attr(Dp, "method") <- "calc_shape_dist (phase)"
class(Dp) <- "dist"

list(Da = Da, Dp = Dp)
}
15 changes: 10 additions & 5 deletions R/curve_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ find_rotation_seed_coord <- function(beta1, beta2,
len1 <- out$len

scl <- 4
minE <- 1000
minE <- Inf
if (mode == "C")
end_idx <- floor(T1 / scl)
else
Expand Down Expand Up @@ -212,11 +212,16 @@ find_rotation_seed_coord <- function(beta1, beta2,
gam <- seq(0, 1, length.out = T1)
}

dist <- innerprod_q2(q1, q2new)
if (dist < -1) dist <- -1
if (dist > 1) dist <- 1
if (scale) {
dist <- innerprod_q2(q1, q2new)
if (dist < -1) dist <- -1
if (dist > 1) dist <- 1
Ec <- acos(dist)
} else {
v <- q1 - q2new
Ec <- sqrt(innerprod_q2(v, v))
}

Ec <- acos(dist)
if (Ec < minE) {
Rbest <- Rout
beta2best <- beta2new
Expand Down
Loading

0 comments on commit c9bf800

Please sign in to comment.