In [1]:
import pandas as pd
import os, sys
from sklearn.metrics import roc_auc_score

In [2]:
from torch.utils.data import DataLoader
from EPACT.utils import load_config, set_seed
from EPACT.dataset import UnlabeledDataset, UnlabeledBacthConverter
from EPACT.trainer import PairedCDR3pMHCCoembeddingTrainer, PairedCDR123pMHCCoembeddingTrainer

In [5]:
from utils.data_mixing import load_dataset, is_invalid

In [8]:
train_dataset = pd.read_pickle('data/pretrained/netmhcpan_pep_cluster_0.5.pkl')
print(train_dataset)

{'AAASSLLYK': 17180, 'AAEAMEVA': 17181, 'AAEQRRSTI': 17182, 'AAESTFESY': 17183, 'AAFVNQHLCG': 17184, 'AAGFDPEVQ': 17185, 'AAGLQDCT': 17186, 'AAGLQDCTM': 17187, 'AAGLQDCTML': 17188, 'AAGLQDCTMLV': 17189, 'AAILKQHKL': 17190, 'AAIRILQQL': 17191, 'AAITDAAVA': 17192, 'AAITDAAVAV': 17193, 'AAKKKGASL': 17194, 'AALAFHLTSR': 17195, 'AALFMYYAK': 17196, 'AALFMYYAKR': 17197, 'AALQSAWQG': 17198, 'AAMAAQLQA': 17199, 'AAMDDFQLI': 17300, 'AAMQRKLEK': 17301, 'AANDPIFVV': 17302, 'AANEMGLIEK': 17303, 'AAPAPAPSW': 17304, 'AAPLILSRI': 17305, 'AARDRQFEK': 17306, 'AARIAGRHM': 17307, 'AARILSEKRK': 17308, 'AARNIVRRA': 17309, 'AASPMLYQL': 17310, 'AASPMLYQLL': 17311, 'AASTLLYATV': 17312, 'AATIQTPTK': 17313, 'AAVALLNKL': 17314, 'AAVDNAVVV': 17315, 'AAVLLLVTHY': 17316, 'AAVSHLTTL': 17317, 'AAYHPQQFI': 17318, 'AAYHPQQFIYA': 17319, 'ACDPHSGHFV': 17320, 'ACQEAVKLK': 17321, 'ADDETSSLP': 17322, 'ADKNLIKCS': 17323, 'ADLRFASEF': 17324, 'ADLVCEQGN': 17325, 'ADMSKLLNL': 17326, 'ADSEITETY': 17327, 'ADSGCVINW': 17328, 'AEAAL

In [34]:
def filter_invalid_seq(dataset):
    remove_index = []
    # drop rows contains illegal animo acid residue
    dataset['illegal_a'] = dataset['CDR3a'].apply(lambda x: is_invalid(x))
    index_list = dataset.index[dataset['illegal_a']]
    remove_index.extend(index_list)
    dataset['illegal_b'] = dataset['CDR3b'].apply(lambda x: is_invalid(x))
    index_list = dataset.index[dataset['illegal_b']]
    remove_index.extend(index_list)
    dataset.drop(index=list(set(remove_index)), inplace=True)
    dataset.drop(columns=['illegal_a', 'illegal_b'], inplace=True)
    return dataset.reset_index(drop=True)

kevin's data are all positive bindings. mixing with some negative data from 10x

In [32]:
dpos = pd.read_csv('processed-data/kevin_positive_remain_gene.csv')
dneg = pd.read_csv('processed-data/10x_negative_all.csv')

In [35]:
dpos = filter_invalid_seq(dpos)
dneg = filter_invalid_seq(dneg)
dpos['sign'] = 1
dneg['sign'] = 0
dpos.rename(columns={"HLA": "MHC"}, inplace=True)
negative_ratio = 5
peps = dpos['Peptide'].unique()
dfs = []
dfs.append(dpos)
for pep in peps:
    count = len(dpos[dpos['Peptide'] == pep])
    total = count * negative_ratio
    cur_neg = dneg[dneg['Peptide'] == pep]
    if len(cur_neg) > total:
        dfs.append(cur_neg.sample(n=count * negative_ratio, random_state=42))
    else:
        dfs.append(cur_neg)
        total -= len(cur_neg)
        df_mat = dpos[dpos['Peptide'] != pep]
        if len(df_mat) > total:
            df_mat = df_mat.sample(n=total, random_state=42)  
        df_mat['Peptide'] = pep
        df_mat['sign'] = 0
        dfs.append(df_mat)
df_final = pd.concat(dfs, axis=0)
df_final.sample(frac=1, random_state=42)
df_final.reset_index(drop=True, inplace=True)
df_final.head()

Unnamed: 0,CDR3a,TRAV,TRAJ,CDR3b,TRBV,TRBJ,Peptide,MHC,sign
0,CILPLAGGTSYGKLTF,TRAV26-2,TRAJ52,CASSLGQAYEQYF,TRBV7-8,TRBJ2-7,FLRGRAYGL,B0801,1
1,CILPLAGGTSYGKLTF,TRAV26-2,TRAJ52,CASSSGQAYEQYF,TRBV7-8,TRBJ2-7,FLRGRAYGL,B0801,1
2,CAVRDTTWDDKIIF,TRAV3,TRAJ30,CASSLGGGEGASEQFF,TRBV7-6,TRBJ2-1,FLRGRAYGL,B0801,1
3,CAVRDTTWAAKIIF,TRAV3,TRAJ30,CASSLGGGEGASEQFF,TRBV7-6,TRBJ2-1,FLRGRAYGL,B0801,1
4,CAPSLNSGGYQKVTF,TRAV2,TRAJ13,CASSVVGGDYGYTF,TRBV2,TRBJ1-2,FLRGRAYGL,B0801,1


In [36]:
df_final['MHC'] = df_final['MHC'].apply(lambda x: f'HLA-{x[0]}*{x[1:3]}:{x[3:]}')

In [37]:
df_final.head()

Unnamed: 0,CDR3a,TRAV,TRAJ,CDR3b,TRBV,TRBJ,Peptide,MHC,sign
0,CILPLAGGTSYGKLTF,TRAV26-2,TRAJ52,CASSLGQAYEQYF,TRBV7-8,TRBJ2-7,FLRGRAYGL,HLA-B*08:01,1
1,CILPLAGGTSYGKLTF,TRAV26-2,TRAJ52,CASSSGQAYEQYF,TRBV7-8,TRBJ2-7,FLRGRAYGL,HLA-B*08:01,1
2,CAVRDTTWDDKIIF,TRAV3,TRAJ30,CASSLGGGEGASEQFF,TRBV7-6,TRBJ2-1,FLRGRAYGL,HLA-B*08:01,1
3,CAVRDTTWAAKIIF,TRAV3,TRAJ30,CASSLGGGEGASEQFF,TRBV7-6,TRBJ2-1,FLRGRAYGL,HLA-B*08:01,1
4,CAPSLNSGGYQKVTF,TRAV2,TRAJ13,CASSVVGGDYGYTF,TRBV2,TRBJ1-2,FLRGRAYGL,HLA-B*08:01,1


rename the column to match with EPACT input

In [38]:
df_input = df_final.rename(columns={'CDR3a': 'CDR3.alpha.aa', 'TRAV': 'V.alpha', 'TRAJ': 'J.alpha', 'CDR3b': 'CDR3.beta.aa', 'TRBV': 'V.beta', 'TRBJ': 'J.beta', 'Peptide': 'Epitope.peptide', 'sign': 'Target'})

In [39]:
df_input.head()

Unnamed: 0,CDR3.alpha.aa,V.alpha,J.alpha,CDR3.beta.aa,V.beta,J.beta,Epitope.peptide,MHC,Target
0,CILPLAGGTSYGKLTF,TRAV26-2,TRAJ52,CASSLGQAYEQYF,TRBV7-8,TRBJ2-7,FLRGRAYGL,HLA-B*08:01,1
1,CILPLAGGTSYGKLTF,TRAV26-2,TRAJ52,CASSSGQAYEQYF,TRBV7-8,TRBJ2-7,FLRGRAYGL,HLA-B*08:01,1
2,CAVRDTTWDDKIIF,TRAV3,TRAJ30,CASSLGGGEGASEQFF,TRBV7-6,TRBJ2-1,FLRGRAYGL,HLA-B*08:01,1
3,CAVRDTTWAAKIIF,TRAV3,TRAJ30,CASSLGGGEGASEQFF,TRBV7-6,TRBJ2-1,FLRGRAYGL,HLA-B*08:01,1
4,CAPSLNSGGYQKVTF,TRAV2,TRAJ13,CASSVVGGDYGYTF,TRBV2,TRBJ1-2,FLRGRAYGL,HLA-B*08:01,1


In [40]:
df_input.to_csv('./input/kevin.csv', index=False)

run in CDR3 mode. Kevin's data only has CDR3 columns

In [41]:
#@markdown Select the EPACT model:
model_name = "CDR3 binding model" #@param ['CDR3 binding model', 'CDR123 binding model']

#@markdown In default, we will use `sample/VDJdb-GLCTLVAML.csv` for prediction.
input_data_path = "input/kevin.csv" #@param {type:"string"}

#@markdown Specify the name of the result folder:
result_dir = "demo/binding" #@param {type:"string"}

#@markdown Specify the number of batch size:
batch_size = 128 #@param {type: "integer"}

In [42]:
if model_name == "CDR3 binding model":
  config_path = 'configs/config-paired-cdr3-pmhc-binding.yml'
  model_location_list = [f'checkpoints/paired-cdr3-pmhc-binding/paired-cdr3-pmhc-binding-model-fold-{i+1}.pt' for i in range(5)]
elif model_name == "CDR123 binding model":
  config_path = 'configs/config-paired-cdr123-pmhc-binding.yml'
  model_location_list = [f'checkpoints/paired-cdr123-pmhc-binding/paired-cdr123-pmhc-binding-model-fold-{i+1}.pt' for i in range(5)]

config = load_config(config_path)
set_seed(config.training.seed)
config.training.gpu_device = 0

In [43]:
dataset = UnlabeledDataset(data_path = input_data_path, hla_lib_path = config.data.hla_lib_path)
data_loader = DataLoader(
        dataset = dataset, batch_size = batch_size, num_workers = 1,
        collate_fn = UnlabeledBacthConverter(max_mhc_len = config.model.mhc_seq_len, use_cdr123=config.data.use_cdr123),
        shuffle = False
    )

if not os.path.exists(result_dir):
  os.makedirs(result_dir)

for i in range(5):
  result_fold_dir = os.path.join(result_dir, f'Fold_{i+1}')

  if not os.path.exists(result_fold_dir):
    os.makedirs(result_fold_dir)

  if config.data.use_cdr123:
      Trainer = PairedCDR123pMHCCoembeddingTrainer(config, result_fold_dir)
  else:
      Trainer = PairedCDR3pMHCCoembeddingTrainer(config, result_fold_dir)

  Trainer.predict(data_loader, model_location=model_location_list[i])

  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 12/12 [00:42<00:00,  3.57s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 12/12 [00:43<00:00,  3.60s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 12/12 [00:43<00:00,  3.59s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrai

prediction result in AUC

In [44]:
data = pd.read_csv(input_data_path)
for i in range(5):
  prediction = pd.read_csv(f'{result_dir}/Fold_{i+1}/predictions.csv')
  if i == 0:
    avg_pred = prediction['Pred'] / 5
  else:
    avg_pred += prediction['Pred'] / 5

data['Pred'] = avg_pred
auc = roc_auc_score(data['Target'], data['Pred'])
print(auc)

0.7888260955780289
