In [1]:
# the discrete case

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import numpy as np
import matplotlib.pyplot as plt

In [2]:
import numpy as np
import matplotlib.pyplot as plt

def make_gmm(n_components=10, dim=2, seed=42,
             mean_range=(-5.0, 5.0), var_range=(0.05, 1.0),
             isotropic=True):
    """
    Draw random parameters for a Gaussian Mixture Model (GMM).

    Parameters
    ----------
    n_components : number of mixture components.
    dim          : dimensionality of each component.
    seed         : seed handed to ``np.random.default_rng``.
    mean_range   : (low, high) bounds; every mean coordinate is drawn
                   uniformly from this interval.
    var_range    : (low, high) bounds for the per-component variance.
    isotropic    : True  -> covariance of component k is ``var_k * I``.
                   False -> a random SPD matrix rescaled so that
                            ``trace / dim`` equals ``var_k``.

    Returns
    -------
    dict with keys ``weights`` (n_components,), ``means`` (n_components, dim),
    ``covs`` (n_components, dim, dim) and ``rng`` (the generator, in the state
    it has after all parameter draws).
    """
    rng = np.random.default_rng(seed)
    mean_lo, mean_hi = mean_range[0], mean_range[1]
    var_lo, var_hi = var_range[0], var_range[1]
    identity = np.eye(dim)

    # The order of the draws below (weights -> means -> variances ->
    # per-component matrices) determines the output for a given seed.
    weights = rng.dirichlet(np.ones(n_components))
    means = rng.uniform(mean_lo, mean_hi, size=(n_components, dim))
    component_vars = rng.uniform(var_lo, var_hi, size=n_components)

    def _random_spd(avg_var):
        # A @ A.T is PSD; the small ridge makes it strictly positive definite.
        root = rng.normal(size=(dim, dim))
        spd = root @ root.T + 1e-3 * identity
        # Rescale so that the mean diagonal entry equals avg_var.
        return spd * (avg_var / (np.trace(spd) / dim))

    if isotropic:
        cov_list = [v * identity for v in component_vars]
    else:
        cov_list = [_random_spd(v) for v in component_vars]

    return {"weights": weights, "means": means, "covs": np.array(cov_list), "rng": rng}

def sample_gmm(n_samples, weights, means, covs, rng=None):
    """
    Draw ``n_samples`` points from the GMM given by (weights, means, covs).

    ``rng`` is a NumPy ``Generator``; when it is None a fresh, unseeded
    generator is created. Component labels are drawn first, then the points
    of each non-empty component are drawn in increasing component order.

    Returns
    -------
    (points, labels): ``points`` has shape (n_samples, dim) and ``labels``
    holds the index of the component that generated each row.
    """
    generator = np.random.default_rng() if rng is None else rng
    n_components = len(weights)
    points = np.empty((n_samples, means.shape[1]))

    # One categorical draw per sample decides which component it belongs to.
    labels = generator.choice(n_components, size=n_samples, p=weights)

    for component in range(n_components):
        members = np.flatnonzero(labels == component)
        if members.size == 0:
            continue
        points[members] = generator.multivariate_normal(
            mean=means[component], cov=covs[component], size=members.size
        )
    return points, labels

In [3]:
# Build three 2-D GMM datasets that differ only in the RNG seed. Each one is
# shifted by its global minimum and divided by its global maximum, so every
# coordinate ends up in [0, 1] (a single scale for both axes).
save_dataset = []

for dataset_seed in (123, 127, 128):
    params = make_gmm(n_components=20, dim=2, seed=dataset_seed,
                      mean_range=(-6, 6), var_range=(0.2, 0.8), isotropic=True)
    X, z = sample_gmm(n_samples=20000, **params)

    shifted = X - X.min()
    train_data = shifted / shifted.max()

    save_dataset.append(train_data)

In [6]:
# BELOW IS KICA

from scipy.linalg import eigh

def negative_half_power(RX):
    """
    Return RX^(-1/2) built from the eigendecomposition given by
    ``torch.linalg.eigh``: V diag(lambda^-1/2) V^T.
    """
    eigvals, eigvecs = torch.linalg.eigh(RX)
    inv_sqrt_diag = torch.diag(eigvals.pow(-1/2))
    return eigvecs.matmul(inv_sqrt_diag).matmul(eigvecs.T)

def compute_gram_matrix(X, Y, sigma):
    """
    Gaussian-kernel Gram matrix between the rows of X and the rows of Y.

    Parameters
    ----------
    X : array of shape (n, d)
    Y : array of shape (m, d)
    sigma : kernel bandwidth (standard deviation).

    Returns
    -------
    Array of shape (n, m) with entries exp(-||x_i - y_j||^2 / (2 * sigma^2)).
    """
    # Expanded form ||x||^2 + ||y||^2 - 2 x.y, evaluated for all pairs at once.
    x_sq_norms = np.sum(X**2, axis=1).reshape(-1, 1)
    y_sq_norms = np.sum(Y**2, axis=1)
    pairwise_sq_dists = x_sq_norms + y_sq_norms - 2 * np.dot(X, Y.T)

    # Cancellation in the expanded form can yield tiny negative values, which
    # would produce kernel entries greater than 1; clamp them to zero.
    pairwise_sq_dists = np.maximum(pairwise_sq_dists, 0)

    return np.exp(-pairwise_sq_dists / (2 * sigma**2))

def gaussian_kernel(x, y, sigma):
    """Gaussian kernel value exp(-||x - y||^2 / (2 * sigma^2)) for one pair of points."""
    squared_distance = np.linalg.norm(x-y)**2
    bandwidth_term = 2*sigma**2
    return np.exp(-squared_distance / bandwidth_term)

def kernel_matrix(data, sigma):
    """
    Gaussian Gram matrix of ``data`` with itself.

    ``data`` is flattened into a single column first, so every element is
    treated as one scalar sample.
    """
    as_column = data[:].reshape(-1, 1)
    return compute_gram_matrix(as_column, as_column, sigma)

def hsic_score(X, Y, sigma_x, sigma_y):
    """
    HSIC statistic trace(K H L H) / (n - 1)^2 with Gaussian kernels.

    K and L are the Gram matrices of X and Y (bandwidths ``sigma_x`` and
    ``sigma_y``), H = I - (1/n) 11^T is the centering matrix and n = len(X).
    """
    n = len(X)
    centering = np.eye(n) - np.ones((n, n)) / n

    gram_x = kernel_matrix(X, sigma_x)
    gram_y = kernel_matrix(Y, sigma_y)

    centered_product = gram_x @ centering @ gram_y @ centering
    return (1.0 / (n-1)**2) * np.trace(centered_product)


# def HSIC_spectrum(pick_data, sigma=.1, alpha=1e-5):
    
# #     gram_matrix_x = compute_gram_matrix(pick_data[:, 0].reshape(-1, 1), pick_data[:, 0].reshape(-1, 1), sigma)
# #     gram_matrix_u = compute_gram_matrix(pick_data[:, 1].reshape(-1, 1), pick_data[:, 1].reshape(-1, 1), sigma)
# #     gram_matrix_cross = compute_gram_matrix(pick_data[:, 0].reshape(-1, 1), pick_data[:, 1].reshape(-1, 1), sigma)

# #     RX = torch.from_numpy(gram_matrix_x) + torch.eye(gram_matrix_x.shape[0])*alpha
# #     RY = torch.from_numpy(gram_matrix_u) + torch.eye(gram_matrix_u.shape[0])*alpha
# #     RXY = torch.from_numpy(gram_matrix_cross) 

# #     normalized_density = negative_half_power(RX)@RXY@negative_half_power(RY)
# #     U, S, V = torch.linalg.svd(normalized_density)    
    
# #     # IMPLEMENT THE MEASURE

    
    
#     return U, S, V

def KICA_spectrum(data_x, data_y, alpha=1e-1, kernel_var=0.001):
    """
    Kernel-ICA style dependence spectrum between two paired sample sets.

    Both inputs are turned into centered Gaussian Gram matrices Kx, Ky. With
    RF = (Kx + alpha I)(Kx + alpha I)^T, RG = (Ky + alpha I)(Ky + alpha I)^T
    and P = Kx Ky^T, the function

    * solves the generalized symmetric eigenproblem
      [[0, P], [P^T, 0]] v = lambda [[RF, 0], [0, RG]] v, and
    * computes the SVD of P_STAR = RF^(-1/2) P RG^(-1/2).

    Parameters
    ----------
    data_x, data_y : 2-D tensors (or array-likes) of shape (n, d_x) / (n, d_y)
        holding n paired samples. They are detached and moved to the CPU.
    alpha : ridge added to the centered Gram matrices before squaring.
    kernel_var : variance of the Gaussian kernel. The default 0.001 is the
        value that used to be hard-coded here.

    Returns
    -------
    (spectrum_kica, eigenvectors, U, S, V, measure_kica, measure_hsic)
        spectrum_kica : the n largest generalized eigenvalues, descending.
        eigenvectors  : (2n, 2n) generalized eigenvectors from ``eigh``.
        U, S, V       : ``torch.svd`` factors of P_STAR.
        measure_kica  : -1/2 * sum(log(1 - spectrum_kica^2)).
        measure_hsic  : trace of P_STAR.
    """
    def _gaussian_gram(samples):
        # Same kernel as the notebook's `gauss` helper: squared differences
        # are averaged (not summed) over the feature axis. Kept local so this
        # function does not depend on a global defined in a later cell.
        diffs = samples.unsqueeze(1) - samples.unsqueeze(0)
        return torch.exp(-(diffs ** 2).mean(2) / (2 * kernel_var))

    def _inverse_sqrt(sym_matrix):
        # Symmetric inverse square root via eigendecomposition.
        eigvals, eigvecs = torch.linalg.eigh(sym_matrix)
        return eigvecs @ torch.diag(eigvals ** (-1/2)) @ eigvecs.T

    data_x = torch.as_tensor(data_x).detach().cpu()
    if not data_x.is_floating_point():
        data_x = data_x.float()
    data_y = torch.as_tensor(data_y).detach().cpu().to(data_x.dtype)

    gram_matrix_x = _gaussian_gram(data_x)
    gram_matrix_u = _gaussian_gram(data_y)

    num_sample = gram_matrix_x.shape[0]
    # Build the helper matrices in the Gram matrices' dtype so float64 inputs
    # work as well as float32 ones.
    identity = torch.eye(num_sample, dtype=gram_matrix_x.dtype)
    N0_matrix = identity - 1 / num_sample  # centering matrix I - (1/n) 11^T

    normalized_x = N0_matrix @ gram_matrix_x @ N0_matrix
    normalized_u = N0_matrix @ gram_matrix_u @ N0_matrix

    normalized_x_add = normalized_x + identity * alpha
    normalized_u_add = normalized_u + identity * alpha

    # Each product is needed for both the eigenproblem and the SVD; compute once.
    RF = normalized_x_add @ normalized_x_add.T
    RG = normalized_u_add @ normalized_u_add.T
    P = normalized_x @ normalized_u.T

    # Left-hand side: off-diagonal blocks hold the cross term.
    matrix_a = np.zeros((num_sample*2, num_sample*2))
    matrix_a[:num_sample, num_sample:] = P.numpy()
    matrix_a[num_sample:, :num_sample] = P.T.numpy()

    # Right-hand side: block-diagonal regularised auto terms.
    matrix_b = np.zeros((num_sample*2, num_sample*2))
    matrix_b[:num_sample, :num_sample] = RF.numpy()
    matrix_b[num_sample:, num_sample:] = RG.numpy()

    eigenvalues, eigenvectors = eigh(matrix_a, matrix_b)

    # eigh returns ascending eigenvalues; keep the upper half, largest first.
    spectrum_kica = eigenvalues[num_sample:][::-1]

    P_STAR = _inverse_sqrt(RF) @ P @ _inverse_sqrt(RG)

    U, S, V = torch.svd(P_STAR)

    measure_kica = (-1/2)*np.log((1 - spectrum_kica**2)).sum()
    measure_hsic = np.trace(P_STAR.numpy())

    return spectrum_kica, eigenvectors, U, S, V, measure_kica, measure_hsic


In [27]:
class decoder_MINER(nn.Module):
    """
    Fully connected network with a bounded positive output.

    ``forward`` applies five ReLU-activated linear layers (fc1, fc2, fc3,
    fc4, fc6) followed by ``sigmoid(fc5(.)) + 0.1``, so every output lies in
    (0.1, 1.1). The layers fc7-fc9 and all BatchNorm layers are registered
    but not used by ``forward``.
    """

    def __init__(self, input_dim = 784, HIDDEN = 2000, out_dim = 200):
        super().__init__()
        self.dim = out_dim

        # Layers are registered in the order fc1, bn1, fc2, bn2, ..., fc9, bn9,
        # fc5. That order fixes both the parameter ordering and the sequence
        # of random draws used for weight initialisation.
        self.fc1 = nn.Linear(input_dim, HIDDEN, bias=True)
        self.bn1 = nn.BatchNorm1d(HIDDEN)
        for tag in (2, 3, 4, 6, 7, 8, 9):
            setattr(self, "fc{}".format(tag), nn.Linear(HIDDEN, HIDDEN, bias=True))
            setattr(self, "bn{}".format(tag), nn.BatchNorm1d(HIDDEN))

        self.fc5 = nn.Linear(HIDDEN, out_dim, bias=True)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc6):
            hidden = torch.relu(layer(hidden))

        # Sigmoid shifted by 0.1 keeps the output strictly positive.
        return torch.sigmoid(self.fc5(hidden))+1e-1


In [28]:
class decoder_MINES(nn.Module):
    """
    Fully connected network with an unbounded (linear) output.

    ``forward`` applies five ReLU-activated linear layers (fc1, fc2, fc3,
    fc4, fc6) and returns the raw output of ``fc5``. The layers fc7-fc9 and
    all BatchNorm layers are registered but not used by ``forward``.
    """

    def __init__(self, input_dim = 784, HIDDEN = 2000, out_dim = 200):
        super().__init__()
        self.dim = out_dim

        # Layers are registered in the order fc1, bn1, fc2, bn2, ..., fc9, bn9,
        # fc5. That order fixes both the parameter ordering and the sequence
        # of random draws used for weight initialisation.
        self.fc1 = nn.Linear(input_dim, HIDDEN, bias=True)
        self.bn1 = nn.BatchNorm1d(HIDDEN)
        for tag in (2, 3, 4, 6, 7, 8, 9):
            setattr(self, "fc{}".format(tag), nn.Linear(HIDDEN, HIDDEN, bias=True))
            setattr(self, "bn{}".format(tag), nn.BatchNorm1d(HIDDEN))

        self.fc5 = nn.Linear(HIDDEN, out_dim, bias=True)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc6):
            hidden = torch.relu(layer(hidden))

        # No output non-linearity: the score is unconstrained.
        return self.fc5(hidden)

In [31]:
class encoder(nn.Module):
    """
    Stochastic fully connected encoder.

    Every forward pass concatenates ``noise_dim`` fresh uniform(0, 1) noise
    features to the input, so repeated calls on the same input give different
    outputs (also in eval mode). The widened input then goes through eight
    Linear -> BatchNorm -> ReLU stages and a final sigmoid, so every output
    lies in (0, 1).

    Parameters
    ----------
    input_dim : number of features of the data fed to ``forward``.
    HIDDEN    : width of every hidden layer.
    out_dim   : number of output features.
    noise_dim : number of uniform noise features appended to the input.
                Defaults to 50, the value that used to be hard-coded.
    """

    def __init__(self, input_dim = 784, HIDDEN = 2000, out_dim = 200, noise_dim = 50):
        super(encoder, self).__init__()
        self.dim = out_dim
        self.noise_dim = noise_dim

        # The first layer sees the data features plus the appended noise.
        self.fc1 = nn.Linear(input_dim+noise_dim, HIDDEN, bias=True)
        self.bn1 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc2 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn2 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc3 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn3 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc4 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn4 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc6 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn6 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc7 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn7 = torch.nn.BatchNorm1d(HIDDEN)

        self.fc8 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn8 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc9 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn9 = torch.nn.BatchNorm1d(HIDDEN)

        self.fc5 = nn.Linear(HIDDEN, out_dim, bias=True)

    def forward(self, x):
        # Noise is drawn on the CPU (same random stream as before) and then
        # moved to wherever the input lives, instead of a hard-coded .cuda();
        # this lets the module run on the CPU or on any GPU.
        noise = torch.zeros((x.shape[0], self.noise_dim)).uniform_().to(x.device)
        x = torch.cat((x, noise), 1)

        x = torch.relu(self.bn1((self.fc1(x))))
        x = torch.relu(self.bn2((self.fc2(x))))
        x = torch.relu(self.bn3((self.fc3(x))))
        x = torch.relu(self.bn4((self.fc4(x))))
        x = torch.relu(self.bn6((self.fc6(x))))
        x = torch.relu(self.bn7((self.fc7(x))))
        x = torch.relu(self.bn8((self.fc8(x))))
        x = torch.relu(self.bn9((self.fc9(x))))

        x = torch.sigmoid(self.fc5(x))

        return x

In [None]:
# Train an encoder/critic pair for each (objective, dataset) combination, plot
# the averaged encoder output on a grid over [0, 1]^2, and save the dependence
# measures from KICA_spectrum together with the training curve.

def gauss(A,B,var):
    """Gaussian Gram matrix; squared differences are averaged over the feature axis."""
    # Defined once at module level: KICA_spectrum looks `gauss` up as a global.
    return torch.exp(-((A.unsqueeze(1) - B.unsqueeze(0))**2).mean(2)/(2*var))

# NOTE(review): hard-coded GPU index; adjust to the machine this runs on.
torch.cuda.set_device(4)

# savefig / np.save do not create missing directories; without this the run
# would fail only after the whole training loop has finished.
os.makedirs('./max_dataset', exist_ok=True)

# Evaluation grid over the unit square (independent of the run, so built once).
grid_size = 50
x = np.linspace(0.0, 1, grid_size)
y = np.linspace(0.0, 1, grid_size)
X, Y = np.meshgrid(x, y)
grid_points = np.column_stack([X.flatten(), Y.flatten()])
grid_points = torch.from_numpy(grid_points)
heatmap_extent = [0, 1, 0, 1]

for mine_index in [0, 1]:
    for set_index in [0, 1, 2]:

        train_data = save_dataset[set_index]

        # Same seeds for every run so the runs are comparable.
        torch.manual_seed(0)
        np.random.seed(0)

        E = encoder(input_dim = 2, out_dim = 1).cuda()

        # The critic sees the 2-D sample concatenated with the 1-D code.
        if mine_index == 0:
            D = decoder_MINES(input_dim = 3, out_dim = 1).cuda()
        else:
            D = decoder_MINER(input_dim = 3, out_dim = 1).cuda()

        optimizer_E = optim.Adam([
              {'params': E.parameters(), 'lr': 0.00001, 'betas': (0.5, 0.9)},
           ])

        optimizer_D = optim.Adam([
              {'params': D.parameters(), 'lr': 0.00001, 'betas': (0.5, 0.9)},
           ])

        elbo_curve = []
        pdf_curve = []

        data_dim = 2
        center_dim = 30

        var_noise = 0.001
        var_samples = 0.001

        # The training batch is the same 5000 samples at every step, so it is
        # moved to the GPU once instead of once per iteration.
        batch_data = torch.from_numpy(train_data[:5000]).float().cuda()

        for i in range(0, 20001):

            # apply encoders
            encoded = E(batch_data)

            # Perturb both the samples and their codes with Gaussian noise.
            batch_noise = batch_data + torch.zeros((batch_data.shape)).cuda().normal_()*np.sqrt(var_noise)
            encoded_noise = encoded + torch.zeros((encoded.shape)).cuda().normal_()*np.sqrt(var_noise)

            # Two independent permutations break the sample/code pairing.
            index = np.arange(5000)
            np.random.shuffle(index)
            index2 = np.arange(5000)
            np.random.shuffle(index2)

            joint = torch.cat((batch_noise, encoded_noise), 1)
            disjoint = torch.cat((batch_noise[index], encoded_noise[index2]), 1)

            output_joint = D(joint)
            output_uniform = D(disjoint)

            if mine_index == 0:
                # mean(T_joint) - log(mean(exp(T_shuffled))); the second term
                # is evaluated with logsumexp so large critic outputs cannot
                # overflow exp().
                log_mean_exp = torch.logsumexp(output_uniform.reshape(-1), dim=0) - np.log(output_uniform.numel())
                error = torch.mean(output_joint) - log_mean_exp
            else:
                # Mean critic output on paired data divided by the RMS of the
                # critic output on shuffled data.
                error = output_joint.mean()/(torch.sqrt((output_uniform**2).mean()))

            # Both networks are updated to maximise `error`.
            (-error).backward()

            pdf_curve.append(error.item())

            optimizer_E.step()
            optimizer_D.step()

            optimizer_E.zero_grad()
            optimizer_D.zero_grad()

            if i%100 == 0:
                print(i, error.item())

        # The encoder injects fresh noise on every call, so average 100
        # evaluation-mode passes over the grid.
        output_mean = 0
        E.eval()
        with torch.no_grad():
            for n in range(0, 100):
                output = E(grid_points.cuda().float()).detach().cpu().numpy()
                output_mean = (output_mean*n+output)/(n+1)
        E.train()

        heatmap = output_mean.reshape(grid_size, grid_size)

        # The second file-name field used to be `set_inf`, which is not defined
        # anywhere in this notebook; `mine_index` keeps the two objectives from
        # overwriting each other and matches the .npy names below.
        for image, path_template in (
            (heatmap, './max_dataset/MI_set_{0}_center_{1}_negative.png'),
            (-heatmap, './max_dataset/MI_set_{0}_center_{1}.png'),
        ):
            plt.figure(figsize=(3, 3))
            plt.imshow(image, cmap='coolwarm', extent=heatmap_extent, origin='lower', aspect='auto', zorder=1)
            plt.axis('off')
            plt.savefig(path_template.format(set_index, mine_index), dpi=500, bbox_inches='tight')
            plt.show()

        # Dependence measures between the batch and its code from the last step.
        spectrum_kica, eigenvectors, U, S, V, measure_kica, measure_hsic = KICA_spectrum(batch_data.detach().cpu(), encoded.detach().cpu(), alpha=1)

        np.save('./max_dataset/measure_kica_{0}_figure_MINEMINE_{1}.npy'.format(set_index, mine_index), measure_kica)
        np.save('./max_dataset/measure_hsic_{0}_figure_MINEMINE_{1}.npy'.format(set_index, mine_index), measure_hsic)
        np.save('./max_dataset/pdf_array_{0}_MINEMINE_{1}.npy'.format(set_index, mine_index), np.array(pdf_curve))