In [1]:
import os
import pandas as pd
import numpy as np
import uuid

from utils.funcs import get_model, get_embeddings, upsert_to_index
from tqdm.autonotebook import tqdm
from pinecone import Pinecone

In [2]:
tcrs = pd.read_parquet("data/antigen_specific_tcrs.parquet")
print(tcrs.shape)
tcrs.head()

(48540, 9)


Unnamed: 0,Species,Antigen Epitope,Antigen Protein,Antigen Source,CDR3.beta.aa,TRBV,TRBJ,Reference,Database
0,Mouse,IKAVYNFATCG,Pre-glycoprotein polyprotein GP complex,LCMV,CASSDAGANTEVF,TRBV8-1,TRBJ1-1,PMID:1716213,McPAS-TCR
1,Mouse,IKAVYNFATCG,Pre-glycoprotein polyprotein GP complex,LCMV,CASSDAGAYAEQF,TRBV8-1,TRBJ2-1,PMID:1716213,McPAS-TCR
2,Mouse,IKAVYNFATCG,Pre-glycoprotein polyprotein GP complex,LCMV,CASSDAGGAAEVF,TRBV8-3,TRBJ1-1,PMID:1716213,McPAS-TCR
3,Mouse,IKAVYNFATCG,Pre-glycoprotein polyprotein GP complex,LCMV,CASSDAGHSPLYF,TRBV8-1,TRBJ1-6,PMID:1716213,McPAS-TCR
4,Mouse,IKAVYNFATCG,Pre-glycoprotein polyprotein GP complex,LCMV,CASSDAWGGAEQYF,TRBV8-3,TRBJ2-6,PMID:1716213,McPAS-TCR


In [3]:
tcrs["CDR3.beta.aa"].str.len().max()

38

In [4]:
tcrs["CDR3.beta.aa"].duplicated().any()

False

In [5]:
tcrs["Species"].value_counts()

Species
Human    45750
Mouse     2790
Name: count, dtype: int64

In [6]:
tcrs["Antigen Source"].value_counts()[:5]

Antigen Source
CMV           19441
SARS-CoV-2     5000
Influenza      4226
EBV            4104
InfluenzaA     3585
Name: count, dtype: int64

In [7]:
tcrs["Antigen Epitope"].value_counts()[:10]

Antigen Epitope
KLGGALQAK         12093
NLVPMVATV          4885
GILGFVFTL          3799
LPRRSGAAGA         1933
ELAGIGILTV         1401
GLCTLVAML          1329
AVFDRKSDAK          997
YLQPRTFLL           886
TFEYVSQPFLMDLE      779
LLWNGPMAV           775
Name: count, dtype: int64

In [8]:
tcrs["Antigen Source"].nunique()

112

In [9]:
tcrs["Antigen Epitope"].nunique()

1274

In [10]:
# downsample KLGGALQAK TCRs
tcrs_klg = tcrs.loc[tcrs["Antigen Epitope"] == "KLGGALQAK"]
tcrs_other = tcrs.loc[tcrs["Antigen Epitope"] != "KLGGALQAK"]

tcrs_klg = tcrs_klg.sample(n=5000, random_state=0)
tcrs = pd.concat([tcrs_klg, tcrs_other], ignore_index=True)
# shuffle
tcrs.sample(frac=1, random_state=0)
print(tcrs.shape)
tcrs.head()

(41447, 9)


Unnamed: 0,Species,Antigen Epitope,Antigen Protein,Antigen Source,CDR3.beta.aa,TRBV,TRBJ,Reference,Database
0,Human,KLGGALQAK,IE1,CMV,CASTPGLALNNEQFF,TRBV19*01,TRBJ2-1*01,https://www.10xgenomics.com/resources/applicat...,VDJdb
1,Human,KLGGALQAK,IE1,CMV,CSARGLSSYEQYF,TRBV20-1*01,TRBJ2-7*01,https://www.10xgenomics.com/resources/applicat...,VDJdb
2,Human,KLGGALQAK,IE1,CMV,CASSSMLTEKLFF,TRBV11-2*01,TRBJ1-4*01,https://www.10xgenomics.com/resources/applicat...,VDJdb
3,Human,KLGGALQAK,IE1,CMV,CASSVEGTQYF,TRBV9*01,TRBJ2-3*01,https://www.10xgenomics.com/resources/applicat...,VDJdb
4,Human,KLGGALQAK,IE1,CMV,CASSLSAGGHFYEQYF,TRBV27*01,TRBJ2-7*01,https://www.10xgenomics.com/resources/applicat...,VDJdb


In [11]:
tcrs["Antigen Epitope"].value_counts()[:10]

Antigen Epitope
KLGGALQAK         5000
NLVPMVATV         4885
GILGFVFTL         3799
LPRRSGAAGA        1933
ELAGIGILTV        1401
GLCTLVAML         1329
AVFDRKSDAK         997
YLQPRTFLL          886
TFEYVSQPFLMDLE     779
LLWNGPMAV          775
Name: count, dtype: int64

In [12]:
pubmed_base_url = "https://pubmed.ncbi.nlm.nih.gov/"


# pubmed link
def format_link(s):
    if s.startswith("PMID:"):
        return s.replace("PMID:", pubmed_base_url)
    return s


tcrs["Reference"] = tcrs["Reference"].map(format_link)
tcrs["Reference"].iloc[0]

'https://www.10xgenomics.com/resources/application-notes/a-new-way-of-exploring-immunity-linking-highly-multiplexed-antigen-recognition-to-immune-repertoire-and-phenotype/#'

In [13]:
# build a validation set from epitopes with more than threshold
tcrs_val = tcrs.groupby("Antigen Epitope").filter(lambda x: len(x) > 2000)
tcrs_val["Antigen Epitope"].value_counts()

Antigen Epitope
KLGGALQAK    5000
NLVPMVATV    4885
GILGFVFTL    3799
Name: count, dtype: int64

In [14]:
# downsample to TCRs per epitope
tcrs_val = tcrs_val.groupby("Antigen Epitope").sample(n=100, random_state=0)
tcrs_val.head()

Unnamed: 0,Species,Antigen Epitope,Antigen Protein,Antigen Source,CDR3.beta.aa,TRBV,TRBJ,Reference,Database
12529,Human,GILGFVFTL,Matrix protein (M1),Influenza,CASSILGKDTQYF,TRBV19,TRBJ2-3,https://pubmed.ncbi.nlm.nih.gov/28423320,McPAS-TCR
22837,Human,GILGFVFTL,M,InfluenzaA,CASSLLGFSDGGTGELFF,TRBV5-4*01,TRBJ2-2*01,https://pubmed.ncbi.nlm.nih.gov/28423320,VDJdb
12860,Human,GILGFVFTL,Matrix protein (M1),Influenza,CAISDLSITGGDNYGYTF,TRBV1-1,TRBJ1-2:01,https://pubmed.ncbi.nlm.nih.gov/28300170,McPAS-TCR
23099,Human,GILGFVFTL,M,InfluenzaA,CASSERRQGLGNQPQHF,TRBV10-1*01,TRBJ1-5*01,https://pubmed.ncbi.nlm.nih.gov/28423320,VDJdb
23717,Human,GILGFVFTL,M,InfluenzaA,CASNRREHDEQFF,TRBV19*01,TRBJ2-1*01,https://pubmed.ncbi.nlm.nih.gov/28423320,VDJdb


In [15]:
tcrs_val["Antigen Epitope"].value_counts()

Antigen Epitope
GILGFVFTL    100
KLGGALQAK    100
NLVPMVATV    100
Name: count, dtype: int64

In [16]:
# need to remove validation set from training set
tcrs = tcrs[~tcrs.index.isin(tcrs_val.index)]
print(tcrs.shape, tcrs_val.shape)

(41147, 9) (300, 9)


In [17]:
# reset index
tcrs = tcrs.reset_index(drop=True)
tcrs_val = tcrs_val.reset_index(drop=True)

In [18]:
tokenizer, model = get_model("facebook/esm2_t6_8M_UR50D")

2024-06-17 17:32:28.142 
  command:

    streamlit run /home/ytian/anaconda3/envs/receptorgpt/lib/python3.11/site-packages/ipykernel_launcher.py [ARGUMENTS]
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
tokenizer.all_special_tokens

['<eos>', '<unk>', '<pad>', '<cls>', '<mask>']

In [34]:
embeddings = []
batch_size = 100
seqs = tcrs["CDR3.beta.aa"].tolist()

for i in tqdm(range(0, len(seqs), batch_size)):
    embeds = get_embeddings(tokenizer, model, seqs[i : i + batch_size])
    embeddings.append(embeds)

embeddings = np.concatenate(embeddings, axis=0)
embeddings.shape

  0%|          | 0/412 [00:00<?, ?it/s]

(41147, 320)

In [21]:
# convert to list
embeddings = embeddings.tolist()
len(embeddings)

41147

In [22]:
# prepare metadata
metadatas = tcrs.to_dict("records")
metadatas[:1]

[{'Species': 'Human',
  'Antigen Epitope': 'KLGGALQAK',
  'Antigen Protein': 'IE1',
  'Antigen Source': 'CMV',
  'CDR3.beta.aa': 'CASTPGLALNNEQFF',
  'TRBV': 'TRBV19*01',
  'TRBJ': 'TRBJ2-1*01',
  'Reference': 'https://www.10xgenomics.com/resources/application-notes/a-new-way-of-exploring-immunity-linking-highly-multiplexed-antigen-recognition-to-immune-repertoire-and-phenotype/#',
  'Database': 'VDJdb'}]

In [23]:
# merge the three lists int a single list of dictionaries
# use uuid to generate unique ids
records = [
    {"id": str(uuid.uuid4()), "values": e, "metadata": m}
    for e, m in zip(embeddings, metadatas)
]

In [24]:
len(records)

41147

In [25]:
records[0]

{'id': '1584404c-461e-4d8f-a1c4-9e83ebefd075',
 'values': [0.14121215045452118,
  -0.3008734881877899,
  0.2475678026676178,
  0.08334764838218689,
  -0.019711311906576157,
  -0.12751878798007965,
  -0.1914614886045456,
  -0.02174030989408493,
  -0.1290866732597351,
  -0.1492486447095871,
  -0.0704425498843193,
  0.3020123243331909,
  0.13828590512275696,
  0.05327414721250534,
  -0.22204852104187012,
  0.08621577173471451,
  0.0815887525677681,
  -0.053235460072755814,
  0.21515080332756042,
  -0.1412162333726883,
  -0.18838292360305786,
  0.06571920216083527,
  0.07026966661214828,
  0.04556738957762718,
  0.09365881234407425,
  0.26238828897476196,
  0.018247194588184357,
  -0.07799090445041656,
  0.13413658738136292,
  -0.005087621044367552,
  -0.005955008324235678,
  -0.10034336894750595,
  0.09907011687755585,
  -0.25078582763671875,
  0.07476787269115448,
  -0.5109313726425171,
  0.23566466569900513,
  -0.08076376467943192,
  0.3819611668586731,
  0.14221638441085815,
  -0.29946

In [26]:
# pinecone client
pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
pc_index = pc.Index("tcrs")

In [27]:
upsert_to_index(pc_index, records)

100%|██████████| 412/412 [02:06<00:00,  3.27it/s]


In [28]:
print(pc_index.describe_index_stats())

{'dimension': 320,
 'index_fullness': 0.0,
 'namespaces': {'': {'vector_count': 42600}},
 'total_vector_count': 42600}


In [29]:
import pickle

In [36]:
# save embeddings for evaluation
with open("data/tcr_embeddings.pkl", "wb") as f:
    pickle.dump(np.array(embeddings), f)

In [31]:
# save dataframe for evaluation
tcrs.to_parquet("data/tcrs_final.parquet")
tcrs_val.to_parquet("data/tcrs_val.parquet")

In [32]:
# embed validation set
seqs = tcrs_val["CDR3.beta.aa"].tolist()
embeddings = []
for i in tqdm(range(0, len(seqs), batch_size)):
    embeds = get_embeddings(tokenizer, model, seqs[i : i + batch_size])
    embeddings.append(embeds)

embeddings = np.concatenate(embeddings, axis=0)
embeddings.shape

  0%|          | 0/3 [00:00<?, ?it/s]

(300, 320)

In [33]:
with open("data/tcrs_val_embeddings.pkl", "wb") as f:
    pickle.dump(embeddings, f)