In [None]:
###MIG###


from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.cluster import KMeans
from sklearn.metrics import mutual_info_score
from sklearn.model_selection import train_test_split

import numpy as np
import anndata
import pandas as pd

def encode_categorical(data):
    """Integer-encode every column of a 2-D array with its own LabelEncoder.

    Args:
        data: 2-D array of categorical labels (``prep_data`` passes
            ``adata.obs[...].values``).

    Returns:
        (encoded, encoders): ``encoded`` has the same shape as ``data`` with
        integer dtype; ``encoders[i]`` is the LabelEncoder fitted on column i.
    """
    codes = np.zeros_like(data, dtype=int)
    fitted_encoders = [LabelEncoder() for _ in range(data.shape[1])]
    for col, encoder in enumerate(fitted_encoders):
        codes[:, col] = encoder.fit_transform(data[:, col])
    return codes, fitted_encoders

def prep_data(adata, embedding, covriate_keys=None):
    """Build the latent-code and covariate-label matrices used for MIG.

    Args:
        adata: AnnData whose ``.obs`` holds the categorical covariates.
        embedding: Latent codes with one row per observation, either an
            AnnData (its ``.X`` is used) or an array-like. Sparse matrices
            (anything exposing ``toarray()``) are densified first.
        covriate_keys: A column name, or a list of column names, in
            ``adata.obs`` to label-encode as discrete factors. (The
            misspelled parameter name is kept for caller compatibility.)

    Returns:
        (mus, ys): ``mus`` has shape (n_latents, n_obs) and ``ys`` holds the
        integer-encoded covariates with shape (n_covariates, n_obs). Both are
        fresh copies that do not alias the inputs.

    Raises:
        ValueError: if no covariate keys are given, or if the embedding and
            ``adata.obs`` disagree on the number of observations.
    """
    if isinstance(covriate_keys, str):
        # A bare string would select a 1-D Series, which encode_categorical
        # cannot index column-wise.
        covriate_keys = [covriate_keys]
    if covriate_keys is None or len(covriate_keys) == 0:
        raise ValueError("covriate_keys must name at least one column of adata.obs")

    encoded_factors_of_variation, _ = encode_categorical(
        adata.obs[list(covriate_keys)].values
    )

    if isinstance(embedding, anndata.AnnData):
        embedding_data = embedding.X
    else:
        embedding_data = embedding
    if hasattr(embedding_data, "toarray"):
        # np.array() on a scipy sparse matrix yields a 0-d object array,
        # not a dense matrix, so densify explicitly.
        embedding_data = embedding_data.toarray()

    mus = np.asarray(embedding_data)
    ys = np.asarray(encoded_factors_of_variation)

    if mus.shape[0] != ys.shape[0]:
        raise ValueError(
            f"embedding has {mus.shape[0]} observations but adata.obs has {ys.shape[0]}"
        )

    # Transpose to (features, observations); .copy() gives contiguous arrays
    # that do not share memory with the caller's data.
    return mus.T.copy(), ys.T.copy()


def compute_mig(mus, ys, covariate_names=None):
    """Return the mutual information gap (MIG) score dict.

    Thin public wrapper: forwards ``mus``, ``ys`` and ``covariate_names``
    unchanged to ``_compute_mig`` and returns whatever it returns.
    """
    result = _compute_mig(mus, ys, covariate_names)
    return result

def _compute_mig(mus, ys, covariate_names=None, num_bins=10):
    """Compute the discrete Mutual Information Gap (MIG).

    Each latent row of ``mus`` is histogram-discretized into ``num_bins``
    bins. The mutual information between every (latent, covariate) pair is
    computed. For each covariate, the gap between its largest and
    second-largest MI is divided by that covariate's entropy. The score is the
    mean of these normalized gaps. A per-covariate top-3 report is printed.

    Args:
        mus: Latent codes, shape (n_latents, n_obs); needs n_latents >= 2.
        ys: Integer-encoded covariates, shape (n_covariates, n_obs).
        covariate_names: Optional labels for the printed report; defaults to
            "Covariate j". Must have one entry per row of ``ys``.
        num_bins: Number of histogram bins per latent dimension.

    Returns:
        dict with key ``"discrete_mig"``. Covariates with zero entropy
        (constant across all observations) are excluded from the mean because
        their normalized gap is 0/0. If every covariate is excluded, the
        score is NaN.

    Raises:
        ValueError: on fewer than two latent dimensions, mismatched
            observation counts, or a ``covariate_names`` of the wrong length.
    """
    import warnings

    if mus.shape[0] < 2:
        raise ValueError("MIG needs at least two latent dimensions to form a gap")
    if mus.shape[1] != ys.shape[1]:
        raise ValueError(
            f"mus has {mus.shape[1]} observations but ys has {ys.shape[1]}"
        )

    num_factors = ys.shape[0]
    if covariate_names is None:
        covariate_names = [f"Covariate {j}" for j in range(num_factors)]
    elif len(covariate_names) != num_factors:
        raise ValueError(
            f"got {len(covariate_names)} covariate names for {num_factors} covariates"
        )

    discretized_mus = make_discretizer(
        mus, num_bins=num_bins, discretizer_fn=_histogram_discretize
    )
    # m[i, j] = MI(latent i, covariate j); shape (n_latents, n_covariates).
    m = discrete_mutual_info(discretized_mus, ys)

    for j, name in enumerate(covariate_names):
        top_indices = np.argsort(m[:, j])[::-1][:3]
        print(f"Top 3 MI scores for covariate '{name}':")
        for idx in top_indices:
            print(f"  Latent dim {idx}: MI = {m[idx, j]:.4f}")

    entropy = discrete_entropy(ys)
    # Sort each covariate column in descending MI order: row 0 is the best
    # latent, row 1 the runner-up.
    sorted_m = np.sort(m, axis=0)[::-1]
    gaps = sorted_m[0, :] - sorted_m[1, :]

    informative = entropy > 0
    if not informative.all():
        skipped = [covariate_names[j] for j in np.flatnonzero(~informative)]
        warnings.warn(
            f"Excluding zero-entropy covariates from MIG: {skipped}",
            RuntimeWarning,
            stacklevel=2,
        )

    score_dict = {}
    if informative.any():
        score_dict["discrete_mig"] = np.mean(gaps[informative] / entropy[informative])
    else:
        score_dict["discrete_mig"] = np.nan

    print(score_dict)
    print("Entropy values:", entropy)
    return score_dict

def discrete_mutual_info(mus, ys):
    """Pairwise mutual information between latent rows and factor rows.

    Args:
        mus: Discretized latent codes, one latent dimension per row.
        ys: Discrete factor labels, one factor per row.

    Returns:
        Float array of shape (mus.shape[0], ys.shape[0]) whose entry
        [i, j] is ``mutual_info_score(ys[j, :], mus[i, :])``.
    """
    mi_table = np.zeros([mus.shape[0], ys.shape[0]])
    # np.ndindex walks (code, factor) pairs in row-major order.
    for code_idx, factor_idx in np.ndindex(*mi_table.shape):
        mi_table[code_idx, factor_idx] = mutual_info_score(
            ys[factor_idx, :], mus[code_idx, :]
        )
    return mi_table

def discrete_entropy(ys):
    """Entropy of each discrete factor row of ``ys``.

    Uses the identity H(Y) = I(Y; Y), so each entry is
    ``mutual_info_score(y, y)``. The values are in the same units as the MI
    matrix, which keeps the MIG normalization unit-free.

    Returns:
        Float array of length ``ys.shape[0]``.
    """
    return np.array(
        [mutual_info_score(ys[row, :], ys[row, :]) for row in range(ys.shape[0])],
        dtype=float,
    )

def _identity_discretizer(target, num_bins):
    del num_bins
    return target


def _histogram_discretize(target, num_bins=10):
    discretized = np.zeros_like(target)
    for i in range(target.shape[0]):
        discretized[i, :] = np.digitize(target[i, :], np.histogram(
            target[i, :], num_bins)[1][:-1])
    return discretized


def make_discretizer(target, num_bins=10, discretizer_fn=_histogram_discretize):
    """Discretize ``target`` with ``discretizer_fn``.

    Despite the name, this applies the discretizer immediately and returns
    the discretized data. It does not return a discretizer object.
    ``discretizer_fn`` is called positionally as ``discretizer_fn(target, num_bins)``.
    """
    discretized = discretizer_fn(target, num_bins)
    return discretized


def score_disentanglement(adata, embedding_data, embedding_basal, covriate_keys=None, continuous_covriate_keys=None):
    """Prepare data and compute the MIG disentanglement score.

    Args:
        adata: Object whose ``.obs`` holds the covariate columns (forwarded
            to ``prep_data``).
        embedding_data: Latent codes to score (forwarded to ``prep_data``).
        embedding_basal: Not used by this function.
        covriate_keys: Categorical covariate columns; also used as the
            covariate labels in the printed MIG report.
        continuous_covriate_keys: Not used by this function.

    Returns:
        (mig, mus, ys): the MIG score dict plus the prepared latent matrix
        and encoded covariate matrix.
    """
    latents, factors = prep_data(adata, embedding_data, covriate_keys=covriate_keys)
    print("Computing MIG")
    scores = compute_mig(latents, factors, covariate_names=covriate_keys)
    return scores, latents, factors

# Score embedding `z` against the categorical covariates of `adata_normal`.
# NOTE(review): `adata_normal` and `z` are not defined in this cell; they are
# presumably created by an earlier notebook cell.
mig_1, mus, ys = score_disentanglement(
    adata_normal,
    z,
    embedding_basal=None,
    covriate_keys=["cell_type", "sex", "donor_id"],
)

# The MIG call does not return its MI matrix, so rebuild it here (same
# default 10-bin discretization) for inspection.
discretized_mus = _histogram_discretize(mus)
m = discrete_mutual_info(discretized_mus, ys)

print("MI matrix shape:", m.shape)
print("Max MI per factor:", m.max(axis=0))
print("Which latents have highest MI per factor:", m.argmax(axis=0))

print("MIG Score:", mig_1)
