In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from sklearn.feature_selection import mutual_info_regression
from sklearn.neighbors import KernelDensity
from scipy.stats import multivariate_normal
from scipy.special import logsumexp
from scipy import linalg
import scipy.integrate as integrate
from scipy.special import gamma
import time

from sklearn.mixture import GaussianMixture as GMM
from sklearn.mixture._gaussian_mixture import _estimate_log_gaussian_prob, _compute_precision_cholesky, _estimate_gaussian_covariances_full
from sklearn.utils import check_random_state
from sklearn import cluster
from sklearn.model_selection import KFold

In [2]:
class my_GMM(GMM):
    """
    Custom GMM class based on the sklearn GMM class.

    This allows to work with a GMM with fixed parameters, without fitting it:
    ``weights_init``, ``means_init`` and ``covariances_init`` are exposed
    directly as ``weights_``, ``means_`` and ``covariances_``, and the Cholesky
    factors of the precisions are derived from ``covariances_`` on first use
    when the model has not been fitted.
    It also allows to estimate MI with a certain number of MC samples.
    The different initialisation types are dealt with separately.
    """
    def __init__(self,
                 n_components=1,
                 covariance_type="full",
                 tol=1e-5,
                 reg_covar=1e-6,
                 max_iter=100,
                 n_init=1,
                 init_params="random",
                 random_state=None,
                 warm_start=False,
                 verbose=0,
                 verbose_interval=10,
                 weights_init=None,
                 means_init=None,
                 precisions_init=None,
                 covariances_init=None
                 ):
        super(my_GMM, self).__init__(n_components=n_components,
                 covariance_type=covariance_type,
                 tol=tol,
                 reg_covar=reg_covar,
                 max_iter=max_iter,
                 n_init=n_init,
                 init_params=init_params,
                 random_state=random_state,
                 warm_start=warm_start,
                 verbose=verbose,
                 verbose_interval=verbose_interval,
                 weights_init=weights_init,
                 means_init=means_init,
                 precisions_init=precisions_init,
                )

        # expose the fixed parameters directly so that an unfitted model can be
        # sampled from and scored
        self.means_ = means_init
        self.covariances_ = covariances_init
        # kept under the constructor argument name; presumably needed so that
        # sklearn's get_params()/clone() can read it back
        self.covariances_init = covariances_init
        self.weights_ = weights_init

    def _ensure_precisions_cholesky(self):
        """Derive ``precisions_cholesky_`` from ``covariances_`` when it is missing.

        Fitting sets ``precisions_cholesky_``; a model built from fixed
        parameters does not have it yet, but the scoring methods need it.
        """
        covariances = getattr(self, "covariances_", None)
        if getattr(self, "precisions_cholesky_", None) is None and covariances is not None:
            self.precisions_cholesky_ = _compute_precision_cholesky(
                np.asarray(covariances), self.covariance_type
            )

    def score_samples(self, X):
        """Compute the log-likelihood of each sample.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point.

        Returns
        -------
        log_prob : array, shape (n_samples,)
            Log-likelihood of each sample in `X` under the current model.
        """
        # the parent's fitted-model check is skipped on purpose (fixed-parameter use)
        X = np.asarray(X)
        self._ensure_precisions_cholesky()
        return logsumexp(self._estimate_weighted_log_prob(X), axis=1)

    def predict(self, X):
        """Predict the labels for the data samples in X using the current model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point.

        Returns
        -------
        labels : array, shape (n_samples,)
            Component labels.
        """
        # the parent's fitted-model check is skipped on purpose (fixed-parameter use)
        X = np.asarray(X)
        self._ensure_precisions_cholesky()
        return self._estimate_weighted_log_prob(X).argmax(axis=1)

    def predict_proba(self, X):
        """Evaluate the components' density for each sample.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point.

        Returns
        -------
        resp : array, shape (n_samples, n_components)
            Density of each Gaussian component for each sample in X.
        """
        # copied here to remove the fitting check
        X = np.asarray(X)
        self._ensure_precisions_cholesky()
        _, log_resp = self._estimate_log_prob_resp(X)
        return np.exp(log_resp)

    def sample(self, n_samples=1):
        """Generate random samples from the Gaussian mixture.

        Parameters
        ----------
        n_samples : int, default=1
            Number of samples to generate. Integral floats such as ``1e5`` are
            accepted and converted to int.

        Returns
        -------
        X : array, shape (n_samples, n_features)
            Randomly generated sample.
        y : array, shape (n_samples,)
            Component labels.

        Raises
        ------
        ValueError
            If ``n_samples < 1``.
        """
        # copied here to remove the fitting check
        if n_samples < 1:
            raise ValueError(
                "Invalid value for 'n_samples': %d . The sampling requires at "
                "least one sample." % (n_samples)
            )
        # numpy's multinomial expects an integer number of trials
        n_samples = int(n_samples)

        _, n_features = self.means_.shape
        # with an int random_state every call draws the same samples
        rng = check_random_state(self.random_state)
        n_samples_comp = rng.multinomial(n_samples, self.weights_)

        if self.covariance_type == "full":
            X = np.vstack(
                [
                    rng.multivariate_normal(mean, covariance, int(sample))
                    for (mean, covariance, sample) in zip(
                        self.means_, self.covariances_, n_samples_comp
                    )
                ]
            )
        elif self.covariance_type == "tied":
            X = np.vstack(
                [
                    rng.multivariate_normal(mean, self.covariances_, int(sample))
                    for (mean, sample) in zip(self.means_, n_samples_comp)
                ]
            )
        else:
            X = np.vstack(
                [
                    mean + rng.randn(sample, n_features) * np.sqrt(covariance)
                    for (mean, covariance, sample) in zip(
                        self.means_, self.covariances_, n_samples_comp
                    )
                ]
            )

        y = np.concatenate(
            [np.full(sample, j, dtype=int) for j, sample in enumerate(n_samples_comp)]
        )

        return (X, y)

    def score_samples_marginal(self, X, index=0):
        """Compute the log-likelihood of each sample under a 1-D marginal of the model.

        The marginal of the mixture along feature ``index`` is the mixture with
        the same weights, means ``means_[:, index]`` and variances
        ``covariances_[:, index, index]``. This assumes full covariance
        matrices (``covariances_`` indexed as [component, feature, feature]).

        Parameters
        ----------
        X : array-like of shape (n_samples, 1)
            One-dimensional data points, one per row.
        index : int, default=0
            Feature to marginalise onto: 0 (marginal x) or 1 (marginal y).

        Returns
        -------
        log_prob : array, shape (n_samples,)
            Log-likelihood of each sample in `X` under the marginal model.
        """
        X = np.asarray(X)
        # Cholesky factor of a 1x1 precision matrix: 1 / standard deviation
        oned_cholesky = np.sqrt(1/self.covariances_[:, index, index]).reshape(-1, 1, 1)
        marginal_logprob = _estimate_log_gaussian_prob(
            X, self.means_[:, index].reshape(-1, 1), oned_cholesky, self.covariance_type
        )

        return logsumexp(np.log(self.weights_) + marginal_logprob, axis=1)

    def estimate_MI_MC(self, MC_samples=100):
        """Estimate the mutual information between the two features with Monte Carlo integration.

        Samples are drawn from the joint model and MI is estimated as the
        sample mean of ``log p(x, y) - log p(x) - log p(y)``. Assumes a
        two-feature model with x in column 0 and y in column 1.

        Parameters
        ----------
        MC_samples : int
            Number of Monte Carlo samples used for the numerical integration.

        Returns
        -------
        MI : float
            The estimated mutual information, in nats (natural logarithm).
        """
        points, _ = self.sample(MC_samples)

        # log-likelihood under the joint model
        joint = self.score_samples(points)

        # log-likelihood under the marginals; index=0 corresponds to x, index=1 corresponds to y
        marginal_x = self.score_samples_marginal(points[:, :1], index=0)
        marginal_y = self.score_samples_marginal(points[:, 1:2], index=1)

        return np.mean(joint - marginal_x - marginal_y)

    def fit_predict(self, X, y=None):
        """Estimate model parameters using X and predict the labels for X.

        The method fits the model n_init times and sets the parameters with
        which the model has the largest likelihood or lower bound. Within each
        trial, the method iterates between E-step and M-step for `max_iter`
        times until the change of likelihood or lower bound is less than
        `tol`, otherwise, a :class:`~sklearn.exceptions.ConvergenceWarning` is
        raised. After fitting, it predicts the most probable label for the
        input data points.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point.
        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        labels : array, shape (n_samples,)
            Component labels.

        Raises
        ------
        ValueError
            If there are fewer samples than components.
        """
        X = self._validate_data(X, dtype=[np.float64, np.float32], ensure_min_samples=2)
        if X.shape[0] < self.n_components:
            raise ValueError(
                "Expected n_samples >= n_components "
                f"but got n_components = {self.n_components}, "
                f"n_samples = {X.shape[0]}"
            )
        self._check_initial_parameters(X)

        # if we enable warm_start, we will have a unique initialisation
        do_init = not (self.warm_start and hasattr(self, "converged_"))
        n_init = self.n_init if do_init else 1

        max_lower_bound = -np.inf
        self.converged_ = False

        random_state = check_random_state(self.random_state)

        for init in range(n_init):
            self._print_verbose_msg_init_beg(init)

            if do_init:
                self._initialize_parameters(X, random_state)

            lower_bound = -np.inf if do_init else self.lower_bound_

            for n_iter in range(1, self.max_iter + 1):
                prev_lower_bound = lower_bound

                log_prob_norm, log_resp = self._e_step(X)
                self._m_step(X, log_resp)
                lower_bound = self._compute_lower_bound(log_resp, log_prob_norm)

                change = lower_bound - prev_lower_bound
                self._print_verbose_msg_iter_end(n_iter, change)

                if abs(change) < self.tol:
                    self.converged_ = True
                    break

            self._print_verbose_msg_init_end(lower_bound)

            # the "== -np.inf" clause guarantees best_params is set on the first
            # trial even when the lower bound is -inf or NaN
            if lower_bound > max_lower_bound or max_lower_bound == -np.inf:
                max_lower_bound = lower_bound
                best_params = self._get_parameters()
                best_n_iter = n_iter

        if not self.converged_:
            # imported here: neither name is imported at the top of the notebook
            import warnings
            from sklearn.exceptions import ConvergenceWarning

            warnings.warn(
                "Initialization %d did not converge. "
                "Try different init parameters, "
                "or increase max_iter, tol "
                "or check for degenerate data." % (init + 1),
                ConvergenceWarning,
            )

        self._set_parameters(best_params)
        self.n_iter_ = best_n_iter
        self.lower_bound_ = max_lower_bound

        # Always do a final e-step to guarantee that the labels returned by
        # fit_predict(X) are always consistent with fit(X).predict(X)
        # for any value of max_iter and tol (and any random_state).
        _, log_resp = self._e_step(X)

        return log_resp.argmax(axis=1)
        


In [3]:
# now we also focus on initialising the GMM parameters
# we provide five different initialisation types, which return weights, means, covariances and precisions
# these are passed as input to the GMM class, so that its own initialisation procedure is bypassed

  
def initialize_parameters(X, random_state, n_components=1, s=None, reg_covar=1e-6, init_type='random'):
    """Initialize the GMM parameters (weights, means, covariances and precisions).

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Data used to place the initial components.
    random_state : None, int or RandomState
        Seed or random number generator controlling the randomness of the
        chosen initialisation method (passed through ``check_random_state``).
    n_components : int, default=1
        Number of components of the GMM to fit.
    s : float, optional
        Only used by 'random' and 'minmax': component standard deviation, so
        that covariances are ``s**2 * identity``. If None, s is chosen such that
        ``n_components`` balls of radius s have a total volume equal to the
        volume of the bounding box of the data.
    reg_covar : float, default=1e-6
        Regularization added to the diagonal of the covariance matrices. Only
        used by 'random_sklearn' and 'kmeans_sklearn'.
    init_type : {'random', 'minmax', 'kmeans', 'random_sklearn', 'kmeans_sklearn'}, default='random'
        The method used to initialize the weights, means and covariances:
            'random': weights are uniform, covariances are ``s**2 * identity``.
            Each mean is a randomly selected data sample plus a Gaussian offset
            with covariance ``s**2 * identity``.
            'minmax': same as above, but means are drawn uniformly over the
            bounding box of the data.
            'kmeans': k-means clustering run as in Algorithm 1 from Bloemer & Bujna (arXiv:1312.5946),
            as implemented by Melchior & Goulding (arXiv:1611.05806).
            WARNING: the result is not deterministic even if random_state is set, because
            scipy.cluster.vq.kmeans2 uses its own initialization.
            TO DO: require scipy > 1.7, and include "seed=random_state" in the kmeans call.
            'random_sklearn': responsibilities are initialized randomly (normalised per sample).
            'kmeans_sklearn': responsibilities are initialized from sklearn KMeans labels.

    Returns
    -------
    weights : ndarray of shape (n_components,)
        The initial weights of the GMM model.
    means : ndarray of shape (n_components, n_features)
        The initial means of the GMM model.
    covariances : ndarray of shape (n_components, n_features, n_features)
        The initial covariance matrices of the GMM model.
    precisions : ndarray of shape (n_components, n_features, n_features)
        The inverses of ``covariances``.

    Raises
    ------
    ValueError
        If ``init_type`` is not one of the supported values.
    """
    valid_init_types = ('random', 'minmax', 'kmeans', 'random_sklearn', 'kmeans_sklearn')
    if init_type not in valid_init_types:
        raise ValueError(
            f"Unknown initialisation type {init_type!r}; it should be one of "
            "'random', 'minmax', 'kmeans', 'random_sklearn', 'kmeans_sklearn'"
        )

    n_samples, n_dim = X.shape

    random_state = check_random_state(random_state)
    if s is None and init_type in ('random', 'minmax'):
        min_pos = X.min(axis=0)
        max_pos = X.max(axis=0)
        vol_data = np.prod(max_pos - min_pos)
        # invert n_components * (volume of an n_dim-ball of radius s) = vol_data
        s = (vol_data / n_components * gamma(n_dim*0.5 + 1))**(1/n_dim) / np.sqrt(np.pi)
        print(f"Scale s set to s={s:.2f}...")

    if init_type == "random":
        weights = np.repeat(1/n_components, n_components)
        # initialize components around data points with uncertainty s
        refs = random_state.randint(0, n_samples, size=n_components)
        means = X[refs] + random_state.multivariate_normal(np.zeros(n_dim), s**2 * np.eye(n_dim), size=n_components)
        covariances = np.repeat(s**2 * np.eye(n_dim)[np.newaxis, :, :], n_components, axis=0)

    elif init_type == "minmax":
        weights = np.repeat(1/n_components, n_components)
        min_pos = X.min(axis=0)
        max_pos = X.max(axis=0)
        means = min_pos + (max_pos-min_pos)*random_state.rand(n_components, n_dim)
        covariances = np.repeat(s**2 * np.eye(n_dim)[np.newaxis, :, :], n_components, axis=0)

    elif init_type == 'kmeans':
        from scipy.cluster.vq import kmeans2
        center, label = kmeans2(X, n_components)
        weights = np.zeros(n_components)
        means = np.zeros((n_components, n_dim))
        covariances = np.zeros((n_components, n_dim, n_dim))

        for k in range(n_components):
            mask = (label == k)
            weights[k] = mask.sum() / len(X)
            # NOTE(review): an empty cluster gives a NaN mean here
            means[k, :] = X[mask].mean(axis=0)
            d_m = X[mask] - means[k, :]
            # for each point i, take the outer product of d_m with itself and sum over i.
            # NOTE(review): normalised by len(X) rather than mask.sum(), i.e. the cluster
            # covariance is scaled by its weight; kept as in the cited implementation — confirm
            covariances[k, :, :] = (d_m[:, :, None] * d_m[:, None, :]).sum(axis=0) / len(X)

    elif init_type == "random_sklearn":
        resp = random_state.rand(n_samples, n_components)
        resp /= resp.sum(axis=1)[:, np.newaxis]
        nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps

        weights = nk/n_samples
        means = np.dot(resp.T, X) / nk[:, np.newaxis]
        covariances = _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar)

    else:  # "kmeans_sklearn"
        resp = np.zeros((n_samples, n_components))
        label = (
            cluster.KMeans(
                n_clusters=n_components, n_init=1, random_state=random_state
            )
            .fit(X)
            .labels_
        )
        # hard assignments: each sample fully belongs to its k-means cluster
        resp[np.arange(n_samples), label] = 1
        nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps

        weights = nk/n_samples
        means = np.dot(resp.T, X) / nk[:, np.newaxis]
        covariances = _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar)

    # np.linalg.inv inverts each matrix of the stack independently
    precisions = np.linalg.inv(covariances)

    return weights, means, covariances, precisions


In [16]:
def MI_procedure_diffconvergence(X, n_components=1, n_folds=5, n_inits=5, init_type='random', reg_covar=1e-6, tol=1e-6):
    """
    Select the best initialisation seed for a GMM with a fixed number of components by K-fold cross-validation.

    For each seed r in range(n_inits), initial parameters are generated once from the full X with
    ``initialize_parameters(X, r, ...)``. A ``my_GMM`` is then fitted from those parameters on every
    training fold, and the mean per-sample log-likelihood is recorded on the training and validation folds.
    The fold split is shuffled with a fixed seed (42), so it is the same for all seeds.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Data to model.
    n_components : int, default=1
        Number of GMM components.
    n_folds : int, default=5
        Number of cross-validation folds.
    n_inits : int, default=5
        Number of initialisation seeds tried (seeds 0 .. n_inits-1).
    init_type : str, default='random'
        Initialisation type forwarded to ``initialize_parameters``.
    reg_covar : float, default=1e-6
        Covariance regularisation used when fitting the GMM.
    tol : float, default=1e-6
        EM convergence tolerance used when fitting the GMM.

    Returns
    -------
    best_seed : int
        Seed with the highest mean validation log-likelihood.
    best_val_score : float
        Mean (over folds) per-sample validation log-likelihood of ``best_seed``.
    best_train_score : float
        Mean (over folds) per-sample training log-likelihood of ``best_seed``.
    """
    # mean log-likelihood per seed, averaged over folds
    val_scores_seeds = np.zeros(n_inits)
    train_scores_seeds = np.zeros(n_inits)

    # the random seed is fixed here, but results should be independent of the exact split
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)

    for r in range(n_inits):
        # NOTE(review): the initial parameters are computed on the full X, so they also see the
        # validation folds. reg_covar is not forwarded (initialize_parameters uses its default);
        # callers re-creating the chosen initialisation must do the same to get identical parameters.
        w_init, m_init, c_init, p_init = initialize_parameters(X, r, n_components=n_components, init_type=init_type)
        validation_scores = []
        training_scores = []

        for train_indices, valid_indices in kf.split(X):
            X_training = X[train_indices]
            X_validation = X[valid_indices]

            fitted_gmm = my_GMM(n_components=n_components, reg_covar=reg_covar,
                                tol=tol, max_iter=10000,
                                random_state=r, weights_init=w_init,
                                means_init=m_init, precisions_init=p_init).fit(X_training)

            # mean logL per sample, since folds might have slightly different sizes
            validation_scores.append(fitted_gmm.score_samples(X_validation).mean())
            training_scores.append(fitted_gmm.score_samples(X_training).mean())

        val_scores_seeds[r] = np.mean(validation_scores)
        train_scores_seeds[r] = np.mean(training_scores)

    # select the seed with the highest validation score and report *its* scores
    best_seed = np.argmax(val_scores_seeds)
    best_val_score = val_scores_seeds[best_seed]
    best_train_score = train_scores_seeds[best_seed]

    return best_seed, best_val_score, best_train_score

### Let's look at the MI between a single latent (latent 3) and a single factor of variation ('scale')

In [118]:
# ground-truth factor labels and learned latent codes; rows are assumed to refer to the same samples
all_labels, all_latents = np.load('./labels.npy'), np.load('./latents.npy')

In [112]:
# factor name -> [number of distinct values, column index in all_labels]
dict_labels = {
    'floor_hue': [10, 0],
    'object_hue': [10, 1],
    'orientation': [15, 2],
    'scale': [8, 3],
    'shape': [4, 4],
    'wall_hue': [10, 5],
}

# the (latent, factor) pair analysed below
latent_id = 3
label_id = 'scale'
label_values, label_number = dict_labels[label_id]

In [113]:
# configuration of the model selection and bootstrap procedure
n_inits = 5                    # initialisation seeds tried per number of components
n_folds = 3                    # cross-validation folds
init_type = 'random_sklearn'   # initialisation scheme passed to initialize_parameters
n_bootstrap = 100              # bootstrap resamples used to estimate the MI spread
MC_samples = 100_000           # Monte Carlo samples per conditional model (int: it is a sample count)
tol = 1e-5                     # EM tolerance on the change of the lower bound
reg_covar = 1e-15              # regularisation added to the covariance diagonals
components_range = 15          # maximum number of GMM components tried
patience = 1                   # non-improving component counts tolerated before stopping the search

In [114]:
# first identify the initialisation (number of components and seed) of each conditional
# model p(z | f = label_value): one model per value of the selected factor

tic = time.time()

first_labels = all_labels[:, label_number]
first_latents = all_latents[:, latent_id]

init_params_ = []
for label_value in range(label_values):
    # latents whose factor equals label_value; this is the data for p(z|f = label_value)
    X = first_latents[first_labels == label_value].reshape(-1, 1)

    best_val = -np.inf
    pat_counter = 0
    for n_components in range(1, components_range+1):
        current_seed, current_val, _ = MI_procedure_diffconvergence(X, n_components=n_components, n_folds=n_folds,
                                                                    init_type=init_type, n_inits=n_inits, tol=tol, reg_covar=reg_covar)

        # early stopping on the validation score
        if current_val > best_val:
            best_val = current_val
            best_seed = current_seed
            best_components = n_components
            pat_counter = 0  # patience counts consecutive non-improving steps
            print(n_components, best_val)
        else:
            pat_counter += 1
            if pat_counter >= patience:
                print(f'Convergence reached at {best_components} components')
                break
    else:
        # the search ran out without early stopping: still keep the best model found,
        # otherwise init_params_ would no longer line up with the label values
        print(f'No convergence within {components_range} components; using {best_components} components')

    w_init, m_init, c_init, p_init = initialize_parameters(X, best_seed, n_components=best_components, init_type=init_type)
    init_params_.append({'w': w_init, 'm': m_init, 'c': c_init, 'p': p_init, 'bc': best_components, 'seed': best_seed})


1 -0.9225725943635975
2 -0.872994249281602
3 -0.8321391906262585
4 -0.8311610133708109
Convergence reached at 4 components
1 -0.8083382910562578
2 -0.7885373212853821
3 -0.7882585784567695
4 -0.7614027289823787
Convergence reached at 4 components
1 -0.8326579785064778
2 -0.7620893425803016
Convergence reached at 2 components
1 -0.9233398813296141
2 -0.8265421218357178
Convergence reached at 2 components
1 -1.0214874337539877
2 -0.9080152775853855
3 -0.8772289751526117
4 -0.8723631583496928
5 -0.8623213742366582
Convergence reached at 5 components
1 -1.1582618172226533
2 -1.0604359380570296
3 -0.9613187116535062
Convergence reached at 3 components
1 -1.2758897339915867
2 -1.1467730469207205
3 -1.0940507376391382
Convergence reached at 3 components
1 -1.3914381284494821
2 -1.3255693795169412
3 -1.297369495848088
Convergence reached at 3 components


In [115]:
# then bootstrap the conditional models and estimate MI with Monte Carlo integration:
#   I(z; f) = 1/K sum_k E_{z ~ p(z|f=k)} [ log p(z|f=k) - log( 1/K sum_j p(z|f=j) ) ]
# with K = label_values (uniform weight over the factor values).
# n_bootstrap and MC_samples come from the configuration cell.

MI_estimates = np.zeros(n_bootstrap)

# the conditional data sets do not change between bootstrap rounds, so extract them once
conditional_data = [
    first_latents[first_labels == label_value].reshape(-1, 1)
    for label_value in range(label_values)
]

for i in range(n_bootstrap):
    # we use i to change the seed so that the results will be fully reproducible
    rng = np.random.default_rng(i)

    all_gmms = []
    for label_value, X in enumerate(conditional_data):
        params = init_params_[label_value]

        # resample with replacement, same size as the original data
        X_bs = rng.choice(X, X.shape[0])
        gmm = my_GMM(n_components=params['bc'], reg_covar=reg_covar,
                     tol=tol, max_iter=10000,
                     random_state=params['seed'], weights_init=params['w'],
                     means_init=params['m'], precisions_init=params['p']).fit(X_bs)
        all_gmms.append(gmm)

    # estimate MI using MC
    MI = 0
    for gmm in all_gmms:
        samples = gmm.sample(int(MC_samples))[0]
        log_p = gmm.score_samples(samples)
        # log of the equal-weight mixture of all conditionals, computed in log space
        # to avoid exp() underflow for samples far from some components
        log_mixture = logsumexp([other.score_samples(samples) for other in all_gmms], axis=0) - np.log(label_values)
        MI += np.mean(log_p - log_mixture)

    MI_estimates[i] = MI / label_values


In [116]:
# wall-clock seconds since `tic` was last set
elapsed_seconds = time.time() - tic
elapsed_seconds

497.16844725608826

In [117]:
# bootstrap mean and standard deviation of the MI estimate
MI_estimates.mean(), MI_estimates.std()

(0.48360447744050394, 0.008677881488270743)

### Now repeat the estimate for every (factor, latent) pair

In [127]:
label_list = ['floor_hue', 'object_hue', 'orientation', 'scale', 'shape', 'wall_hue']

# configuration (each value set exactly once)
n_inits = 5
n_folds = 3
init_type = 'random_sklearn'
n_bootstrap = 100
MC_samples = 100_000  # int: it is passed to my_GMM.sample as a sample count
tol = 1e-5
reg_covar = 1e-15
components_range = 15
patience = 1
# number of latent dimensions analysed; NOTE(review): hard-coded, presumably all_latents.shape[1] — confirm
n_latents = 6

# all_MI_estimates[label_number, latent_id] holds the bootstrap MI estimates of that (factor, latent) pair
all_MI_estimates = np.zeros((len(label_list), n_latents, n_bootstrap))
for label_id in label_list:
    label_values, label_number = dict_labels[label_id]
    first_labels = all_labels[:, label_number]

    for latent_id in range(n_latents):
        tic = time.time()
        first_latents = all_latents[:, latent_id]

        # data for each conditional model p(z | f = label_value)
        conditional_data = [
            first_latents[first_labels == label_value].reshape(-1, 1)
            for label_value in range(label_values)
        ]

        # 1) select number of components and seed of each conditional model by cross-validation
        init_params_ = []
        for X in conditional_data:
            best_val = -np.inf
            pat_counter = 0
            for n_components in range(1, components_range+1):
                current_seed, current_val, _ = MI_procedure_diffconvergence(X, n_components=n_components, n_folds=n_folds,
                                                                            init_type=init_type, n_inits=n_inits, tol=tol, reg_covar=reg_covar)

                # early stopping on the validation score
                if current_val > best_val:
                    best_val = current_val
                    best_seed = current_seed
                    best_components = n_components
                    pat_counter = 0  # patience counts consecutive non-improving steps
                    print(n_components, best_val)
                else:
                    pat_counter += 1
                    if pat_counter >= patience:
                        print(f'Convergence reached at {best_components} components')
                        break
            else:
                # the search ran out without early stopping: still keep the best model found,
                # otherwise init_params_ would no longer line up with the label values
                print(f'No convergence within {components_range} components; using {best_components} components')

            w_init, m_init, c_init, p_init = initialize_parameters(X, best_seed, n_components=best_components, init_type=init_type)
            init_params_.append({'w': w_init, 'm': m_init, 'c': c_init, 'p': p_init, 'bc': best_components, 'seed': best_seed})

        # 2) bootstrap the conditional models and estimate MI with Monte Carlo integration
        MI_estimates = np.zeros(n_bootstrap)
        for i in range(n_bootstrap):
            # we use i to change the seed so that the results will be fully reproducible
            rng = np.random.default_rng(i)

            all_gmms = []
            for params, X in zip(init_params_, conditional_data):
                # resample with replacement, same size as the original data
                X_bs = rng.choice(X, X.shape[0])
                gmm = my_GMM(n_components=params['bc'], reg_covar=reg_covar,
                             tol=tol, max_iter=10000,
                             random_state=params['seed'], weights_init=params['w'],
                             means_init=params['m'], precisions_init=params['p']).fit(X_bs)
                all_gmms.append(gmm)

            MI = 0
            for gmm in all_gmms:
                samples = gmm.sample(MC_samples)[0]
                log_p = gmm.score_samples(samples)
                # log of the equal-weight mixture of all conditionals, in log space to avoid exp() underflow
                log_mixture = logsumexp([other.score_samples(samples) for other in all_gmms], axis=0) - np.log(label_values)
                MI += np.mean(log_p - log_mixture)

            MI_estimates[i] = MI / label_values

        print(time.time() - tic)
        print(label_number, latent_id, np.mean(MI_estimates), np.std(MI_estimates))
        all_MI_estimates[label_number, latent_id] = MI_estimates

1 -1.4278538146226774
2 -1.4278538135308911
Convergence reached at 2 components
1 -1.4046774618230098
Convergence reached at 1 components
1 -1.4325455522509278
Convergence reached at 1 components
1 -1.424767729705979
Convergence reached at 1 components
1 -1.4283032917651177
Convergence reached at 1 components
1 -1.4125616410525061
Convergence reached at 1 components
1 -1.4122224524532392
2 -1.4122222038837862
Convergence reached at 2 components
1 -1.4246574684293973
2 -1.4246572975602465
Convergence reached at 2 components
1 -1.4149368854946884
Convergence reached at 1 components
1 -1.437017648937302
2 -1.4370175583514921
Convergence reached at 2 components
144.1802794933319
0 0 0.0003306559476049711 0.00010236223990093201
1 -1.4015621613883826
Convergence reached at 1 components
1 -1.4110030157116442
Convergence reached at 1 components
1 -1.4039307835017414
2 -1.4039302951910182
Convergence reached at 2 components
1 -1.4203862852613272
2 -1.4203862837339727
Convergence reached at 2 co

Convergence reached at 4 components
1 -1.4216620745122215
2 -1.421661089920449
Convergence reached at 2 components
1 -1.4142816477857867
2 -1.4142788947002385
3 -1.4142781441638632
Convergence reached at 3 components
310.7913281917572
1 3 0.0004720170854559289 0.00013773354353362158
1 1.2084234382091399
2 1.2084234759043824
Convergence reached at 2 components
1 0.8775035817313525
Convergence reached at 1 components
1 -0.053926999987213155
2 -0.05391001451392313
Convergence reached at 2 components
1 -0.04888620565854201
2 -0.048885129692210955
Convergence reached at 2 components
1 0.8816826899658352
2 0.8816854760484963
Convergence reached at 2 components
1 1.267984973192198
2 1.2679870476356925
Convergence reached at 2 components
1 1.335510741790392
2 1.3355123195954979
Convergence reached at 2 components
1 1.3648887237557394
2 1.36488899800387
3 1.3648892619598294
Convergence reached at 3 components
1 1.3537729231313358
2 1.353785919078345
Convergence reached at 2 components
1 1.31973

Convergence reached at 1 components
1 -1.433899498551363
Convergence reached at 1 components
1 -1.43550651315188
2 -1.4355061923665395
Convergence reached at 2 components
1 -1.445844014652316
Convergence reached at 1 components
1 -1.45596831225422
2 -1.4559682499512345
Convergence reached at 2 components
1 -1.442023564396135
Convergence reached at 1 components
1 -1.4537502470369275
Convergence reached at 1 components
244.18999767303467
2 5 0.0004907757482120115 0.00013123776522044176
1 -1.4261821793131622
Convergence reached at 1 components
1 -1.4163259982863392
Convergence reached at 1 components
1 -1.4283071348251133
Convergence reached at 1 components
1 -1.4275667029348487
Convergence reached at 1 components
1 -1.40889240249167
Convergence reached at 1 components
1 -1.4336732205697222
Convergence reached at 1 components
1 -1.408945293606255
Convergence reached at 1 components
1 -1.423650919913966
Convergence reached at 1 components
60.357749938964844
3 0 0.00029810367514393146 9.184

116.44382071495056
5 2 0.0004858506047150676 0.00010568306644996125
1 -1.4037691984553218
2 -1.403757503964939
Convergence reached at 2 components
1 -1.4197695240850214
2 -1.419764260987395
Convergence reached at 2 components
1 -1.4173963157391352
2 -1.4173892233247922
Convergence reached at 2 components
1 -1.4357218425588218
2 -1.4357149878054996
3 -1.435713296895356
4 -1.4357126499975357
Convergence reached at 4 components
1 -1.4414778357411553
2 -1.4414734010588648
3 -1.4414727669259235
Convergence reached at 3 components
1 -1.4173764079529942
2 -1.4173739474900116
3 -1.4173729204842764
4 -1.417363785611747
Convergence reached at 4 components
1 -1.4332013166342243
2 -1.433190972921446
Convergence reached at 2 components
1 -1.4176276576209974
2 -1.4176248604547574
Convergence reached at 2 components
1 -1.4366282844644591
Convergence reached at 1 components
1 -1.4283205064983155
2 -1.428318164804302
3 -1.4283062186428916
Convergence reached at 3 components
249.73053812980652
5 3 0.000

In [129]:
# Bootstrap summary: mean and standard deviation along axis 2 (the bootstrap repetitions)
tuple(reduce_fn(all_MI_estimates, axis=2) for reduce_fn in (np.mean, np.std))

(array([[3.30655948e-04, 3.52582959e-04, 2.29500784e+00, 4.50452214e-04,
         5.44391275e-04, 5.35197628e-04],
        [3.19895069e-04, 1.25555408e-03, 3.53057244e-04, 4.72017085e-04,
         2.22267396e+00, 3.39152809e-04],
        [1.95500989e+00, 5.75835089e-04, 4.08510523e-04, 1.11940241e-03,
         5.01270103e-04, 4.90775748e-04],
        [2.98103675e-04, 4.78504531e-02, 2.73146690e-04, 4.83604477e-01,
         1.87243202e-03, 4.58129904e-04],
        [7.90845796e-05, 6.90157123e-02, 2.46522328e-04, 2.92330021e-01,
         1.74966487e-04, 1.15425389e-04],
        [6.28729930e-04, 6.13152621e-04, 4.85850605e-04, 6.07341147e-04,
         4.57732774e-04, 2.28933610e+00]]),
 array([[1.02362240e-04, 1.05465737e-04, 2.53718333e-04, 1.20422458e-04,
         1.31929553e-04, 1.19856802e-04],
        [9.11831765e-05, 3.97352189e-04, 9.60674294e-05, 1.37733544e-04,
         1.37660047e-03, 9.44933331e-05],
        [3.12532781e-03, 1.22393921e-04, 9.20244810e-05, 5.25133029e-04,
     

In [128]:
# Persist the MI estimates (last axis presumably = bootstrap repetitions; expected shape (6, 6, 100)).
# The global name `all_MI_estimates` is easy to rebind in another cell, so check the rank
# before writing to avoid silently saving a scratch array of the wrong shape.
assert np.ndim(all_MI_estimates) == 3, f'expected a 3-D array, got shape {np.shape(all_MI_estimates)}'
np.save('./MI_latents_labels.npy', all_MI_estimates)
all_MI_estimates.shape

(6, 6, 100)

### Also look at disentanglement between latents

In [123]:
# Disentanglement among latents: estimate MI(latent_i; latent_j) for every pair i < j.
# The number of GMM components is chosen with a simple greedy rule (stop as soon as the
# validation score stops improving), which is enough for the sake of this argument.

# --- configuration: defined before first use so the cell runs on a fresh kernel ---
n_inits = 5
n_folds = 3
init_type = 'random_sklearn'
n_bootstrap = 100
MC_samples = 1e5
tol = 1e-5
reg_covar = 1e-15
components_range = 15  # candidate numbers of components: 1 .. components_range

n_latents = all_latents.shape[1]
# MI_latents[i, j] holds n_bootstrap MI estimates for latents i < j;
# the diagonal and lower triangle are left at zero.
MI_latents = np.zeros((n_latents, n_latents, n_bootstrap))


def _select_n_components(X, components_range, n_folds, init_type, n_inits, tol, reg_covar):
    """Greedily pick the number of GMM components for `X`.

    Components are added one at a time; the search stops at the first count whose
    validation score (from MI_procedure_diffconvergence) does not improve, and the
    previous count is kept.

    Returns:
        (best_components, best_seed) -- the selected component count and the seed
        returned by MI_procedure_diffconvergence for that count.

    Raises:
        RuntimeError: if even a single component does not yield a usable score.
    """
    best_val = -np.inf
    best_seed = None
    for n_components in range(1, components_range + 1):
        current_seed, current_val, _ = MI_procedure_diffconvergence(
            X, n_components=n_components, n_folds=n_folds,
            init_type=init_type, n_inits=n_inits, tol=tol, reg_covar=reg_covar)

        if current_val > best_val:
            best_val = current_val
            best_seed = current_seed
            print(n_components, best_val)
            continue

        # val score did not increase: keep the previous number of components
        best_components = n_components - 1
        if best_components == 0:
            # previously this fell through with a stale/undefined seed from another pair
            raise RuntimeError(f'No usable validation score with 1 component (got {current_val})')
        print(f'Convergence reached at {best_components} components')
        return best_components, best_seed

    # The score improved all the way up: previously the MI estimates of the *previous*
    # pair were silently reused here. Use the largest model tried instead.
    print(f'Convergence not reached; using {components_range} components')
    return components_range, best_seed


def _bootstrap_MI(X, n_components, seed, n_bootstrap, init_type, reg_covar, tol, MC_samples):
    """Return `n_bootstrap` MC estimates of MI from GMMs fitted to bootstrap resamples of `X`.

    Every fit starts from the same initial parameters (derived from `seed`); the
    bootstrap draw i is seeded with i so results are fully reproducible.
    """
    w_init, m_init, _, p_init = initialize_parameters(X, seed, n_components=n_components, init_type=init_type)
    MI_estimates = np.zeros(n_bootstrap)
    for i in range(n_bootstrap):
        rng = np.random.default_rng(i)
        X_bs = rng.choice(X, X.shape[0])  # resample rows with replacement
        gmm = my_GMM(n_components=n_components, reg_covar=reg_covar,
                     tol=tol, max_iter=10000,
                     random_state=seed, weights_init=w_init,
                     means_init=m_init, precisions_init=p_init).fit(X_bs)
        # For a "warm start", the next fit could instead be initialised from
        # gmm.weights_, gmm.means_ and gmm.precisions_.
        MI_estimates[i] = gmm.estimate_MI_MC(MC_samples=MC_samples)
    return MI_estimates


for latbin1 in range(n_latents):
    samples1 = all_latents[:, latbin1]
    for latbin2 in range(latbin1 + 1, n_latents):  # upper triangle only (i < j)
        samples2 = all_latents[:, latbin2]
        X = np.column_stack((samples1, samples2))  # (n_samples, 2)

        initial_time = time.time()
        # NOTE: do not reuse the name `all_MI_estimates` here -- it holds the
        # latents-vs-labels results that are saved and summarised elsewhere.
        best_components, best_seed = _select_n_components(
            X, components_range, n_folds, init_type, n_inits, tol, reg_covar)
        MI_estimates = _bootstrap_MI(
            X, best_components, best_seed, n_bootstrap, init_type, reg_covar, tol, MC_samples)

        print(f'Total time to run the procedure: {time.time()-initial_time:.2f} s')
        print()
        MI_latents[latbin1, latbin2] = MI_estimates
        print(np.mean(MI_estimates), np.std(MI_estimates))

1 -2.8351839731527075
Convergence reached at 1 components
Total time to run the procedure: 6.78 s

2.1854010013858267e-05 2.8188332567922855e-05
1 -2.858851631122507
Convergence reached at 1 components
Total time to run the procedure: 6.68 s

9.105459309968977e-06 1.816437594410711e-05
1 -2.8466428486092608
2 -2.8466419399345155
Convergence reached at 2 components
Total time to run the procedure: 16.35 s

9.915161717324454e-05 6.702912576368392e-05
1 -2.8627978572886206
Convergence reached at 1 components
Total time to run the procedure: 6.72 s

7.275620813958173e-06 1.4821284488435358e-05
1 -2.8696041505675383
2 -2.8696037266052596
Convergence reached at 2 components
Total time to run the procedure: 16.23 s

1.4595423324363134e-05 3.8313708857700765e-05
1 -2.850887724324322
Convergence reached at 1 components
Total time to run the procedure: 6.77 s

1.8293658419239556e-05 2.7888542865517755e-05
1 -2.8388687422340992
2 -2.838866166058255
3 -2.8388636051302023
Convergence reached at 3 c

In [124]:
# Per-pair bootstrap mean and spread of the latent-latent MI (axis 2 = bootstrap draws)
MI_latents.mean(axis=2), MI_latents.std(axis=2)

(array([[0.00000000e+00, 2.18540100e-05, 9.10545931e-06, 9.91516172e-05,
         7.27562081e-06, 1.45954233e-05],
        [0.00000000e+00, 0.00000000e+00, 1.82936584e-05, 1.70746017e-05,
         3.08572227e-04, 1.80598031e-05],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 2.81221398e-05,
         5.83909497e-05, 1.37628371e-04],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
         1.31509619e-05, 5.60589660e-05],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
         0.00000000e+00, 1.57033080e-04],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
         0.00000000e+00, 0.00000000e+00]]),
 array([[0.00000000e+00, 2.81883326e-05, 1.81643759e-05, 6.70291258e-05,
         1.48212845e-05, 3.83137089e-05],
        [0.00000000e+00, 0.00000000e+00, 2.78885429e-05, 2.04134082e-05,
         1.05144476e-04, 2.90489039e-05],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 3.28454610e-05,
     

In [126]:
# Store the pairwise latent-latent MI estimates (only entries with i < j are filled)
np.save(file='./MI_latents_latents.npy', arr=MI_latents)
np.shape(MI_latents)

(6, 6, 100)