# Confounding mitigating reimplementations of asw batch score 
Author: Pia Rautenstrauch

Date: 2024-07-16

"Previous"/Standard implementation
- nearest cluster

"Confounding mitigating"
- furthest cluster
- mean to every cell of other clusters

In [1]:
import numpy as np
from sklearn.metrics.pairwise import pairwise_distances

In [9]:
# Modified from scib (v1.0.1)
def silhouette_batch_custom(
        adata,
        batch_key,
        group_key,
        embed,
        metric='euclidean',
        return_all=False,
        scale=True,
        verbose=True,
        between_cluster_distances='nearest'
):
    """
    Modification of silhouette_batch from scib package (custom silhouette_samples function) to prevent confounding by nested batch effects.
    Absolute silhouette score of batch labels subsetted for each group. Groups are usually cell types in this context.
    between_cluster_distances='nearest' is equivalent to scib original implementation

    :param batch_key: batches to be compared against
    :param group_key: group labels to be subsetted by e.g. cell type
    :param embed: name of column in adata.obsm
    :param metric: see sklearn silhouette score
    :param scale: if True, scale between 0 and 1
    :param return_all: if True, return all silhouette scores and label means
        default False: return average width silhouette (ASW)
    :param between_cluster_distances: one out of 'mean_other', 'furthest', 'nearest'
    :param verbose:
    :return:
        average width silhouette ASW
        mean silhouette per group in pd.DataFrame
        Absolute silhouette scores per group label
    """
    if embed not in adata.obsm.keys():
        print(adata.obsm.keys())
        raise KeyError(f'{embed} not in obsm')

    sil_all = pd.DataFrame(columns=['group', 'silhouette_score'])

    for group in adata.obs[group_key].unique():
        adata_group = adata[adata.obs[group_key] == group]
        n_batches = adata_group.obs[batch_key].nunique()

        if (n_batches == 1) or (n_batches == adata_group.shape[0]):
            continue
        
        #Modified
        sil_per_group = silhouette_samples_custom(
            adata_group.obsm[embed],
            adata_group.obs[batch_key],
            metric=metric,
            between_cluster_distances=between_cluster_distances,
        )

        # take only absolute value
        sil_per_group = [abs(i) for i in sil_per_group]

        if scale:
            # scale s.t. highest number is optimal
            sil_per_group = [1 - i for i in sil_per_group]

        #sil_all = sil_all.append(
        #    pd.DataFrame({
        #        'group': [group] * len(sil_per_group),
        #        'silhouette_score': sil_per_group
        #    })
        #)
        
        sil_all = pd.concat([sil_all, pd.DataFrame({
                'group': [group] * len(sil_per_group),
                'silhouette_score': sil_per_group
            })], ignore_index=True)

    sil_all = sil_all.reset_index(drop=True)
    sil_means = sil_all.groupby('group').mean()
    asw = sil_means['silhouette_score'].mean()

    if verbose:
        print(f'mean silhouette per cell: {sil_means}')

    if return_all:
        return asw, sil_means, sil_all

    return asw

In [10]:
def silhouette_samples_custom(X, labels, metric="euclidean", between_cluster_distances="nearest"):
    """
    Compute the average silhouette score for the dataset X with the given labels.

    Parameters:
    X : array-like, shape (n_samples, n_features)
        Feature array.
    labels : array-like, shape (n_samples,)
        Labels of each point.
        
    metric : metric for distance calculation, default:"euclidean", alternatives, e.g., "cosine"
    
    between_cluster_distances: one out of "mean_other", "furthest", "nearest"


    Returns:
    score : float
        The average silhouette score.
    """

    # Number of clusters
    unique_labels = np.unique(labels)
    n_clusters = len(unique_labels)

    # If there's only one cluster or no clusters, return 0 as silhouette score cannot be computed
    if n_clusters == 1 or n_clusters == 0:
        return 0

    # Initialize silhouette scores
    silhouette_scores = np.zeros(len(X))

    # Calculate pairwise distance matrix
    #distance_matrix = np.linalg.norm(X[:, np.newaxis] - X, axis=2)
    distance_matrix = pairwise_distances(X, metric=metric)
    
    for i in range(len(X)):
        # Points in the same cluster
        same_cluster = labels == labels[i]
        other_clusters = labels != labels[i]
        # Exclude the current point for intra-cluster distance
        same_cluster[i] = False

        # a: Mean distance from i to all other points in the same cluster
        if np.sum(same_cluster) > 0:
            a = np.mean(distance_matrix[i, same_cluster])
        else:
            a = 0

        # b: Mean distance from i to all points in the furthest different cluster
        if between_cluster_distances == "furthest":
            b = np.max([
                np.mean(distance_matrix[i, labels == label]) 
                for label in unique_labels if label != labels[i]
            ])
        
        # b: Mean distance from i to all points in any other cluster
        elif between_cluster_distances == "mean_other":
            b = np.mean(distance_matrix[i, other_clusters]) 
            
        # b: Mean distance from i to all points in the nearest different cluster
        else:
            b = np.min([
                np.mean(distance_matrix[i, labels == label]) 
                for label in unique_labels if label != labels[i]
            ])

        # Silhouette score for point i
        silhouette_scores[i] = (b - a) / max(a, b)

    # Average silhouette score for all points
    # return np.mean(silhouette_scores)
    
    return silhouette_scores


In [11]:
?silhouette_batch_custom

[0;31mSignature:[0m
[0msilhouette_batch_custom[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0madata[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbatch_key[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mgroup_key[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0membed[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmetric[0m[0;34m=[0m[0;34m'euclidean'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreturn_all[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mscale[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mverbose[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbetween_cluster_distances[0m[0;34m=[0m[0;34m'nearest'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Modification of silhouette_batch from scib package (custom silhouette_samples function)
Absolute silhouette score of batch labels subsetted for each group. Groups are usually cell types in this conte

In [12]:
?silhouette_samples_custom

[0;31mSignature:[0m
[0msilhouette_samples_custom[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mX[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlabels[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmetric[0m[0;34m=[0m[0;34m'euclidean'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbetween_cluster_distances[0m[0;34m=[0m[0;34m'nearest'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Compute the average silhouette score for the dataset X with the given labels.

Parameters:
X : array-like, shape (n_samples, n_features)
    Feature array.
labels : array-like, shape (n_samples,)
    Labels of each point.
    
metric : metric for distance calculation, default:"euclidean", alternatives, e.g., "cosine"

between_cluster_distances: one out of "mean_other", "furthest", "nearest"


Returns:
score : float
    The average silhouette score.
[0;31mFile:[0m      /tmp/7367919.1.all.q/ipykernel_2608002/4094074416.py
[0;31mType:[0m      functio