In [1]:
import pickle
import numpy as np

from utils.funcs import *
from tqdm.autonotebook import tqdm

In [2]:
# Unpickle the precomputed TCR embeddings and report how many there are.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open("data/tcr_embeddings.pkl", "rb") as handle:
    embeddings = pickle.load(handle)
len(embeddings)

41147

In [3]:
# Length of the first embedding vector.
len(embeddings[0])

320

In [4]:
# Load the TCR metadata table and preview its first rows.
# NOTE(review): `pd` is not imported explicitly in this notebook; presumably it
# arrives through `from utils.funcs import *` — confirm.
tcrs = pd.read_parquet("data/tcrs_final.parquet")
tcrs.head()

Unnamed: 0,Species,Antigen Epitope,Antigen Protein,Antigen Source,CDR3.beta.aa,TRBV,TRBJ,Reference,Database
0,Human,KLGGALQAK,IE1,CMV,CASTPGLALNNEQFF,TRBV19*01,TRBJ2-1*01,https://www.10xgenomics.com/resources/applicat...,VDJdb
1,Human,KLGGALQAK,IE1,CMV,CSARGLSSYEQYF,TRBV20-1*01,TRBJ2-7*01,https://www.10xgenomics.com/resources/applicat...,VDJdb
2,Human,KLGGALQAK,IE1,CMV,CASSSMLTEKLFF,TRBV11-2*01,TRBJ1-4*01,https://www.10xgenomics.com/resources/applicat...,VDJdb
3,Human,KLGGALQAK,IE1,CMV,CASSVEGTQYF,TRBV9*01,TRBJ2-3*01,https://www.10xgenomics.com/resources/applicat...,VDJdb
4,Human,KLGGALQAK,IE1,CMV,CASSLSAGGHFYEQYF,TRBV27*01,TRBJ2-7*01,https://www.10xgenomics.com/resources/applicat...,VDJdb


In [5]:
# Print the number of embeddings next to the number of metadata rows.
print(len(embeddings), len(tcrs))

41147 41147


In [6]:
# Sanity check: query with the first stored embedding, asking for a single hit
# and passing the first row's Species as the third argument.
# NOTE(review): presumably the third argument filters by species — confirm in utils.funcs.
search_embedding(embeddings[0].tolist(), 1, tcrs.loc[0, "Species"])

{'ids': ['1584404c-461e-4d8f-a1c4-9e83ebefd075'],
 'metadatas': [{'Antigen Epitope': 'KLGGALQAK',
   'Antigen Protein': 'IE1',
   'Antigen Source': 'CMV',
   'CDR3.beta.aa': 'CASTPGLALNNEQFF',
   'Database': 'VDJdb',
   'Reference': 'https://www.10xgenomics.com/resources/application-notes/a-new-way-of-exploring-immunity-linking-highly-multiplexed-antigen-recognition-to-immune-repertoire-and-phenotype/#',
   'Species': 'Human',
   'TRBJ': 'TRBJ2-1*01',
   'TRBV': 'TRBV19*01'}],
 'scores': [1.00276697]}

In [7]:
def top_match_cdr3(embedding, species):
    """Return the CDR3 beta sequence of the single best match for an embedding.

    Parameters
    ----------
    embedding : object exposing ``.tolist()``
        Query embedding; converted to a plain list before the search.
    species : str
        Forwarded unchanged to ``search_embedding`` as its third argument.

    Returns
    -------
    str or float
        The ``"CDR3.beta.aa"`` metadata value of the first returned hit, or
        ``np.nan`` when the search raised an error or returned no usable hit.
    """
    try:
        result = search_embedding(embedding.tolist(), 1, species)
        pred_cdr3 = result["metadatas"][0]["CDR3.beta.aa"]
    except Exception:
        # Best-effort lookup: a failed query is recorded as missing so a long
        # batch run keeps going. `Exception` (rather than a bare `except`) lets
        # KeyboardInterrupt and SystemExit propagate, so the run can be stopped.
        pred_cdr3 = np.nan
    return pred_cdr3

In [8]:
from joblib import Parallel, delayed

In [9]:
# Top-1 lookup for every stored embedding, dispatched through joblib with n_jobs=-1.
species_column = tcrs["Species"]
lookup_jobs = (
    delayed(top_match_cdr3)(vector, species_column.loc[row])
    for row, vector in enumerate(tqdm(embeddings, total=len(embeddings)))
)
pred_cdr3s = Parallel(n_jobs=-1)(lookup_jobs)

  0%|          | 0/41147 [00:00<?, ?it/s]

In [10]:
# Ground-truth CDR3 beta sequences, in table row order.
true_cdr3s = tcrs["CDR3.beta.aa"].tolist()

In [11]:
# Print the number of predictions next to the number of ground-truth sequences.
print(len(pred_cdr3s), len(true_cdr3s))

41147 41147


In [12]:
# Element-wise comparison of true vs. retrieved CDR3; the mean of the boolean
# mask is the fraction of exact matches.
is_exact_match = np.array(true_cdr3s) == np.array(pred_cdr3s)
accuracy = is_exact_match.mean()
print(f"Accuracy: {accuracy*100:.1f}%")

Accuracy: 97.8%


In [156]:
# Validation embeddings.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open("data/tcrs_val_embeddings.pkl", "rb") as handle:
    embeddings_val = pickle.load(handle)

# Validation metadata table.
tcrs_val = pd.read_parquet("data/tcrs_val.parquet")

In [157]:
# Preview the first rows of the validation table.
tcrs_val.head()

Unnamed: 0,Species,Antigen Epitope,Antigen Protein,Antigen Source,CDR3.beta.aa,TRBV,TRBJ,Reference,Database
0,Human,GILGFVFTL,Matrix protein (M1),Influenza,CASSILGKDTQYF,TRBV19,TRBJ2-3,https://pubmed.ncbi.nlm.nih.gov/28423320,McPAS-TCR
1,Human,GILGFVFTL,M,InfluenzaA,CASSLLGFSDGGTGELFF,TRBV5-4*01,TRBJ2-2*01,https://pubmed.ncbi.nlm.nih.gov/28423320,VDJdb
2,Human,GILGFVFTL,Matrix protein (M1),Influenza,CAISDLSITGGDNYGYTF,TRBV1-1,TRBJ1-2:01,https://pubmed.ncbi.nlm.nih.gov/28300170,McPAS-TCR
3,Human,GILGFVFTL,M,InfluenzaA,CASSERRQGLGNQPQHF,TRBV10-1*01,TRBJ1-5*01,https://pubmed.ncbi.nlm.nih.gov/28423320,VDJdb
4,Human,GILGFVFTL,M,InfluenzaA,CASNRREHDEQFF,TRBV19*01,TRBJ2-1*01,https://pubmed.ncbi.nlm.nih.gov/28423320,VDJdb


In [158]:
# Collapse every label that starts with "Influenza" (e.g. "InfluenzaA") into
# the single label "Influenza"; all other labels are kept as they are.
tcrs_val["Antigen Source"] = [
    "Influenza" if source.startswith("Influenza") else source
    for source in tcrs_val["Antigen Source"]
]

In [159]:
# Number of validation rows per Antigen Source label.
tcrs_val["Antigen Source"].value_counts()

Antigen Source
CMV          200
Influenza    100
Name: count, dtype: int64

In [160]:
# Display the head again to inspect the Antigen Source column.
tcrs_val.head()

Unnamed: 0,Species,Antigen Epitope,Antigen Protein,Antigen Source,CDR3.beta.aa,TRBV,TRBJ,Reference,Database
0,Human,GILGFVFTL,Matrix protein (M1),Influenza,CASSILGKDTQYF,TRBV19,TRBJ2-3,https://pubmed.ncbi.nlm.nih.gov/28423320,McPAS-TCR
1,Human,GILGFVFTL,M,Influenza,CASSLLGFSDGGTGELFF,TRBV5-4*01,TRBJ2-2*01,https://pubmed.ncbi.nlm.nih.gov/28423320,VDJdb
2,Human,GILGFVFTL,Matrix protein (M1),Influenza,CAISDLSITGGDNYGYTF,TRBV1-1,TRBJ1-2:01,https://pubmed.ncbi.nlm.nih.gov/28300170,McPAS-TCR
3,Human,GILGFVFTL,M,Influenza,CASSERRQGLGNQPQHF,TRBV10-1*01,TRBJ1-5*01,https://pubmed.ncbi.nlm.nih.gov/28423320,VDJdb
4,Human,GILGFVFTL,M,Influenza,CASNRREHDEQFF,TRBV19*01,TRBJ2-1*01,https://pubmed.ncbi.nlm.nih.gov/28423320,VDJdb


In [161]:
import time

In [162]:
# For each validation TCR embedding, retrieve up to k matches and keep the
# "Antigen Source" of every hit in the order returned. These ranked labels are
# what the mean reciprocal rank (MRR) and precision@k cells below consume.
k = 20
pred_epitopes = []

# perf_counter is the clock intended for measuring elapsed intervals
start = time.perf_counter()

for i in tqdm(tcrs_val.index.tolist()):
    embedding = embeddings_val[i]
    result = search_embedding(embedding.tolist(), k, tcrs_val.loc[i, "Species"])
    # Take whatever was returned (capped at k) instead of indexing range(k), so
    # a query that yields fewer than k hits does not abort the loop.
    pred_epitopes.append([meta["Antigen Source"] for meta in result["metadatas"][:k]])

end = time.perf_counter()
# average time per tcr
print(f"Time per TCR: {(end-start)/len(tcrs_val):.2f} seconds")

  0%|          | 0/300 [00:00<?, ?it/s]

Time per TCR: 0.36 seconds


In [163]:
# Number of label lists collected (one is appended per validation row).
len(pred_epitopes)

300

In [164]:
# Number of retrieved labels kept for the first query.
len(pred_epitopes[0])

20

In [165]:
# Clean up: rewrite each list in place so that any label starting with
# "Influenza" becomes exactly "Influenza".
for hits in pred_epitopes:
    hits[:] = [
        "Influenza" if source.startswith("Influenza") else source
        for source in hits
    ]

In [166]:
# Retrieved labels for the first validation TCR.
pred_epitopes[0]

['CMV',
 'Influenza',
 'Influenza',
 'EBV',
 'SARS-CoV-2',
 'Influenza',
 'DENV1',
 'SARS-CoV-2',
 'DENV3/4',
 'Influenza',
 'CMV',
 'Neoantigen',
 'EBV',
 'SARS-CoV-2',
 'CMV',
 'CMV',
 'EBV',
 'HIV-1',
 'SARS-CoV-2',
 'HCV']

In [167]:
# Reciprocal rank per query (averaged later into the mean reciprocal rank, MRR):
# 1 / (1-based position of the first retrieved label equal to the query's
# "Antigen Source"), or 0 when no retrieved label matches.
mmr = []
for row in tcrs_val.index.tolist():
    target = tcrs_val.loc[row, "Antigen Source"]
    first_hit_rank = next(
        (rank for rank, source in enumerate(pred_epitopes[row], start=1) if source == target),
        None,
    )
    mmr.append(0 if first_hit_rank is None else 1 / first_hit_rank)

In [168]:
# Mean of the per-query reciprocal ranks stored in `mmr` (the label prints as "MMR").
print(f"MMR: {np.mean(mmr):.2f}")

MMR: 0.53


In [169]:
# NOTE(review): 1 / mean(reciprocal ranks) is the harmonic mean of the first-hit
# ranks (queries with no hit contribute 0 to the mean), not an arithmetic
# average position. The ranks compare "Antigen Source" labels, while the
# message says "epitope".
print(f"On average, TCR with the same epitope found at position {1/np.mean(mmr):.0f}")

On average, TCR with the same epitope found at position 2


In [170]:
# precision@k: share of the first k retrieved labels that equal the query's own
# "Antigen Source".
k = 10
precision = []

for row in tcrs_val.index.tolist():
    target = tcrs_val.loc[row, "Antigen Source"]
    is_hit = [pred_epitopes[row][pos] == target for pos in range(k)]
    precision.append(sum(is_hit) / k)

In [171]:
# Mean precision@k over all validation rows.
print(f"Precision@{k}: {np.mean(precision):.2f}")

Precision@10: 0.32


In [175]:
# Mean precision@k over the first 100 entries of `precision`.
# NOTE(review): presumably one epitope group per 100-row slice — confirm the row order of tcrs_val.
print(f"Precision@{k}: {np.mean(precision[:100]):.2f}")

Precision@10: 0.29


In [173]:
# Mean precision@k over entries 100-199 of `precision`.
# NOTE(review): presumably one epitope group per 100-row slice — confirm the row order of tcrs_val.
print(f"Precision@{k}: {np.mean(precision[100:200]):.2f}")

Precision@10: 0.32


In [174]:
# Mean precision@k over entries 200 onward of `precision`.
# NOTE(review): presumably one epitope group per 100-row slice — confirm the row order of tcrs_val.
print(f"Precision@{k}: {np.mean(precision[200:]):.2f}")

Precision@10: 0.34
