In [None]:
import numpy as np

def flatten_upper_triangular_excluding_diagonal(matrix):
    """
    Flattens the upper triangular part of a matrix, excluding the diagonal.

    The elements are returned in row-major order, i.e. in the same order as
    SciPy's condensed distance vectors (``pdist`` / ``cophenet``). Every
    element above the diagonal is returned, including entries that happen to
    be exactly zero, so a square ``n x n`` input always yields
    ``n * (n - 1) / 2`` values.

    Args:
        matrix: A list of lists or a NumPy array representing a 2-D matrix.

    Returns:
        A 1-D NumPy array containing the flattened upper triangular elements.

    Raises:
        ValueError: If the input is not 2-dimensional.
    """
    np_matrix = np.asarray(matrix)
    if np_matrix.ndim != 2:
        raise ValueError("Input matrix must be 2-dimensional.")

    # Index the strict upper triangle directly. Masking with np.triu and then
    # dropping zeros would also drop genuine zero entries above the diagonal.
    row_idx, col_idx = np.triu_indices(np_matrix.shape[0], k=1, m=np_matrix.shape[1])
    return np_matrix[row_idx, col_idx]

from collections import OrderedDict
import random

def get_max_identity_class_size(high_average_similarity_idx):
    """
    Return the number of members in the largest identity class.

    Args:
        high_average_similarity_idx: Mapping from class id to a sized
            collection of member indices. Iteration order is irrelevant.

    Returns:
        The maximum class size, as computed by ``np.max`` (which raises
        ``ValueError`` for an empty mapping).
    """
    class_sizes = [len(members) for members in high_average_similarity_idx.values()]
    return np.max(class_sizes)




def get_id_class_representative(data_dict: OrderedDict, keys):
    """
    Draw one random representative from each requested identity class.

    Every entry of ``data_dict`` holds two lists, ``'unused'`` and ``'used'``.
    A representative is popped from ``'unused'`` and appended to ``'used'``.
    When ``'unused'`` is exhausted, it is refilled from ``'used'`` first.
    ``data_dict`` is modified in place.

    NOTE: this function is defined a second time further down in the same
    cell; the later definition is the one that stays bound.

    Args:
        data_dict: Mapping ``key -> {'unused': list, 'used': list}``.
        keys: Iterable of class keys to sample from.

    Returns:
        Tuple ``(elements, data_dict)`` with one drawn element per key, in
        the order of ``keys``, and the (mutated) input mapping.
    """
    picked = []
    for class_key in keys:
        pools = data_dict[class_key]
        unused, used = pools['unused'], pools['used']

        # Recycle the already-drawn members once the pool has run dry.
        if not unused:
            unused.extend(used)
            used.clear()
            random.shuffle(unused)

        chosen = unused.pop(random.randint(0, len(unused) - 1))
        used.append(chosen)
        picked.append(chosen)

    return picked, data_dict
def create_dataset_assembly_dict(high_average_similarity_idx):
    """
    Build the sampling bookkeeping used by ``get_id_class_representative``.

    Args:
        high_average_similarity_idx: Mapping from identity-class id to an
            iterable of member gene indices.

    Returns:
        Tuple ``(id_class_dictionary, all_genes_in_identity_class)`` where
        ``id_class_dictionary[k]`` is ``{'unused': [...], 'used': []}`` with
        an independent copy of the class members, and
        ``all_genes_in_identity_class`` is the flat list of all members in
        iteration order of the input mapping.
    """
    id_class_dictionary = {}
    all_genes_in_identity_class = []
    for class_id, members in high_average_similarity_idx.items():
        # list(...) copies, so sampling never mutates the caller's mapping,
        # and it also accepts tuples/arrays as class members.
        id_class_dictionary[class_id] = {'unused': list(members), 'used': []}
        # extend() is linear overall; re-concatenating the list each round is not.
        all_genes_in_identity_class.extend(members)
    return id_class_dictionary, all_genes_in_identity_class

from collections import OrderedDict
import random

def get_id_class_representative(data_dict: OrderedDict, keys) -> tuple:
    """
    Draw one random representative from each requested identity class.

    Every entry of ``data_dict`` holds two lists, ``'unused'`` and ``'used'``.
    A representative is removed from ``'unused'`` and appended to ``'used'``.
    When ``'unused'`` is exhausted, all ``'used'`` members are moved back
    before drawing. ``data_dict`` is modified in place.

    Args:
        data_dict: Mapping ``key -> {'unused': list, 'used': list}``.
        keys: Iterable of class keys to sample from.

    Returns:
        Tuple ``(elements, data_dict)``: the drawn elements (one per key, in
        the order of ``keys``) and the mutated input mapping.

    Raises:
        KeyError: If a key is missing from ``data_dict``.
        ValueError: If a class has no members at all (both lists empty).
    """
    elements = []
    for key in keys:
        value = data_dict[key]

        # Refill the pool from the already-used members once it is exhausted.
        if not value['unused']:
            value['unused'].extend(value['used'])
            value['used'].clear()
            random.shuffle(value['unused'])

        if not value['unused']:
            # Fail with a message naming the class instead of randint's
            # opaque "empty range" error.
            raise ValueError(f"Identity class {key!r} has no members to sample from.")

        random_index = random.randint(0, len(value['unused']) - 1)
        element = value['unused'].pop(random_index)
        value['used'].append(element)
        elements.append(element)

    return elements, data_dict


In [None]:
import sys
sys.path.append('/data_nfs/og86asub/netmap/netmap-evaluation/')

import scanpy as sc
import time 

from netmap.src.utils.misc import write_config

from netmap.src.model.negbinautoencoder import *
import scanpy as sc

from sklearn.model_selection import train_test_split
import time
from captum.attr import GradientShap, LRP
from netmap.src.model.inferrence_simple import *
from netmap.src.utils.data_utils import attribution_to_anndata
from netmap.src.model.pipeline import *
import numpy as np


from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import HDBSCAN

import os.path as op
import os

import anndata as ad
from statsmodels.stats.nonparametric import rank_compare_2indep

import numpy as np
import pandas as pd
import scipy.sparse as scs
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd

from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import HDBSCAN
from captum.attr import *
import pingouin as pingu
import torch
import torch.nn as nn



class ZINBLoss(nn.Module):
    """
    Zero-Inflated Negative Binomial (ZINB) negative log-likelihood.

    For a count ``y`` with NB mean ``mu``, dispersion ``theta`` and
    zero-inflation probability ``pi``:

    * ``y == 0``: ``-log(pi + (1 - pi) * NB(0 | mu, theta))``
    * ``y  > 0``: ``-log(1 - pi) - log NB(y | mu, theta)``
    """

    def __init__(self, scale_factor=1.0, eps=1e-10, ridge_lambda=0.0):
        """
        Args:
            scale_factor (float): Scale factor applied to predictions.
            eps (float): Small value for numerical stability.
            ridge_lambda (float): Regularization weight for the zero-inflation probability (pi).
        """
        super(ZINBLoss, self).__init__()
        self.scale_factor = scale_factor
        self.eps = eps
        self.ridge_lambda = ridge_lambda

    def forward(self, y_true, y_pred, theta, pi):
        """
        Compute the ZINB loss.

        Args:
            y_true (torch.Tensor): Ground truth counts (non-negative).
            y_pred (torch.Tensor): Predicted mean values (mu).
            theta (torch.Tensor): Dispersion parameter (shape parameter).
            pi (torch.Tensor): Zero-inflation probability (between 0 and 1).

        Returns:
            torch.Tensor: Scalar mean ZINB negative log-likelihood, plus the
            ridge penalty on ``pi`` when ``ridge_lambda`` is non-zero.
        """
        eps = self.eps
        y_true = y_true.float()
        y_pred = y_pred.float() * self.scale_factor
        # Clip theta to avoid numerical issues.
        theta = torch.clamp(theta.float(), max=1e6)
        # Keep pi away from the boundaries so the logs below stay finite.
        pi = torch.clamp(pi.float(), min=eps, max=1 - eps)

        # Negative log-likelihood of the plain negative binomial.
        nb_nll = (
            torch.lgamma(theta + eps)
            + torch.lgamma(y_true + 1.0)
            - torch.lgamma(y_true + theta + eps)
            + (theta + y_true) * torch.log(1.0 + (y_pred / (theta + eps)))
            + y_true * (torch.log(theta + eps) - torch.log(y_pred + eps))
        )

        # Non-zero counts can only come from the NB component, which is
        # selected with probability (1 - pi); without this term pi receives
        # no gradient from non-zero observations.
        nb_case = nb_nll - torch.log(1.0 - pi + eps)

        # Zero counts come either from the inflation component or from the NB.
        zero_nb = torch.pow(theta / (theta + y_pred + eps), theta)
        zero_case = -torch.log(pi + ((1.0 - pi) * zero_nb) + eps)

        # Select the matching branch per element.
        result = torch.where(y_true < eps, zero_case, nb_case)

        # Ridge penalty discourages the model from explaining everything by dropout.
        result = result + self.ridge_lambda * torch.square(pi)

        return torch.mean(result)  # Mean loss over all elements




    

class ZINBAutoencoder(nn.Module):
    """
    Autoencoder with three decoder heads producing ZINB parameters.

    The encoder maps the input to a latent vector; three independent decoders
    map the latent vector to the mean (``mu``), dispersion (``theta``) and
    zero-inflation probability (``pi``), each with the input dimensionality.

    The boolean attributes ``forward_theta_only``, ``forward_mu_only``,
    ``latent_only`` and ``forward_pi_only`` make ``forward`` return a single
    tensor instead of the ``(mu, theta, pi)`` tuple. If several are set, the
    first one in that order wins.
    """

    def __init__(self, input_dim, latent_dim, dropout_rate=0.0, hidden_dim = 128):
        """
        Args:
            input_dim: Number of input (and reconstructed) features.
            latent_dim: Size of the latent representation.
            dropout_rate: Dropout probability applied after each hidden ReLU.
            hidden_dim: Width of the single hidden layer in every sub-network.
        """
        super(ZINBAutoencoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),  # Dropout after activation
            nn.Linear(hidden_dim, latent_dim)
        )

        # Decoder for mean (mu)
        self.decoder_mu = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),  # Dropout after activation
            nn.Linear(hidden_dim, input_dim),
            nn.Softplus()  # Ensure non-negative predictions
        )

        # Decoder for dispersion (theta)
        self.decoder_theta = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),  # Dropout after activation
            nn.Linear(hidden_dim, input_dim),
            nn.Softplus()  # Ensure non-negative dispersion
        )

        # Decoder for zero-inflation probability (pi)
        self.decoder_pi = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),  # Dropout after activation
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # Ensure probability values between 0 and 1
        )

        self.zinb_loss = ZINBLoss()  # Use ZINBLoss for the computation
        # Output-selection switches read by forward(); all off means "return everything".
        self.forward_mu_only = False
        self.forward_theta_only = False
        self.latent_only = False
        self.forward_pi_only = False

    def _decode(self, latent):
        """Run all three decoder heads and return ``(mu, theta, pi)``."""
        mu = self.decoder_mu(latent)
        theta = self.decoder_theta(latent)
        pi = self.decoder_pi(latent)
        return mu, theta, pi

    def forward(self, x):
        """
        Encode ``x`` and decode it into ZINB parameters.

        Returns:
            ``(mu, theta, pi)`` by default, or the single tensor selected by
            the ``*_only`` switches (priority: theta, mu, latent, pi).
        """
        latent = self.encoder(x)
        mu, theta, pi = self._decode(latent)

        if self.forward_theta_only:
            return theta
        elif self.forward_mu_only:
            return mu
        elif self.latent_only:
            return latent
        elif self.forward_pi_only:
            return pi
        else:
            return mu, theta, pi

    def compute_loss(self, x):
        """
        Return the ZINB reconstruction loss of ``x``.

        The parameters are decoded directly rather than through ``forward``
        so that the loss stays correct while one of the ``*_only`` switches
        is enabled (``forward`` would then return a single tensor and the
        tuple unpacking would fail or silently split that tensor by rows).
        """
        mu, theta, pi = self._decode(self.encoder(x))
        return self.zinb_loss(x, mu, theta, pi)
    

def create_model_zoo(data_tensor, n_models = 4, n_epochs = 500):
    """
    Train ``n_models`` negative-binomial autoencoders on random 99% subsets.

    NOTE: this definition is replaced by the ``create_model_zoo`` defined
    directly below it in the same cell.

    Args:
        data_tensor: 2-D tensor; the second dimension is used as ``input_dim``.
        n_models: Number of independently trained models.
        n_epochs: Number of epochs passed to ``train_autoencoder``.

    Returns:
        List with the trained models, in training order.
    """
    def _train_single_model():
        # Only the training part of the split is used; 1% is held out and discarded.
        train_split, _held_out = train_test_split(data_tensor, test_size=0.01, shuffle=True)

        autoencoder = NegativeBinomialAutoencoder(
            input_dim=data_tensor.shape[1], latent_dim=10, dropout_rate=0.02
        ).cuda()
        adam = torch.optim.Adam(autoencoder.parameters(), lr=1e-4)

        return train_autoencoder(autoencoder, train_split.cuda(), adam, num_epochs=n_epochs)

    return [_train_single_model() for _ in range(n_models)]


def create_model_zoo(data_tensor, n_models = 4, n_epochs = 500):
    """
    Train ``n_models`` ZINB autoencoders, each on a fresh random 99% subset.

    NOTE: a later notebook cell defines ``create_model_zoo`` again with
    additional required arguments; once that cell has run, this version is
    no longer the one bound to the name.

    Args:
        data_tensor: 2-D tensor; the second dimension is used as ``input_dim``.
        n_models: Number of independently trained models.
        n_epochs: Number of epochs passed to ``train_autoencoder``.

    Returns:
        List with the trained models, in training order.
    """
    zoo = []
    for _model_index in range(n_models):
        # The 1% held-out part of the split is not used.
        train_part = train_test_split(data_tensor, test_size=0.01, shuffle=True)[0]

        net = ZINBAutoencoder(input_dim=data_tensor.shape[1], latent_dim=10, dropout_rate=0.02)
        net = net.cuda()
        adam = torch.optim.Adam(net.parameters(), lr=1e-4)

        net = train_autoencoder(net, train_part.cuda(), adam, num_epochs=n_epochs)
        zoo.append(net)
    return zoo




def set_latent_true(model_zoo):
    """
    Switch every model to return its latent representation from ``forward``.

    Sets ``latent_only`` to True and clears ``forward_mu_only`` and
    ``forward_theta_only`` on each model (in place).

    Returns:
        The same ``model_zoo`` object that was passed in.
    """
    flag_values = (
        ('forward_mu_only', False),
        ('forward_theta_only', False),
        ('latent_only', True),
    )
    for model in model_zoo:
        for flag_name, flag_value in flag_values:
            setattr(model, flag_name, flag_value)
    return model_zoo


def set_all_false(model_zoo):
    """
    Reset every output-selection switch so ``forward`` returns all outputs.

    Clears ``forward_mu_only``, ``forward_theta_only``, ``latent_only`` and
    ``forward_pi_only`` on each model (in place). ``forward_pi_only`` is
    included because ``ZINBAutoencoder.forward`` also honours it; leaving it
    set would keep a model in pi-only mode despite this reset.

    Returns:
        The same ``model_zoo`` object that was passed in.
    """
    for mo in model_zoo:
        mo.forward_mu_only = False
        mo.forward_theta_only = False
        mo.latent_only = False
        mo.forward_pi_only = False
    return model_zoo

def shuffle_each_column_independently(tensor):
    """
    Return a copy of a 2D tensor in which every column is permuted on its own.

    A separate random permutation of the row indices is drawn for each
    column, so values stay in their column but row alignment across columns
    is destroyed.

    Args:
        tensor (torch.Tensor): The input tensor; must have exactly 2 dimensions.

    Returns:
        torch.Tensor: A new tensor of the same shape, dtype and device.

    Raises:
        ValueError: If ``tensor`` is not 2-dimensional.
    """
    if tensor.dim() != 2:
        raise ValueError("Input tensor must be 2-dimensional to shuffle columns.")

    n_rows, n_cols = tensor.size(0), tensor.size(1)
    result = torch.empty_like(tensor)

    for col in range(n_cols):
        # One fresh permutation per column.
        row_order = torch.randperm(n_rows)
        result[:, col] = tensor[:, col][row_order]

    return result


def attribution_one_target(
        target_gene,
        lrp_model,
        input_data,
        background,
        xai_type='lrp-like',
        randomize_background = False):
    """
    Compute attributions for one target with every explainer in ``lrp_model``.

    NOTE: a later notebook cell defines ``attribution_one_target`` again with
    an extra ``selected_variables`` parameter.

    Args:
        target_gene: Target passed as ``target=`` to each explainer.
        lrp_model: Sequence of explainer objects exposing ``attribute``.
        input_data: Tensor handed to ``attribute`` as the inputs.
        background: Baseline tensor, only passed on for ``'shap-like'``.
        xai_type: ``'lrp-like'`` (no baselines) or ``'shap-like'`` (with baselines).
        randomize_background: If True, the background is column-shuffled again
            before every explainer, each shuffle starting from the previous one.

    Returns:
        List with one NumPy attribution array per explainer.
    """
    collected = []
    for explainer in lrp_model:
        # Re-randomize the background for each round.
        if randomize_background:
            background = shuffle_each_column_independently(background)

        if xai_type == 'lrp-like':
            attribution = explainer.attribute(input_data, target=target_gene)
        elif xai_type == 'shap-like':
            attribution = explainer.attribute(input_data, baselines = background, target = target_gene)

        collected.append(attribution.detach().cpu().numpy())
    return collected

def get_differential_edges(attribution_anndata, percentile = 10):
    """
    Select feature (column) indices that differ between Leiden clusters.

    For every cluster in ``obs['leiden']`` the cluster's rows are compared
    against all other rows with ``rank_compare_2indep``. A column is kept
    when its p-value is below ``0.01 / n_columns**2`` and its ``prob1`` is at
    least 0.9. If there is only one cluster, the columns whose mean absolute
    value lies above the ``100 - percentile`` or below the ``percentile``
    percentile are returned instead.

    Args:
        attribution_anndata: Object with ``.X`` (2-D array) and
            ``.obs['leiden']`` (one label per row).
        percentile: Tail size in percent used by the single-cluster fallback.

    Returns:
        Sorted 1-D NumPy array of unique column indices. A column that is
        significant for several clusters is reported once.
    """
    cluster_labels = attribution_anndata.obs['leiden']
    categories = np.unique(cluster_labels)
    values = attribution_anndata.X

    if len(categories) > 1:
        n_columns = values.shape[1]
        # Bonferroni-style threshold over all column pairs.
        alpha = 0.01 / (n_columns * n_columns)
        hits_per_cluster = []
        for cat in categories:
            statisi = rank_compare_2indep(
                x1=values[cluster_labels == cat],
                x2=values[cluster_labels != cat],
            )
            sig_and_high = np.where((statisi.pvalue < alpha) & (statisi.prob1 >= 0.9))[0]
            hits_per_cluster.append(sig_and_high)
        # De-duplicate: the same column may be significant in several clusters,
        # which would otherwise yield duplicated edges downstream.
        return np.unique(np.concatenate(hits_per_cluster))

    # Fallback for a single cluster: take both tails of the mean |attribution|.
    m = np.abs(values).mean(axis=0)
    top_indices = np.where(m > np.percentile(m, 100 - percentile))[0]
    bottom_indices = np.where(m < np.percentile(m, percentile))[0]
    return np.union1d(top_indices, bottom_indices)

def get_percentile_edges(attribution_anndata, percentile = 10):
    """
    Select the columns whose mean lies in either tail of the distribution.

    The (signed) column means of ``attribution_anndata.X`` are computed; a
    column is selected when its mean is strictly above the
    ``100 - percentile`` percentile or strictly below the ``percentile``
    percentile of all column means.

    Args:
        attribution_anndata: Object exposing a 2-D ``.X``.
        percentile: Tail size in percent.

    Returns:
        Sorted NumPy array of unique selected column indices.
    """
    column_means = attribution_anndata.X.mean(axis=0)

    upper_cut = np.percentile(column_means, 100-percentile)
    lower_cut = np.percentile(column_means, percentile)

    upper_tail = np.nonzero(column_means > upper_cut)[0]
    lower_tail = np.nonzero(column_means < lower_cut)[0]

    # union1d returns the sorted, de-duplicated combination of both tails.
    return np.union1d(upper_tail, lower_tail)

def get_edges(attribution_anndata, use_differential=False, percentile = 10):
    """
    Dispatch edge selection to the differential or the percentile strategy.

    Args:
        attribution_anndata: Attribution object forwarded to the selector.
        use_differential: If truthy use ``get_differential_edges``, otherwise
            ``get_percentile_edges``.
        percentile: Forwarded unchanged to the chosen selector.

    Returns:
        Whatever the chosen selector returns (column indices).
    """
    selector = get_differential_edges if use_differential else get_percentile_edges
    return selector(attribution_anndata, percentile=percentile)
    
def get_explainer(model, explainer_type, raw=False):
    """
    Build the attribution explainer for ``model``.

    Args:
        model: Model handed to the explainer constructor.
        explainer_type: ``'GuidedBackprop'``, ``'Deconvolution'`` or
            ``'GradientShap'``.
        raw: Only used for ``'GradientShap'``; when truthy the explainer is
            created with ``multiply_by_inputs=False``, otherwise with True.

    Returns:
        Tuple ``(explainer, explainer_mode)`` where ``explainer_mode`` is
        ``'lrp-like'`` for GuidedBackprop/Deconvolution and ``'shap-like'``
        for GradientShap.

    Raises:
        ValueError: For any other ``explainer_type``.
    """
    if explainer_type == 'GuidedBackprop':  # fast
        return GuidedBackprop(model), 'lrp-like'

    if explainer_type == 'Deconvolution':  # fast
        return Deconvolution(model), 'lrp-like'

    if explainer_type == 'GradientShap':  # fast
        multiply = False if raw else True
        return GradientShap(model, multiply_by_inputs=multiply), 'shap-like'

    raise ValueError('no such method')

def compute_correlation_metric(data, cor_type):
    """
    Compute a feature-by-feature association matrix used as penalty.

    Args:
        data: 2-D array with observations in rows and features in columns.
        cor_type: One of
            ``'pingouin.pcorr'`` (partial correlation),
            ``'np.cov'`` (covariance),
            ``'np.corrcoef'`` (Pearson correlation; the historical spellings
            ``'np.corrcoeff'`` and ``'np.corcoeff'`` are accepted as well), or
            ``'None'`` for no metric.

    Returns:
        The association matrix, or the scalar ``1`` for ``'None'`` and for
        unrecognised values. A warning is emitted for unrecognised values,
        because a caller computing ``1 - result`` would otherwise silently
        multiply its data by zero.
    """
    import warnings

    if cor_type == 'pingouin.pcorr':
        return pingu.pcorr(pd.DataFrame(data))
    if cor_type == 'np.cov':
        return np.cov(data.T)
    if cor_type in ('np.corrcoef', 'np.corrcoeff', 'np.corcoeff'):
        return np.corrcoef(data.T)

    if cor_type is not None and cor_type != 'None':
        warnings.warn(
            f"Unknown correlation metric {cor_type!r}; falling back to the constant 1.",
            stacklevel=2,
        )
    return 1

def aggregate_attributions(attributions, strategy = 'mean'):
    """
    Collapse a stack of attribution arrays along the first axis.

    Args:
        attributions: Array-like whose first axis enumerates the individual
            attribution results (e.g. one per model).
        strategy: ``'mean'``, ``'sum'`` or ``'median'``. Any other value
            falls back to the mean.

    Returns:
        The aggregated array with the first axis removed.
    """
    if strategy == 'sum':
        reducer = np.sum
    elif strategy == 'median':
        reducer = np.median
    else:
        # 'mean' and every unrecognised strategy use mean aggregation.
        reducer = np.mean
    return reducer(attributions, axis = 0)
    

    
def wrapper(models, data_train_full_tensor, gene_names, config):
    """
    Compute, aggregate, cluster and filter attributions for every target gene.

    Steps:
      1. Switch every model to mean-only output and wrap it in an explainer.
      2. For each target column, compute one attribution matrix per model.
      3. Aggregate the per-model matrices (``config.aggregation_strategy``).
      4. Optionally multiply by ``1 - penalty`` (``config.penalty``).
      5. Cluster each target's attribution matrix (scale, PCA, neighbors, Leiden).
      6. Select edges per target and stack them into one AnnData.

    NOTE: a later notebook cell defines ``wrapper`` again with an additional
    ``variables_selected`` parameter.

    Args:
        models: Trained models; ``forward_mu_only`` is set to True on each.
        data_train_full_tensor: 2-D tensor used as input and as background.
        gene_names: NumPy array of feature names, indexable by integer arrays.
        config: Object providing ``xai_method``, ``raw_attribution``,
            ``aggregation_strategy``, ``penalty``, ``use_differential`` and
            ``percentile``.

    Returns:
        The object returned by ``attribution_to_anndata``; ``var`` lists
        ``source``/``target`` per selected edge and ``obs`` holds one Leiden
        clustering column per successfully clustered target.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if len(models) == 0:
        raise ValueError("wrapper() needs at least one trained model.")

    data = data_train_full_tensor.detach().cpu().numpy()
    tms = []
    name_list = []
    target_names = []
    # Leiden labels per target; must exist before the clustering loop fills it.
    clusterings = {}

    for trained_model in models:
        trained_model.forward_mu_only = True
        explainer, xai_type = get_explainer(trained_model, config.xai_method, config.raw_attribution)
        tms.append(explainer)

    ## ATTRIBUTIONS
    attributions = []
    for g in tqdm(range(data_train_full_tensor.shape[1])):
        attributions_list = attribution_one_target(
            g,
            tms,
            data_train_full_tensor,
            data_train_full_tensor,
            xai_type=xai_type,
            randomize_background = True)
        attributions.append(attributions_list)

    ## AGGREGATION: REPLACE LIST BY AGGREGATED DATA
    for i in range(len(attributions)):
        attributions[i] = aggregate_attributions(attributions[i], strategy=config.aggregation_strategy)

    ## PENALIZE:
    if config.penalty != 'None':
        penalty_matrix = compute_correlation_metric(data, cor_type=config.penalty)
        for i in range(len(attributions)):
            attributions[i] = np.dot(attributions[i], (1-penalty_matrix))

    ## CLUSTERING: CLUSTER EACH TARGET INDIVIDUALLY
    for i in range(len(attributions)):
        attributions[i] = ad.AnnData(attributions[i])
        sc.pp.scale(attributions[i])

        # PCA cannot extract more components than min(n_obs, n_vars) - 1, so
        # shrink the request instead of retrying with identical arguments.
        n_comps = min(50, min(attributions[i].shape) - 1)
        if n_comps < 1:
            continue
        try:
            sc.pp.pca(attributions[i], n_comps=n_comps)
        except Exception as err:
            # Best effort: a target that cannot be embedded is left unclustered.
            print(f"PCA failed for target {gene_names[i]}: {err}")
            continue

        sc.pp.neighbors(attributions[i], n_neighbors=15)
        sc.tl.leiden(attributions[i], resolution=0.1)

        clusterings[f'T_{gene_names[i]}'] = np.array(attributions[i].obs['leiden'])

    ## EDGE SELECTION:
    for i in range(len(attributions)):
        edge_indices = get_edges(attributions[i], use_differential=config.use_differential, percentile=config.percentile)
        name_list = name_list + list(gene_names[edge_indices])
        target_names = target_names+[gene_names[i]]* len(edge_indices)
        attributions[i] = attributions[i][:,edge_indices].X

    attributions = np.hstack(attributions)

    index_list = [f"{s}_{t}" for (s, t) in zip(name_list, target_names)]
    cou = pd.DataFrame({'index': index_list, 'source':name_list, 'target':target_names})
    cou = cou.set_index('index')

    clusterings = pd.DataFrame(clusterings)

    grn_adata = attribution_to_anndata(attributions, var=cou, obs = clusterings)

    return grn_adata

def run_netmap(config, dataset_config):
    """
    Run the full pipeline for one dataset: load, train, attribute, save.

    Reads the AnnData file at ``config.input_data``, trains a model zoo on
    the selected expression matrix, computes the attribution AnnData with
    ``wrapper`` and writes it, together with a copy of the configuration and
    a small timing report, into ``config.output_directory``.

    NOTE(review): ``create_model_zoo`` and ``wrapper`` are called here with
    the signatures defined in this cell. Later notebook cells redefine both
    with additional required arguments, after which these calls no longer
    match.

    Args:
        config: Run configuration; the attributes read here are
            ``output_directory``, ``input_data``, ``layer``, ``n_models`` and
            ``adata_filename``, plus the method ``write_yaml``.
        dataset_config: Not used inside this function.

    Returns:
        None. Results are written to disk.
    """

    print('Version 2')
    start_total = time.monotonic()
    
    ## Load config and setup outputs
    os.makedirs(config.output_directory, exist_ok=True)
    sc.settings.figdir = config.output_directory
    config.write_yaml(yaml_file=op.join(config.output_directory, 'config.yaml'))

    ## load data
    adata = sc.read_h5ad(config.input_data)
    

    ## Get the data matrix from the AnnData object

    gene_names = np.array(adata.var.index)
    # Everything from here until just before the h5ad write counts as "netmap" time.
    model_start = time.monotonic()

    # Use the 'counts' layer when requested, otherwise the main matrix.
    if config.layer == 'counts':
        data_tensor = adata.layers['counts']
    else:
        data_tensor = adata.X

    # Sparse matrices are densified before conversion to a float32 tensor.
    if scs.issparse(data_tensor):
        data_tensor = torch.tensor(data_tensor.todense(), dtype=torch.float32)
    else:
        data_tensor = torch.tensor(data_tensor, dtype=torch.float32)


    print(data_tensor.shape)

    # The number of epochs is fixed to 500 here; only n_models comes from the config.
    model_zoo = create_model_zoo(data_tensor, n_models=config.n_models, n_epochs=500)
    grn_adata = wrapper(model_zoo, data_tensor.cuda(), gene_names, config)

    # Copy the 'cell_id' and 'grn' columns positionally from the input observations
    # (both must exist after reset_index()).
    adob = adata.obs.reset_index()
    grn_adata.obs['cell_id'] = np.array(adob['cell_id'])
    grn_adata.obs['grn'] = np.array(adob['grn'])

    
    # Taken before the write so that file output is excluded from the model timing.
    model_elapsed = time.monotonic()-model_start
    grn_adata.write_h5ad(op.join(config.output_directory,config.adata_filename))

    time_elapsed_total = time.monotonic()-start_total


    res = {'time_elapsed_total': time_elapsed_total, 'time_elapsed_netmap': model_elapsed} 
    write_config(res, file=op.join(config.output_directory, 'results.yaml'))



In [None]:



sys.path.append('/data_nfs/og86asub/netmap/netmap-evaluation/')
from netmap.src.utils.data_utils import *
from netmap.src.utils.tf_utils import *
from netmap.src.utils.netmap_config import NetmapConfig
from netmap.src.model.negbinautoencoder import *
from netmap.src.model.negbinautoencoder import train_autoencoder
from netmap.src.model.inferrence_simple import *
from netmap.src.model.pipeline import *

from src.data_simulation.data_simulation_config import DataSimulationConfig


import yaml
def read_config(file):
    """
    Parse a YAML file with ``yaml.safe_load``.

    Args:
        file: Path of the YAML file to read.

    Returns:
        The parsed document (whatever ``yaml.safe_load`` produces for it).
    """
    with open(file, "r") as handle:
        return yaml.safe_load(handle)

import os.path as op

# All locations are derived from these constants so that the notebook can be
# pointed at another checkout or simulation by editing a single place.
PROJECT_ROOT = "/data_nfs/og86asub/netmap/netmap-evaluation"
NETWORK_DIR = op.join(PROJECT_ROOT, "data/clustered_network")
SIMULATION_NAME = "net_172_54892_net_131_54992_net_158_55084"

# Path of the data-simulation config for this run.
dada = op.join(
    PROJECT_ROOT,
    "results/configurations/data_simulation/config_easy",
    f"{SIMULATION_NAME}.config.yaml",
)

# Plain-dict view of the simulation config, needed for the edge-list file names below.
dataset_config = read_config(dada)

nets = [pd.read_csv(op.join(NETWORK_DIR, filename), sep='\t') for filename in dataset_config['edgelist']]
common = [pd.read_csv(op.join(NETWORK_DIR, filename), sep='\t') for filename in dataset_config['common_edges']]


config = NetmapConfig.read_yaml(
    op.join(PROJECT_ROOT, "results/netmap/config_22/config_easy", SIMULATION_NAME, "config.yaml")
)
# From here on `dataset_config` is the typed config object, not the dict read above.
dataset_config = DataSimulationConfig.read_yaml(dada)



start_total = time.monotonic()

## Load config and setup outputs
os.makedirs(config.output_directory, exist_ok=True)
sc.settings.figdir = config.output_directory
config.write_yaml(yaml_file=op.join(config.output_directory, 'config.yaml'))

## load data
adata = sc.read_h5ad(config.input_data)


## Get the data matrix from the AnnData object
gene_names = np.array(adata.var.index)
model_start = time.monotonic()

# Use the 'counts' layer when requested, otherwise the main matrix.
if config.layer == 'counts':
    data_tensor = adata.layers['counts']
else:
    data_tensor = adata.X

# Sparse matrices are densified before conversion to a float32 tensor.
if scs.issparse(data_tensor):
    data_tensor = torch.tensor(data_tensor.todense(), dtype=torch.float32)
else:
    data_tensor = torch.tensor(data_tensor, dtype=torch.float32)


print(data_tensor.shape)


In [None]:

def create_model_zoo(data_tensor, id_class_dictionary, keys, other_genes, n_models = 4, n_epochs = 500,
                     latent_dim = 10, dropout_rate = 0.02, learning_rate = 1e-4, test_size = 0.01,
                     device = 'cuda'):
    """
    Train a zoo of ZINB autoencoders, each on a different feature subset.

    For every model one representative per identity class is drawn with
    ``get_id_class_representative``; the model is trained on the columns
    ``other_genes + representatives`` (in that order) of ``data_tensor``.

    Args:
        data_tensor: 2-D tensor with features in columns.
        id_class_dictionary: Sampling state as produced by
            ``create_dataset_assembly_dict``; it is updated while sampling.
        keys: Identity-class keys to draw a representative from.
        other_genes: Column indices that are part of every model.
        n_models: Number of models to train.
        n_epochs: Number of epochs passed to ``train_autoencoder``.
        latent_dim: Latent size of each autoencoder.
        dropout_rate: Dropout probability of each autoencoder.
        learning_rate: Learning rate of the Adam optimizer.
        test_size: Fraction of rows left out of training for each model.
        device: Device the model and its training data are moved to.

    Returns:
        Tuple ``(model_zoo, variables_selected, representatives)``: the
        trained models, the column indices each model was trained on, and
        the representatives drawn for each model.
    """
    model_zoo = []
    variables_selected = []
    representatives = []
    for _ in range(n_models):

        elements, id_class_dictionary = get_id_class_representative(id_class_dictionary, keys)

        # Shared genes first, class representatives last: a column's position
        # inside a model's input follows this order.
        current_data_selection = list(other_genes)+list(elements)
        current_data = data_tensor[:,current_data_selection]
        # Only the training part is used; the held-out rows are discarded.
        data_train, _held_out = train_test_split(current_data, test_size=test_size, shuffle=True)

        model = ZINBAutoencoder(input_dim=current_data.shape[1], latent_dim=latent_dim, dropout_rate=dropout_rate)
        model = model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        model = train_autoencoder(
                model,
                data_train.to(device),
                optimizer,
                num_epochs=n_epochs
            )
        model_zoo.append(model)
        variables_selected.append(current_data_selection)
        representatives.append(elements)
    return model_zoo, variables_selected, representatives




In [None]:
from scipy.cluster import hierarchy
from scipy.stats import spearmanr
from statsmodels.formula.api import quantreg
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

def get_hierarchical_clustering(adata):
    """
    Average-linkage clustering of features by correlation distance.

    The distance between two features (columns of ``adata.X``) is
    ``1 - Pearson correlation``, so merge heights in the returned linkage
    lie in ``[0, 2]``.

    Args:
        adata: Object exposing a dense 2-D ``.X`` with features in columns.

    Returns:
        The SciPy linkage matrix produced by ``hierarchy.average``.
    """
    from scipy.spatial.distance import squareform

    corr_matrix = np.corrcoef(adata.X.T)

    corr_dist = 1 - corr_matrix
    # Remove floating-point residue so the diagonal is exactly zero.
    np.fill_diagonal(corr_dist, 0.0)

    # hierarchy.average expects a *condensed* distance vector. A square
    # matrix would be interpreted as one observation vector per row and
    # clustered by the Euclidean distance between those rows instead.
    return hierarchy.average(squareform(corr_dist, checks=False))





def fit_regression_clusterings(corr_matrix, dist_linkage, threshold=2):
    """
    Fit a 10th-percentile quantile regression of correlation on cophenetic distance.

    Args:
        corr_matrix: Square feature-by-feature correlation matrix.
        dist_linkage: Linkage matrix over the same features.
        threshold: Only feature pairs with a cophenetic distance strictly
            below this value enter the fit.

    Returns:
        DataFrame sorted by ``cophenet`` with the columns ``cophenet``,
        ``corr`` and ``linear_model`` (the fitted quantile-regression value).
    """
    pairs = pd.DataFrame({
        'cophenet': hierarchy.cophenet(dist_linkage),
        'corr': flatten_upper_triangular_excluding_diagonal(corr_matrix),
    })
    close_pairs = pairs[pairs.cophenet < threshold]
    fitted = quantreg('corr ~ cophenet', close_pairs).fit(q=0.1)

    # Sorting by distance lets the fitted line be drawn from left to right.
    ordered = close_pairs.sort_values('cophenet')
    ordered['linear_model'] = fitted.predict({'cophenet': ordered['cophenet']})
    return ordered

def plot_regression(df):
    """
    Scatter correlation against cophenetic distance with the fitted quantile line.

    Args:
        df: DataFrame with the columns ``cophenet``, ``corr`` and
            ``linear_model``, already sorted by ``cophenet``.

    A dashed horizontal reference line is drawn at a correlation of 0.6.
    The figure is shown with ``plt.show()``; nothing is returned.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.scatter(df['cophenet'], df['corr'], alpha=0.7, label='Data Points')
    ax.plot(df['cophenet'], df['linear_model'], color='red', linewidth=2, label='10th Percentile Quantile Regression Line')
    ax.axhline(y=0.6, color='r', linestyle='--', label='y = 0.6')

    # Labels and legend so the figure can be read on its own.
    ax.set_title('Quantile Regression with correlation threshold')
    ax.set_xlabel('Cophenet')
    ax.set_ylabel('Correlation')
    ax.legend()
    ax.grid(True)
    plt.show()




def cut_clustering_and_gene_mapping(corr_matrix, dist_linkage, threshold, variable_names):
    """
    Cut the dendrogram at ``threshold`` and report all multi-gene clusters.

    Args:
        corr_matrix: Square correlation matrix, indexed like ``variable_names``.
        dist_linkage: Linkage matrix over the same features.
        threshold: Distance at which the dendrogram is cut.
        variable_names: Index-like object with the feature names; must
            support iteration and ``get_loc``.

    Returns:
        Tuple ``(high_average_similarity_idx, high_average_similarity)``;
        both map cluster id to the cluster members, as integer positions and
        as names respectively. Singleton clusters are left out. Despite the
        names, no minimum similarity is applied; the average pairwise
        correlation of each cluster is only printed.
    """
    labels = hierarchy.fcluster(dist_linkage, t=threshold, criterion='distance')

    members_by_cluster = {}
    for name, label in zip(variable_names, labels):
        members_by_cluster.setdefault(label, []).append(name)

    high_average_similarity = {}
    high_average_similarity_idx = {}

    for label, members in members_by_cluster.items():
        if len(members) <= 1:
            continue

        # Correlation sub-matrix restricted to the members of this cluster.
        positions = [variable_names.get_loc(name) for name in members]
        sub_matrix = corr_matrix[np.ix_(positions, positions)]

        # Mean over all distinct pairs (upper triangle without the diagonal).
        pair_rows, pair_cols = np.triu_indices_from(sub_matrix, k=1)
        average_similarity = np.mean(sub_matrix[pair_rows, pair_cols])

        print(f"Cluster {label}: {members}")
        print(f"  Average Similarity: {average_similarity:.4f}")
        print("-" * 45)
        high_average_similarity[label] = members
        high_average_similarity_idx[label] = positions
    return high_average_similarity_idx, high_average_similarity


from scipy.spatial.distance import squareform

# Pearson correlation between genes and the matching correlation distance (0..2).
corr_matrix = np.corrcoef(adata.X.T)
corr_dist = 1 - corr_matrix
# Remove floating-point residue so the diagonal is exactly zero.
np.fill_diagonal(corr_dist, 0.0)
# hierarchy.average needs the condensed form; a square matrix would be read as
# observation vectors and clustered by Euclidean distance between its rows.
dist_linkage = hierarchy.average(squareform(corr_dist, checks=False))
df = fit_regression_clusterings(corr_matrix, dist_linkage, threshold=2)
plot_regression(df)
# The cut height is a correlation distance: clusters merge below 0.75.
cluster_mapping, cluster_mapping_genes = cut_clustering_and_gene_mapping(corr_matrix, dist_linkage, 0.75, adata.var_names)


In [None]:
# Map a running gene index to a column position inside a model's input.
# NOTE(review): `variables_selected` and `representatives` are produced by the
# create_model_zoo(...) call in a later cell, so this cell only works after
# that call has run.
# Number of columns every model shares (all selected columns minus the
# per-class representatives).
max_un = (len(variables_selected[0])-len(representatives[0]))

# The shared columns keep their own position.
gene_to_position_mapper = {i: i for i in range(min(max_un, len(gene_names)))}

# Every member of an identity class points to the single slot of its class;
# slots follow the shared columns, one per class in iteration order.
# presumably this assumes class members occupy consecutive indices after the
# shared genes -- TODO confirm against the gene order of the data.
start_idx = max_un
current_idx = max_un
for members in cluster_mapping.values():
    for _ in members:
        gene_to_position_mapper[current_idx] = start_idx
        current_idx += 1
    start_idx += 1


In [None]:
def attribution_one_target(
        target_gene,
        lrp_model,
        input_data,
        background,
        selected_variables,
        xai_type='lrp-like',
        randomize_background = False):
    """
    Compute attributions for one target with every explainer, each on its own columns.

    Args:
        target_gene: Target passed as ``target=`` to each explainer.
        lrp_model: Sequence of explainers exposing ``attribute``.
        input_data: 2-D tensor holding all features.
        background: 2-D tensor holding the baselines for all features; only
            used for ``'shap-like'``.
        selected_variables: For explainer ``m``, ``selected_variables[m]`` are
            the column indices its model was trained on.
        xai_type: ``'lrp-like'`` (no baselines) or ``'shap-like'`` (with baselines).
        randomize_background: If True, each column of the selected background
            is shuffled independently before it is used.

    Returns:
        List with one NumPy attribution array per explainer.

    Raises:
        ValueError: If ``xai_type`` is not one of the two supported values.
    """
    if xai_type not in ('lrp-like', 'shap-like'):
        raise ValueError(f"Unknown xai_type {xai_type!r}; expected 'lrp-like' or 'shap-like'.")

    attributions_list = []
    for m, model in enumerate(lrp_model):
        columns = selected_variables[m]
        current_data = input_data[:, columns]

        if xai_type == 'lrp-like':
            attribution = model.attribute(current_data, target=target_gene)
        else:
            # Always slice the background, so that it is defined even when no
            # randomization is requested.
            current_background = background[:, columns]
            if randomize_background:
                current_background = shuffle_each_column_independently(current_background)
            attribution = model.attribute(current_data, baselines = current_background, target = target_gene)

        attributions_list.append(attribution.detach().cpu().numpy())
    return attributions_list


def attribution_one_target_one_model(
        target_gene,
        lrp_model,
        input_data,
        background,
        selected_variables,
        xai_type='lrp-like',
        randomize_background = False):
    """
    Compute the attribution of one target for a single explainer.

    Args:
        target_gene: Target passed as ``target=`` to the explainer.
        lrp_model: One explainer exposing ``attribute``.
        input_data: Tensor handed to the explainer (already restricted to
            the columns of its model).
        background: Baseline tensor, only used for ``'shap-like'``.
        selected_variables: Not used; kept for interface compatibility.
        xai_type: ``'lrp-like'`` (no baselines) or ``'shap-like'`` (with baselines).
        randomize_background: If True, each background column is shuffled
            independently before it is used.

    Returns:
        The attribution as a NumPy array.

    Raises:
        ValueError: If ``xai_type`` is not one of the two supported values.
    """
    if xai_type == 'lrp-like':
        attribution = lrp_model.attribute(input_data, target=target_gene)

    elif xai_type == 'shap-like':
        # The background is only needed (and only shuffled) in this branch.
        if randomize_background:
            background = shuffle_each_column_independently(background)
        attribution = lrp_model.attribute(input_data, baselines = background, target = target_gene)

    else:
        raise ValueError(f"Unknown xai_type {xai_type!r}; expected 'lrp-like' or 'shap-like'.")

    return attribution.detach().cpu().numpy()

def get_top_edges_per_cell(grn_adata, top_edges):
    """
    Mark, for every row, the ``top_edges`` columns with the largest values.

    Args:
        grn_adata: 2-D array-like of values (rows = cells, columns = edges).
        top_edges: Number of columns to mark per row. Values larger than the
            number of columns mark every column; values <= 0 mark none.

    Returns:
        Float NumPy array of the same shape holding 1 at the marked
        positions and 0 elsewhere.
    """
    values = np.asarray(grn_adata)
    counters = np.zeros(values.shape)

    n_columns = values.shape[1]
    # Clamp: argpartition rejects a split position outside the row.
    n_top = min(int(top_edges), n_columns)
    if n_top <= 0:
        return counters

    # Everything from position `split` onwards in the partition belongs to
    # the n_top largest values of the row (order within that part is arbitrary).
    split = n_columns - n_top
    top_idx = np.argpartition(values, split, axis=1)[:, split:]

    np.put_along_axis(counters, top_idx, 1, axis=1)
    return counters

    
def assemble_attributions(current_attr, variables_selected, a_shape):
    """
    Scatter per-model attribution columns back into the full feature space.

    Column ``j`` of ``current_attr[i]`` belongs to global column
    ``variables_selected[i][j]``. Contributions that land on the same global
    column are averaged; columns that no model covers stay zero.

    Args:
        current_attr: Sequence of 2-D arrays, one per model.
        variables_selected: Sequence of global column indices, one list per model.
        a_shape: Shape ``(n_rows, n_columns)`` of the assembled result.

    Returns:
        Tuple ``(assembled_attributions, column_counter)`` where
        ``column_counter`` holds the number of contributions per global
        column, with 0 replaced by 1.
    """
    totals = np.zeros(a_shape)
    hits = np.zeros(a_shape[1])

    for model_idx, columns in enumerate(variables_selected):
        for local_col, global_col in enumerate(columns):
            totals[:, global_col] += current_attr[model_idx][:, local_col]
            hits[global_col] += 1

    # Avoid dividing by zero for columns without any contribution.
    hits[hits == 0] = 1
    return totals / hits, hits

def wrapper(models, data_train_full_tensor, gene_names, variables_selected, config):
    """
    Compute attributions with a zoo of models trained on different feature subsets.

    Steps:
      1. Switch every model to mean-only output and wrap it in an explainer.
      2. For each target gene, compute one attribution per model on that
         model's columns, mark the top 100 edges per cell, and assemble both
         back into the full feature space (averaging over models).
      3. Optionally multiply by ``1 - penalty`` (``config.penalty``).
      4. Cluster each target's attribution matrix (scale, PCA, neighbors, Leiden).
      5. Select edges per target and stack them into one AnnData.

    The position of a target inside the models is looked up in the
    module-level ``gene_to_position_mapper``, which must exist before this
    function is called.

    Args:
        models: Trained models; ``forward_mu_only`` is set to True on each.
        data_train_full_tensor: 2-D tensor with all features.
        gene_names: NumPy array of feature names, indexable by integer arrays.
        variables_selected: Column indices used by each model, one list per model.
        config: Object providing ``xai_method``, ``raw_attribution``,
            ``penalty``, ``use_differential`` and ``percentile``.

    Returns:
        The object returned by ``attribution_to_anndata`` with the additional
        layer ``'counter'`` holding the averaged top-edge indicators.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if len(models) == 0:
        raise ValueError("wrapper() needs at least one trained model.")

    data = data_train_full_tensor.detach().cpu().numpy()
    tms = []
    name_list = []
    target_names = []

    for trained_model in models:
        trained_model.forward_mu_only = True
        explainer, xai_type = get_explainer(trained_model, config.xai_method, config.raw_attribution)
        tms.append(explainer)

    attributions = []
    all_counter = []
    ## ATTRIBUTIONS
    for g in tqdm(range(data_train_full_tensor.shape[1])):
        # Position of this gene inside the (reduced) model input.
        current_gene = gene_to_position_mapper[g]
        current_attributions = []
        counter_list = []
        for m in range(len(tms)):

            attributions_list = attribution_one_target_one_model(
                current_gene,
                tms[m], # Select the correct model
                data_train_full_tensor[:, variables_selected[m]],
                data_train_full_tensor[:, variables_selected[m]],
                variables_selected,
                xai_type=xai_type,
                randomize_background = True)
            counters = get_top_edges_per_cell(attributions_list, 100)
            counter_list.append(counters)
            current_attributions.append(attributions_list)

        current_attributions, column_counter = assemble_attributions(current_attributions, variables_selected, data_train_full_tensor.shape)
        counter_list, col = assemble_attributions(counter_list, variables_selected, data_train_full_tensor.shape)
        attributions.append(current_attributions)
        all_counter.append(counter_list)

    ## PENALIZE:
    if config.penalty != 'None':
        penalty_matrix = compute_correlation_metric(data, cor_type=config.penalty)
        for i in range(len(attributions)):
            attributions[i] = np.dot(attributions[i], (1-penalty_matrix))

    ## CLUSTERING: CLUSTER EACH TARGET INDIVIDUALLY
    for i in range(len(attributions)):

        attributions[i] = ad.AnnData(attributions[i])
        sc.pp.scale(attributions[i])

        # PCA cannot extract more components than min(n_obs, n_vars) - 1, so
        # shrink the request instead of retrying with identical arguments.
        n_comps = min(50, min(attributions[i].shape) - 1)
        if n_comps < 1:
            continue
        try:
            sc.pp.pca(attributions[i], n_comps=n_comps)
        except Exception as err:
            # Best effort: a target that cannot be embedded is left unclustered.
            print(f"PCA failed for target {gene_names[i]}: {err}")
            continue

        sc.pp.neighbors(attributions[i], n_neighbors=15)
        sc.tl.leiden(attributions[i], resolution=0.1)

    ## EDGE SELECTION:
    for i in range(len(attributions)):
        # Differential selection needs Leiden labels; targets whose clustering
        # was skipped above fall back to percentile selection.
        has_clusters = 'leiden' in attributions[i].obs
        edge_indices = get_edges(
            attributions[i],
            use_differential=config.use_differential and has_clusters,
            percentile=config.percentile)
        name_list = name_list + list(gene_names[edge_indices])
        target_names = target_names+[gene_names[i]]* len(edge_indices)
        attributions[i] = attributions[i][:,edge_indices].X
        all_counter[i] = all_counter[i][:, edge_indices]

    attributions = np.hstack(attributions)
    all_counter = np.hstack(all_counter)

    index_list = [f"{s}_{t}" for (s, t) in zip(name_list, target_names)]
    cou = pd.DataFrame({'index': index_list, 'source':name_list, 'target':target_names})
    cou = cou.set_index('index')

    grn_adata = attribution_to_anndata(attributions, var=cou)
    grn_adata.layers['counter'] = all_counter
    return grn_adata





In [None]:
def extend_edges(mapping, top):
    """
    Expand an edge list with the genes mapped to each edge's endpoints.

    For every edge ``(source, target)`` in ``top`` the following edges are added:
      * every combination of a gene mapped to ``source`` with a gene mapped
        to ``target`` (when both endpoints are in ``mapping``),
      * every gene mapped to ``source`` paired with the original ``target``,
      * the original ``source`` paired with every gene mapped to ``target``.

    Args:
        mapping: Mapping from a gene to an iterable of associated genes.
        top: DataFrame with at least the columns ``source`` and ``target``.

    Returns:
        New DataFrame: the rows of ``top`` followed by the added edges.
        If nothing can be expanded, a copy of ``top`` is returned.
    """
    expanded = []
    for source, target in zip(top['source'], top['target']):
        mapped_sources = mapping[source] if source in mapping else []
        mapped_targets = mapping[target] if target in mapping else []

        # Both endpoints replaced.
        expanded.extend([se, st] for se in mapped_sources for st in mapped_targets)
        # Only the source replaced.
        expanded.extend([se, target] for se in mapped_sources)
        # Only the target replaced.
        expanded.extend([source, st] for st in mapped_targets)

    if not expanded:
        # Naming the columns of an empty frame after construction would fail.
        return top.copy()

    expanded = pd.DataFrame(expanded, columns=['source', 'target'])
    all_edges = pd.concat([top, expanded])
    return all_edges


In [None]:
# Train 50 autoencoders (500 epochs each), every one on the shared genes plus
# one randomly drawn representative per identity class.
# NOTE(review): `id_class_dictionary`, `keys` and `other_genes` are not assigned
# in any cell visible here -- presumably they come from
# create_dataset_assembly_dict(cluster_mapping); confirm before Restart & Run All.
model_zoo, variables_selected, representatives = create_model_zoo(data_tensor, id_class_dictionary, keys, other_genes, n_models=50, n_epochs=500)

In [None]:
# Override two settings of the loaded config in place before running the wrapper.
config.raw_attribution = True
# NOTE(review): with percentile-based edge selection a value above 50 makes the
# upper tail (> 45th percentile) and the lower tail (< 55th percentile) overlap,
# so every column is selected -- confirm that this is intended.
config.percentile = 55
grn_adata = wrapper(model_zoo, data_tensor.cuda(), gene_names, variables_selected,  config)
