# Comparison of All Segmentation Methods

In [1]:
import scanpy as sc

In [2]:
# Exploratory load of the Cellpose-segmented Xenium AnnData.
# Deliberately NOT bound to `data_dir` / `ad`: later cells use `data_dir` as a
# directory Path (`data_dir / 'labeled_transcripts.parquet'`,
# `data_dir / 'metrics'`), and `ad` is reassigned to the scRNA-seq atlas below.
# Re-running this cell with the old names would silently break those cells.
cellpose_h5ad_path = '/data/peer/moormana/GitHub/dpeerlab/segger-analysis/data/xenium_nsclc/h5ads/cellpose_cell_id.h5ad'
ad_cellpose = sc.read_h5ad(cellpose_h5ad_path)

In [2]:
%load_ext autoreload
%autoreload 2

In [19]:
from requirements import *
from collections import defaultdict
from matplotlib.colors import to_rgba

## MECR Marker Pairs (Mutually Exclusive Gene Pairs)

In [3]:
from segger.validation.utils import *

In [4]:
# Load the scRNA-seq reference atlas, which has been
#  - downsampled to a workable size (10% of cells), and
#  - subset to the Xenium panel genes.
base_dir = Path('/data/peer/moormana/GitHub/dpeerlab/segger_dev/dev/')
scrna_dir = base_dir.joinpath('figure_3', 'data', 'inputs', 'scrnaseq')
atlas = 'integrated_human_lung_cell_atlas'
atlas_path = scrna_dir / f'{atlas}_small.h5ad'
ad = sc.read_h5ad(atlas_path)

In [5]:
# Get positive and negative markers per cell type.
# `find_markers` comes from a star import earlier in the notebook (presumably
# segger.validation.utils); its return structure is not shown here.
# `ct_col` names the cell-type label column of `ad` (presumably in `ad.obs`
# — confirm); it is reused by the mutually-exclusive-genes cell.
ct_col = 'cell_type_simplified'
markers = find_markers(ad, ct_col)

In [6]:
# Get mutually exclusive gene pairs from the per-cell-type markers, for the
# MECR marker-pair evaluation. The exact return format of
# `find_mutually_exclusive_genes` (star-imported) is not shown here — confirm.
me_genes = find_mutually_exclusive_genes(ad, markers, ct_col)

100%|██████████| 10/10 [02:25<00:00, 14.50s/it]


## Calculate Summary Stats

In [57]:
def get_overlap(
    tx,
    a_col,
    b_col,
    mask_col='is_epithelial'
):
    """Count, per reference cell, how many predicted cells substantially overlap it.

    A predicted cell (id in `a_col`) counts toward a reference cell (id in
    `b_col`) when more than 25% of the predicted cell's transcripts fall in
    that reference cell and the predicted cell has more than 5 transcripts.

    Parameters
    ----------
    tx : DataFrame
        One row per transcript (called with a cuDF DataFrame in this notebook).
    a_col : str
        Column with predicted cell ids.
    b_col : str
        Column with reference cell ids.
    mask_col : str
        Boolean column restricting which transcripts are paired up. It is
        ignored when `a_col` contains 'segger' (see below).

    Returns
    -------
    Series indexed by reference cell id (`b_col`), holding the number of
    significantly overlapping predicted cells.
    """
    # Default: only pair up transcripts flagged by `mask_col`
    mask = tx[mask_col]
    if 'segger' in a_col:
        # Segger: keep transcripts that carry a score instead.
        # NOTE(review): this *replaces* the `mask_col` filter rather than
        # combining with it, so Segger overlaps are not restricted to
        # `mask_col` transcripts as for the other methods — confirm intended.
        score_col = a_col.replace('cell_id', 'score')
        mask = ~tx[score_col].isna()
    # Transcript counts per (reference cell, predicted cell) pair
    counts = tx[mask].value_counts([b_col, a_col])
    # Predicted cell id of every pair (level 1 of the MultiIndex)
    idx = counts.index.get_level_values(1)
    
    # Total transcripts per predicted cell, aligned with `counts`
    if 'segger' in a_col:
        totals = tx.loc[mask, a_col].value_counts().loc[idx]
    else:
        # NOTE(review): here totals count all transcripts, not just masked ones
        totals = tx[a_col].value_counts().loc[idx]
    # Fraction of each predicted cell's transcripts inside the reference cell
    fracs = counts / totals.values
    significant = fracs.gt(0.25) & totals.gt(5).values
    # Number of significant predicted cells per reference cell
    overlaps = significant.groupby(b_col).sum()
    return overlaps

In [58]:
def get_contamination(
    tx,
    a_col,
    b_col,
    mask_col='is_epithelial'
):
    """Fraction of `mask_col` transcripts in each predicted target-type cell.

    Transcripts are grouped by predicted cell id (`a_col`), and the boolean
    `mask_col` is averaged per cell. The function keeps only cells where that
    fraction exceeds 0.33 and that hold more than 5 transcripts.

    `b_col` is accepted for signature parity with the other metric functions
    but is not used.

    Returns
    -------
    Series indexed by predicted cell id with the `mask_col` fraction.
    """
    per_cell = tx.groupby(a_col)
    target_frac = per_cell[mask_col].mean()
    keep = target_frac.gt(0.33) & per_cell.size().gt(5)
    return target_frac[keep].dropna()

In [59]:
def get_recall(
    tx,
    a_col,
    b_col,
    mask_col='is_epithelial'
):
    """Recall of each predicted cell against its best-matching reference cell.

    For every predicted cell (`a_col`), the reference cell (`b_col`) sharing
    the most transcripts with it is taken as its match. Recall is the shared
    count divided by that reference cell's total transcript count.

    Matched reference cells with fewer than 5 transcripts are set to null, so
    their recall is dropped. Only predicted cells whose fraction of `mask_col`
    transcripts exceeds 0.33 are returned.

    Builds a `cudf.Series` explicitly, so `tx` is expected to be a cuDF
    DataFrame.

    Returns
    -------
    Series indexed by predicted cell id with the recall value.
    """
    # Calculate all counts
    # Total transcripts per reference cell
    counts_b = tx[b_col].value_counts()
    # Shared transcripts per (predicted, reference) pair
    counts = tx.value_counts([a_col, b_col])
    gb = counts.groupby(a_col)
    # Reference id of each predicted cell's best match.
    # NOTE(review): indexes get_level_values(1) with gb.idxmax(); presumably
    # this relies on cuDF's groupby idxmax returning row positions (pandas
    # returns index labels) — confirm before porting to pandas.
    idx = counts.index.get_level_values(1)[gb.idxmax()]
    # Best-match reference totals, re-indexed by predicted cell id
    counts_b = cudf.Series(counts_b[idx].values, gb.idxmax().index)
    # Too-small reference cells -> null -> recall dropped below
    counts_b[counts_b.lt(5)] = None

    # Calculate recall
    intersection = gb.max()
    recall = intersection / counts_b.loc[intersection.index]

    # Keep only predicted cells whose `mask_col` fraction exceeds 0.33
    # (epithelial cells with the default mask column)
    mask = tx[mask_col].groupby(tx[a_col]).mean().gt(0.33)
    return recall[mask.loc[recall.index]].dropna()

In [60]:
def get_jaccard(
    tx,
    a_col,
    b_col,
    mask_col='is_epithelial',
):
    """Jaccard index between each predicted cell and its best-matching reference cell.

    For each predicted cell id in `a_col`, the reference cell (`b_col`) sharing
    the most transcripts is taken as its match, and

        jaccard = |A & B| / (|A| + |B| - |A & B|)

    where |A| and |B| are the total transcript counts of the predicted cell and
    the matched reference cell. Cells on either side with fewer than 10
    transcripts are set to null, so their score is dropped. Only predicted
    cells whose fraction of `mask_col` transcripts exceeds 0.33 are returned.

    Builds a `cudf.Series` explicitly, so `tx` is expected to be a cuDF
    DataFrame.

    Returns
    -------
    Series indexed by predicted cell id with the Jaccard index.
    """
    # Calculate all counts
    # Total transcripts per predicted cell and per reference cell
    counts_a = tx[a_col].value_counts()
    counts_b = tx[b_col].value_counts()
    # Shared transcripts per (predicted, reference) pair
    counts = tx.value_counts([a_col, b_col])
    gb = counts.groupby(a_col)
    # Reference id of each predicted cell's best match.
    # NOTE(review): indexes get_level_values(1) with gb.idxmax(); presumably
    # relies on cuDF's groupby idxmax returning row positions — confirm.
    idx = counts.index.get_level_values(1)[gb.idxmax()]
    # Best-match reference totals, re-indexed by predicted cell id
    counts_b = cudf.Series(counts_b[idx].values, gb.idxmax().index)
    # Too-small cells on either side -> null -> score dropped below
    counts_a[counts_a.lt(10)] = None
    counts_b[counts_b.lt(10)] = None

    # Calculate jaccard
    intersection = gb.max()
    union = counts_a + counts_b - intersection
    jaccard = intersection / union

    # Keep only predicted cells whose `mask_col` fraction exceeds 0.33
    # (epithelial cells with the default mask column).
    # NOTE(review): unlike the sibling metrics, `mask` is applied without
    # `.loc[jaccard.index]`, relying on index alignment of the boolean mask.
    mask = tx[mask_col].groupby(tx[a_col]).mean().gt(0.33)
    return jaccard[mask].dropna()

In [61]:
def get_purity(
    tx,
    a_col,
    b_col,
    mask_col='is_epithelial',
):
    """Purity of each predicted cell with respect to the reference segmentation.

    Purity is the largest number of transcripts a predicted cell (`a_col`)
    shares with any single reference cell (`b_col`), divided by the predicted
    cell's total transcript count. Predicted cells with fewer than 10
    transcripts are set to null and dropped. Only predicted cells whose
    fraction of `mask_col` transcripts exceeds 0.33 are returned.

    Returns
    -------
    Series indexed by predicted cell id with the purity value.
    """
    # Calculate all counts
    # Total transcripts per predicted cell; too-small cells -> null
    # (assigning None relies on null support, e.g. cuDF as used here)
    counts_a = tx[a_col].value_counts()
    counts_a[counts_a.lt(10)] = None
    # Shared transcripts per (predicted, reference) pair
    counts = tx.value_counts([a_col, b_col])
    gb = counts.groupby(a_col)

    # Calculate purity
    intersection = gb.max()
    purity = intersection / counts_a.loc[intersection.index]

    # Keep only predicted cells whose `mask_col` fraction exceeds 0.33
    # (epithelial cells with the default mask column)
    mask = tx[mask_col].groupby(tx[a_col]).mean().gt(0.33)
    return purity[mask.loc[purity.index]].dropna()

In [62]:
def get_no_transcripts(
    tx,
    a_col,
    mask_col='is_epithelial',
):
    """Number of transcripts assigned to each predicted target-type cell.

    Transcripts are counted per predicted cell id (`a_col`). Only cells whose
    fraction of `mask_col` transcripts exceeds 0.33 are returned.

    Returns
    -------
    Series indexed by predicted cell id with the transcript count.
    """
    assigned = tx[a_col]
    n_transcripts = assigned.value_counts()
    is_target_cell = tx[mask_col].groupby(assigned).mean() > 0.33
    return n_transcripts[is_target_cell].dropna()

In [63]:
# Transcript-level assignment column -> display label used in the figures.
# Insertion order is the iteration order of the metrics loop.
segmentations = dict([
    ('cellpose_cell_id', 'Cellpose'),
    ('10x_cell_id', '10X'),
    ('10x_nucleus_id', '10X Nucleus'),
    ('segger_cell_id_3Q6EISGCD9', 'segger'),
    ('segger_cell_id_HDE46PBXJB', 'segger+'),
    ('baysor_cell_id_c=0.9', 'Baysor, c=0.9'),
    ('baysor_cell_id_c=0.7', 'Baysor, c=0.7'),
    ('baysor_cell_id_c=0.5', 'Baysor'),
])

# Display labels of the segmentations shown in the main (non-supplementary) figures
main_segmentations = ['Cellpose', '10X', '10X Nucleus', 'segger+', 'Baysor']

In [64]:
# Load only the transcript columns needed downstream:
# every segmentation's assignment column, the Cellpose cell-type label,
# the two Segger score columns, and the gene name.
transcripts_filepath = data_dir / 'labeled_transcripts.parquet'
tx_columns = [
    *segmentations,
    'cellpose_cell_type',
    'segger_score_HDE46PBXJB',
    'segger_score_3Q6EISGCD9',
    'feature_name',
]
tx = pd.read_parquet(transcripts_filepath, columns=tx_columns)
# Flag transcripts whose Cellpose cell was typed as epithelial/cancer
tx['is_epithelial'] = tx['cellpose_cell_type'] == 'Epithelial/Cancer'

In [65]:
# Calculate per-cell metrics for every segmentation against the Cellpose
# reference (`targ_col`). Each entry of `metrics` is a list with one item per
# segmentation; the per-cell metrics are DataFrames with columns
# ['value', 'segmentation'], later concatenated and written to CSV.
metrics = defaultdict(list)

targ_col = 'cellpose_cell_id'

# Metrics that compare a predicted segmentation with the reference.
# Insertion order matches the order in which `metrics` is populated.
comparison_metrics = {
    'jaccard': get_jaccard,
    'purity': get_purity,
    'recall': get_recall,
    'overlap': get_overlap,
    'contam': get_contamination,
}


def _as_metric_frame(values, segmentation):
    """Turn a per-cell pandas Series into a ['value', 'segmentation'] DataFrame."""
    frame = values.rename('value').to_frame()
    frame['segmentation'] = segmentation
    return frame


for pred_col, name in tqdm(segmentations.items()):

    # Copy only the needed columns into cuDF for faster groupbys.
    # dict.fromkeys de-duplicates (pred_col == targ_col for the reference)
    # while keeping a deterministic column order.
    gpu_cols = list(dict.fromkeys([pred_col, targ_col, 'is_epithelial']))
    score_col = None
    if 'segger' in pred_col:
        score_col = pred_col.replace('cell_id', 'score')
        gpu_cols.append(score_col)
    tx_cudf = cudf.from_pandas(tx[gpu_cols])

    # Unassign low-confidence Segger transcripts (score < 0.5)
    if score_col is not None:
        low_score = tx_cudf[score_col].lt(0.5)
        tx_cudf.loc[low_score, pred_col] = None

    if pred_col != targ_col:
        for metric, metric_fn in comparison_metrics.items():
            values = metric_fn(tx_cudf, pred_col, targ_col).to_pandas()
            metrics[metric].append(_as_metric_frame(values, name))
    else:
        # The reference compared with itself: store a single placeholder 0
        # so the reference still appears as a category in the plots.
        placeholder = pd.DataFrame([[0, name]], columns=['value', 'segmentation'])
        for metric in comparison_metrics:
            metrics[metric].append(placeholder.copy())

    # No. transcripts per (target-type) cell
    num_tx = get_no_transcripts(tx_cudf, pred_col).to_pandas()
    metrics['num_tx'].append(_as_metric_frame(num_tx, name))

    # No. cells. The previous `dict(name=...)` stored the literal key 'name'
    # and dropped the segmentation label; keep an explicit record instead.
    metrics['num_cells'].append({
        'segmentation': name,
        'value': int(tx_cudf[pred_col].nunique()),
    })

100%|██████████| 8/8 [01:39<00:00, 12.41s/it]


In [66]:
# Write each per-cell metric to <data_dir>/metrics/<metric>.csv, creating the
# directory if needed. `num_cells` holds one record per segmentation rather
# than per-cell frames, so it is not written here.
metrics_dir = data_dir / 'metrics'
metrics_dir.mkdir(parents=True, exist_ok=True)
for metric, frames in tqdm(metrics.items()):
    if metric == 'num_cells':
        continue
    pd.concat(frames, axis=0).to_csv(metrics_dir / f'{metric}.csv')

100%|██████████| 7/7 [00:22<00:00,  3.24s/it]


## Plot Summary Stats

In [90]:
# Reload the saved per-cell metrics for plotting.
metrics = {
    metric: pd.read_csv(data_dir / 'metrics' / f'{metric}.csv', index_col=0)
    for metric in ['contam', 'jaccard', 'overlap', 'purity', 'recall', 'num_tx']
}

# Express contamination as the percentage of non-`mask_col` (non-epithelial)
# transcripts per cell.
metrics['contam']['value'] = 100 * (1 - metrics['contam']['value'])

# The Cellpose reference row is a placeholder 0 (100 after the transform above);
# reset it by its segmentation label. Selecting by the row label '0' would also
# hit any real cell with id 0, and with an integer-parsed index it would append
# a new row instead of updating the placeholder.
is_reference = metrics['contam']['segmentation'].eq('Cellpose')
metrics['contam'].loc[is_reference, 'value'] = 0

In [97]:
# Bar order for the main-figure panels
main_order = [
    'Cellpose',
    '10X',
    '10X Nucleus',
    'segger+',
    'Baysor',
]
# Bar order for the supplementary panels (adds the alternative segger/Baysor runs)
supp_order = [
    'Cellpose',
    '10X',
    '10X Nucleus',
    'segger+',
    'segger',
    'Baysor',
    'Baysor, c=0.7',
    'Baysor, c=0.9',
]
# Fixed y-axis limits per metric so main and supplementary panels match
ylims = dict(
    contam=(0.0, 12.5),
    purity=(0.5, 1.0),
    recall=(0.0, 0.9),
    num_tx=(0, 300),
    overlap=(0, 3.15),
    jaccard=(0, 0.75),
)

### Main

In [101]:
order = main_order
palette = [to_rgba(named_colors[label]) for label in order]
styles = dict(saturation=0.9, capsize=0.3, width=0.6, palette=palette)
# Loop-invariant tick-label styling: rotate 45 degrees, right-aligned at the tick
tickstyles = dict(ha='right', va='top', rotation=45, rotation_mode="anchor")

# One small bar chart per metric, restricted to the main segmentations;
# error bars show the 50% percentile interval (errorbar=('pi', 50)).
for metric, metric_df in tqdm(metrics.items()):
    shown = metric_df.loc[metric_df['segmentation'].isin(order)]
    fig, ax = plt.subplots(figsize=(1.1, 1.1))
    sns.barplot(
        shown,
        x='segmentation',
        y='value',
        errorbar=('pi', 50),
        order=order,
        err_kws=dict(linewidth=0.5, alpha=0.75, color='k'),
        ax=ax,
        linewidth=0,
        **styles,
    )
    ax.tick_params(labelsize=7)
    ax.set_xticklabels(ax.get_xticklabels(), **tickstyles)
    ax.set(xlabel='', ylabel='', ylim=ylims[metric])
    fig.savefig(media_dir / f'{metric}_main.svg')
    plt.close(fig)

100%|██████████| 6/6 [00:28<00:00,  4.78s/it]


### Supplementary

In [102]:
order = supp_order
palette = [to_rgba(named_colors[label]) for label in order]
styles = dict(saturation=0.9, capsize=0.3, width=0.6, palette=palette)
# Loop-invariant tick-label styling: rotate 45 degrees, right-aligned at the tick
tickstyles = dict(ha='right', va='top', rotation=45, rotation_mode="anchor")

# One bar chart per metric over all supplementary segmentations (wider figure);
# error bars show the 50% percentile interval (errorbar=('pi', 50)).
for metric, metric_df in tqdm(metrics.items()):
    shown = metric_df.loc[metric_df['segmentation'].isin(order)]
    fig, ax = plt.subplots(figsize=(1.75, 1.1))
    sns.barplot(
        shown,
        x='segmentation',
        y='value',
        errorbar=('pi', 50),
        order=order,
        err_kws=dict(linewidth=0.5, alpha=0.75, color='k'),
        ax=ax,
        linewidth=0,
        **styles,
    )
    ax.tick_params(labelsize=7)
    ax.set_xticklabels(ax.get_xticklabels(), **tickstyles)
    ax.set(xlabel='', ylabel='', ylim=ylims[metric])
    fig.savefig(media_dir / f'{metric}_supp.svg')
    plt.close(fig)

100%|██████████| 6/6 [00:57<00:00,  9.57s/it]
