In [2]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pyrepseq as prs
import pyrepseq.metric.tcr_metric as tm

import warnings
# Intended to filter out clustcr warnings. Note that 'ignore' with no further
# arguments suppresses every warning from every module, not only clustcr.
warnings.filterwarnings('ignore')

from clustcr.clustering.clustering import ClusteringResult

plt.style.use('seaborn-v0_8-paper')

# Download the VDJdb reference from https://github.com/antigenomics/vdjdb-db/releases
# (release used here: vdjdb-2024-06-13.zip).

# Path to the VDJdb data FOLDER (keep the trailing slash). The loading cell
# appends the file name 'vdjdb_full.txt' to this value, so it must not already
# contain the file name.
vdjdb_path = 'data/'

ModuleNotFoundError: No module named 'clustcr'

In [None]:
# Resolve the VDJdb table location. vdjdb_path may be either the data folder
# or a direct path to the file; plain string concatenation of the file name
# yields a non-existent path when vdjdb_path already points at the file.
vdjdb_file = Path(vdjdb_path)
if not vdjdb_file.is_file():
    vdjdb_file = vdjdb_file / 'vdjdb_full.txt'
vdjdb = pd.read_csv(vdjdb_file, sep='\t', low_memory=False)

# Keep human entries whose beta-chain CDR3 passes prs.isvalidcdr3, restrict to
# the three columns used below, and rename them to CDR3B / TRBV / Epitope.
# The chain builds a new frame instead of mutating a slice of `vdjdb` in place.
is_human_beta = (vdjdb['cdr3.beta'].apply(prs.isvalidcdr3)
                 & (vdjdb['species']=='HomoSapiens'))
vdjdb_beta = (
    vdjdb.loc[is_human_beta, ['cdr3.beta', 'v.beta', 'antigen.epitope']]
    .rename(columns={'cdr3.beta': 'CDR3B',
                     'v.beta' : 'TRBV',
                     'antigen.epitope' : 'Epitope'})
    .drop_duplicates(['TRBV', 'CDR3B'])  # one row per (TRBV, CDR3B) pair
    .dropna()
)
vdjdb_beta = prs.standardize_dataframe(vdjdb_beta)
# Append the *01 allele suffix to every V gene name.
# NOTE(review): presumably required by the TCRdist metric used later - confirm.
vdjdb_beta['TRBV'] = vdjdb_beta['TRBV']+"*01"


# Balanced subsample: keep only epitopes with at least n sequences and draw
# exactly n sequences from each of them.
n = 220
# Fixed seed so the subsample (and everything computed from it) is identical
# after Restart & Run All.
seed = 0
epitope_counts = vdjdb_beta['Epitope'].value_counts()
epitopes = set(epitope_counts[epitope_counts>=n].index)
vdjdb_beta = (
    vdjdb_beta[vdjdb_beta['Epitope'].isin(epitopes)]
    .groupby('Epitope')
    .sample(n=n, random_state=seed)
    .reset_index(drop=True)
)
df = vdjdb_beta
seqs = df['CDR3B']

# Display: (number of sequences, number of distinct epitopes).
df.shape[0], len(df['Epitope'].unique())

In [None]:
def calculate_metrics(clustering, classes=None):
    """Score a clustering against epitope labels using clustcr's ClusteringResult.

    Parameters
    ----------
    clustering : pandas.DataFrame
        Clustering table; its 'node' column is renamed to 'junction_aa'
        before being passed to ClusteringResult.
    classes : pandas.DataFrame, optional
        Label table; its 'CDR3B' and 'Epitope' columns are renamed to
        'junction_aa' and 'epitope'. When omitted, the notebook-level `df`
        is looked up at call time, so re-running the data cell is picked up
        (a `classes=df` default would stay bound to the frame that existed
        when this function was defined).

    Returns
    -------
    tuple
        (retention, purity), where retention is `metrics.retention()` and
        purity is the first element of `metrics.purity()`.
    """
    if classes is None:
        classes = df
    clustering_expanded_clustcr = clustering.rename(columns={'node': 'junction_aa'}, inplace=False)
    epitopes_clustcr = classes.rename(columns={'CDR3B': 'junction_aa', 'Epitope' : 'epitope'}, inplace=False)
    metrics = ClusteringResult(clustering_expanded_clustcr).metrics(epitopes_clustcr)
    retention = metrics.retention()
    purity = metrics.purity()[0]
    return retention, purity

In [None]:
def metrics_vs_threshold(threshold, neighbors, clustering='cc', clustering_kwargs=None):
    """Cluster the sequences at one distance cutoff and score the result.

    Parameters
    ----------
    threshold : number
        Distance cutoff; only pairs with distance strictly below it are kept.
    neighbors : numpy.ndarray
        Two-dimensional array whose third column (index 2) holds the pair
        distance. The notebook builds it with rows of the form (i, j, distance).
    clustering : str, optional
        Clustering method name forwarded to `prs.graph_clustering`.
    clustering_kwargs : dict, optional
        Extra keyword arguments forwarded to `prs.graph_clustering`.
        None (the default) means no extra arguments.

    Returns
    -------
    tuple
        (retention, purity) as returned by `calculate_metrics`.
    """
    # A None sentinel avoids sharing one mutable dict between calls.
    if clustering_kwargs is None:
        clustering_kwargs = {}
    # Rows of `neighbors` that survive the cutoff (this is an edge list).
    edges = neighbors[neighbors[:, 2]<threshold]
    result = prs.graph_clustering(edges, seqs,
                                  clustering=clustering, **clustering_kwargs)
    return calculate_metrics(result)

In [None]:
# Square matrix of pairwise distances between the rows of df, computed with
# the BetaTcrdist metric.
pdists = prs.squareform(tm.BetaTcrdist().calc_pdist_vector(df))
# Edge list with one row (i, j, distance) per unordered pair i < j.
# np.triu_indices(k=1) walks the strict upper triangle in row-major order,
# i.e. the same pair order as a nested `for i / for j in range(i+1, n)` loop,
# but without running millions of iterations in the Python interpreter.
# For fewer than two rows this yields an empty array of shape (0, 3).
rows, cols = np.triu_indices(len(df), k=1)
neighbors = np.column_stack((rows, cols, pdists[rows, cols]))

In [None]:
# Square matrix of pairwise distances between the rows of df, computed with
# the BetaCdr3Levenshtein metric. This reuses (overwrites) the name `pdists`.
pdists = prs.squareform(tm.BetaCdr3Levenshtein().calc_pdist_vector(df))
# Edge list with one row (i, j, distance) per unordered pair i < j.
# np.triu_indices(k=1) walks the strict upper triangle in row-major order,
# i.e. the same pair order as a nested `for i / for j in range(i+1, n)` loop,
# but without running millions of iterations in the Python interpreter.
# For fewer than two rows this yields an empty array of shape (0, 3).
rows, cols = np.triu_indices(len(df), k=1)
neighbors_cdr3_lev = np.column_stack((rows, cols, pdists[rows, cols]))

In [None]:
# Column layout of the sweep tables: the cutoff followed by the two scores
# returned by metrics_vs_threshold.
columns = ['threshold', 'retention', 'purity']

# Connected-component ('cc') clustering on the `neighbors` edge list,
# one table row per cutoff in 2, 4, ..., 98.
sweep = pd.DataFrame(
    data=[
        (cutoff, *metrics_vs_threshold(cutoff, neighbors, clustering='cc'))
        for cutoff in np.arange(2, 100, 2)
    ],
    columns=columns,
)

In [None]:
# Connected-component ('cc') clustering on the `neighbors_cdr3_lev` edge list,
# one table row per cutoff in 0, 1, ..., 11. Uses the `columns` list defined
# in the previous sweep cell.
sweep_cdr3_lev = pd.DataFrame(
    data=[
        (cutoff, *metrics_vs_threshold(cutoff, neighbors_cdr3_lev, clustering='cc'))
        for cutoff in np.arange(0, 12, 1)
    ],
    columns=columns,
)

In [None]:
# Options forwarded to prs.graph_clustering for the 'leiden' method.
# NOTE(review): no random seed is passed here - confirm whether the Leiden
# results are reproducible between runs.
leiden_options = dict(resolution=0.1, objective_function='CPM', n_iterations=4)

# Leiden clustering on the `neighbors` edge list, one table row per cutoff in
# 2, 4, ..., 98. Uses the `columns` list defined in an earlier sweep cell.
sweep_ml = pd.DataFrame(
    data=[
        (cutoff, *metrics_vs_threshold(cutoff, neighbors,
                                       clustering='leiden',
                                       clustering_kwargs=leiden_options))
        for cutoff in np.arange(2, 100, 2)
    ],
    columns=columns,
)

In [None]:
# Sweep tables to overlay, each paired with its legend label.
mct_variants = [
    ('CDR3, CC', sweep_cdr3_lev),
    ('TCRd, CC ', sweep),
    ('TCRd, Leiden', sweep_ml),
]

# Sweep-table columns shown on the horizontal and vertical axes.
x, y = 'retention', 'purity'

fig, ax = plt.subplots(figsize=(4.0, 4.0))

# One marker-and-line curve per variant; the handles are kept for the legend.
mct = []
for label, m in mct_variants:
    line = ax.plot(m[x], m[y], '-o', ms=4, label=label)[0]
    mct.append(line)

ax.set(xlabel=x, ylabel=y)
legend = ax.legend(handles=mct, loc='upper right', fontsize='small')
# Fix both axes to the range 0 to 1.
ax.set(xlim=(0.0, 1.0), ylim=(0.0, 1.0))
fig.tight_layout(pad=0.5)