In [1]:
import os
import pandas as pd
import numpy as np
import uuid

from utils.funcs import get_model, get_embeddings, upsert_to_index
from tqdm.autonotebook import tqdm
from pinecone import Pinecone

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tcrs = pd.read_parquet("data/antigen_specific_tcrs.parquet")
print(tcrs.shape)
tcrs.head()

(48540, 9)


Unnamed: 0,Species,Antigen Epitope,Antigen Protein,Antigen Source,CDR3.beta.aa,TRBV,TRBJ,Reference,Database
0,Mouse,IKAVYNFATCG,Pre-glycoprotein polyprotein GP complex,LCMV,CASSDAGANTEVF,TRBV8-1,TRBJ1-1,PMID:1716213,McPAS-TCR
1,Mouse,IKAVYNFATCG,Pre-glycoprotein polyprotein GP complex,LCMV,CASSDAGAYAEQF,TRBV8-1,TRBJ2-1,PMID:1716213,McPAS-TCR
2,Mouse,IKAVYNFATCG,Pre-glycoprotein polyprotein GP complex,LCMV,CASSDAGGAAEVF,TRBV8-3,TRBJ1-1,PMID:1716213,McPAS-TCR
3,Mouse,IKAVYNFATCG,Pre-glycoprotein polyprotein GP complex,LCMV,CASSDAGHSPLYF,TRBV8-1,TRBJ1-6,PMID:1716213,McPAS-TCR
4,Mouse,IKAVYNFATCG,Pre-glycoprotein polyprotein GP complex,LCMV,CASSDAWGGAEQYF,TRBV8-3,TRBJ2-6,PMID:1716213,McPAS-TCR


In [3]:
tcrs["CDR3.beta.aa"].str.len().max()

38

In [4]:
tcrs["CDR3.beta.aa"].duplicated().any()

False

In [5]:
tcrs["Species"].value_counts()

Species
Human    45750
Mouse     2790
Name: count, dtype: int64

In [6]:
tcrs["Antigen Source"].value_counts()[:5]

Antigen Source
CMV           19441
SARS-CoV-2     5000
Influenza      4226
EBV            4104
InfluenzaA     3585
Name: count, dtype: int64

In [7]:
# downsample CMV TCRs
tcrs_cmv = tcrs.loc[tcrs["Antigen Source"] == "CMV"]
tcrs_other = tcrs.loc[tcrs["Antigen Source"] != "CMV"]

tcrs_cmv = tcrs_cmv.sample(n=5000, random_state=0)
tcrs = pd.concat([tcrs_cmv, tcrs_other], ignore_index=True)
print(tcrs.shape)
tcrs.head()

(34099, 9)


Unnamed: 0,Species,Antigen Epitope,Antigen Protein,Antigen Source,CDR3.beta.aa,TRBV,TRBJ,Reference,Database
0,Human,KLGGALQAK,IE1,CMV,CASSPKTSVTYNEQFF,TRBV7-9*01,TRBJ2-1*01,https://www.10xgenomics.com/resources/applicat...,VDJdb
1,Human,NLVPMVATV,pp65,CMV,CASSLDSLNTIYF,TRBV5-1*01,TRBJ1-3*01,PMID:28423320,VDJdb
2,Human,KLGGALQAK,IE1,CMV,CASSSRTSSTDTQYF,TRBV12-4*01,TRBJ2-3*01,https://www.10xgenomics.com/resources/applicat...,VDJdb
3,Human,KLGGALQAK,IE1,CMV,CSSESGTSEAFF,TRBV29-1*01,TRBJ1-1*01,https://www.10xgenomics.com/resources/applicat...,VDJdb
4,Human,KLGGALQAK,IE1,CMV,CSVEYGLAGSTDTQYF,TRBV29-1*01,TRBJ2-3*01,https://www.10xgenomics.com/resources/applicat...,VDJdb


In [8]:
pubmed_base_url = "https://pubmed.ncbi.nlm.nih.gov/"


# pubmed link
def format_link(s):
    if s.startswith("PMID:"):
        return s.replace("PMID:", pubmed_base_url)
    return s


tcrs["Reference"] = tcrs["Reference"].map(format_link)
tcrs["Reference"].iloc[0]

'https://www.10xgenomics.com/resources/application-notes/a-new-way-of-exploring-immunity-linking-highly-multiplexed-antigen-recognition-to-immune-repertoire-and-phenotype/#'

In [9]:
tokenizer, model = get_model("facebook/esm2_t6_8M_UR50D")

2024-05-17 22:36:16.137 
  command:

    streamlit run /home/ytian/anaconda3/envs/receptorgpt/lib/python3.11/site-packages/ipykernel_launcher.py [ARGUMENTS]
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
embeddings = []
batch_size = 100
seqs = tcrs["CDR3.beta.aa"].tolist()

for i in tqdm(range(0, len(seqs), batch_size)):
    embeds = get_embeddings(tokenizer, model, seqs[i : i + batch_size])
    embeddings.append(embeds)

embeddings = np.concatenate(embeddings, axis=0)
embeddings.shape

100%|██████████| 341/341 [00:28<00:00, 11.90it/s]


(34099, 320)

In [11]:
# convert to list
embeddings = embeddings.tolist()
len(embeddings)

34099

In [12]:
# prepare metadata
metadatas = tcrs.to_dict("records")
metadatas[:1]

[{'Species': 'Human',
  'Antigen Epitope': 'KLGGALQAK',
  'Antigen Protein': 'IE1',
  'Antigen Source': 'CMV',
  'CDR3.beta.aa': 'CASSPKTSVTYNEQFF',
  'TRBV': 'TRBV7-9*01',
  'TRBJ': 'TRBJ2-1*01',
  'Reference': 'https://www.10xgenomics.com/resources/application-notes/a-new-way-of-exploring-immunity-linking-highly-multiplexed-antigen-recognition-to-immune-repertoire-and-phenotype/#',
  'Database': 'VDJdb'}]

In [13]:
# merge the three lists int a single list of dictionaries
# use uuid to generate unique ids
records = [
    {"id": str(uuid.uuid4()), "values": e, "metadata": m}
    for e, m in zip(embeddings, metadatas)
]

In [14]:
len(records)

34099

In [15]:
records[0]

{'id': '3e4da022-1f56-4f8f-98d7-6f076fb939c2',
 'values': [-0.000996788963675499,
  -0.2494887411594391,
  0.1272898018360138,
  0.11934084445238113,
  0.022819960489869118,
  -0.036739613860845566,
  -0.17488379776477814,
  -0.08020184189081192,
  -0.19916585087776184,
  -0.0760434940457344,
  -0.08207184076309204,
  0.37416425347328186,
  0.1361832320690155,
  0.09151215851306915,
  -0.21137793362140656,
  0.03580343350768089,
  0.15209656953811646,
  -0.04301324486732483,
  0.18764396011829376,
  -0.2048741579055786,
  -0.12977296113967896,
  0.05499275028705597,
  0.10816902667284012,
  0.021946214139461517,
  0.09488830715417862,
  0.24555286765098572,
  -0.0034690089523792267,
  -0.033150218427181244,
  0.08120252192020416,
  -0.029781559482216835,
  0.03547398000955582,
  -0.005609595216810703,
  0.05505049601197243,
  -0.23076193034648895,
  0.1011500209569931,
  -0.45892781019210815,
  0.2678738534450531,
  -0.12745873630046844,
  0.32290250062942505,
  0.1663869470357895,
  -

In [16]:
# pinecone client
pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
pc_index = pc.Index("tcrs")

In [17]:
upsert_to_index(pc_index, records)

100%|██████████| 341/341 [01:34<00:00,  3.62it/s]


In [22]:
print(pc_index.describe_index_stats())

{'dimension': 320,
 'index_fullness': 0.0,
 'namespaces': {'': {'vector_count': 34099}},
 'total_vector_count': 34099}


In [19]:
import pickle

In [20]:
# save embeddings for evaluation
with open("data/tcr_embeddings.pkl", "wb") as f:
    pickle.dump(embeddings, f)

In [21]:
# save dataframe for evaluation
tcrs.to_parquet("data/tcrs_final.parquet")