Skip to content

Commit

Permalink
Merge pull request #8 from hoxo-m/RuLSIF
Browse files Browse the repository at this point in the history
implement RuSLIF
  • Loading branch information
hoxo-m authored Jun 23, 2019
2 parents d15d406 + 9257acc commit 3a6cd81
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 18 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

S3method(print,densratio)
export(KLIEP)
export(RuLSIF)
export(densratio)
export(uLSIF)
importFrom(utils,str)
67 changes: 67 additions & 0 deletions R/RuLSIF.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#' Estimate alpha-Relative Density Ratio p(x)/(alpha p(x) + (1-alpha) q(x))
#' by RuLSIF (Relative unconstrained Least-Square Importance Fitting)
#'
#' @param x1 numeric vector or matrix. Data from a numerator distribution p(x).
#' @param x2 numeric vector or matrix. Data from a denominator distribution q(x).
#' @param sigma positive numeric vector. Search range of Gaussian kernel bandwidth.
#' @param lambda positive numeric vector. Search range of regularization parameter.
#' @param alpha numeric value from 0.0 to 1.0. Relative parameter. Default 0.1.
#' @param kernel_num positive integer. Number of kernels.
#' @param verbose logical. Default TRUE.
#'
#' @return RuLSIF object which has `compute_density_ratio()`.
#'
#' @export
RuLSIF <- function(x1, x2,
sigma = 10 ^ seq(-3, 1, length.out = 9),
lambda = 10 ^ seq(-3, 1, length.out = 9),
alpha = 0.1, kernel_num = 100, verbose = TRUE) {

if(verbose) message("################## Start RuLSIF ##################")
if(is.vector(x1)) x1 <- matrix(x1)
if(is.vector(x2)) x2 <- matrix(x2)
if(ncol(x1) != ncol(x2)) stop("x1 and x2 must be same dimensions.")

nx1 <- nrow(x1)
nx2 <- nrow(x2)

kernel_num <- min(kernel_num, nx1)
centers <- x1[sample(nx1, size = kernel_num), , drop = FALSE]

if(length(sigma) != 1 || length(lambda) != 1) {
if(verbose) message("Searching optimal sigma and lambda...")
opt_params <- RuLSIF_search_sigma_and_lambda(x1, x2, centers, sigma, lambda, alpha, verbose)
sigma <- opt_params$sigma
lambda <- opt_params$lambda
if(verbose) message(sprintf("Found optimal sigma = %.3f, lambda = %.3f.", sigma, lambda))
}

if(verbose) message("Optimizing kernel weights...")
phi_x1 <- compute_kernel_Gaussian(x1, centers, sigma)
phi_x2 <- compute_kernel_Gaussian(x2, centers, sigma)
H <- alpha * crossprod(phi_x1) / nx1 + (1 - alpha) * crossprod(phi_x2) / nx2
h <- colMeans(phi_x1)
kernel_weights <- solve(H + diag(lambda, kernel_num, kernel_num)) %*% h
kernel_weights[kernel_weights < 0] <- 0
if(verbose) message("End.")

result <- list(kernel_weights = as.vector(kernel_weights),
lambda = lambda,
alpha = alpha,
kernel_info = list(
kernel = "Gaussian",
kernel_num = kernel_num,
sigma = sigma,
centers = centers
),
compute_density_ratio = function(x) {
if(is.vector(x)) x <- matrix(x)
phi_x <- compute_kernel_Gaussian(x, centers, sigma)
density_ratio <- as.vector(phi_x %*% kernel_weights)
density_ratio
}
)
class(result) <- c("RuLSIF", class(result))
if(verbose) message("################## Finished RuLSIF ###############")
result
}
45 changes: 45 additions & 0 deletions R/RuLSIF_search_sigma_and_lambda.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
RuLSIF_search_sigma_and_lambda <- function(x, y, centers, sigma_list, lambda_list, alpha, verbose) {
nx <- nrow(x)
ny <- nrow(y)
n_min <- min(nx, ny)
kernel_num <- nrow(centers)

score_new <- Inf
sigma_new <- 0
lambda_new <- 0
for (sigma in sigma_list) {
phi_x <- compute_kernel_Gaussian(x, centers, sigma)
phi_y <- compute_kernel_Gaussian(y, centers, sigma)
H <- alpha * crossprod(phi_x) / nx + (1 - alpha) * crossprod(phi_y) / ny
h <- colMeans(phi_x)
phi_x <- t(phi_x[seq_len(n_min), , drop = FALSE])
phi_y <- t(phi_y[seq_len(n_min), , drop = FALSE])
for (lambda in lambda_list) {
B <- H + diag(lambda * (ny - 1) / ny, nrow = kernel_num, ncol = kernel_num)
B_inv <- solve(B)
B_inv_X <- B_inv %*% phi_y
X_B_inv_X <- phi_y * B_inv_X
denom <- ones(n_min, value = ny) - ones(kernel_num) %*% X_B_inv_X
B0 <- B_inv %*% (h %*% ones(n_min)) +
B_inv_X %*% diag(as.vector((t(h) %*% B_inv_X) / denom))
B1 <- B_inv %*% phi_x +
B_inv_X %*% diag(as.vector((ones(kernel_num) %*% (phi_x * B_inv_X)) / denom))
B2 <- (ny-1) * (nx * B0 - B1) / (ny * (nx - 1))
B2[B2 < 0] <- 0
r_y <- t(ones(kernel_num) %*% (phi_y * B2))
r_x <- t(ones(kernel_num) %*% (phi_x * B2))
score <- (crossprod(r_y) / 2 - ones(n_min) %*% r_x) / n_min
if(score < score_new) {
if(verbose) message(sprintf(" sigma = %.3f, lambda = %.3f, score = %.3f", sigma, lambda, score))
score_new <- score
sigma_new <- sigma
lambda_new <- lambda
}
}
}
list(sigma = sigma_new, lambda = lambda_new)
}

ones <- function(size, value=1) {
t(rep(value, size))
}
25 changes: 15 additions & 10 deletions R/densratio.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
#'
#' @param x numeric vector or matrix. Data from a numerator distribution p(x).
#' @param y numeric vector or matrix. Data from a denominator distribution q(y).
#' @param method "uLSIF"(default) or "KLIEP".
#' @param method "uLSIF" (default), "KLIEP", or "RuLSIF".
#' @param sigma positive numeric vector. Search range of Gaussian kernel bandwidth.
#' @param lambda positive numeric vector. Search range of regularization parameter for uLSIF.
#' @param lambda positive numeric vector. Search range of regularization parameter for uLSIF and RuLSIF.
#' @param alpha numeric in [0, 1]. Relative parameter for RuLSIF. Default 0.1.
#' @param kernel_num positive integer. Number of kernels.
#' @param fold positive integer. Numer of the folds of cross validation for KLIEP.
#' @param verbose logical(default TRUE).
#' @param verbose logical (default TRUE).
#'
#' @return densratio object that contains a function to compute estimated density ratio.
#'
Expand All @@ -23,22 +24,28 @@
#' plot(new_x, estimated_density_ratio, pch=19)
#'
#' @export
densratio <- function(x, y, method = c("uLSIF", "KLIEP"),
sigma = "auto", lambda = "auto",
densratio <- function(x, y, method = c("uLSIF", "RuLSIF", "KLIEP"),
sigma = "auto", lambda = "auto", alpha = 0.1,
kernel_num = 100, fold = 5, verbose = TRUE) {
# Prepare Arguments -------------------------------------------------------
method <- match.arg(method)

# To Retain Default Arguments in Functions of Methods ---------------------
params <- alist(x = x, y = y, kernel_num = kernel_num, verbose = verbose)
if(!identical(sigma, "auto")) {
if (!identical(sigma, "auto")) {
params <- c(params, alist(sigma = sigma))
}

# Run ---------------------------------------------------------------------
if(method == "uLSIF") {
if(!identical(lambda, "auto")) params <- c(params, alist(lambda = lambda))
if (method == "uLSIF") {
if (!identical(lambda, "auto")) params <- c(params, alist(lambda = lambda))
result <- do.call(uLSIF, params)
} else if (method == "RuLSIF") {
params <- alist(x1 = x, x2 = y, kernel_num = kernel_num, verbose = verbose)
if (!identical(sigma, "auto")) params <- c(params, alist(sigma = sigma))
if (!identical(lambda, "auto")) params <- c(params, alist(lambda = lambda))
params <- c(params, alist(alpha = alpha))
result <- do.call(RuLSIF, params)
} else {
params <- c(params, alist(fold = fold))
result <- do.call(KLIEP, params)
Expand All @@ -52,5 +59,3 @@ densratio <- function(x, y, method = c("uLSIF", "KLIEP"),
class(result) <- c("densratio", class(result))
result
}


17 changes: 14 additions & 3 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,26 @@ print.densratio <- function(x, digits = 3L, ...) {
str(info$centers, digits.d = digits, give.attr = FALSE)
cat("\n")

cat("Kernel Weights(alpha):\n ")
str(x$alpha, digits.d = digits, give.attr = FALSE)
cat("\n")

if("uLSIF" %in% class(x)) {
cat("Kernel Weights:\n ")
str(x$alpha, digits.d = digits, give.attr = FALSE)
cat("\n")
cat("Regularization Parameter(lambda): ", x$lambda, "\n\n")
}

if("RuLSIF" %in% class(x)) {
cat("Kernel Weights:\n ")
str(x$kernel_weights, digits.d = digits, give.attr = FALSE)
cat("\n")
cat("Regularization Parameter (lambda): ", x$lambda, "\n\n")
cat("Relative Parameter (alpha): ", x$alpha, "\n\n")
}

if("KLIEP" %in% class(x)) {
cat("Kernel Weights:\n ")
str(x$alpha, digits.d = digits, give.attr = FALSE)
cat("\n")
cat("Number of the Folds: ", x$fold, "\n\n")
}

Expand Down
33 changes: 33 additions & 0 deletions man/RuLSIF.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 8 additions & 5 deletions man/densratio.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 3a6cd81

Please sign in to comment.