-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from hoxo-m/RuLSIF
implement RuSLIF
- Loading branch information
Showing
7 changed files
with
183 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
|
||
S3method(print,densratio) | ||
export(KLIEP) | ||
export(RuLSIF) | ||
export(densratio) | ||
export(uLSIF) | ||
importFrom(utils,str) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#' Estimate alpha-Relative Density Ratio p(x)/(alpha p(x) + (1-alpha) q(x)) | ||
#' by RuLSIF (Relative unconstrained Least-Square Importance Fitting) | ||
#' | ||
#' @param x1 numeric vector or matrix. Data from a numerator distribution p(x). | ||
#' @param x2 numeric vector or matrix. Data from a denominator distribution q(x). | ||
#' @param sigma positive numeric vector. Search range of Gaussian kernel bandwidth. | ||
#' @param lambda positive numeric vector. Search range of regularization parameter. | ||
#' @param alpha numeric value from 0.0 to 1.0. Relative parameter. Default 0.1. | ||
#' @param kernel_num positive integer. Number of kernels. | ||
#' @param verbose logical. Default TRUE. | ||
#' | ||
#' @return RuLSIF object which has `compute_density_ratio()`. | ||
#' | ||
#' @export | ||
RuLSIF <- function(x1, x2, | ||
sigma = 10 ^ seq(-3, 1, length.out = 9), | ||
lambda = 10 ^ seq(-3, 1, length.out = 9), | ||
alpha = 0.1, kernel_num = 100, verbose = TRUE) { | ||
|
||
if(verbose) message("################## Start RuLSIF ##################") | ||
if(is.vector(x1)) x1 <- matrix(x1) | ||
if(is.vector(x2)) x2 <- matrix(x2) | ||
if(ncol(x1) != ncol(x2)) stop("x1 and x2 must be same dimensions.") | ||
|
||
nx1 <- nrow(x1) | ||
nx2 <- nrow(x2) | ||
|
||
kernel_num <- min(kernel_num, nx1) | ||
centers <- x1[sample(nx1, size = kernel_num), , drop = FALSE] | ||
|
||
if(length(sigma) != 1 || length(lambda) != 1) { | ||
if(verbose) message("Searching optimal sigma and lambda...") | ||
opt_params <- RuLSIF_search_sigma_and_lambda(x1, x2, centers, sigma, lambda, alpha, verbose) | ||
sigma <- opt_params$sigma | ||
lambda <- opt_params$lambda | ||
if(verbose) message(sprintf("Found optimal sigma = %.3f, lambda = %.3f.", sigma, lambda)) | ||
} | ||
|
||
if(verbose) message("Optimizing kernel weights...") | ||
phi_x1 <- compute_kernel_Gaussian(x1, centers, sigma) | ||
phi_x2 <- compute_kernel_Gaussian(x2, centers, sigma) | ||
H <- alpha * crossprod(phi_x1) / nx1 + (1 - alpha) * crossprod(phi_x2) / nx2 | ||
h <- colMeans(phi_x1) | ||
kernel_weights <- solve(H + diag(lambda, kernel_num, kernel_num)) %*% h | ||
kernel_weights[kernel_weights < 0] <- 0 | ||
if(verbose) message("End.") | ||
|
||
result <- list(kernel_weights = as.vector(kernel_weights), | ||
lambda = lambda, | ||
alpha = alpha, | ||
kernel_info = list( | ||
kernel = "Gaussian", | ||
kernel_num = kernel_num, | ||
sigma = sigma, | ||
centers = centers | ||
), | ||
compute_density_ratio = function(x) { | ||
if(is.vector(x)) x <- matrix(x) | ||
phi_x <- compute_kernel_Gaussian(x, centers, sigma) | ||
density_ratio <- as.vector(phi_x %*% kernel_weights) | ||
density_ratio | ||
} | ||
) | ||
class(result) <- c("RuLSIF", class(result)) | ||
if(verbose) message("################## Finished RuLSIF ###############") | ||
result | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
RuLSIF_search_sigma_and_lambda <- function(x, y, centers, sigma_list, lambda_list, alpha, verbose) { | ||
nx <- nrow(x) | ||
ny <- nrow(y) | ||
n_min <- min(nx, ny) | ||
kernel_num <- nrow(centers) | ||
|
||
score_new <- Inf | ||
sigma_new <- 0 | ||
lambda_new <- 0 | ||
for (sigma in sigma_list) { | ||
phi_x <- compute_kernel_Gaussian(x, centers, sigma) | ||
phi_y <- compute_kernel_Gaussian(y, centers, sigma) | ||
H <- alpha * crossprod(phi_x) / nx + (1 - alpha) * crossprod(phi_y) / ny | ||
h <- colMeans(phi_x) | ||
phi_x <- t(phi_x[seq_len(n_min), , drop = FALSE]) | ||
phi_y <- t(phi_y[seq_len(n_min), , drop = FALSE]) | ||
for (lambda in lambda_list) { | ||
B <- H + diag(lambda * (ny - 1) / ny, nrow = kernel_num, ncol = kernel_num) | ||
B_inv <- solve(B) | ||
B_inv_X <- B_inv %*% phi_y | ||
X_B_inv_X <- phi_y * B_inv_X | ||
denom <- ones(n_min, value = ny) - ones(kernel_num) %*% X_B_inv_X | ||
B0 <- B_inv %*% (h %*% ones(n_min)) + | ||
B_inv_X %*% diag(as.vector((t(h) %*% B_inv_X) / denom)) | ||
B1 <- B_inv %*% phi_x + | ||
B_inv_X %*% diag(as.vector((ones(kernel_num) %*% (phi_x * B_inv_X)) / denom)) | ||
B2 <- (ny-1) * (nx * B0 - B1) / (ny * (nx - 1)) | ||
B2[B2 < 0] <- 0 | ||
r_y <- t(ones(kernel_num) %*% (phi_y * B2)) | ||
r_x <- t(ones(kernel_num) %*% (phi_x * B2)) | ||
score <- (crossprod(r_y) / 2 - ones(n_min) %*% r_x) / n_min | ||
if(score < score_new) { | ||
if(verbose) message(sprintf(" sigma = %.3f, lambda = %.3f, score = %.3f", sigma, lambda, score)) | ||
score_new <- score | ||
sigma_new <- sigma | ||
lambda_new <- lambda | ||
} | ||
} | ||
} | ||
list(sigma = sigma_new, lambda = lambda_new) | ||
} | ||
|
||
ones <- function(size, value=1) { | ||
t(rep(value, size)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.