In [None]:
import os
import warnings

# Work from the parent directory; relative paths used later (e.g. the model
# save path passed to trvae.save) resolve against it.
os.chdir("../")

In [None]:
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown

In [None]:
# scanpy's set_figure_params applies its defaults (e.g. dpi=80, frameon=True)
# for every argument not passed, so separate calls overwrite each other's
# settings. Pass everything in a single call.
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [None]:
# Load the preprocessed focal cortical dataset (filename suggests raw counts).
adata = sc.read(
    '/work/trvae_new/New_fixed_data/focal_cortical_processed_RAW.h5ad'
)

In [None]:
# Display the AnnData summary of the loaded dataset.
adata

In [None]:
# Train/test split: hold out 10% of cells, with a fixed seed for reproducibility.
from sklearn.model_selection import train_test_split
train_ids, test_ids = train_test_split(adata.obs_names, test_size=0.1, random_state=42)
adata.obs["split"] = "train"
adata.obs.loc[test_ids, "split"] = "test"

# .copy() turns the boolean-mask views into independent AnnData objects.
# Later cells write to their obs/obsm; on a view, that triggers an implicit
# view->copy conversion (ImplicitModificationWarning).
train_adata = adata[adata.obs["split"] == "train"].copy()
test_adata = adata[adata.obs["split"] == "test"].copy()

In [None]:
# Early-stopping / learning-rate-schedule settings forwarded to TRVAE.train().
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)



In [None]:
# Give every cell the same condition label. The model below is built with
# condition_key="dummy_condition" and a single condition, so the condition
# input is constant across cells.
train_adata.obs["dummy_condition"] = "same_condition"
test_adata.obs["dummy_condition"] = "same_condition"
adata.obs["dummy_condition"] = "same_condition"

In [None]:
# Build a trVAE with one (dummy) condition, so conditioning is effectively
# constant, then train it with the early-stopping settings defined above.
trvae = sca.models.TRVAE(
    adata=train_adata,
    condition_key="dummy_condition",
    conditions=["same_condition"],
    hidden_layer_sizes=[128, 128],
)
trvae.train(
    n_epochs=300,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
)

In [None]:
# Display the trVAE model object.
trvae

In [None]:
# Persist the trained model. The path is relative to the working directory
# set by os.chdir("../") in the first cell.
# NOTE(review): scArches' save presumably refuses to write into an existing
# directory unless overwrite=True is passed, so re-running this cell may fail; verify.
trvae.save("trvae_new/fixed_models/trvae_focal_cortical_raw_model_newconditions")

In [None]:
# TRVAE.load is a classmethod that *returns* the loaded model. Calling it on an
# instance and discarding the result (as before) leaves `trvae` unchanged,
# so the result must be assigned.
trvae = sca.models.TRVAE.load(
    "/work/trvae_new/trvae_newpredict/trvae_new/fixed_models/focal_cortical_new_RAW_trVAE",
    adata=train_adata,
    map_location=torch.device("cpu"),
)


In [None]:
# This is the directory the model was loaded from, so it already exists.
# scArches' save raises for an existing directory unless overwrite=True.
# Note that this replaces the stored model with the current `trvae`.
trvae.save(
    "/work/trvae_new/trvae_newpredict/trvae_new/fixed_models/focal_cortical_new_RAW_trVAE",
    overwrite=True,
)

In [None]:
model = trvae  # alias used by the prediction/evaluation cells below

In [None]:
from scarches.trainers.trvae._utils import make_dataset,custom_collate

In [None]:
def predict_trvae(model, adata, condition_key, batch_size=128, use_mean=False):
    """Run a trained trVAE over ``adata`` and return latents and reconstructions.

    Args:
        model: trained scArches ``TRVAE`` wrapper; ``model.model`` is the torch module.
        adata: AnnData to encode. ``make_dataset`` builds the input tensors from it.
            The encoder input is ``log1p(x)``, so ``X`` presumably holds raw counts
            (TODO confirm).
        condition_key: ``obs`` column with the condition labels, encoded with the
            model's ``condition_encoder``.
        batch_size: number of cells per forward pass.
        use_mean: if True, use the posterior mean ``z1_mean`` as the latent code
            (deterministic). If False (default), draw a sample with
            ``model.model.sampling``, as before.

    Returns:
        (latent, reconstructed): numpy arrays with one row per cell, in dataloader
        order (``shuffle=False``). ``reconstructed`` is the first decoder output
        multiplied by each cell's total count (row sum of ``x``).
    """
    model.model.eval()

    predict_data, _ = make_dataset(
        adata,
        train_frac=1.0,
        condition_key=condition_key,
        cell_type_keys=None,
        condition_encoder=model.model.condition_encoder,
        cell_type_encoder=None,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset=predict_data,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=custom_collate,
        num_workers=0,
    )

    # Follow the device of the model's weights, so this works on CPU and GPU alike.
    device = next(model.model.parameters()).device

    latent_list = []
    reconstructed_list = []
    with torch.no_grad():
        for batch_data in dataloader:
            batch_data = {key: value.to(device) for key, value in batch_data.items()}
            x = batch_data["x"]

            # Per-cell size factor (row sum), computed in torch so it stays on
            # `device`; np.ravel on a CUDA tensor would fail. Broadcast it to
            # [batch_size, n_genes].
            size_factor_view = x.sum(dim=1).unsqueeze(1).expand(x.size(0), x.size(1))

            x_log = torch.log1p(x)
            z1_mean, z1_log_var = model.model.encoder(x_log, batch_data["batch"])
            latent = z1_mean if use_mean else model.model.sampling(z1_mean, z1_log_var)
            latent_list.append(latent.cpu().numpy())

            # Decode the latent; the first decoder output is scaled by the size factor.
            recon_x, _ = model.model.decoder(latent, batch_data["batch"])
            reconstructed_list.append((size_factor_view * recon_x).cpu().numpy())

    latent = np.concatenate(latent_list, axis=0)
    reconstructed = np.concatenate(reconstructed_list, axis=0)
    return latent, reconstructed

In [None]:
# With Gpu run this instead: 

def predict_trvae(model, adata, condition_key, batch_size=128, use_mean=False):
    """Run a trained trVAE over ``adata`` and return latents and reconstructions.

    This cell redefines ``predict_trvae``; it shadows the definition in the
    previous cell.

    Args:
        model: trained scArches ``TRVAE`` wrapper; ``model.model`` is the torch module.
        adata: AnnData to encode. ``make_dataset`` builds the input tensors from it.
            The encoder input is ``log1p(x)``, so ``X`` presumably holds raw counts
            (TODO confirm).
        condition_key: ``obs`` column with the condition labels, encoded with the
            model's ``condition_encoder``.
        batch_size: number of cells per forward pass.
        use_mean: if True, use the posterior mean ``z1_mean`` as the latent code
            (deterministic). If False (default), draw a sample with
            ``model.model.sampling``, as before.

    Returns:
        (latent, reconstructed): numpy arrays with one row per cell, in dataloader
        order (``shuffle=False``). ``reconstructed`` is the first decoder output
        multiplied by each cell's total count (row sum of ``x``).
    """
    model.model.eval()

    predict_data, _ = make_dataset(
        adata,
        train_frac=1.0,
        condition_key=condition_key,
        cell_type_keys=None,
        condition_encoder=model.model.condition_encoder,
        cell_type_encoder=None,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset=predict_data,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=custom_collate,
        num_workers=0,
    )

    # Take the device from the model's weights rather than model.trainer, which
    # is not set up for a model that was loaded but never trained in this session.
    device = next(model.model.parameters()).device

    latent_list = []
    reconstructed_list = []
    with torch.no_grad():
        for batch_data in dataloader:
            batch_data = {key: value.to(device) for key, value in batch_data.items()}
            x = batch_data["x"]

            # Per-cell size factor (sum across genes), broadcast to [batch_size, n_genes].
            size_factor_view = x.sum(dim=1).unsqueeze(1).expand(x.size(0), x.size(1))

            x_log = torch.log1p(x)
            z1_mean, z1_log_var = model.model.encoder(x_log, batch_data["batch"])
            latent = z1_mean if use_mean else model.model.sampling(z1_mean, z1_log_var)
            latent_list.append(latent.cpu().numpy())

            # Decode the latent; the first decoder output is scaled by the size factor.
            recon_x, _ = model.model.decoder(latent, batch_data["batch"])
            reconstructed_list.append((size_factor_view * recon_x).cpu().numpy())

    latent = np.concatenate(latent_list, axis=0)
    reconstructed = np.concatenate(reconstructed_list, axis=0)
    return latent, reconstructed

In [None]:
# Latents and reconstructions for the held-out cells.
latent, rec = predict_trvae(
    model,
    test_adata,
    condition_key="dummy_condition",
)

In [None]:
# Latents and reconstructions for every cell (used by the disentanglement metrics).
latent_2, rec_2 = predict_trvae(
    model,
    adata,
    condition_key="dummy_condition",
)

In [None]:
# Per-cell total of the reconstructed values. If the decoder output is
# normalised per cell, this should track each cell's library size (TODO confirm).
np.sum(rec, axis=1)

In [None]:
# Observed expression of the held-out cells, taken from the full object in
# test_adata's cell order, so the rows should line up with `rec`.
adata_2 = adata[test_adata.obs_names].X

# The metric cells below expect a dense ndarray; AnnData may store X sparse.
if not isinstance(adata_2, np.ndarray):
    print("Converting y_true from sparse to dense.")
    adata_2 = adata_2.toarray()

# 1-D copy of every (cell, gene) entry, for element-wise metrics.
adata_2_flat = adata_2.flatten()

In [None]:
# Attach the trVAE outputs to the held-out cells and save them.
import anndata

# predict_trvae returns numpy arrays; unwrap defensively in case an AnnData is passed in.
if isinstance(rec, anndata.AnnData):
    rec = rec.X

# Work on an independent object (not a view of `adata`) before editing obsm.
test_adata = test_adata.copy()
test_adata.obsm["X_reconstructed"] = rec
# The output filename promises the latent as well, so store it next to the
# reconstruction (one row per held-out cell).
test_adata.obsm["X_trvae"] = latent

# Save the entire object with the latent and reconstructed data.
test_adata.write("adata_post_with_latent_and_recon_focal_cortical_trVAE_RAW_asde.h5ad")

In [None]:
import os

# Report the working directory and whether it is writable.
print("CWD:", os.getcwd(), "Writable?", os.access(os.getcwd(), os.W_OK))

# Make test_adata an independent object (not a view) before editing obsm.
test_adata = test_adata.copy()
test_adata.obsm["X_reconstructed"] = rec

# Write to an absolute path under /work (independent of the CWD printed above).
outfn = "/work/trvae_new/trvae_newpredict/2_adata_post_with_latent_and_reconstructed_focal_cortical_trVAEasdas.h5ad"
test_adata.write(outfn)
print("Wrote to", outfn)


In [None]:
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
    mutual_info_score
)

In [None]:
# Plain (not adjusted) R^2 over all (cell, gene) entries: both matrices are
# flattened so r2_score treats them as a single target.
adata_2_flat = adata_2.flatten()
rec_2_flat = rec.flatten()

r_square = r2_score(y_true=adata_2_flat, y_pred=rec_2_flat)
print("R2:", r_square)

In [None]:
# Mean squared error between observed and reconstructed matrices
# (sklearn averages the per-gene errors uniformly).
mse = mean_squared_error(y_true=adata_2, y_pred=rec)
print(mse)

In [None]:
# Mean absolute error between observed and reconstructed matrices
# (sklearn averages the per-gene errors uniformly).
mae = mean_absolute_error(y_true=adata_2, y_pred=rec)
print(f"Mean absolute error (MAE): {mae}")

In [None]:
#### Disentanglement ####





In [None]:
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.cluster import KMeans
from sklearn.metrics import mutual_info_score
from sklearn.model_selection import train_test_split
###MIG###
import numpy as np
import anndata
import pandas as pd

def encode_categorical(data):
    """Label-encode each column of a 2-D array independently.

    Returns:
        (encoded_data, encoders): an int array with the same shape as ``data``,
        and the fitted ``LabelEncoder`` for each column, in column order.
    """
    encoded_data = np.zeros_like(data, dtype=int)
    encoders = []
    for col in range(data.shape[1]):
        encoder = LabelEncoder()
        encoded_data[:, col] = encoder.fit_transform(data[:, col])
        encoders.append(encoder)
    return encoded_data, encoders

def prep_data(adata, embedding, covriate_keys=None, continuous_covriate_keys=None, test_size=0.25):
    """Split cells into train/test and return latents and encoded covariates.

    Args:
        adata: AnnData whose ``obs[covriate_keys]`` hold the factors of variation.
        embedding: latent codes (ndarray or AnnData), one row per cell of ``adata``.
        covriate_keys: obs columns to label-encode as factors.
        continuous_covriate_keys: accepted but not used.
        test_size: fraction of cells assigned to the test split (seed 42).

    Returns:
        (mus_train, ys_train, mus_test, ys_test), each transposed so rows are
        latent dimensions / factors and columns are cells.
    """
    cell_positions = range(len(adata.obs_names))
    idx_train, idx_test = train_test_split(cell_positions, test_size=test_size, random_state=42)
    print("Splitting complete.")

    factors, _ = encode_categorical(adata.obs[covriate_keys].values)

    embedding_data = embedding.X if isinstance(embedding, anndata.AnnData) else embedding

    mus_train = np.array(embedding_data[idx_train])
    ys_train = np.array(factors[idx_train])
    mus_test = np.array(embedding_data[idx_test])
    ys_test = np.array(factors[idx_test])

    return mus_train.T.copy(), ys_train.T.copy(), mus_test.T.copy(), ys_test.T.copy()

def compute_mig(mus_train, ys_train, covariate_names=None):
    """Public entry point for the mutual information gap; delegates to ``_compute_mig``."""
    result = _compute_mig(mus_train, ys_train, covariate_names)
    return result

def _compute_mig(mus_train, ys_train, covariate_names=None):
    """Compute the discrete Mutual Information Gap (MIG).

    Each latent dimension is histogram-discretized. The mutual information
    between every latent dimension and every covariate is computed. For each
    covariate, the gap between its two most informative latent dimensions is
    divided by the covariate's entropy, and the score is the mean of these
    normalized gaps.

    Args:
        mus_train: array (num_latents, num_cells) of latent codes.
        ys_train: array (num_factors, num_cells) of integer-encoded covariates.
        covariate_names: labels for the printed per-covariate report; defaults to
            "Covariate j".

    Returns:
        dict with key "discrete_mig". Covariates with zero entropy (one value
        for every cell) are excluded from the mean, because their normalized gap
        is 0/0 and would turn the whole score into NaN. If every covariate is
        constant, the score is NaN.
    """
    score_dict = {}
    discretized_mus = make_discretizer(mus_train, discretizer_fn=_histogram_discretize)

    m = discrete_mutual_info(discretized_mus, ys_train)

    if covariate_names is None:
        covariate_names = [f"Covariate {j}" for j in range(m.shape[1])]

    for j in range(m.shape[1]):
        top_indices = np.argsort(m[:, j])[::-1][:3]
        top_scores = m[top_indices, j]
        print(f"Top 3 MI scores for covariate '{covariate_names[j]}':")
        for idx, score in zip(top_indices, top_scores):
            print(f"  Latent dim {idx}: MI = {score:.4f}")

    assert m.shape[0] == mus_train.shape[0]
    assert m.shape[1] == ys_train.shape[0]

    entropy = discrete_entropy(ys_train)
    sorted_m = np.sort(m, axis=0)[::-1]
    gaps = sorted_m[0, :] - sorted_m[1, :]

    informative = entropy > 0
    for j in np.flatnonzero(~informative):
        print(f"Skipping covariate '{covariate_names[j]}': constant (zero entropy).")
    if np.any(informative):
        score_dict["discrete_mig"] = np.mean(gaps[informative] / entropy[informative])
    else:
        score_dict["discrete_mig"] = np.nan

    print("Þetta er score:", score_dict)
    print("Entropy values:", entropy)
    return score_dict

def discrete_mutual_info(mus, ys):
    """Mutual information matrix of shape (num_codes, num_factors).

    Entry (i, j) is ``mutual_info_score(ys[j], mus[i])``; both inputs hold one
    row per code/factor and one column per cell.
    """
    mi = np.zeros([mus.shape[0], ys.shape[0]])
    for code_idx, factor_idx in np.ndindex(mi.shape):
        mi[code_idx, factor_idx] = mutual_info_score(ys[factor_idx, :], mus[code_idx, :])
    return mi

def discrete_entropy(ys):
    """Entropy of each factor row, computed as its mutual information with itself."""
    entropies = np.zeros(ys.shape[0])
    for factor_idx, factor_row in enumerate(ys):
        entropies[factor_idx] = mutual_info_score(factor_row, factor_row)
    return entropies

def _identity_discretizer(target, num_bins):
    del num_bins
    return target

def make_discretizer(target, num_bins=10, discretizer_fn=_identity_discretizer):
    """Apply ``discretizer_fn(target, num_bins)``; the default leaves ``target`` as is."""
    discretized = discretizer_fn(target, num_bins)
    return discretized

def _histogram_discretize(target, num_bins=10):
    discretized = np.zeros_like(target)
    for i in range(target.shape[0]):
        discretized[i, :] = np.digitize(target[i, :], np.histogram(
            target[i, :], num_bins)[1][:-1])
    return discretized

def k_means_discretize(target, num_clusters=10):
    """Discretize each row of ``target`` by 1-D KMeans clustering (seed 0).

    Returns an array shaped like ``target`` whose entries are the cluster labels.
    """
    discretized = np.zeros_like(target)
    for row_idx in range(target.shape[0]):
        column_vector = target[row_idx, :].reshape(-1, 1)
        clusterer = KMeans(n_clusters=num_clusters, random_state=0)
        clusterer.fit(column_vector)
        discretized[row_idx, :] = clusterer.labels_
    return discretized

def score_disentanglement(adata, embedding_data, embedding_basal, covriate_keys=None, continuous_covriate_keys=None, test_size=0.25):
    """Compute the MIG of ``embedding_data`` against categorical obs covariates.

    Args:
        adata: AnnData whose ``obs`` holds the covariate columns; rows must align
            with ``embedding_data``.
        embedding_data: latent codes (array or AnnData), one row per cell.
        embedding_basal: unused; kept for call compatibility.
        covriate_keys: obs columns treated as factors of variation.
        continuous_covriate_keys: forwarded to ``prep_data``.
        test_size: fraction of cells held out by ``prep_data``. Previously this
            argument was ignored; it is now forwarded. Only the training split is scored.

    Returns:
        dict from ``compute_mig`` with key "discrete_mig".
    """
    mus_train, ys_train, _mus_test, _ys_test = prep_data(
        adata,
        embedding_data,
        covriate_keys=covriate_keys,
        continuous_covriate_keys=continuous_covriate_keys,
        test_size=test_size,
    )
    print('Computing MIG')
    mig = compute_mig(mus_train, ys_train, covariate_names=covriate_keys)
    return mig

# Run MIG on the latents of all cells against their categorical obs covariates.
mig_1 = score_disentanglement(
    adata,
    latent_2,
    None,  # embedding_basal: not used by score_disentanglement
    covriate_keys=[
        "cell_type",
        "tissue",
        "development_stage",
        "donor_id",
        "development_stage_ontology_term_id",
        "lateralization",
    ],
)

print("MIG Score:", mig_1)


In [None]:
# Finalized DCI computation based on disentanglement_lib

import numpy as np
import pandas as pd
import anndata
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from scipy.stats import entropy

# === Encoding and Preprocessing ===
def encode_categorical(data):
    """Label-encode each column of a 2-D array; returns an int array of the same shape."""
    encoded_data = np.zeros_like(data, dtype=int)
    for col_idx in range(data.shape[1]):
        encoded_data[:, col_idx] = LabelEncoder().fit_transform(data[:, col_idx])
    return encoded_data

def remove_duplicate_columns(df):
    """Drop columns whose values duplicate an earlier column (first occurrence is kept)."""
    # drop_duplicates operates on rows, so transpose, dedupe, and transpose back.
    transposed = df.T
    return transposed.drop_duplicates().T

def prep_data(adata, embedding, covariate_keys, test_size=0.25):
    """Split cells into train/test and return transposed latents and encoded covariates.

    Covariate columns with identical values are de-duplicated before encoding.
    The split uses ``random_state=42``.

    Returns:
        (mus_train, ys_train, mus_test, ys_test); rows are latent dims / factors,
        columns are cells.
    """
    idx_train, idx_test = train_test_split(
        range(len(adata)), test_size=test_size, random_state=42
    )
    covariates = remove_duplicate_columns(adata.obs[covariate_keys].copy())
    encoded_factors = encode_categorical(covariates.values)
    if isinstance(embedding, anndata.AnnData):
        embedding_data = embedding.X
    else:
        embedding_data = embedding
    return (
        embedding_data[idx_train].T,
        encoded_factors[idx_train].T,
        embedding_data[idx_test].T,
        encoded_factors[idx_test].T,
    )

# === Importance Matrix ===
def compute_importance_rf(x_train, y_train, x_test, y_test):
    """Fit one random forest per factor and collect feature importances.

    Args:
        x_train, x_test: latents shaped (num_codes, num_cells).
        y_train, y_test: factors shaped (num_factors, num_cells).

    Returns:
        (importance_matrix (num_codes, num_factors), mean train accuracy, mean test accuracy).
    """
    n_factors = y_train.shape[0]
    importance_matrix = np.zeros((x_train.shape[0], n_factors))
    features_train, features_test = x_train.T, x_test.T
    train_acc, test_acc = [], []
    for factor in range(n_factors):
        clf = RandomForestClassifier(random_state=42, max_depth=5)
        clf.fit(features_train, y_train[factor])
        importance_matrix[:, factor] = np.abs(clf.feature_importances_)
        train_acc.append(np.mean(clf.predict(features_train) == y_train[factor]))
        test_acc.append(np.mean(clf.predict(features_test) == y_test[factor]))
    return importance_matrix, np.mean(train_acc), np.mean(test_acc)

# === Disentanglement ===
def disentanglement_per_code(importance_matrix):
    """Per-code disentanglement: 1 - entropy of the code's row-normalized importances.

    The entropy uses base = number of factors. All-zero rows are divided by
    1e-11 instead of 0, and 1e-11 is added before the entropy to avoid log(0).
    """
    n_factors = importance_matrix.shape[1]
    row_totals = importance_matrix.sum(axis=1, keepdims=True)
    row_totals = np.where(row_totals == 0, 1e-11, row_totals)
    probs = importance_matrix / row_totals
    return 1. - entropy(probs.T + 1e-11, base=n_factors)

def disentanglement(importance_matrix):
    """Importance-weighted average of per-code disentanglement; 0.0 if all importances are zero."""
    per_code = disentanglement_per_code(importance_matrix)
    grand_total = importance_matrix.sum()
    if grand_total == 0.:
        return 0.0
    code_weights = importance_matrix.sum(axis=1) / grand_total
    return np.sum(per_code * code_weights)

# === Completeness ===
def completeness_per_factor(importance_matrix):
    """Per-factor completeness: 1 - entropy of each column (base = number of codes)."""
    n_codes = importance_matrix.shape[0]
    smoothed = importance_matrix + 1e-11  # avoid log(0)
    return 1. - entropy(smoothed, base=n_codes)

def completeness(importance_matrix):
    """Importance-weighted average of per-factor completeness; 0.0 if all importances are zero."""
    per_factor = completeness_per_factor(importance_matrix)
    grand_total = importance_matrix.sum()
    if grand_total == 0.:
        return 0.0
    factor_weights = importance_matrix.sum(axis=0) / grand_total
    return np.sum(per_factor * factor_weights)

# === DCI Master Function ===
def compute_dci(mus_train, ys_train, mus_test, ys_test):
    """Compute DCI scores (disentanglement, completeness, informativeness).

    Informativeness is the mean random-forest accuracy on the train and test splits.
    """
    raw_importance, train_acc, test_acc = compute_importance_rf(
        mus_train, ys_train, mus_test, ys_test
    )
    # Zero out numerically negligible importances before computing D and C.
    importance_matrix = np.where(raw_importance < 1e-11, 0, raw_importance)
    scores = {}
    scores["disentanglement"] = disentanglement(importance_matrix)
    scores["completeness"] = completeness(importance_matrix)
    scores["informativeness_train"] = train_acc
    scores["informativeness_test"] = test_acc
    return scores

In [None]:
# DCI on all cells; prep_data here is the DCI version defined in the previous cell.
covariate_keys = [
    "cell_type",
    "tissue",
    "development_stage",
    "donor_id",
    "development_stage_ontology_term_id",
    "lateralization",
]
mus_train, ys_train, mus_test, ys_test = prep_data(
    adata, latent_2, covariate_keys=covariate_keys
)
dci_scores = compute_dci(mus_train, ys_train, mus_test, ys_test)
dci_scores

In [None]:
#SAP score
from sklearn import svm

def compute_sap(mus, ys, mus_test, ys_test, continuous_factors):
    """Compute the SAP (Separated Attribute Predictability) score.

    Args:
        mus, ys: training latents (num_latents, n) and factors (num_factors, n).
        mus_test, ys_test: held-out latents and factors, same layout.
        continuous_factors: True to treat factors as continuous, False for discrete.

    Returns:
        Dictionary with key "SAP_score".
    """
    sap_scores = _compute_sap(mus, ys, mus_test, ys_test, continuous_factors)
    return sap_scores

def _compute_sap(mus, ys, mus_test, ys_test, continuous_factors):
    """Build the score matrix and reduce it to the SAP score (mean top-two gap)."""
    score_matrix = compute_score_matrix(mus, ys, mus_test, ys_test, continuous_factors)
    # Expected shape: [num_latents, num_factors].
    n_rows, n_cols = score_matrix.shape[0], score_matrix.shape[1]
    assert n_rows == mus.shape[0]
    assert n_cols == ys.shape[0]
    return {"SAP_score": compute_avg_diff_top_two(score_matrix)}

def compute_score_matrix(mus, ys, mus_test, ys_test, continuous_factors):
    """Build the SAP score matrix of shape (num_latents, num_factors).

    Entry (i, j) measures how well latent dimension i alone predicts factor j.
    - continuous_factors=True: squared Pearson correlation between ``mus[i]`` and
      ``ys[j]`` on the training data (``mus_test``/``ys_test`` are not used). If either
      variable is (near-)constant, the score is 0 rather than 0/0 = NaN, which would
      poison the SAP mean.
    - continuous_factors=False: test accuracy of a one-feature ``LinearSVC``
      (C=0.01, balanced class weights) trained on ``mus[i]`` to predict ``ys[j]``.

    Args:
        mus, mus_test: latents shaped (num_latents, num_train) / (num_latents, num_test).
        ys, ys_test: factors shaped (num_factors, num_train) / (num_factors, num_test).
        continuous_factors: treat factors as continuous (True) or discrete (False).

    Returns:
        np.ndarray of shape (num_latents, num_factors).
    """
    num_latents = mus.shape[0]
    num_factors = ys.shape[0]
    score_matrix = np.zeros([num_latents, num_factors])
    for i in range(num_latents):
        mu_i = mus[i, :]
        for j in range(num_factors):
            y_j = ys[j, :]
            if continuous_factors:
                # Attribute is considered continuous.
                cov_mu_i_y_j = np.cov(mu_i, y_j, ddof=1)
                cov_mu_y = cov_mu_i_y_j[0, 1] ** 2
                var_mu = cov_mu_i_y_j[0, 0]
                var_y = cov_mu_i_y_j[1, 1]
                if var_mu > 1e-12 and var_y > 1e-12:
                    score_matrix[i, j] = cov_mu_y / (var_mu * var_y)
                # else: leave the 0 from np.zeros (no measurable correlation).
            else:
                # Attribute is considered discrete.
                classifier = svm.LinearSVC(C=0.01, class_weight="balanced")
                classifier.fit(mu_i[:, np.newaxis], y_j)
                pred = classifier.predict(mus_test[i, :][:, np.newaxis])
                score_matrix[i, j] = np.mean(pred == ys_test[j, :])
    return score_matrix

def compute_avg_diff_top_two(matrix):
    """Mean over columns of (largest - second-largest) entry in each column."""
    ordered = np.sort(matrix, axis=0)
    best, runner_up = ordered[-1, :], ordered[-2, :]
    return np.mean(best - runner_up)

# SAP on the DCI train/test split, treating the covariates as discrete.
sap = compute_sap(
    mus_train,
    ys_train,
    mus_test,
    ys_test,
    continuous_factors=False,
)
sap

In [None]:
# IRS 


def compute_irs(mus, ys, diff_quantile=0.99):
    """Compute the IRS score over latent dimensions with non-zero variance.

    Args:
        mus: latents shaped (num_latents, num_cells).
        ys: factors shaped (num_factors, num_cells); passed through ``make_discretizer``
            with its default discretizer.
        diff_quantile: quantile of within-group deviations (see scalable_disentanglement_score).

    Returns:
        dict with "IRS" (0.0 if no latent dimension varies) and "num_active_dims".
    """
    ys_discrete = make_discretizer(ys)

    active_mask = mus.var(axis=1) > 0
    active_mus = mus[active_mask, :]

    if active_mus.size:
        irs_score = scalable_disentanglement_score(
            ys_discrete.T, active_mus.T, diff_quantile
        )["avg_score"]
    else:
        irs_score = 0.0

    return {"IRS": irs_score, "num_active_dims": int(np.sum(active_mask))}


def _drop_constant_dims(ys):
    """Returns a view of the matrix `ys` with dropped constant rows."""
    ys = np.asarray(ys)
    if ys.ndim != 2:
        raise ValueError("Expecting a matrix.")

    variances = ys.var(axis=1)
    active_mask = variances > 0.
    return ys[active_mask, :]


def scalable_disentanglement_score(gen_factors, latents, diff_quantile=0.99):
    """Computes IRS scores of a dataset.

    Assumes no noise in X and crossed generative factors (one sample per
    combination of gen_factors), with each g_i equally probable and all g_i
    independent.

    Args:
        gen_factors: array (num samples, num generative factors) of ground-truth factors.
        latents: array (num samples, num latent dimensions) of latent variables.
        diff_quantile: quantile in [0, 1] of within-group deviations to use
            (1.0 reproduces the version in the paper).

    Returns:
        Dictionary with per-latent "disentanglement_scores", the weighted
        "avg_score", "parents" (most-related factor per latent), the "IRS_matrix",
        and the "max_deviations" used as the normalizer.
    """
    n_factors = gen_factors.shape[1]
    n_latents = latents.shape[1]

    # Normalizer: largest absolute deviation of each latent from its mean.
    centered = latents - latents.mean(axis=0)
    max_deviations = np.max(np.abs(centered), axis=0)

    cum_deviations = np.zeros([n_latents, n_factors])
    for factor_idx in range(n_factors):
        factor_column = gen_factors[:, factor_idx]
        levels = np.unique(factor_column, axis=0)
        assert levels.ndim == 1
        for level in levels:
            members = latents[factor_column == level, :]
            # Spread of Z around E[Z | g_i = level], at the requested quantile.
            spread = np.abs(members - np.mean(members, axis=0))
            cum_deviations[:, factor_idx] += np.percentile(spread, q=diff_quantile*100, axis=0)
        cum_deviations[:, factor_idx] /= levels.shape[0]

    # Normalize each latent dimension by its maximal deviation.
    irs_matrix = 1.0 - cum_deviations / max_deviations[:, np.newaxis]
    disentanglement_scores = irs_matrix.max(axis=1)
    if np.sum(max_deviations) > 0.0:
        avg_score = np.average(disentanglement_scores, weights=max_deviations)
    else:
        avg_score = np.mean(disentanglement_scores)

    return {
        "disentanglement_scores": disentanglement_scores,
        "avg_score": avg_score,
        "parents": irs_matrix.argmax(axis=1),
        "IRS_matrix": irs_matrix,
        "max_deviations": max_deviations,
    }



# IRS on the DCI training split.
irs = compute_irs(
    mus_train,
    ys_train,
    diff_quantile=0.99,
)
irs