In [None]:
import igraph as ig
import numpy as np
import multiprocessing as mp
from sklearn.metrics import jaccard_score
from scipy.sparse import csr_matrix, coo_matrix, csc_matrix
import ipywidgets as widgets
from ipywidgets import interact
import seaborn as sns

In [None]:
# Core scverse libraries
import scanpy as sc
import anndata as ad
import pandas as pd
import numpy as np
import pertpy as pt
import matplotlib.pyplot as plt
# Data retrieval
import pooch
import scanpy.external as sce
import os

In [None]:
# High-resolution figures for export.
plt.rcParams.update({'figure.dpi': 400})


In [None]:
# Read in the cell x lineage-barcode matrix (or matrices) -- there may be several if the
# experiment is split across single-cell preps. Cell barcodes get a "-1" suffix so they
# stay distinct from the other prep's barcodes after concatenation (presumably this also
# matches the suffix used in the doublet list -- confirm).
sparse_path_1 = 'xyz.csv'
if not os.path.exists(sparse_path_1):
    # The first matrix is required: fail loudly instead of re-suffixing a stale variable
    # left over from an earlier run.
    raise FileNotFoundError(f"Lineage barcode matrix not found: {sparse_path_1}")
sparse_matrix_1 = pd.read_csv(sparse_path_1, sep=',', header='infer', index_col='cellbc')
sparse_matrix_1.index = sparse_matrix_1.index + "-1"

In [None]:
# (cells, lineage barcodes) in the first prep
np.shape(sparse_matrix_1)

In [None]:
# Optional second matrix, only present when a second single-cell kit/prep was used.
# It stays None when the file is absent; the suffix is applied only to freshly read data,
# so re-running the cell never double-suffixes the barcodes.
sparse_path_2 = 'xyz.csv'
sparse_matrix_2 = None
if os.path.exists(sparse_path_2):
    sparse_matrix_2 = pd.read_csv(sparse_path_2, sep=',', header='infer', index_col='cellbc')
    sparse_matrix_2.index = sparse_matrix_2.index + "-2"

In [None]:
# (cells, lineage barcodes) in the second prep, or None when no second matrix was loaded.
sparse_matrix_2.shape if sparse_matrix_2 is not None else None

In [None]:
# Best practice: run doublet detection on the transcriptome dataset separately, then import
# the list of detected doublets here so their lineage-barcode signal can be removed before
# clone calling -- doublets can drive artificial clone collisions.
# The file is indexed by cell barcode ('cellbc'); its barcodes presumably carry the same
# "-1"/"-2" suffixes as the matrices above -- confirm, or the filter below will match nothing.
doublet_path = 'zyx.csv'
doublet_df = pd.read_csv(doublet_path, index_col='cellbc')

In [None]:
# Give the unnamed leftover index column (if present) a readable name.
doublet_df = doublet_df.rename(mapper={'Unnamed: 0': "doublet_cells"}, axis='columns')

In [None]:
# Collect the per-prep matrices for outer concatenation; the optional second prep is
# skipped when it was not loaded (None).
sparse_matrix_list = [m for m in (sparse_matrix_1, sparse_matrix_2) if m is not None]

In [None]:
# Stack cells from all preps; barcode columns missing from a prep become NaN.
sparse_matrix = pd.concat(sparse_matrix_list, axis=0, join='outer')

In [None]:
# Barcodes absent from one prep are NaN after the outer concat; treat them as zero counts
# so the (mostly-zero) count matrix stays numeric.
sparse_matrix = sparse_matrix.fillna(value=0)

In [None]:
# (cells, lineage barcodes) after merging preps
np.shape(sparse_matrix)

In [None]:
# Remove doublet cells (rows whose barcode appears in the doublet list), so that doublets
# filtered from the transcriptome are also excluded from the lineage (e.g. gRNA) analysis.
is_doublet = sparse_matrix.index.isin(doublet_df.index)
sparse_matrix = sparse_matrix.loc[~is_doublet]



In [None]:
# Dimensions after doublet removal
np.shape(sparse_matrix)

In [None]:
# Drop lineage-barcode columns that became all-zero once doublets were removed.
sparse_matrix = sparse_matrix.loc[:, sparse_matrix.sum(axis=0).ne(0)]


In [None]:
# Dimensions after dropping barcodes that were only seen in doublets
np.shape(sparse_matrix)

In [None]:
# Column sums = total counts per lineage barcode; row sums = total counts per cell.
sparse_lbc_sums = sparse_matrix.sum(axis=0)
sparse_cell_sums = sparse_matrix.sum(axis=1)


In [None]:
# Distribution of column sums: total counts per lineage barcode across all cells
# (equal to the number of cells carrying the barcode if the matrix is 0/1).
plt.hist(sparse_lbc_sums, bins=10)
plt.yscale('log')
plt.xlabel('Counts per lineage barcode')
plt.ylabel('Number of lineage barcodes (log scale)');


In [None]:
# Distribution of row sums: total lineage-barcode counts per cell
# (equal to the number of distinct barcodes per cell if the matrix is 0/1).
plt.hist(sparse_cell_sums, bins=50)
plt.yscale('log')
plt.xlabel('Lineage-barcode counts per cell')
plt.ylabel('Number of cells (log scale)');


In [None]:
import igraph as ig
import numpy as np
import multiprocessing as mp
from sklearn.metrics import jaccard_score
from scipy.sparse import csr_matrix, coo_matrix, csc_matrix

# Force the "fork" start method for worker processes (not available on Windows);
# presumably needed so workers can reach functions defined in this notebook -- confirm.
mp.set_start_method(method="fork", force=True)

def convert_to_sparse_matrix(mtx):
    """Return ``mtx`` as a SciPy CSR matrix.

    A CSR matrix is returned unchanged (same object); COO and CSC matrices are converted
    with ``tocsr()``; anything else (dense arrays, DataFrames, ...) is passed to the
    ``csr_matrix`` constructor.
    """
    if isinstance(mtx, csr_matrix):
        return mtx
    if isinstance(mtx, (coo_matrix, csc_matrix)):
        return mtx.tocsr()
    return csr_matrix(mtx)

def jaccard_similarity(set1, set2):
    """Jaccard index |A ∩ B| / |A ∪ B| of two collections of column indices.

    Gives the same value as sklearn's ``jaccard_score`` on 0/1 indicator vectors
    (including 0.0 when both inputs are empty), but works directly on the index sets
    instead of materialising dense vectors and calling sklearn for every cell pair.

    Parameters
    ----------
    set1, set2 : iterable of hashable indices (e.g. numpy arrays of column indices)

    Returns
    -------
    float in [0, 1]
    """
    a, b = set(set1), set(set2)
    union = len(a | b)
    return len(a & b) / union if union else 0.0

def process_barcode(lbc_idx, mtx):
    """Build candidate edges between all cells that share one lineage barcode.

    Parameters
    ----------
    lbc_idx : int
        Column index of the lineage barcode in ``mtx``.
    mtx : scipy.sparse.csr_matrix
        Cell x barcode matrix (rows are cells, columns are lineage barcodes).

    Returns
    -------
    (edge_list, edge_attrs, weight_list, jac_list)
        Four element-wise aligned lists with one entry per unordered pair (j < k) of the
        rows that have a nonzero entry in column ``lbc_idx``:
        - edge_list: (row_j, row_k) row indices into ``mtx``
        - edge_attrs: ``lbc_idx`` repeated (the barcode that produced the edge)
        - weight_list: Dice coefficient 2|A ∩ B| / (|A| + |B|) of the two rows' barcode sets
        - jac_list: Jaccard index |A ∩ B| / |A ∪ B| of the same sets
    """
    cell_rows = mtx[:, lbc_idx].nonzero()[0]
    sub_mtx = mtx[cell_rows, :]
    # Barcode set of each cell, computed once per cell instead of once per pair.
    barcode_sets = [set(sub_mtx[r].nonzero()[1]) for r in range(sub_mtx.shape[0])]
    n_cells = len(barcode_sets)

    edge_list, edge_attrs, weight_list, jac_list = [], [], [], []
    for j in range(n_cells - 1):
        set_j = barcode_sets[j]
        for k in range(j + 1, n_cells):
            set_k = barcode_sets[k]
            shared = len(set_j & set_k)
            total = len(set_j) + len(set_k)
            union = total - shared  # |A ∪ B| = |A| + |B| - |A ∩ B|

            edge_list.append((int(cell_rows[j]), int(cell_rows[k])))
            edge_attrs.append(lbc_idx)
            weight_list.append(2 * shared / total if total else 0.0)
            jac_list.append(shared / union if union else 0.0)

    return edge_list, edge_attrs, weight_list, jac_list

def build_graph_from_sparse_mtx_w_pattern_weights(mtx, n_cores, row_names=None, col_names=None):
    """Build an undirected cell-cell multigraph from a cell x lineage-barcode matrix.

    Each row of ``mtx`` becomes a vertex. For every barcode column with a nonzero entry in
    at least two rows, ``process_barcode`` links every pair of those rows, so a pair of
    cells sharing several barcodes receives one parallel edge per shared barcode.

    Parameters
    ----------
    mtx : array-like, DataFrame or scipy sparse matrix
        Rows are cells, columns are lineage barcodes; nonzero entries mark presence.
    n_cores : int
        Number of worker processes used to process barcodes in parallel.
    row_names : sequence, optional
        Values stored as the vertex attribute ``cellID`` (default: row indices).
    col_names : sequence, optional
        Barcode labels stored as the edge attribute ``lbc_idx`` (default: column indices).

    Returns
    -------
    igraph.Graph
        Edge attributes: ``lbc_idx`` (barcode that produced the edge), ``weight``
        (Dice coefficient of the two cells' barcode sets) and ``jaccard_metric``
        (Jaccard index of the same sets).
    """
    mtx = convert_to_sparse_matrix(mtx)

    num_nodes = mtx.shape[0]
    g = ig.Graph(directed=False)
    g.add_vertices(num_nodes)
    g.vs['cellID'] = list(row_names) if row_names is not None else list(range(num_nodes))

    # A barcode can only create edges if it is present in >= 2 cells. Count cells with a
    # nonzero entry (instead of summing counts) so a barcode seen repeatedly in a single
    # cell is not dispatched to a worker.
    cells_per_lbc = np.asarray((mtx != 0).sum(axis=0)).ravel()
    lbc_vec = np.flatnonzero(cells_per_lbc > 1)

    edges, edge_attrs, weights, jaccards = [], [], [], []
    if lbc_vec.size:
        # No `if __name__ == "__main__"` guard here: inside a function it only made the
        # result silently edgeless when called from an imported module. Workers must be able
        # to reach the notebook-defined process_barcode, presumably why "fork" is forced at
        # module level -- confirm before running under "spawn".
        with mp.Pool(n_cores) as pool:
            results = pool.starmap(process_barcode, [(int(lbc_idx), mtx) for lbc_idx in lbc_vec])
        for res_edges, res_attrs, res_weights, res_jaccards in results:
            edges.extend(res_edges)
            edge_attrs.extend(res_attrs)
            weights.extend(res_weights)
            jaccards.extend(res_jaccards)

    g.add_edges(edges)

    if col_names is not None:
        g.es['lbc_idx'] = [col_names[idx] for idx in edge_attrs]
    else:
        g.es['lbc_idx'] = edge_attrs

    g.es['weight'] = weights
    g.es['jaccard_metric'] = jaccards

    return g


In [None]:
g = build_graph_from_sparse_mtx_w_pattern_weights(
    sparse_matrix,
    n_cores=8,
    row_names=sparse_matrix.index,
    col_names=sparse_matrix.columns,
)

In [None]:
# Collapse the parallel edges (one per shared barcode) into a single edge per cell pair.
# Parallel edges between the same two cells carry identical weight / jaccard_metric values
# (both are computed from the two cells' full barcode sets), so taking the mean keeps that
# value. Attributes not listed here (lbc_idx) are presumably not carried over -- verify.
g = g.simplify(
    multiple=True,
    combine_edges={'weight': "mean", 'jaccard_metric': 'mean'},
)

In [None]:
# Distribution of Jaccard similarity between connected cell pairs.
# Note: network_patterning thresholds on the 'weight' edge attribute (Dice coefficient),
# not on this score. Dice = 2J / (1 + J), so "Dice > t" is equivalent to "Jaccard > t / (2 - t)".
plt.hist(g.es['jaccard_metric'], bins=101)
plt.xlim(0, 1.1)
plt.xlabel('Pairwise Jaccard index')
plt.ylabel('Number of cell pairs');

In [None]:
# Apply an edge-weight threshold and extract clones (connected components of the filtered graph)
def network_patterning(gr: "ig.Graph", threshold: float, edge_attr: str = 'weight') -> pd.DataFrame:
    """Remove weak edges from ``gr`` and return its connected components as clones.

    Parameters
    ----------
    gr : igraph.Graph
        Cell graph whose vertices carry a ``cellID`` attribute. It is not modified.
    threshold : float
        - ``threshold == 1``: keep only edges with ``edge_attr >= 1`` (perfect matches).
        - ``threshold < 1``: keep edges with ``edge_attr > threshold``.
        - ``threshold > 1``: no edges are removed.
          NOTE(review): this looks inverted relative to the other branches; confirm intent.
    edge_attr : str, default 'weight'
        Edge attribute compared against ``threshold``.
        ``build_graph_from_sparse_mtx_w_pattern_weights`` stores the Dice coefficient in
        'weight' and the Jaccard index in 'jaccard_metric'. Pass 'jaccard_metric' to
        threshold on the Jaccard score instead.

    Returns
    -------
    pandas.DataFrame
        Columns ``cloneID`` (component index) and ``cellbc`` (the vertex's cellID), one row
        per vertex, with rows having a null cellID dropped. Isolated vertices become
        singleton clones.
    """
    # The annotation is a string so defining this cell does not require igraph names to
    # resolve (the previous ``GraphBase`` annotation was never imported).
    scores = gr.es[edge_attr] if gr.ecount() else []
    if threshold == 1:
        exclusion_edges = [idx for idx, score in enumerate(scores) if score < threshold]
    elif threshold < 1:
        exclusion_edges = [idx for idx, score in enumerate(scores) if score <= threshold]
    else:
        exclusion_edges = []

    filtered_graph = gr.copy()
    filtered_graph.delete_edges(exclusion_edges)

    cell_ids = filtered_graph.vs['cellID'] if filtered_graph.vcount() else []
    rows = [
        (clone_id, cell_ids[node_idx])
        for clone_id, comp in enumerate(filtered_graph.components())
        for node_idx in comp
    ]
    return pd.DataFrame(rows, columns=['cloneID', 'cellbc']).dropna()

In [None]:
def _run_threshold(args):
    """Worker: cluster one (graph, threshold) pair and label its clone column ``cloneID_<threshold>``."""
    graph, thr = args
    return network_patterning(graph, thr).rename(columns={'cloneID': f'cloneID_{thr}'})


def iterate_clone_pattern_parallel(gr: "ig.Graph", test_thresholds: list, max_workers=None) -> pd.DataFrame:
    """Call clones at several edge-weight thresholds in parallel and join the results per cell.

    Parameters
    ----------
    gr : igraph.Graph
        Cell graph as accepted by ``network_patterning``.
    test_thresholds : iterable of float
        Thresholds to evaluate in addition to the perfect-match run (threshold 1).
    max_workers : int, optional
        Number of worker processes (default: the CPU count).

    Returns
    -------
    pandas.DataFrame
        One row per cell, sorted by ``cellbc``. Columns are ``cellbc``, ``cloneID_1``, then
        ``cloneID_<t>`` for each threshold in input order. Assumes cellbc values are unique.
    """
    import multiprocessing
    from concurrent.futures import ProcessPoolExecutor

    try:
        from tqdm.auto import tqdm
    except ImportError:  # progress bar is optional; fall back to a plain iterator
        def tqdm(iterable, **_kwargs):
            return iterable

    baseline = network_patterning(gr, 1).rename(columns={'cloneID': 'cloneID_1'})

    thresholds = list(test_thresholds)
    n_workers = max_workers or multiprocessing.cpu_count()
    # No gr.copy() per task: every task already receives its own unpickled graph in the
    # worker, and network_patterning copies before deleting edges. With chunksize > 1,
    # tasks in one chunk share a single pickled copy of the graph.
    args_list = [(gr, threshold) for threshold in thresholds]
    chunksize = max(1, len(args_list) // (n_workers * 4))

    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        results = list(tqdm(executor.map(_run_threshold, args_list, chunksize=chunksize),
                            total=len(args_list), desc="Processing thresholds"))

    # Join all runs on cellbc with one concat instead of one outer merge per threshold,
    # which re-copied the growing frame every iteration (quadratic in the number of runs).
    frames = [clones.set_index('cellbc') for clones in [baseline, *results]]
    combined = pd.concat(frames, axis=1, join='outer').sort_index()
    return combined.rename_axis('cellbc').reset_index()

In [None]:
def summarize_clone_lists(iterative_clone_df: pd.DataFrame) -> pd.DataFrame:
    """Summarise clone-size statistics for every clone-assignment column.

    Parameters
    ----------
    iterative_clone_df : pandas.DataFrame
        A ``cellbc`` column plus one ``cloneID_<threshold>`` column per run (as produced by
        ``iterate_clone_pattern_parallel``). Column order does not matter.

    Returns
    -------
    pandas.DataFrame
        One row per run. Includes clone counts and size statistics over multi-cell clones
        (clones with more than one cell). ``topN_*`` statistics cover multi-cell clones strictly
        larger than the (100 - N)th percentile of multi-cell clone sizes. ``threshold`` is
        parsed from the run name. Runs with no assignments get zeros.
    """
    try:
        from tqdm.auto import tqdm
    except ImportError:  # progress bar is optional; fall back to a plain iterator
        def tqdm(iterable, **_kwargs):
            return iterable

    stat_keys = [
        'length', 'biggest_clone_size', 'number_single', 'number_multicells',
        'avg_clone_size', 'top5_avg', 'top10_avg', 'top25_avg',
        'median_clone_size', 'clone_size_variance',
        'top5_variance', 'top10_variance', 'top25_variance',
    ]

    def stats_above_thresh(vals, thresh):
        """Mean and population variance of ``vals`` strictly above ``thresh``; (0, 0) if none."""
        subset = vals[vals > thresh]
        if subset.size == 0:
            return 0, 0
        return subset.mean(), subset.var()

    # Select runs by name, not position: every column except the cell barcode is a run.
    clone_cols = [col for col in iterative_clone_df.columns if col != 'cellbc']
    summary_list = []

    for col in tqdm(clone_cols, desc="Summarizing clone lists"):
        data = iterative_clone_df[col].dropna()
        if data.empty:
            summary_list.append({'run': col, **{k: 0 for k in stat_keys}})
            continue

        counts = data.value_counts().values  # number of cells in each clone
        multicells = counts[counts > 1]      # sizes of clones with more than one cell
        has_multi = len(multicells) > 0

        # Percentile cut-offs of multi-cell clone sizes that define the top 5/10/25 % groups.
        top5, top10, top25 = np.percentile(multicells, [95, 90, 75]) if has_multi else (0, 0, 0)
        top5_avg, top5_var = stats_above_thresh(multicells, top5)
        top10_avg, top10_var = stats_above_thresh(multicells, top10)
        top25_avg, top25_var = stats_above_thresh(multicells, top25)

        summary_list.append({
            'run': col,
            'length': len(np.unique(data)),
            'biggest_clone_size': counts.max(),
            'number_single': (counts == 1).sum(),
            'number_multicells': (counts > 1).sum(),
            'avg_clone_size': multicells.mean() if has_multi else 0,
            'top5_avg': top5_avg,
            'top10_avg': top10_avg,
            'top25_avg': top25_avg,
            'median_clone_size': np.median(multicells) if has_multi else 0,
            'clone_size_variance': multicells.var() if has_multi else 0,
            'top5_variance': top5_var,
            'top10_variance': top10_var,
            'top25_variance': top25_var,
        })

    # Explicit columns keep the schema even when there are no runs at all.
    summary_df = pd.DataFrame(summary_list, columns=['run', *stat_keys])
    # Numeric threshold parsed from the run name, e.g. 'cloneID_0.5' -> 0.5.
    summary_df['threshold'] = (
        summary_df['run'].astype(str).str.extract(r'(\d+\.?\d*)', expand=False).astype(float)
    )
    return summary_df

In [None]:
# Example usage: sweep edge-weight thresholds from 0 to 0.999 in steps of 0.001.
# Rounding keeps the column names clean (e.g. 'cloneID_0.009' rather than
# 'cloneID_0.009000000000000001'); the exports below look up 'cloneID_0.5' and
# 'cloneID_0.2' by name.
test_thresholds = np.round(np.arange(0, 1.0, 0.001), 3)
iterative_clone_df = iterate_clone_pattern_parallel(g, test_thresholds, max_workers=4)


In [None]:
summary_df = summarize_clone_lists(iterative_clone_df)
# Last expression -> rich table display instead of a plain-text print.
summary_df.head()

In [None]:
# Preview only: one row per cell and one column per threshold run.
iterative_clone_df.head()

In [None]:
# Preview only: one row of clone statistics per threshold run.
summary_df.head()

In [None]:
# Long format: one (run, statistic, value) row per measurement.
summary_df_long = pd.melt(summary_df, id_vars='run')

In [None]:
# Preview only: the long table has one row per (run, statistic).
summary_df_long.head()

In [None]:
# Clone size variance is one of the best heuristic indicators of noisy vs. high-quality data.
# The higher the quality of the TrackerSeq library (high diversity & even representation) AND
# the better the single-cell preparation (e.g. low doublet rate, minimal RNA dropout),
# the less noise will be observed and the lower the threshold can be set using these heuristics.

# Lower-quality datasets show massive variances that drop off precipitously (like a knee),
# while high-quality datasets may not have such a knee but keep a low variance even at the
# low end of the similarity scores -- use the initial distribution of pairwise scores as a
# sanity check when interpreting these values!

sns.relplot(
    data=summary_df,
    x="threshold", y="clone_size_variance",
    kind="line",
    height=3, aspect=1.5, facet_kws=dict(sharex=False),
).set_axis_labels("threshold", "Clone size variance (clones > 1 cell)")
plt.title('Collapsed barcode representation')

plt.show()

In [None]:
# Size of the largest clone at each threshold; a sudden jump suggests components merging
# through noisy edges.
sns.relplot(
    data=summary_df,
    x="threshold", y="biggest_clone_size",
    kind="line",
    height=3, aspect=1.5, facet_kws=dict(sharex=False),
).set_axis_labels("threshold", "Largest clone size (cells)")
plt.title('Collapsed barcode representation')

plt.show()

In [None]:
# Mean size of the largest multi-cell clones (above the 90th percentile of multi-cell clone sizes).
sns.relplot(
    data=summary_df,
    x="threshold", y="top10_avg",
    kind="line",
    height=3, aspect=1.5, facet_kws=dict(sharex=False),
).set_axis_labels("threshold", "Mean size of top-10% multi-cell clones")
plt.title('Collapsed barcode representation')

plt.show()

In [None]:
# Number of clones containing more than one cell at each threshold.
sns.relplot(
    data=summary_df,
    x="threshold", y="number_multicells",
    kind="line",
    height=3, aspect=1.5, facet_kws=dict(sharex=False),
).set_axis_labels("threshold", "Number of multi-cell clones")
plt.title('Collapsed barcode representation')

plt.show()

In [None]:
# Mean size of multi-cell clones at each threshold (log-scaled y axis).
sns.relplot(
    data=summary_df,
    x="threshold", y="avg_clone_size",
    kind="line",
    height=3, aspect=1.5, facet_kws=dict(sharex=False),
).set_axis_labels("threshold", "Mean multi-cell clone size")
plt.yscale('log')
plt.title('Collapsed barcode representation')

plt.show()

In [None]:
# Choose a threshold for filtering network edges by edge weight, using the heuristics above
# to minimize noise while preserving clonal information.
# Top tier = perfect barcode-set matches only (the threshold-1 run keeps only edges with weight >= 1).
toptier_clones = iterative_clone_df.loc[:, ['cellbc', 'cloneID_1']]

In [None]:
toptier_clones.to_csv(path_or_buf='toptier_clones.csv')

In [None]:
# Best practice: generate and test several cutoffs and examine their effect on downstream
# clonal analysis, e.g. counts of shared clones across annotated groups, number of implausible
# clones, z-score coupling and correlation analysis.
clones_zero_point_5 = iterative_clone_df.loc[:, ['cellbc', 'cloneID_0.5']]

In [None]:
clones_zero_point_5.to_csv(path_or_buf='clones_0.5.csv')

In [None]:
# A more permissive cutoff for comparison.
clones_zero_point_2 = iterative_clone_df.loc[:, ['cellbc', 'cloneID_0.2']]

In [None]:
clones_zero_point_2.to_csv(path_or_buf='clones_0.2.csv')