This document shows how to train TRAP and run inference with it. It walks through the whole pipeline: generating CDR3beta and epitope features, training TRAP, running inference, and computing the corresponding evaluation metrics.

### Step 1. Read the CSV files used for training and inference, and collect the unique CDR3 and epitope (pMHC) sequences for feature generation.
- Here we use the random-split scenario as an example; replace `csv_path` with the path where you placed the files.
- The files under this path should follow the format of https://github.com/gejingxuan/TRAP/tree/main/TRAP/data_split/randomly. `kfold.csv` and `train.csv` are used for model training: `kfold.csv` contains the positive samples together with the sampled negative samples and is used for binary classification training, while `train.csv` is the subset of positive samples in `kfold.csv` and is used for contrastive learning training. To prevent data leakage, the TCR-pMHC pairs in `test.csv` do not appear in `train.csv` or `kfold.csv`; `test.csv` is used to evaluate model performance.
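The split properties described above can be verified before generating any features. A minimal sketch, assuming `train.csv`, `kfold.csv` and `test.csv` all carry the `CDR3` and `pmhc` columns used in the cell below; `pair_set` and `check_split` are hypothetical helpers, not part of TRAP. Because the label column name is not shown in this document, the check compares (CDR3, pmhc) pairs only.

```python
import pandas as pd


def pair_set(df: pd.DataFrame) -> set:
    """Return the set of (CDR3, pmhc) pairs contained in one split."""
    return set(zip(df["CDR3"], df["pmhc"]))


def check_split(train: pd.DataFrame, kfold: pd.DataFrame, test: pd.DataFrame):
    """Return (train_is_subset_of_kfold, sorted list of test pairs also seen in training)."""
    train_pairs, kfold_pairs, test_pairs = pair_set(train), pair_set(kfold), pair_set(test)
    # any test pair that also occurs in train.csv or kfold.csv would leak into training
    leaked = test_pairs & (train_pairs | kfold_pairs)
    return train_pairs <= kfold_pairs, sorted(leaked)
```

For the random split, loading the three files from the directory of `csv_path` with `pd.read_csv` and passing them to `check_split` should return `(True, [])`.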

In [1]:
import pandas as pd
import glob
import os

# you can change the path to your own data
csv_path = '/home/gejingxuan/TCR/data/our_data/data_v1/randomly/*.csv'
all_files = glob.glob(csv_path)


all_cdr3 = []
all_pmhc = []
for file in all_files:
    df = pd.read_csv(file)
    if 'CDR3' in df.columns:
        all_cdr3.extend(df['CDR3'].tolist())
    if 'pmhc' in df.columns:
        all_pmhc.extend(df['pmhc'].tolist())


all_cdr3 = list(set(all_cdr3))
all_pmhc = list(set(all_pmhc))
print(f'Number of unique CDR3 sequences: {len(all_cdr3)}')
print(f'Number of unique pmhc: {len(all_pmhc)}')

Number of unique CDR3 sequences: 36380
Number of unique pmhc: 531
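An empty `CDR3` cell in any CSV would be read by pandas as `NaN` (a float), which would break Step 2, where file names are built with `'./cdr_feat_rep/' + b_seq`. In addition, `list(set(...))` does not give a reproducible order across Python sessions, because string hashing is randomized. A minimal sketch of a NaN-safe, order-preserving alternative, assuming the same `CDR3` and `pmhc` column names; `unique_in_order` is a hypothetical helper, not part of TRAP.

```python
import pandas as pd


def unique_in_order(frames, column):
    """Collect `column` from every DataFrame that has it.

    Missing values (NaN/None) are dropped and duplicates are removed while
    keeping first-seen order, so the result is reproducible for a fixed input order.
    """
    seen = {}  # dicts preserve insertion order (Python 3.7+)
    for df in frames:
        if column in df.columns:
            for value in df[column].dropna():
                seen.setdefault(value, None)
    return list(seen)
```

With the variables above, this could be used as `frames = [pd.read_csv(f) for f in sorted(all_files)]` followed by `all_cdr3 = unique_in_order(frames, 'CDR3')` and `all_pmhc = unique_in_order(frames, 'pmhc')`.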


### Step 2. Generation of CDR3beta embedding features.
- We use ESM2 representations as the input features for CDR3beta. For each CDR3beta sequence we generate features with embedding dimension 1280 and store them under `./cdr_feat_rep/` (one pickle file per sequence); change this path as needed.
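`b_max = 18` appears to set the maximum CDR3beta length expected by the model, so variable-length per-residue embeddings have to be brought to a common shape. With fair-esm, per-residue vectors are usually taken from `esm_model(tokens, repr_layers=[33])["representations"][33]`, dropping the BOS and EOS positions. We do not reproduce `cdr_align` here: the sketch below, with the hypothetical helper `pad_residue_features`, only illustrates the generic idea of zero-padding or truncating an `(L, 1280)` matrix to `(b_max, 1280)` with a mask of real residues. TRAP may align residues differently.

```python
import numpy as np


def pad_residue_features(emb, max_len):
    """Zero-pad (or truncate) an (L, D) per-residue embedding to (max_len, D).

    Returns the padded matrix and a boolean mask that is True for real residues.
    """
    length, dim = emb.shape
    keep = min(length, max_len)
    padded = np.zeros((max_len, dim), dtype=emb.dtype)
    padded[:keep] = emb[:keep]  # residues beyond max_len are dropped
    mask = np.zeros(max_len, dtype=bool)
    mask[:keep] = True
    return padded, mask
```

Keeping the mask alongside the padded matrix lets attention layers ignore the padded positions instead of treating zero vectors as residues.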

In [2]:
import torch
from TRAP.cdr_gen import get_input, cdr_align
import pickle

b_max = 18
device = 0
dev = torch.device(f'cuda:{device}' if torch.cuda.is_available() else "cpu")

os.makedirs('./cdr_feat_rep', exist_ok=True)

# esm
esm_model, esm_alphabet = torch.hub.load("facebookresearch/esm:main", 
                                            "esm2_t33_650M_UR50D")
esm_batch_converter = esm_alphabet.get_batch_converter()
esm_model = esm_model.to(dev)
esm_model.eval()

with torch.no_grad():
    for b_seq in all_cdr3:
        dic = cdr_align(esm_model, esm_batch_converter, dev,
                                            b_seq,
                                            use_cpu=True, 
                                            b_max=b_max)
        with open('./cdr_feat_rep/' + b_seq, 'wb') as f:
            pickle.dump(dic, f)

Using cache found in /home/gejingxuan/.cache/torch/hub/facebookresearch_esm_main
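The loop above recomputes every feature file each time it runs. A minimal sketch of a resumable variant, assuming one pickle file per key as in Steps 2 and 3; `cache_features` is a hypothetical helper. It skips keys whose file already exists and writes through a temporary file, so an interrupted run does not leave a truncated pickle behind.

```python
import os
import pickle
import tempfile


def cache_features(keys, compute, out_dir):
    """Pickle compute(key) to out_dir/key for every key without an existing file.

    Returns the number of newly written files.
    """
    os.makedirs(out_dir, exist_ok=True)
    written = 0
    for key in keys:
        path = os.path.join(out_dir, key)
        if os.path.exists(path):  # already computed in an earlier run
            continue
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(compute(key), f)
        os.replace(tmp_path, path)  # rename only after the pickle is fully written
        written += 1
    return written


if __name__ == "__main__":
    # toy demo: the "feature" of a sequence is just its length
    with tempfile.TemporaryDirectory() as d:
        print(cache_features(["CASSL", "CASSP"], len, d))  # 2
        print(cache_features(["CASSL", "CASSP", "CAWS"], len, d))  # 1
```

For Step 2, this would be called inside `torch.no_grad()` as `cache_features(all_cdr3, lambda s: cdr_align(esm_model, esm_batch_converter, dev, s, use_cpu=True, b_max=b_max), './cdr_feat_rep')`.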


### Step 3. Generation of epitope embedding features.
- The pMHC structures predicted with AlphaFold2 (AF2) are available on Zenodo (https://zenodo.org/records/15062393); download them and set `struc_path` to their location. As with CDR3beta, the generated epitope features are stored under `./epi_feat_rep` (one pickle file per pMHC).
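`epi_feature_esm` takes both the ESM2 model and the AF2 structure directory, together with three distance thresholds (5, 8 and 15, presumably in Å). Its exact feature definition is not shown in this document. The sketch below only illustrates one common way to turn several distance cutoffs into per-residue structural features: counting the MHC atoms within each cutoff of every peptide residue. `contact_counts` and its coordinate inputs are hypothetical and are not TRAP's implementation.

```python
import numpy as np


def contact_counts(pep_xyz, mhc_xyz, thresholds=(5.0, 8.0, 15.0)):
    """Count MHC atoms within each distance cutoff of every peptide residue.

    pep_xyz: (P, 3) one representative coordinate per peptide residue.
    mhc_xyz: (M, 3) coordinates of MHC atoms.
    Returns a (P, len(thresholds)) integer matrix.
    """
    # pairwise Euclidean distances, shape (P, M)
    dists = np.linalg.norm(pep_xyz[:, None, :] - mhc_xyz[None, :, :], axis=-1)
    return np.stack([(dists <= t).sum(axis=1) for t in thresholds], axis=1)
```

Nested cutoffs of this kind separate direct contacts (small threshold) from the wider structural neighbourhood (large threshold), which is one plausible reason for passing three thresholds rather than one.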

In [3]:
from TRAP.pmhc_gen import epi_feature_esm
dis_threshold1 = 5
dis_threshold2 = 8
dis_threshold3 = 15
epi_max = 12

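# replace with the directory containing the pMHC structures downloaded from Zenodo
struc_path = '/home/gejingxuan/TCR/data/our_data/pmhc_struc'
file_path = './epi_feat_rep'

os.makedirs(file_path, exist_ok=True)

# epitope features are computed on the CPU
dev = torch.device("cpu")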
esm_model, esm_alphabet = torch.hub.load("facebookresearch/esm:main", 
                                            "esm2_t33_650M_UR50D")
esm_batch_converter = esm_alphabet.get_batch_converter()
esm_model = esm_model.to(dev)
esm_model.eval()

with torch.no_grad():
    for epi in all_pmhc:
        system_path = struc_path+'/'+epi
        epi_feat = epi_feature_esm(system_path, dis_threshold1, dis_threshold2, dis_threshold3, epi_max,
                                   esm_model, esm_batch_converter, dev)
        with open(file_path+'/'+epi,'wb') as fo:
            pickle.dump(epi_feat, fo)

Using cache found in /home/gejingxuan/.cache/torch/hub/facebookresearch_esm_main
@> 74 atoms and 1 coordinate set(s) were parsed in 0.00s.
@> 2960 atoms and 1 coordinate set(s) were parsed in 0.03s.
@> 97 atoms and 1 coordinate set(s) were parsed in 0.00s.
@> 2936 atoms and 1 coordinate set(s) were parsed in 0.03s.
@> 83 atoms and 1 coordinate set(s) were parsed in 0.00s.
@> 2913 atoms and 1 coordinate set(s) were parsed in 0.03s.
@> 69 atoms and 1 coordinate set(s) were parsed in 0.00s.
@> 2923 atoms and 1 coordinate set(s) were parsed in 0.03s.
@> 63 atoms and 1 coordinate set(s) were parsed in 0.00s.
@> 2946 atoms and 1 coordinate set(s) were parsed in 0.06s.
@> 102 atoms and 1 coordinate set(s) were parsed in 0.00s.
@> 2983 atoms and 1 coordinate set(s) were parsed in 0.03s.
@> 91 atoms and 1 coordinate set(s) were parsed in 0.00s.
@> 2947 atoms and 1 coordinate set(s) were parsed in 0.03s.
@> 86 atoms and 1 coordinate set(s) were parsed in 0.00s.
@> 2971 atoms and 1 coordinate set

### Step 4. Training and inference
- The cell below lists the hyperparameters used for model training. `home_path` is the working directory where model checkpoints and other files produced during the run are stored; replace it with your own path. `cdr_dir` and `epi_dir` are the feature directories generated in Steps 2 and 3, and `data_dir` is the directory containing the CSV files described in Step 1.
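The training cell in this step uses `FocalLoss(gamma=2, alpha=30 / (30 + 1))` from `TRAP.utils` as the binary classification loss. For reference, a minimal NumPy sketch of the standard binary focal loss (Lin et al., 2017), `FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)`, where `p_t` is the predicted probability of the true class and `alpha_t` is `alpha` for positives and `1 - alpha` for negatives. Whether `TRAP.utils.FocalLoss` applies `alpha` to the positive class in the same way is an assumption; check its source before relying on this sketch.

```python
import numpy as np


def binary_focal_loss(y_true, p_pred, gamma=2.0, alpha=30 / (30 + 1), eps=1e-7):
    """Mean binary focal loss; alpha weights positives, (1 - alpha) negatives."""
    y = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(p_pred, dtype=float), eps, 1.0 - eps)  # avoid log(0)
    p_t = np.where(y == 1, p, 1.0 - p)  # probability assigned to the true class
    alpha_t = np.where(y == 1, alpha, 1.0 - alpha)
    # (1 - p_t)**gamma down-weights examples that are already well classified
    return float(np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)))
```

With `gamma=0` and `alpha=0.5` this reduces to half of the ordinary binary cross-entropy; increasing `gamma` shifts the loss toward hard, misclassified pairs.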

In [8]:
from TRAP.train import TrapDataset, DataLoaderX, collate_fn_v2, TransTRAP, run_a_eval_epoch, \
    run_clip_train_epoch, run_bi_train_epoch, cal_aupr
from TRAP.utils import set_random_seed, EarlyStopping, FocalLoss  # helpers used in this cell
import numpy as np
import datetime
import time
from torch.utils.data import DataLoader,Subset
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold
  
# training parameters
gpuid = 0
lr = 10**-4
epochs = 500
batch_size = 256
tolerance = 0.0
patience = 30
l2 = 0.00

# model parameters
in_feat_b = 1280
in_feat_p = 1470
d_ff = 512
d_k = 128
d_v = 128
n_heads = 4
n_layers = 3
d_model = 200
dropout = 0.3
d_graph_layer = 200
d_FC_layer = 100
n_FC_layer = 2
n_tasks = 1

home_path = './rep_randomly'
cdr_dir = './cdr_feat_rep'
epi_dir = './epi_feat_rep'
data_dir = '/home/gejingxuan/TCR/data/our_data/data_v1/randomly'
path_marker = '/'
reuslt_path = '%s/result' %home_path

# create <home_path>/result/model_save and <home_path>/result/stats
os.makedirs(reuslt_path + path_marker + 'model_save', exist_ok=True)
os.makedirs(reuslt_path + path_marker + 'stats', exist_ok=True)


positive_dataset = TrapDataset(cdr_dir= cdr_dir, epi_dir= epi_dir, data_dirs=data_dir+'/train.csv',
                        graph_ls_file=home_path + path_marker +'train.bin',
                        graph_dic_path=home_path + path_marker + 'train', path_marker='/',
                        del_tmp_files=True, p_max=epi_max, b_max=b_max)

positive_idx = np.arange(0, len(positive_dataset))

kfold_dataset = TrapDataset(cdr_dir= cdr_dir, epi_dir= epi_dir, data_dirs=data_dir+'/kfold.csv',
                        graph_ls_file=home_path + path_marker + 'kfold.bin',
                        graph_dic_path=home_path + path_marker + 'kfold', path_marker='/',
                        del_tmp_files=True, p_max=epi_max, b_max=b_max)
    
test_dataset = TrapDataset(cdr_dir= cdr_dir, epi_dir= epi_dir, data_dirs=data_dir+'/test.csv',
                        graph_ls_file=home_path + path_marker + 'test.bin',
                        graph_dic_path=home_path + path_marker + 'test', path_marker='/',
                        del_tmp_files=True, p_max=epi_max, b_max=b_max)


kfold_dataloader = DataLoaderX(kfold_dataset, batch_size, shuffle=False, num_workers=0,
                                    collate_fn=collate_fn_v2)
test_dataloader = DataLoaderX(test_dataset, batch_size, shuffle=False, num_workers=0,
                                    collate_fn=collate_fn_v2)


torch.cuda.empty_cache()
dt = datetime.datetime.now()
filename = reuslt_path + path_marker + 'model_save/{}_{:02d}_{:02d}_{:02d}_{:d}.pth'.format(
    dt.date(), dt.hour, dt.minute, dt.second, dt.microsecond)
set_random_seed(0)

# model
TRAPModel = TransTRAP(p_max=epi_max, b_max=b_max, in_feat_b=in_feat_b, in_feat_p=in_feat_p, 
                        d_model=d_model, d_ff=d_ff, d_k=d_k, d_v=d_v, n_heads=n_heads,
                    n_layers=n_layers, dropout=dropout, d_graph_layer=d_graph_layer, 
                    d_FC_layer=d_FC_layer, n_FC_layer=n_FC_layer, n_tasks=n_tasks,
                    batch_size=batch_size)
print('number of parameters : ', sum(p.numel() for p in TRAPModel.parameters() if p.requires_grad))
print(TRAPModel)

device = torch.device("cuda:%s" % gpuid if torch.cuda.is_available() else "cpu")
TRAPModel.to(device)
optimizer = torch.optim.Adam(TRAPModel.parameters(), lr=lr, weight_decay=l2)
loss_fn = FocalLoss(gamma=2, alpha=30 / (30 + 1))


kf = KFold(n_splits=5)  # 5-fold

for train_index, valid_index in kf.split(positive_idx):
    stopper = EarlyStopping(mode='lower', patience=patience, tolerance=tolerance, filename=filename)
    train_dataset = Subset(positive_dataset, train_index)
    valid_dataset = Subset(positive_dataset, valid_index)

    train_dataloader = DataLoaderX(train_dataset, batch_size, shuffle=True, num_workers=0,
                                collate_fn=collate_fn_v2)
    valid_dataloader = DataLoaderX(valid_dataset, batch_size, shuffle=False, num_workers=0,
                                collate_fn=collate_fn_v2)

    train_loss_record = []
    train_auc_record = []
    valid_loss_record = []
    valid_auc_record = []
    for epoch in range(epochs):
        st = time.time()
        # train
        run_clip_train_epoch(TRAPModel, train_dataloader, optimizer, device)
        run_bi_train_epoch(TRAPModel, kfold_dataloader, loss_fn, optimizer, device)

        # validation
        train_true,  train_pred, train_key, train_loss, train_loss2 = run_a_eval_epoch(TRAPModel, train_dataloader, loss_fn, device)
        valid_true,  valid_pred, valid_key, valid_loss, valid_loss2 = run_a_eval_epoch(TRAPModel, valid_dataloader, loss_fn, device)
        kfold_true,  kfold_pred, kfold_key, kfold_loss, kfold_loss2 = run_a_eval_epoch(TRAPModel, kfold_dataloader, loss_fn, device)

        kfold_auc = roc_auc_score(kfold_true, kfold_pred)
        early_stop = stopper.step(valid_loss, TRAPModel)

        end = time.time()

        if early_stop:
            break
        print("epoch:%s\ttrain_loss:%.4f\ttrain_loss2:%.4f\tvalid_loss:%.4f\tvalid_loss2:%.4f\tkfold_loss:%.4f\tkfold_loss2:%.4f\tkfold_auc:%.4f\ttime:%.3fs" %(
        epoch, train_loss, train_loss2, valid_loss, valid_loss2, kfold_loss, kfold_loss2, kfold_auc, end - st))


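# load the checkpoint saved by the last fold's early stopper
# (all folds train the same model and share the same checkpoint `filename`)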
stopper.load_checkpoint(TRAPModel)

# test
# run_a_eval_epoch returns (labels, predictions, keys, loss, loss2), as unpacked in the training loop
kfold_true, kfold_pred, kfold_keys, kfold_loss, _ = run_a_eval_epoch(TRAPModel, kfold_dataloader,
                                                                     loss_fn, device)
test_true, test_pred, te_keys, test_loss, _ = run_a_eval_epoch(TRAPModel, test_dataloader,
                                                               loss_fn, device)


pd_kfold = pd.DataFrame({'key': kfold_keys, 'kfold_true': kfold_true, 'kfold_pred': kfold_pred})
pd_te = pd.DataFrame({'key': te_keys, 'test_true': test_true, 'test_pred': test_pred})

pd_kfold.to_csv(
    reuslt_path + path_marker + 'stats' + path_marker + '{}_{:02d}_{:02d}_{:02d}_{:d}_kfold.csv'.format(
        dt.date(), dt.hour, dt.minute, dt.second, dt.microsecond), index=False)
pd_te.to_csv(
    reuslt_path + path_marker + 'stats' + path_marker + '{}_{:02d}_{:02d}_{:02d}_{:d}_te.csv'.format(
        dt.date(), dt.hour, dt.minute, dt.second, dt.microsecond), index=False)


kfold_auc = roc_auc_score(kfold_true, kfold_pred)
test_auc = roc_auc_score(test_true, test_pred)

kfold_aupr = cal_aupr(kfold_true, kfold_pred)
test_aupr = cal_aupr(test_true, test_pred)

print('***The Best TRAP model***')
print("kfold_aupr:%.4f kfold_auc:%.4f te_aupr:%.4f te_auc:%.4f" % (
    kfold_aupr, kfold_auc, test_aupr, test_auc))


Generate CDRb-epitope graph...
main process start >>> pid=1751642
main process end (time:77.5854344367981 S)

Generate CDRb-epitope graph...
main process start >>> pid=1751642
main process end (time:383.8331575393677 S)

Generate CDRb-epitope graph...
main process start >>> pid=1751642
main process end (time:41.690550327301025 S)

number of parameters :  5487501
TransTRAP(
  (b_encoder): Encoder_(
    (src_emb): Linear(in_features=1280, out_features=200, bias=True)
    (pos_emb): PositionalEncoding(
      (dropout): Dropout(p=0.3, inplace=False)
    )
    (layers): ModuleList(
      (0-2): 3 x EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=200, out_features=512, bias=False)
  