In [None]:
# Finalized DCI computation based on disentanglement_lib

import numpy as np
import pandas as pd
import anndata
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from scipy.stats import entropy

# === Encoding and Preprocessing ===
def encode_categorical(data):
    """Label-encode each column of a 2-D array independently.

    Parameters
    ----------
    data : array exposing ``shape`` and ``[:, j]`` indexing
        One column per categorical variable.

    Returns
    -------
    numpy.ndarray
        Integer array shaped like ``data``. Column ``j`` holds whatever a
        fresh ``LabelEncoder`` returns for ``data[:, j]``; encoders are not
        shared between columns, so codes are only comparable within a column.
    """
    codes = np.zeros_like(data, dtype=int)
    for col_idx in range(data.shape[1]):
        # A new encoder per column keeps each column's code space separate.
        codes[:, col_idx] = LabelEncoder().fit_transform(data[:, col_idx])
    return codes

def remove_duplicate_columns(df):
    """Drop columns whose values duplicate an earlier column.

    The first occurrence of each distinct column is kept, in the original
    column order. Duplicates are detected on the transposed frame, but the
    result is sliced out of ``df`` itself so the surviving columns keep their
    original dtypes (a ``.T.drop_duplicates().T`` round trip turns every
    column of a mixed-dtype frame into ``object``).

    Parameters
    ----------
    df : pandas.DataFrame

    Returns
    -------
    pandas.DataFrame
        A new frame containing only the non-duplicated columns; ``df`` is
        left untouched.
    """
    # Positional boolean mask, so duplicate column *labels* cannot confuse
    # the selection.
    is_duplicate = df.T.duplicated().to_numpy()
    return df.loc[:, ~is_duplicate].copy()

def prep_data(adata, embedding, covariate_keys, test_size=0.25):
    """Split an embedding and its encoded covariates into train/test parts.

    Parameters
    ----------
    adata : object with ``len()`` and an ``obs`` table
        ``adata.obs[covariate_keys]`` supplies the covariate columns.
    embedding : anndata.AnnData or array-like
        If an ``AnnData``, its ``.X`` is used; otherwise the object itself is
        indexed. Rows are selected with positions drawn from
        ``range(len(adata))``.
    covariate_keys : list
        Column names to read from ``adata.obs``.
    test_size : float, default 0.25
        Forwarded to ``train_test_split``.

    Returns
    -------
    tuple
        ``(mus_train, ys_train, mus_test, ys_test)``, each transposed
        relative to the row-indexed slices taken here.
    """
    # Fixed random_state: the same row positions are drawn on every call.
    train_rows, test_rows = train_test_split(
        range(len(adata)), test_size=test_size, random_state=42
    )

    # Duplicate covariate columns are removed before encoding.
    covariates = remove_duplicate_columns(adata.obs[covariate_keys].copy())
    factors = encode_categorical(covariates.values)

    if isinstance(embedding, anndata.AnnData):
        codes = embedding.X
    else:
        codes = embedding

    mus_train, mus_test = (codes[rows] for rows in (train_rows, test_rows))
    ys_train, ys_test = (factors[rows] for rows in (train_rows, test_rows))
    return mus_train.T, ys_train.T, mus_test.T, ys_test.T

# === Importance Matrix ===
def compute_importance_rf(x_train, y_train, x_test, y_test):
    """Fit one random forest per factor and collect feature importances.

    Parameters
    ----------
    x_train, x_test : arrays indexed ``[code, sample]``
        Transposed before being passed to the classifier.
    y_train, y_test : arrays indexed ``[factor, sample]``
        Row ``i`` is the label vector for factor ``i``.

    Returns
    -------
    tuple
        ``(importance_matrix, train_acc, test_acc)`` where
        ``importance_matrix`` has shape ``(num_codes, num_factors)`` and the
        accuracies are means over all factors.
    """
    n_factors = y_train.shape[0]
    n_codes = x_train.shape[0]
    importances = np.zeros((n_codes, n_factors))
    accuracy = {"train": [], "test": []}

    for factor in range(n_factors):
        forest = RandomForestClassifier(random_state=42, max_depth=5)
        forest.fit(x_train.T, y_train[factor])
        # Column `factor` holds the importance of every code for that factor.
        importances[:, factor] = np.abs(forest.feature_importances_)
        for split, samples, labels in (
            ("train", x_train.T, y_train[factor]),
            ("test", x_test.T, y_test[factor]),
        ):
            accuracy[split].append(np.mean(forest.predict(samples) == labels))

    return importances, np.mean(accuracy["train"]), np.mean(accuracy["test"])

# === Disentanglement ===
def disentanglement_per_code(importance_matrix):
    """Compute a disentanglement score for every code (row).

    Each row is normalised to sum to one (all-zero rows are left at zero),
    a small epsilon is added, and the score is ``1 - H`` where ``H`` is the
    entropy of the row taken in base ``num_factors``.

    Parameters
    ----------
    importance_matrix : numpy.ndarray, shape (num_codes, num_factors)

    Returns
    -------
    numpy.ndarray, shape (num_codes,)
        One score per code. With exactly one factor column every score is
        1.0: the row distribution has a single outcome and therefore zero
        entropy.
    """
    num_factors = importance_matrix.shape[1]
    if num_factors == 1:
        # entropy(..., base=1) divides by log(1) == 0 and yields NaN, which
        # would poison the aggregated score. Return the zero-entropy value.
        return np.ones(importance_matrix.shape[0])

    row_sums = importance_matrix.sum(axis=1, keepdims=True)
    # Avoid 0/0 for codes that carry no importance at all.
    safe_row_sums = np.where(row_sums == 0, 1e-11, row_sums)
    normalized = importance_matrix / safe_row_sums
    # entropy() works along axis 0, hence the transpose to put factors first.
    return 1. - entropy(normalized.T + 1e-11, base=num_factors)

def disentanglement(importance_matrix):
    """Aggregate per-code disentanglement into one importance-weighted score.

    Each code's score is weighted by its share of the total importance
    (row sum divided by the sum of the whole matrix). Returns ``0.0`` when
    the matrix sums to zero.
    """
    per_code = disentanglement_per_code(importance_matrix)
    grand_total = importance_matrix.sum()
    if grand_total != 0.:
        weights = importance_matrix.sum(axis=1) / grand_total
        return np.sum(weights * per_code)
    # No importance anywhere: nothing to weight by.
    return 0.0

# === Completeness ===
def completeness_per_factor(importance_matrix):
    """Compute a completeness score for every factor (column).

    The score is ``1 - H`` where ``H`` is the entropy of the column (after
    adding a small epsilon) taken in base ``num_codes``.

    Parameters
    ----------
    importance_matrix : numpy.ndarray, shape (num_codes, num_factors)

    Returns
    -------
    numpy.ndarray, shape (num_factors,)
        One score per factor. With exactly one code row every score is 1.0:
        the column distribution has a single outcome and therefore zero
        entropy.
    """
    num_codes = importance_matrix.shape[0]
    if num_codes == 1:
        # entropy(..., base=1) divides by log(1) == 0 and yields NaN, which
        # would poison the aggregated score. Return the zero-entropy value.
        return np.ones(importance_matrix.shape[1])
    # entropy() normalises and reduces along axis 0, i.e. over codes.
    return 1. - entropy(importance_matrix + 1e-11, base=num_codes)

def completeness(importance_matrix):
    """Aggregate per-factor completeness into one importance-weighted score.

    Each factor's score is weighted by its share of the total importance
    (column sum divided by the sum of the whole matrix). Returns ``0.0``
    when the matrix sums to zero.
    """
    per_factor = completeness_per_factor(importance_matrix)
    grand_total = importance_matrix.sum()
    if grand_total != 0.:
        weights = importance_matrix.sum(axis=0) / grand_total
        return np.sum(weights * per_factor)
    # No importance anywhere: nothing to weight by.
    return 0.0

# === DCI Master Function ===
def compute_dci(mus_train, ys_train, mus_test, ys_test):
    """Run the DCI pipeline: importances, then disentanglement/completeness.

    The four arguments are forwarded unchanged to ``compute_importance_rf``.
    Importances below ``1e-11`` are set to zero before scoring.

    Returns
    -------
    dict
        Keys ``"disentanglement"``, ``"completeness"``,
        ``"importance_matrix"`` (the thresholded matrix), ``"train_acc"``
        and ``"test_acc"``.
    """
    raw_importance, train_acc, test_acc = compute_importance_rf(
        mus_train, ys_train, mus_test, ys_test
    )

    threshold = 1e-11
    negligible = raw_importance < threshold
    importance_matrix = np.where(negligible, 0, raw_importance)

    scores = {}
    scores["disentanglement"] = disentanglement(importance_matrix)
    scores["completeness"] = completeness(importance_matrix)
    scores["importance_matrix"] = importance_matrix
    scores["train_acc"] = train_acc
    scores["test_acc"] = test_acc
    return scores


In [None]:
# Columns of adata.obs used as the ground-truth factors for DCI.
covariate_keys = ["cell_type", "sex", "donor_id"]
# NOTE(review): `adata` and `z` are not defined in the visible cells —
# presumably created earlier in the notebook; confirm before Run All.
mus_train, ys_train, mus_test, ys_test = prep_data(adata, z, covariate_keys)
dci_scores = compute_dci(mus_train, ys_train, mus_test, ys_test)
# Prints the whole result dict, including the full importance matrix.
print(dci_scores)

# Keep the thresholded importance matrix under its own name.
importance_matrix = dci_scores["importance_matrix"]