Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 1f0551c
Showing
83 changed files
with
10,844 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
Package: approxOT | ||
Type: Package | ||
Title: Approximate and Exact Optimal Transport Methods | ||
Version: 1.0 | ||
Date: 2022-01-07 | ||
Author: Eric Dunipace | ||
Maintainer: Eric Dunipace <edunipace@mail.harvard.edu> | ||
Description: R and C++ functions to perform exact and | ||
approximate optimal transport. All C++ methods are linkable | ||
to other R packages via their header files. | ||
License: GPL (>= 3.0) | ||
Imports: Rcpp (>= 1.0.3), stats | ||
LinkingTo: Rcpp, RcppEigen, RcppCGAL, BH | ||
BugReports: https://github.com/ericdunipace/approxOT/issues | ||
Suggests: testthat (>= 2.1.0), transport | ||
RoxygenNote: 7.1.1 | ||
NeedsCompilation: yes | ||
Packaged: 2022-01-08 06:42:15 UTC; eifer | ||
Repository: CRAN | ||
Date/Publication: 2022-01-10 18:02:44 UTC |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
ffc6da520b30cee455f03c91d446b5aa *DESCRIPTION | ||
9fbdc692ea6fbaca50a7af68505b86fa *NAMESPACE | ||
6f99890dc04a2890d5e681cf2c38bebc *R/RcppExports.R | ||
38f63ac108c704c098adb38f4dc1bf1c *R/approxOT-package.R | ||
e89a14923f71ebdcf13fe33720903d14 *R/cost_calculation.R | ||
f6e038fc2a5c9d9c92ed343875268bfd *R/hilbert.projection.R | ||
b4c05a951e57184a971f2acc6440c3f9 *R/sinkhorn.R | ||
ad36cef7a9aecde7345b10a8fd3869cf *R/transport_options.R | ||
36958fd303db2cb31128755e69856c3e *R/transport_plan.R | ||
311b0d3ad8f09ecc724c1592b5100036 *R/unload.R | ||
7f8b1cde45986aa4fe02f40495ce2d9a *R/wasserstein_distance.R | ||
63bde141206b385278d79f5c15cb50fb *README.md | ||
4c468a3dda16ec70573948ba2307bcf4 *inst/CITATION | ||
212874d8f2d98598293a57753726e73f *inst/include/approxOT_types.h | ||
eb436d6657691aa863a72656186276b0 *inst/include/cost.h | ||
4597916ea7a841fd78a11dfaec20db9b *inst/include/full_bipartitegraph.h | ||
0d603a0b2f7eaba69ef299c2b20bd8f3 *inst/include/hilbert_cgal.h | ||
2cfac11d6126e69c934afc0b15bff605 *inst/include/network_simplex_simple.h | ||
0da96d40d7068ec3f2f97272f95662e2 *inst/include/networkflow.h | ||
5e8bb1cf08a1cb713b383405a7f345b4 *inst/include/round_feasible.h | ||
27a82b2fec3ddf5034bfd4b8086813fb *inst/include/shortsimplex.h | ||
1ac021a2e1d96b4de46fe59b9fea1952 *inst/include/sort.h | ||
31f82312cdd9f89e6cffc181c066989c *inst/include/systematic_sample.h | ||
2c1a2b8955cc66f0de33788849545a45 *inst/include/trans_approxOT.h | ||
e23e89153b9ecf7dd49d2b55f4e12422 *inst/include/trans_gandkhorn.h | ||
0e0290d9b97ed8400db7d5a4bdc726c6 *inst/include/trans_greenkhorn.h | ||
2bcb98d404baeb3bd71450829a849d8d *inst/include/trans_hilbert.h | ||
2b2395e6f5ad3a08b417bd9f31f70b00 *inst/include/trans_randkhorn.h | ||
c64aea71a0f6f6f4f0bcae065484fd5c *inst/include/trans_rank.h | ||
ad10d58988825a440cee1feb6dc5ef69 *inst/include/trans_shortsimplex.h | ||
bef5c93b23d7d2ead6ecd3f60014002b *inst/include/trans_sinkhorn.h | ||
a280ae57738632ce2dbe4865d88fc912 *inst/include/trans_swap.h | ||
c74777fa68b59bcd78c5c551dd07c210 *inst/include/trans_univariate.h | ||
90d10ef4657954b575ee7e7cc58aa3b1 *inst/include/trans_univariate_approx_pwr.h | ||
059decd76013ef6b6c0bdda77713c035 *inst/include/transport.h | ||
dabbf34aee9a5abae8ea3c997faded2e *inst/include/utils.h | ||
a02025eeb3bba22eef8e3e9872d654e1 *inst/include/wasserstein.h | ||
02c0fde86acfe291d7741cc91dc0bc4d *man/approxOT-package.Rd | ||
7e600afb27119941489f7728ea749f27 *man/cost_calc.Rd | ||
cc30319d1ac5ceeae61d670ac895f4e2 *man/general_1d_transport.Rd | ||
da8c8bcb69470af52e7a5b2fc25570e7 *man/hilbert.projection.Rd | ||
80a97f51ec324464ed2976f8489bbab8 *man/hilbert_proj_.Rd | ||
41208402d74440e69edad0112d6d376b *man/sinkhorn_distance.Rd | ||
3dc9063765331f08febb48cd1ee51472 *man/sinkhorn_pot.Rd | ||
e9d43312b66ed6b7906d33b7d382338c *man/sinkhorn_transport.Rd | ||
c249f2f549f9c49860c8824539a298ad *man/transport_options.Rd | ||
9989e719a3f4bf5b9782c95b1f9cc6ac *man/transport_plan.Rd | ||
ef90bdbb031a55d8864f447e227a00a1 *man/transport_plan_given_C.Rd | ||
452ecf171a57c89063c2cd32434dc458 *man/transport_plan_multimarg.Rd | ||
a6f646d3ecd74f526caa073b0dff2d4b *man/wasserstein.Rd | ||
36808b5a1ae1985d1b39c7082ea3787e *src/Makevars | ||
f85e1e7a2701c8a8158d5cbd7e0338e5 *src/Makevars.win | ||
7b60fb130fd2b35435a96bc34c3e82eb *src/RcppExports.cpp | ||
49021e38b3c7bbdf5b8054f6abdf2ffb *src/cost.cpp | ||
babb4706ef54cd811d58a0e2eacb2d0c *src/hilbert_cgal.cpp | ||
6f2638b646746faf1cb8570d24e46d41 *src/hilbert_projection.cpp | ||
8284b7b9df1c0e15dc9c8e277b08340c *src/networkflow.cpp | ||
f943a99f31dbe84692c121a3484a4698 *src/round_feasible.cpp | ||
f02a6d33280c2b3a0f959b270f5af6d5 *src/shortsimplex.c | ||
41ba408ab047211fb91406d088f74d4f *src/sinkhorn_test.cpp | ||
68beead4ce2a14a7fa754492c54a0c64 *src/sort.cpp | ||
e12e3fcddfa14e85518c851484a3e3d2 *src/systematic_sample.cpp | ||
cdd858e511fbb7fbb41f768dabc5cd01 *src/trans_approxOT.cpp | ||
ea5a7d5665da9ae50a50e494202751b2 *src/trans_gandkhorn.cpp | ||
5cdbd68a93789760913dfd3c456d10df *src/trans_greenkhorn.cpp | ||
87c99b4b0090269538cb4c3a5b7fa75a *src/trans_hilbert.cpp | ||
b679a0824ac45bdc49bf90daed8c8510 *src/trans_randkhorn.cpp | ||
73f6aba1bedf54bb50ad57adb4b8e3fa *src/trans_rank.cpp | ||
81f0800b9c07c573abb41d012c32ca5c *src/trans_shortsimplex.cpp | ||
061884405b51201b1bf60726b902d35b *src/trans_sinkhorn.cpp | ||
1d47e723a68bdaa7c8fbf2a4dde1b420 *src/trans_swap.cpp | ||
e4758bc972ae748200814c028b551124 *src/trans_univariate.cpp | ||
b4474ba7a38371a72a188fee5b0727c9 *src/trans_univariate_approx_pwr.cpp | ||
9b92dd5ad49a83fd5f7367116961b161 *src/transport.cpp | ||
6ea5b2133bab675b93b3b6477af507da *src/utils.cpp | ||
6473d081693854b3ed38e877e1015295 *src/wasserstein.cpp | ||
f42704eb6cd1acfd149d06a89c47640e *tests/testthat.R | ||
1a4516e061efac67730d2a9634504c05 *tests/testthat/test-general_hilbert_transport.R | ||
644760a1befa1c0fefd491c368ba6763 *tests/testthat/test-hilbert.R | ||
347f099ccb131e7e238686f18335404f *tests/testthat/test-transport_plan.R | ||
d5ad857c16b7aae050f35c514ffe61ec *tests/testthat/test-transport_plan_multimarg.R | ||
452cccd2ab98505eb0e18c433ccf5ff8 *tests/testthat/test-wasserstein.R |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
export(cost_calc) | ||
export(general_1d_transport) | ||
export(hilbert.projection) | ||
export(sinkhorn_pot) | ||
export(transport_options) | ||
export(transport_plan) | ||
export(transport_plan_given_C) | ||
export(transport_plan_multimarg) | ||
export(wasserstein) | ||
importFrom(Rcpp,evalCpp) | ||
importFrom(Rcpp,sourceCpp) | ||
useDynLib(approxOT) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Generated by using Rcpp::compileAttributes() -> do not edit by hand | ||
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 | ||
|
||
cost_calculation_ <- function(A_, B_, p) { | ||
.Call('_approxOT_cost_calculation_', PACKAGE = 'approxOT', A_, B_, p) | ||
} | ||
|
||
multi_marg_final_cost_ <- function(idx_, data_, mass_, M, D, p, ground_p) { | ||
.Call('_approxOT_multi_marg_final_cost_', PACKAGE = 'approxOT', idx_, data_, mass_, M, D, p, ground_p) | ||
} | ||
|
||
multi_marg_given_dist_ <- function(idx_, mass_, cost_, M, N_cost, p) { | ||
.Call('_approxOT_multi_marg_given_dist_', PACKAGE = 'approxOT', idx_, mass_, cost_, M, N_cost, p) | ||
} | ||
|
||
#' Returns orders along the Hilbert space-filling Curve | ||
#' | ||
#' @param A a matrix of data-values of class Eigen::MatrixXd | ||
#' @return An integer vector of orders | ||
#' @keywords internal | ||
hilbert_proj_ <- function(A) { | ||
.Call('_approxOT_hilbert_proj_', PACKAGE = 'approxOT', A) | ||
} | ||
|
||
sinkhorn_ <- function(p_, q_, cost_matrix_, epsilon, niterations) { | ||
.Call('_approxOT_sinkhorn_', PACKAGE = 'approxOT', p_, q_, cost_matrix_, epsilon, niterations) | ||
} | ||
|
||
sinkhorn_pot_ <- function(mass_a, mass_b, cost_matrix, epsilon, niterations, unbiased, cost_matrix_A, cost_matrix_B) { | ||
.Call('_approxOT_sinkhorn_pot_', PACKAGE = 'approxOT', mass_a, mass_b, cost_matrix, epsilon, niterations, unbiased, cost_matrix_A, cost_matrix_B) | ||
} | ||
|
||
transport_C_ <- function(mass_a_, mass_b_, cost_matrix_, method_, epsilon_, niter_, unbiased_, threads_, cost_matrix_A_, cost_matrix_B_) { | ||
.Call('_approxOT_transport_C_', PACKAGE = 'approxOT', mass_a_, mass_b_, cost_matrix_, method_, epsilon_, niter_, unbiased_, threads_, cost_matrix_A_, cost_matrix_B_) | ||
} | ||
|
||
transport_ <- function(A_, B_, p, ground_p, method_, a_sort, epsilon_ = 0.0, niter_ = 0L, unbiased_ = FALSE, threads_ = 1L) { | ||
.Call('_approxOT_transport_', PACKAGE = 'approxOT', A_, B_, p, ground_p, method_, a_sort, epsilon_, niter_, unbiased_, threads_) | ||
} | ||
|
||
transport_swap_ <- function(A_, B_, idx_, mass_, p, ground_p, tolerance_, niter_ = 0L) { | ||
.Call('_approxOT_transport_swap_', PACKAGE = 'approxOT', A_, B_, idx_, mass_, p, ground_p, tolerance_, niter_) | ||
} | ||
|
||
wasserstein_ <- function(mass_, cost_, p, from_, to_) { | ||
.Call('_approxOT_wasserstein_', PACKAGE = 'approxOT', mass_, cost_, p, from_, to_) | ||
} | ||
|
||
wasserstein_p_iid_ <- function(X_, Y_, p) { | ||
.Call('_approxOT_wasserstein_p_iid_', PACKAGE = 'approxOT', X_, Y_, p) | ||
} | ||
|
||
wasserstein_p_iid_p_ <- function(X_, Y_, p) { | ||
.Call('_approxOT_wasserstein_p_iid_p_', PACKAGE = 'approxOT', X_, Y_, p) | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#' An R package to perform exact and approximate optimal transport. | ||
#' | ||
#' R and C++ functions to perform exact and approximate optimal transport. All C++ methods are linkable to other R packages via their header files. | ||
#' @author Eric Dunipace | ||
#' @docType package | ||
#' @name approxOT | ||
#' @useDynLib approxOT | ||
#' @importFrom Rcpp sourceCpp | ||
#' @importFrom Rcpp evalCpp | ||
#' @rdname approxOT-package | ||
"_PACKAGE" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#' Calculate cost matrix | ||
#' | ||
#' @param X matrix of values in first sample. Observations should be by column, not rows. | ||
#' @param Y matrix of Values in second sample. Observations should be by column, not rows. | ||
#' @param ground_p power of the Lp norm to use in cost calculation. | ||
#' | ||
#' @return matrix of costs | ||
#' @export | ||
#' | ||
#' @examples | ||
#' X <- matrix(rnorm(10*100), 10, 100) | ||
#' Y <- matrix(rnorm(10*100), 10, 100) | ||
#' # the Euclidean distance | ||
#' cost <- cost_calc(X, Y, ground_p = 2) | ||
cost_calc <- function(X, Y, ground_p){ | ||
|
||
if (!is.double(ground_p) ) ground_p <- as.double(ground_p) | ||
if (nrow(X) != nrow(Y)) { | ||
stop("Rows of X and Y should be equal to have same dimension. Observations should be unique by column") | ||
} | ||
|
||
return( cost_calculation_(X, Y, ground_p) ) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#' Get order along the Hilbert curve | ||
#' | ||
#' @param X matrix of values. Observations are unique by rows. | ||
#' @param Sigma Covariance of the data. If provided, uses a Mahalanobis distance. | ||
#' | ||
#' @return Index of orders | ||
#' @export | ||
#' | ||
#' @examples | ||
#' X <- matrix(rnorm(10*3), 3, 10) | ||
#' idx <- hilbert.projection(X) | ||
#' print(idx) | ||
hilbert.projection <- function(X, Sigma = NULL) { | ||
X <- as.matrix(t(X)) | ||
if(!is.null(Sigma)) { | ||
X <- backsolve(chol(Sigma), X, transpose = TRUE) | ||
} | ||
idx <- hilbert_proj_(X) + 1 # +1 to adjust C to R indexing | ||
return(idx) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# sinkhorn_distance <- function(x1, x2, p = 1, ground_p = 2, eps = 0.05, niter = 100){ | ||
# w1 <- rep(1/ncol(x1), ncol(x1)) | ||
# w2 <- rep(1/ncol(x2), ncol(x2)) | ||
# C <- cost_calc(x1, x2, ground_p) | ||
# epsilon <- eps * median(C^p) | ||
# wass <- sinkhorn_(w1, w2, C^p, epsilon, niter) | ||
# ### CORRECTION OF THE MARGINALS | ||
# # explained in the appendix of Coupling of Particle Filters, Jacob Lindsten Schon (arXiv v2 appendix E) | ||
# Phat <- wass$transportmatrix | ||
# u <- rowSums(Phat) | ||
# utilde <- colSums(Phat) | ||
# alpha <- min(pmin(w1/u, w2/utilde)) | ||
# r <- (w1 - alpha * u) / (1 - alpha) | ||
# rtilde <- (w2 - alpha * utilde) / (1 - alpha) | ||
# P <- alpha * Phat + (1 - alpha) * matrix(r, ncol = 1) %*% matrix(rtilde, nrow = 1) | ||
# return(list(uncorrected = (sum(Phat * C^p))^(1/p), corrected = (sum(P * C^p))^(1/p))) | ||
# } | ||
|
||
#' Test sinkhorn distance | ||
#' | ||
#' @param mass_x empiric measure of first sample | ||
#' @param mass_y empiric measure of second sample | ||
#' @param cost cost matrix | ||
#' @param p power to raise the cost matrix by | ||
#' @param eps epsilon of cost matrix | ||
#' @param niter number of iterations | ||
#' | ||
#' @return a numeric value | ||
#' | ||
#' @keywords internal | ||
sinkhorn_distance <- function(mass_x, mass_y, cost = NULL, p = 1, eps = 0.05, niter = 100){ | ||
costp <- cost^p | ||
epsilon <- eps * stats::median(costp) | ||
wass <- sinkhorn_(mass_x, mass_y, costp, epsilon, niter) | ||
### CORRECTION OF THE MARGINALS | ||
# explained in the appendix of Coupling of Particle Filters, Jacob Lindsten Schon (arXiv v2 appendix E) | ||
Phat <- wass$transportmatrix | ||
u <- rowSums(Phat) | ||
utilde <- colSums(Phat) | ||
alpha <- min(pmin(mass_x/u, mass_y/utilde)) | ||
r <- (mass_x - alpha * u) / (1 - alpha) | ||
rtilde <- (mass_y - alpha * utilde) / (1 - alpha) | ||
P <- if ( alpha < 1 ) { | ||
alpha * Phat + (1 - alpha) * matrix(r, ncol = 1) %*% matrix(rtilde, nrow = 1) | ||
} else { | ||
Phat | ||
} | ||
return(list(uncorrected = (sum(Phat * costp))^(1/p), corrected = (sum(P * costp))^(1/p))) | ||
} | ||
|
||
#' Test sinkhorn transportation plan | ||
#' | ||
#' @param mass_x empiric measure of first sample | ||
#' @param mass_y empiric measure of second sample | ||
#' @param cost cost matrix | ||
#' @param eps epsilon of cost matrix | ||
#' @param niterations number of iterations | ||
#' | ||
#' @return transportation plan as list with slots "from","to", and "mass" | ||
#' | ||
#' @keywords internal | ||
sinkhorn_transport <- function(mass_x, mass_y, cost = NULL, eps = 0.05, niterations = 100){ | ||
n1 <- length(mass_x) | ||
n2 <- length(mass_y) | ||
# costp <- cost^p | ||
epsilon <- eps * stats::median(cost) | ||
transp <- sinkhorn_(mass_x, mass_y, cost, epsilon, niterations) | ||
### CORRECTION OF THE MARGINALS | ||
# explained in the appendix of Coupling of Particle Filters, Jacob Lindsten Schon (arXiv v2 appendix E) | ||
Phat <- transp$transportmatrix | ||
u <- rowSums(Phat) | ||
utilde <- colSums(Phat) | ||
alpha <- min(pmin(mass_x/u, mass_y/utilde)) | ||
r <- (mass_x - alpha * u) / (1 - alpha) | ||
rtilde <- (mass_y - alpha * utilde) / (1 - alpha) | ||
P <- alpha * Phat + (1 - alpha) * matrix(r, ncol = 1) %*% matrix(rtilde, nrow = 1) | ||
return(list(from = rep(1:n1, n2), | ||
to = rep(1:n1, each = n2), | ||
mass = c(P))) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#' Function returning supported optimal transportation methods. | ||
#' | ||
#' @return Returns a vector of supported transport methods | ||
#' @export | ||
#' | ||
#' @details The currently supported methods are | ||
#' \itemize{ | ||
#' \item{exact, networkflow:}{ Utilize the networkflow algorithm to solve the exact optimal transport problem} | ||
#' \item{shortsimplex:}{ Use the shortsimplex algorithm to solve the exact optimal transport problem} | ||
#' \item{sinkhorn:}{ Use Sinkhorn's algorithm to solve the approximate optimal transport problem} | ||
#' \item{greenkhorn:}{ Use the Greenkhorn algorithm to solve the approximate optimal transport problem} | ||
#' \item{randkhorn:}{ (NOT CURRENTLY IMPLEMENTED) Use the randkhorn algorithm to solve the approximate optimal transport problem} | ||
#' \item{grandkhorn:}{ (NOT CURRENTLY IMPLEMENTED) Use the grandkhorn algorithm to solve the approximate optimal transport problem} | ||
#' \item{hilbert:}{ Use hilbert sorting to perform approximate optimal transport} | ||
#' \item{rank:}{ use the average covariate ranks to perform approximate optimal transport} | ||
#' \item{univariate:}{ Use appropriate optimal transport methods for univariate data} | ||
#' \item{swapping:}{ Utilize the swapping algorithm to perform approximate optimal transport} | ||
#' \item{sliced:}{ Use the sliced optimal transport distance} | ||
#' } | ||
transport_options <- function() { | ||
return(c("exact", "networkflow","shortsimplex", | ||
"sinkhorn", "greenkhorn", | ||
"randkhorn", "gandkhorn", | ||
"hilbert", "rank", "univariate", | ||
"univariate.approximation.pwr", | ||
"swapping", "sliced")) | ||
} |
Oops, something went wrong.