This notebook is a tutorial on using the trained TRAP model weights to predict whether a CDR3beta sequence binds a specified epitope. Before running the notebook, install the environment: see the requirements.txt file, or the packaged environment trap.tar.gz, at https://zenodo.org/records/15062393

Step 1. Generate candidate CDR3beta sequences for virtual screening.

In [4]:
import os
import subprocess

import pandas as pd

# Number of CDR3beta sequences to sample with OLGA, and where OLGA writes them.
n_generate = 20
screen_file = 'screen.tsv'

# Sample human TRB sequences with OLGA. check=True makes a failing (or missing)
# OLGA installation raise here instead of silently falling through and reading
# a stale screen.tsv left over from a previous run.
subprocess.run(
    ['olga-generate_sequences', '--humanTRB', '-o', screen_file, '-n', str(n_generate)],
    check=True,
)

# OLGA's output has no header row; column 1 is the amino-acid CDR3 used below.
df = pd.read_csv(screen_file, sep='\t', header=None)

# Keep only sequences that start with C and end with F. na=False treats a
# missing sequence as non-matching instead of putting NaN into the mask.
is_canonical = df.iloc[:, 1].str.match('^C.*F$', na=False)

# Strip the leading C and trailing F before feature generation.
candidate_cdr = df.loc[is_canonical].iloc[:, 1].str[1:-1].tolist()

print(f'Found {len(candidate_cdr)} valid CDR3beta sequences')
print('First 5 sequences:', candidate_cdr[:5])

[0m

Starting sequence generation... 
Completed generating all 20 sequences in 0.00 seconds.
Found 20 valid CDR3beta sequences
First 5 sequences: ['SARRRARTEAF', 'ASSGLLAKNIQY', 'AQRSVSGANEKLF', 'ASSLAVLSHNSPLH', 'ASSLDSYVKNTQY']


Step 2. Generate CDR3beta embedding features (only CDR3beta sequences whose length is <= 18 are kept).

In [None]:
import os
import pickle

import torch
from TRAP.cdr_gen import cdr_align

# Maximum CDR3beta length; also passed to cdr_align and to TRAP in Step 4.
b_max = 18
device = 0
dev = torch.device(f'cuda:{device}' if torch.cuda.is_available() else "cpu")

# Only CDR3beta sequences with length <= b_max are kept. Later cells (CSV
# manifest, inference) read candidate_cdr, so the filter is applied to it here.
n_before_filter = len(candidate_cdr)
candidate_cdr = [b_seq for b_seq in candidate_cdr if len(b_seq) <= b_max]
print(f'Kept {len(candidate_cdr)}/{n_before_filter} CDR3beta sequences with length <= {b_max}')

cdr_feat_dir = './cdr_feat'
os.makedirs(cdr_feat_dir, exist_ok=True)

# ESM-2 (esm2_t33_650M_UR50D) from torch.hub; it is handed to cdr_align, which
# presumably uses it to embed each sequence (see TRAP.cdr_gen).
esm_model, esm_alphabet = torch.hub.load("facebookresearch/esm:main",
                                         "esm2_t33_650M_UR50D")
esm_batch_converter = esm_alphabet.get_batch_converter()
esm_model = esm_model.to(dev)
esm_model.eval()

with torch.no_grad():
    for b_seq in candidate_cdr:
        # NOTE(review): use_cpu=True is kept from the original; its effect is
        # defined inside TRAP.cdr_gen.cdr_align.
        dic = cdr_align(esm_model, esm_batch_converter, dev,
                        b_seq,
                        use_cpu=True,
                        b_max=b_max)
        # One pickle per sequence, named by the sequence itself; Step 4 reads
        # this directory via TrapDataset(cdr_dir='./cdr_feat').
        with open(os.path.join(cdr_feat_dir, b_seq), 'wb') as f:
            pickle.dump(dic, f)

Using cache found in /home/gejingxuan/.cache/torch/hub/facebookresearch_esm_main


Step 3. Generate epitope embedding features (e.g. HLA-A*02:01 GLCTLVAML).

In [6]:
import os
import pickle

import torch
from TRAP.pmhc_gen import epi_feature_esm

# Distance thresholds forwarded to epi_feature_esm (their meaning is defined
# in TRAP.pmhc_gen).
dis_threshold1 = 5
dis_threshold2 = 8
dis_threshold3 = 15
# Maximum epitope length; also used as p_max in Step 4.
epi_max = 12
# pMHC identifier: HLA allele and peptide joined with '_'. It is used both as the
# structure entry name under struc_path and as the output feature file name.
epi = 'A_02_01_GLCTLVAML'
# Directory holding the pMHC structures. Set TRAP_PMHC_STRUC_PATH to point at your
# own copy; the fallback is the author's original location.
struc_path = os.environ.get('TRAP_PMHC_STRUC_PATH',
                            '/home/gejingxuan/TCR/data/our_data/pmhc_struc')
file_path = './epi_feat'

os.makedirs(file_path, exist_ok=True)

# Epitope features are computed on CPU.
dev = torch.device("cpu")
esm_model, esm_alphabet = torch.hub.load("facebookresearch/esm:main",
                                         "esm2_t33_650M_UR50D")
esm_batch_converter = esm_alphabet.get_batch_converter()
esm_model = esm_model.to(dev)
esm_model.eval()

with torch.no_grad():
    system_path = os.path.join(struc_path, epi)
    epi_feat = epi_feature_esm(system_path, dis_threshold1, dis_threshold2, dis_threshold3, epi_max,
                               esm_model, esm_batch_converter, dev)
    # Step 4 reads this via TrapDataset(epi_dir='./epi_feat').
    with open(os.path.join(file_path, epi), 'wb') as fo:
        pickle.dump(epi_feat, fo)

Using cache found in /home/gejingxuan/.cache/torch/hub/facebookresearch_esm_main
@> 62 atoms and 1 coordinate set(s) were parsed in 0.00s.
@> 2947 atoms and 1 coordinate set(s) were parsed in 0.03s.


Step 3.5. Save the candidate CDR3 and epitope information to a CSV file.

In [None]:
import pandas as pd

# Build the inference manifest read by TrapDataset in Step 4. Every candidate
# CDR3beta is paired with the single pMHC `epi`. 'labels' is a constant 0
# placeholder (presumably required by the dataset's CSV schema, since no ground
# truth exists for generated sequences). 'idx' is a running row index.
n_rows = len(candidate_cdr)
data = {}
data['CDR3'] = candidate_cdr
data['pmhc'] = [epi for _ in range(n_rows)]
data['labels'] = [0 for _ in range(n_rows)]
data['idx'] = range(n_rows)

df = pd.DataFrame(data)
df.to_csv('test.csv', index=False)
print('CSV file created with shape:', df.shape)

CSV file created with shape: (20, 4)


Step 4. Inference (predictions are written to prediction.csv).

In [None]:
import pandas as pd
import torch
from TRAP.train import TrapDataset, DataLoaderX, collate_fn_v2, TransTRAP

batch_size = 32
model_path = './2024-01-20_19_19_14_707076.pth'

# Reads the manifest from Step 3.5 plus the pickled features from Steps 2 and 3.
test_dataset = TrapDataset(cdr_dir='./cdr_feat', epi_dir='./epi_feat', data_dirs='test.csv',
                           graph_ls_file='test.bin',
                           graph_dic_path='test', path_marker='/',
                           del_tmp_files=True, p_max=epi_max, b_max=b_max)

test_dataloader = DataLoaderX(test_dataset, batch_size, shuffle=False, num_workers=0,
                              collate_fn=collate_fn_v2)

# Feature widths must match the trained checkpoint. in_feat_b=1280 presumably
# corresponds to the ESM-2 650M embedding size used in Step 2.
TRAP_Model = TransTRAP(p_max=epi_max, b_max=b_max, in_feat_b=1280, in_feat_p=1470, batch_size=batch_size)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
TRAP_Model.to(device)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
TRAP_Model.load_state_dict(torch.load(model_path,
                                      map_location=device)['model_state_dict'])
TRAP_Model.eval()

pred = []
key = []

with torch.no_grad():
    # The dataset's labels are placeholders and are not needed for inference.
    for bgb, bgp, _labels, key_ in test_dataloader:
        bgb, bgp = bgb.to(device), bgp.to(device)
        pred_, _ = TRAP_Model(bgb, bgp)
        # squeeze() collapses a single-sample batch (e.g. a final batch of 1) to
        # a 0-d tensor, which list.extend cannot iterate. atleast_1d keeps it a
        # 1-element vector and is a no-op for larger batches.
        pred.extend(torch.atleast_1d(pred_.squeeze()).cpu().numpy())
        key.extend(key_)

# One row per dataset sample: its key and the model's prediction.
pd_te = pd.DataFrame({'key': key, 'test_pred': pred})
pd_te.to_csv('prediction.csv', index=False)

Loading previously saved dgl graphs and corresponding data...
