Skip to content

Commit

Permalink
Bug fix kld_est (multiple numerical columns were incorrectly split by…
Browse files Browse the repository at this point in the history
… the discrete variables)
  • Loading branch information
niklhart committed Jan 26, 2024
1 parent 9cb9e6f commit 1678d7e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
13 changes: 7 additions & 6 deletions R/kld-estimation-interfaces.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,24 +97,25 @@ kld_est <- function(X, Y = NULL, q = NULL, estimator.continuous = kld_est_nn,
}

# now we know it's a mixed discrete/continuous dataset
Xcont <- as.matrix(X[vartype == "c"])
Xcont <- as.data.frame(X[vartype == "c"])
Xdisc <- X[vartype == "d"]
iXdisc <- interaction(Xdisc, drop = TRUE)
sXcont <- split(Xcont, f = iXdisc)
sXcont <- lapply(X = split(Xcont, f = iXdisc), FUN = as.matrix)

if (two.sample) {
Ycont <- as.matrix(Y[vartype == "c"])
Ycont <- as.data.frame(Y[vartype == "c"])
Ydisc <- Y[vartype == "d"]
sYcont <- split(Ycont, f = interaction(Ydisc, drop = TRUE))
sYcont <- lapply(X = split(Ycont, f = interaction(Ydisc, drop = TRUE)),
FUN = as.matrix)

KLcont <- mapply(FUN = estimator.continuous, X = sXcont, Y = sYcont)
KLdisc <- estimator.discrete(Xdisc,Ydisc)
KLdisc <- estimator.discrete(X = Xdisc, Y = Ydisc)

} else {
qXcond <- lapply(Xdisc[match(levels(iXdisc),iXdisc), ],
function(d) {force(d); function(c) q$cond(c,d)})
KLcont <- mapply(FUN = estimator.continuous, X = sXcont, q = qXcond)
KLdisc <- estimator.discrete(Xdisc, q = q$disc)
KLdisc <- estimator.discrete(X = Xdisc, q = q$disc)
}

# return compound KL divergence
Expand Down
36 changes: 18 additions & 18 deletions tests/testthat/test-kld-interfaces.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,28 +103,28 @@ test_that("kld_est works as expected for numeric data", {
test_that("kld_est works as expected for mixed data", {

# check that heuristic for detecting column type works
Xnn <- data.frame(A = c(1,1,1,2,2),B = c(1,1,2,2,2))
Ynn <- data.frame(A = c(1,1,2,2), B = c(1,2,1,2))
Xnc <- Xnn; Xnc$B <- as.character(Xnc$B)
Ync <- Ynn; Ync$B <- as.character(Ync$B)
Xnnn <- data.frame(A = 1:5, B = 6:10, C = c(1,1,2,2,2))
Ynnn <- data.frame(A = 11:14, B = 15:18, C = c(1,2,1,2))
Xnnc <- Xnnn; Xnnc$C <- as.character(Xnnc$C)
Ynnc <- Ynnn; Ynnc$C <- as.character(Ynnc$C)

KLnn_est <- kld_est(Xnn, Ynn, vartype = c("c","d"))
KLnc_est <- kld_est(Xnc, Ync)
KLnnn_est <- kld_est(Xnnn, Ynnn, vartype = c("c","c","d"))
KLnnc_est <- kld_est(Xnnc, Ynnc)

expect_equal(KLnn_est, KLnc_est)
expect_equal(KLnnn_est, KLnnc_est)

# check that computed KL-D agrees with hardcoded mixed KL-D
X1 <- Xnn$A[Xnn$B == 1]
X2 <- Xnn$A[Xnn$B == 2]
Y1 <- Ynn$A[Ynn$B == 1]
Y2 <- Ynn$A[Ynn$B == 2]
X1 <- Xnnn[Xnnn$C == 1, c("A","B")]
X2 <- Xnnn[Xnnn$C == 2, c("A","B")]
Y1 <- Ynnn[Ynnn$C == 1, c("A","B")]
Y2 <- Ynnn[Ynnn$C == 2, c("A","B")]

p1 <- mean(Xnn$B == 1); p2 <- 1 - p1
q1 <- mean(Ynn$B == 1); q2 <- 1 - q1
p1 <- mean(Xnnn$C == 1); p2 <- 1 - p1
q1 <- mean(Ynnn$C == 1); q2 <- 1 - q1

KL_ref <- p1*kld_est_nn(X1, Y1) + p2*kld_est_nn(X2, Y2) + kld_discrete(c(p1,p2),c(q1,q2))

expect_equal(KLnn_est,KL_ref)
expect_equal(KLnnn_est,KL_ref)

# 2D example, one sample
X <- data.frame(A = rnorm(5),
Expand All @@ -134,11 +134,11 @@ test_that("kld_est works as expected for mixed data", {

KL_Xq_est <- kld_est(X, q = q, vartype = c("c","d"))

p1 <- mean(Xnn$B == 1); p2 <- 1 - p1
p0 <- mean(X$B == 0); p1 <- 1 - p0

KL_Xq_ref <- p1*kld_est_nn(X$A[X$B == 0], q = function(x) dnorm(x, mean = 0)) +
p2*kld_est_nn(X$A[X$B == 1], q = function(x) dnorm(x, mean = 1)) +
kld_discrete(c(p1,1-p1), c(0.5,0.5))
KL_Xq_ref <- p0 * kld_est_nn(X$A[X$B == 0], q = function(x) dnorm(x, mean = 0)) +
p1 * kld_est_nn(X$A[X$B == 1], q = function(x) dnorm(x, mean = 1)) +
kld_discrete(c(p0,p1), c(0.5,0.5))

expect_equal(KL_Xq_est,KL_Xq_ref)

Expand Down

0 comments on commit 1678d7e

Please sign in to comment.