In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import math
from scipy.special import digamma
from sklearn.neighbors import BallTree, KDTree

In [None]:
# Pick the compute device once, up front: CUDA when a GPU is visible, otherwise CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print("Using device:", DEVICE)

In [None]:
# tree_type name -> sklearn neighbour-search class; KSG._tree instantiates the chosen
# one with metric="chebyshev".
_metric_tree_types = dict(kd_tree=KDTree, ball_tree=BallTree)

class KSG:
    """Kraskov-Stoegbauer-Grassberger (KSG, algorithm 1) mutual-information estimator.

    All distances use the Chebyshev (max) norm, as the KSG estimator requires.

    Parameters
    ----------
    k_neighbors : int
        Number of nearest neighbours in the joint (x, y) space; must be >= 1.
        Capped at ``n_samples - 1`` when the estimator is called.
    tree_type : {"kd_tree", "ball_tree"}
        Key into ``_metric_tree_types`` selecting the neighbour-search structure.
    """

    def __init__(self, k_neighbors=3, tree_type="kd_tree"):
        if k_neighbors < 1:
            raise ValueError(f"k_neighbors must be >= 1, got {k_neighbors}")
        # The attribute keeps its original (British) spelling for backward compatibility.
        self.k_neighbours, self.tree_type = k_neighbors, tree_type

    def _tree(self, data):
        """Build a Chebyshev-metric neighbour tree of the configured type over `data`."""
        return _metric_tree_types[self.tree_type](data, metric="chebyshev")

    def __call__(self, x, y, std=False):
        """Estimate I(X; Y) in nats from paired samples.

        Parameters
        ----------
        x, y : array-like
            Paired samples; the first axis indexes samples and must have equal length.
            Any trailing axes are flattened into features.
        std : bool
            If True, also return the standard error of the per-sample digamma terms.

        Returns
        -------
        float or (float, float)
            The MI estimate clipped at 0, optionally with its standard error.

        Raises
        ------
        ValueError
            If x and y have different sample counts or fewer than 2 samples.
        """
        x, y = np.asarray(x), np.asarray(y)
        n_samples = x.shape[0]
        if y.shape[0] != n_samples:
            raise ValueError(
                f"x and y must have the same number of samples, got {n_samples} and {y.shape[0]}"
            )
        if n_samples < 2:
            raise ValueError(f"KSG needs at least 2 samples, got {n_samples}")

        k_neighbours = min(self.k_neighbours, n_samples - 1)
        x, y = x.reshape(n_samples, -1), y.reshape(n_samples, -1)
        xy = np.concatenate([x, y], axis=-1)

        xt, yt, xyt = self._tree(x), self._tree(y), self._tree(xy)
        # Column 0 of the query result is each point itself, so column k is its k-th neighbour.
        distances, _ = xyt.query(xy, k=k_neighbours + 1)
        # KSG counts marginal neighbours *strictly* closer than eps, but query_radius is
        # inclusive (<=). Shrinking eps by exactly one ulp turns <= into < at every scale.
        # A fixed offset like 1e-15 rounds away once eps exceeds ~32, and it can make a
        # zero eps negative.
        radius = np.nextafter(distances[:, k_neighbours], 0.0)
        # The counts include the query point itself, i.e. they equal n_x + 1 (resp. n_y + 1)
        # from the KSG formula psi(k) + psi(N) - <psi(n_x + 1) + psi(n_y + 1)>.
        nx = xt.query_radius(x, radius, count_only=True)
        ny = yt.query_radius(y, radius, count_only=True)

        vals = digamma(nx) + digamma(ny)
        mi = max(0.0, digamma(k_neighbours) + digamma(n_samples) - np.mean(vals))
        return (mi, np.std(vals) / math.sqrt(n_samples)) if std else mi

In [None]:
def random_orthonormal_matrices_torch(d, k, batch, device=DEVICE, dtype=torch.float32, generator=None):
    """
    Sample `batch` independent d x k matrices with orthonormal columns (Haar / uniform on
    the Stiefel manifold).

    Gaussian matrices are orthonormalised with a batched reduced QR. QR alone is not
    Haar-distributed because Q is only unique up to the signs of diag(R). Each column j
    of Q is therefore multiplied by sign(R_jj) to remove that bias.

    Args:
        d: ambient dimension (number of rows).
        k: number of orthonormal columns; must satisfy k <= d.
        batch: number of independent matrices to draw.
        device: torch device for the result.
        dtype: floating dtype of the Gaussian draw and of the result.
        generator: optional torch.Generator (on `device`) for reproducible draws that do
            not consume the global RNG. None uses the global RNG, as before.

    Returns:
        Tensor of shape (batch, d, k) with Q[b].T @ Q[b] ~= I_k for every b.

    Raises:
        ValueError: if k > d. Reduced QR would otherwise silently return d columns
            instead of k.
    """
    if k > d:
        raise ValueError(f"cannot draw {k} orthonormal columns in dimension {d} (need k <= d)")
    A = torch.randn((batch, d, k), device=device, dtype=dtype, generator=generator)
    Q, R = torch.linalg.qr(A)  # reduced mode: Q (batch, d, k), R (batch, k, k)
    signs = torch.sign(torch.diagonal(R, dim1=-2, dim2=-1))  # (batch, k)
    signs[signs == 0] = 1.0  # a zero pivot carries no sign to undo; leave that column as is
    return Q * signs.unsqueeze(-2)  # (batch, 1, k): scales column j of each Q[b] by signs[b, j]

In [None]:
def whiten_torch(X, eps=1e-6, device=DEVICE):
    """
    ZCA (Mahalanobis) whitening of the rows of `X`, computed with torch on `device`.

    Columns are centred, then multiplied by W = U diag(1 / sqrt(S + eps)) U^T. Here
    C = U diag(S) U^T is the unbiased (n - 1) sample covariance. Because the eigenbasis
    is rotated back (U ... U^T), this is ZCA rather than PCA whitening. The output has
    approximately identity covariance and stays aligned with the original coordinates.

    Args:
        X: samples of shape (n, d), or (n,), which is treated as (n, 1). Accepts numpy
            arrays, torch tensors, or anything `torch.as_tensor` accepts.
        eps: added to every covariance eigenvalue before the inverse square root. This
            keeps near-singular directions (e.g. when d >= n) finite.
        device: torch device for the computation and the result.

    Returns:
        float32 tensor of shape (n, d) on `device`.

    Raises:
        ValueError: if X has more than 2 dimensions, or fewer than 2 rows (the n - 1
            covariance would divide by zero).
    """
    X = torch.as_tensor(X, dtype=torch.float32, device=device)
    if X.ndim == 1:
        X = X.unsqueeze(1)
    if X.ndim != 2:
        raise ValueError(f"X must be 1-D or 2-D, got shape {tuple(X.shape)}")
    n = X.shape[0]
    if n < 2:
        raise ValueError(f"whitening needs at least 2 samples, got {n}")

    Xc = X - X.mean(dim=0, keepdim=True)
    cov = (Xc.T @ Xc) / (n - 1)  # (d, d), symmetric positive semi-definite

    # For a symmetric PSD matrix the SVD equals the eigendecomposition, and its singular
    # values are never negative, so sqrt cannot produce NaN from round-off.
    U, S, _ = torch.linalg.svd(cov)
    W = (U * torch.rsqrt(S + eps)) @ U.T  # == U @ diag(1 / sqrt(S + eps)) @ U.T
    return Xc @ W

In [None]:
def make_projection_matrices(dX, dY, proj_k, n_proj, device=DEVICE, seed=0):
    """
    Draw the shared bank of random orthonormal projections used to slice X and Y.

    The same `seed` always yields the same pair, so several datasets can be compared
    under identical slices. Seeding happens inside `torch.random.fork_rng`. The matrices
    are therefore identical to seeding the global RNG directly, but the caller's global
    torch RNG state is restored on exit instead of being reset to `seed`.

    Args:
        dX, dY: feature dimensions of X and Y.
        proj_k: dimension of each slice (columns per projection matrix).
        n_proj: number of projections to draw.
        device: torch device for the result.
        seed: seed for the projection draw.

    Returns:
        (U_list, V_list): tensors of shape (n_proj, dX, proj_k) and (n_proj, dY, proj_k).
    """
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        U_list = random_orthonormal_matrices_torch(dX, proj_k, n_proj, device=device)  # (n_proj, dX, k)
        V_list = random_orthonormal_matrices_torch(dY, proj_k, n_proj, device=device)  # (n_proj, dY, k)
    return U_list, V_list

def sliced_mi_with_fixed_proj(X_np, Y_np, U_list, V_list, knn=5, device=DEVICE):
    """
    Sliced mutual information of (X, Y) over a fixed bank of random projections.

    X and Y are whitened separately. Then, for each projection index i, the KSG estimate
    of I(X_w U_i; Y_w V_i) is computed on the projected samples.

    Args:
        X_np, Y_np: paired samples of shapes (n, dX) and (n, dY) (numpy or torch).
        U_list, V_list: projection banks of shapes (n_proj, dX, k) and (n_proj, dY, k),
            e.g. from `make_projection_matrices`; they must share the same n_proj.
        knn: number of neighbours for the KSG estimator.
        device: torch device used for whitening and projection.

    Returns:
        (mean, standard error of the mean, per-projection estimates as a numpy array).
        The standard error uses the sample std (ddof=1) and is NaN for a single projection.

    Raises:
        ValueError: if the projection banks are empty or have different lengths.
    """
    n_proj = U_list.shape[0]
    if n_proj == 0 or V_list.shape[0] != n_proj:
        raise ValueError(
            f"need equally sized, non-empty projection banks, got {n_proj} and {V_list.shape[0]}"
        )

    Xw = whiten_torch(X_np, device=device)
    Yw = whiten_torch(Y_np, device=device)

    # Project every whitened sample through every slice at once -> (n_proj, n, k).
    # The banks are first aligned to the whitened data's device and dtype so einsum never mixes them.
    X_proj = torch.einsum('nd,bdk->bnk', Xw, U_list.to(Xw)).detach().cpu().numpy()
    Y_proj = torch.einsum('nd,bdk->bnk', Yw, V_list.to(Yw)).detach().cpu().numpy()

    ksg = KSG(k_neighbors=knn)
    mis = np.array([float(ksg(X_proj[i], Y_proj[i])) for i in range(n_proj)])
    se = mis.std(ddof=1) / np.sqrt(n_proj) if n_proj > 1 else float("nan")
    return mis.mean(), se, mis


In [None]:
def compare_two_datasets(X1, Y1, X2, Y2, proj_k=2, n_proj=500, knn=3, seed=0):
    """
    Compare the sliced MI of two paired datasets under one shared projection bank.

    Both datasets are sliced with the same seed-determined projections, so differences
    between their estimates are not confounded by different slice draws. Prints
    mean +/- standard error per dataset and shows overlaid histograms of the
    per-projection estimates.

    Args:
        X1, Y1: paired samples of dataset A, shapes (n1, dX) and (n1, dY).
        X2, Y2: paired samples of dataset B, shapes (n2, dX) and (n2, dY).
        proj_k: dimension of each slice.
        n_proj: number of random projections.
        knn: number of neighbours for the KSG estimator.
        seed: seed for the projection bank.

    Returns:
        ((mean1, se1, mis1), (mean2, se2, mis2)), each as returned by
        `sliced_mi_with_fixed_proj`.

    Raises:
        ValueError: if the datasets do not share the feature dimensions dX and dY. The
            single projection bank could not be applied to both otherwise.
    """
    dX, dY = X1.shape[1], Y1.shape[1]
    if X2.shape[1] != dX or Y2.shape[1] != dY:
        raise ValueError(
            f"datasets must share feature dims: A has ({dX}, {dY}), "
            f"B has ({X2.shape[1]}, {Y2.shape[1]})"
        )
    U_list, V_list = make_projection_matrices(dX, dY, proj_k, n_proj, seed=seed, device=DEVICE)

    mean1, se1, mis1 = sliced_mi_with_fixed_proj(X1, Y1, U_list, V_list, knn=knn)
    mean2, se2, mis2 = sliced_mi_with_fixed_proj(X2, Y2, U_list, V_list, knn=knn)

    print(f"Dataset A: {mean1:.4f} ± {se1:.4f}")
    print(f"Dataset B: {mean2:.4f} ± {se2:.4f}")

    _plot_sliced_mi_histograms(mis1, mean1, mis2, mean2)

    return (mean1, se1, mis1), (mean2, se2, mis2)


def _plot_sliced_mi_histograms(mis1, mean1, mis2, mean2, bins=30):
    """Overlay the per-projection MI histograms of datasets A and B, with dashed mean lines."""
    # Shared bin edges make the two histograms' bar heights directly comparable.
    edges = np.histogram_bin_edges(np.concatenate([mis1, mis2]), bins=bins)
    fig, ax = plt.subplots(figsize=(8, 4))
    # Explicit "C0"/"C1" so each dashed mean line matches the colour of its own bars.
    ax.hist(mis1, bins=edges, alpha=0.5, color="C0", label='Dataset A')
    ax.hist(mis2, bins=edges, alpha=0.5, color="C1", label='Dataset B')
    ax.axvline(mean1, color="C0", linestyle='--')
    ax.axvline(mean2, color="C1", linestyle='--')
    ax.legend()
    ax.set_xlabel("Sliced MI")
    ax.set_ylabel("Count")
    ax.set_title("Distribution of Sliced MI across projections")
    plt.show()

In [None]:
# Experiment configuration.
X_dim, Y_dim = 15, 20
n_samples = 1000
SEED = 42      # seeds the data draw below and the projection bank
N_PROJ = 1000  # number of random projections (slices)
PROJ_K = 2     # dimension of each slice
KNN = 3        # neighbours for the KSG estimator

# NOTE(review): `CorrelatedNormal` is not imported or defined in any earlier cell of this
# notebook, so Restart & Run All fails here. TODO: add its import to the imports cell.
# NOTE(review): the meaning of its first argument (1.0 vs 2.0 below) is not documented here.

# Seed the global NumPy and torch RNGs so the samples are reproducible. This assumes
# `CorrelatedNormal.rvs` draws from one of these global generators; TODO confirm.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dataset A
dist1 = CorrelatedNormal(1.0, X_dim, Y_dim)
X1, Y1 = dist1.rvs(n_samples)

# Dataset B
dist2 = CorrelatedNormal(2.0, X_dim, Y_dim)
X2, Y2 = dist2.rvs(n_samples)

resA, resB = compare_two_datasets(
    X1, Y1, X2, Y2, proj_k=PROJ_K, n_proj=N_PROJ, knn=KNN, seed=SEED
)
