In [1]:
import pandas as pd
import os, sys
from sklearn.metrics import roc_auc_score

In [2]:
from utils.data_mixing import is_invalid

In [3]:
# Paired-chain TCR table (presumably the model's pretraining data, per the path).
# flag_dataset() uses its unique `cdr3b` values to decide which CDR3b are "seen".
df_cdr = pd.read_csv(
    "data/pretrained/combined_paired_pretrain_data.csv",
)
df_cdr.head()

Unnamed: 0,cdr3a,cdr3b,va,ja,vb,jb,cdr1a,cdr2a,cdr1b,cdr2b
0,CAYRSVYRSFMYSGGGADGLTF,CASSLEVGGGEETQYF,TRAV38-2/DV8,TRAJ45,TRBV7-6,TRBJ2-5,TSESDYY,QEAYKQQN,SGHVS,FNYEAQ
1,CAYRSNSYGGSQGNLIF,CASSLVSFGDHGFF,TRAV38-2/DV8,TRAJ42,TRBV7-6,TRBJ1-1,TSESDYY,QEAYKQQN,SGHVS,FNYEAQ
2,CAYSHSGNTPLVF,CASSPRTVSTNEKLFF,TRAV38-2/DV8,TRAJ29,TRBV7-6,TRBJ1-4,TSESDYY,QEAYKQQN,SGHVS,FNYEAQ
3,CAYRSAQGAQKLVF,CASSFSAGGYEQYF,TRAV38-2/DV8,TRAJ54,TRBV7-6,TRBJ2-7,TSESDYY,QEAYKQQN,SGHVS,FNYEAQ
4,CAYRSGNQFYF,CASSIIHGTGIHNEQFF,TRAV38-2/DV8,TRAJ49,TRBV7-6,TRBJ2-1,TSESDYY,QEAYKQQN,SGHVS,FNYEAQ


In [4]:
# Peptide–MHC binding-affinity table; flag_dataset() uses its unique
# `Epitope.peptide` values to decide which peptides are "seen".
df_epi = pd.read_csv(
    "data/pretrained/netMHCpan-MHC-I-BA-data.csv",
)
df_epi.head()

Unnamed: 0,Epitope.peptide,Target,MHC
0,AAASSLLYK,0.010608,HLA-A*02:06
1,AAASSLLYK,0.031401,HLA-A*02:03
2,AAASSLLYK,0.153147,HLA-A*33:01
3,AAASSLLYK,0.488861,HLA-A*31:01
4,AAASSLLYK,0.573507,HLA-A*68:01


In [5]:
def flag_dataset(dataset, train_data, tcrb_column, pep_column):
    """Flag CDR3b / peptide values seen in the reference data and filter rows.

    Two indicator columns are added to a merged copy of ``dataset``:

    * ``use_tcrb`` is 1 if the row's CDR3b occurs in ``df_cdr['cdr3b']``, else 0.
    * ``use_pep`` is 1 if the row's peptide occurs in ``df_epi['Epitope.peptide']``, else 0.

    Rows are then dropped when both flags are 1, or when ``is_invalid`` flags the
    ``CDR3a`` or CDR3b sequence (illegal amino-acid residues).

    Parameters
    ----------
    dataset : pandas.DataFrame
        Input pairs; must contain ``tcrb_column``, ``pep_column`` and ``'CDR3a'``.
    train_data :
        Unused; kept for call compatibility (callers in this notebook pass ``''``).
        The reference sets are the module-level ``df_cdr`` and ``df_epi`` frames.
    tcrb_column : str
        Name of the CDR3-beta column in ``dataset``.
    pep_column : str
        Name of the peptide column in ``dataset``.

    Returns
    -------
    pandas.DataFrame
        Filtered frame with a fresh RangeIndex. It also carries the helper columns
        ``tcrb`` and ``pep``, which hold the matched reference value or NaN when unseen.
    """
    in_train_cdr = pd.DataFrame(df_cdr['cdr3b'].unique(), columns=['tcrb'])
    in_train_pep = pd.DataFrame(df_epi['Epitope.peptide'].unique(), columns=['pep'])

    # Left-join against de-duplicated reference values so every input row is kept
    # exactly once; a non-null match means the value appears in the reference data.
    df_merge = dataset.merge(in_train_cdr, how='left', left_on=tcrb_column, right_on='tcrb')
    df_merge['use_tcrb'] = df_merge['tcrb'].notna().astype('int64')
    df_merge = df_merge.merge(in_train_pep, how='left', left_on=pep_column, right_on='pep')
    df_merge['use_pep'] = df_merge['pep'].notna().astype('int64')

    # When both TCR and peptide are seen in training, remove the row.
    both_seen = (df_merge['use_tcrb'] == 1) & (df_merge['use_pep'] == 1)
    # Rows whose CDR3 sequences contain illegal amino-acid residues are removed too.
    illegal_a = df_merge['CDR3a'].apply(is_invalid).astype(bool)
    illegal_b = df_merge[tcrb_column].apply(is_invalid).astype(bool)

    keep = ~(both_seen | illegal_a | illegal_b)
    return df_merge.loc[keep].reset_index(drop=True)

In [6]:
def _swap_negatives_for_peptide(df_pos, cdr3, peptide, pep, use_pep, negative_ratio):
    """Relabel TCRs that bind other peptides as non-binders of ``pep`` (up to ratio x positives)."""
    is_pep = df_pos[peptide] == pep
    total = int(is_pep.sum()) * negative_ratio
    # A CDR3b that is a positive for `pep` must not be reused as a negative for it.
    binders = set(df_pos.loc[is_pep, cdr3])
    candidates = df_pos[~is_pep & ~df_pos[cdr3].isin(binders)].copy()
    candidates[peptide] = pep
    candidates['use_pep'] = use_pep
    candidates['sign'] = 0
    # NOTE(review): the swapped row keeps the HLA of the TCR's original pMHC, not an
    # allele that presents `pep` -- confirm this is intended for the benchmark.
    candidates = candidates.drop_duplicates()
    if len(candidates) > total:
        candidates = candidates.sample(n=total, random_state=42)
    return candidates


def build_negative_swap(positive_data, train_data, cdr3, peptide, negative_ratio = 5):
    """Build negatives by pairing each positive peptide with TCRs from other peptides.

    For every peptide in the flagged positives, up to ``negative_ratio`` times its
    positive count of swapped (TCR, peptide) rows are labelled ``sign = 0``. TCRs
    whose CDR3b is already a positive for that peptide are excluded. ``use_pep`` on
    the negatives mirrors the peptide's own seen/unseen flag.

    Parameters
    ----------
    positive_data : pandas.DataFrame
        Positive pairs, passed to ``flag_dataset``.
    train_data :
        Forwarded to ``flag_dataset``; unused there.
    cdr3, peptide : str
        Column names of the CDR3-beta and peptide fields.
    negative_ratio : int
        Maximum number of negatives per positive record of each peptide.

    Returns
    -------
    pandas.DataFrame
        Shuffled negatives (``random_state=42``) with a fresh RangeIndex. If there
        are no positives, an empty frame with the flagged-positive columns is returned.
    """
    df_pos = flag_dataset(positive_data, train_data, cdr3, peptide)
    df_pos['sign'] = 1
    dfs = []

    # Unseen peptides do not need to worry about duplicates in the training set.
    unseen_mask = df_pos['use_pep'] == 0
    unseen_pep_record_total = int(unseen_mask.sum())
    for pep in df_pos.loc[unseen_mask, peptide].unique():
        dfs.append(_swap_negatives_for_peptide(df_pos, cdr3, peptide, pep, 0, negative_ratio))
    print(f'***dataset contains {unseen_pep_record_total} records using peptide outside training dataset***')

    # Seen peptides.
    seen_pep_record_total = 0
    for pep in df_pos.loc[df_pos['use_pep'] == 1, peptide].unique():
        count = int((df_pos[peptide] == pep).sum())
        seen_pep_record_total += count
        print(f'### {pep} - {count} ###')
        dfs.append(_swap_negatives_for_peptide(df_pos, cdr3, peptide, pep, 1, negative_ratio))
    print(f'***dataset contains {seen_pep_record_total} records using peptide inside training dataset***')

    if not dfs:
        return df_pos.iloc[0:0].copy()
    df_final = pd.concat(dfs, axis=0)
    # DataFrame.sample returns a new frame, so the shuffle must be assigned to take effect.
    return df_final.sample(frac=1, random_state=42).reset_index(drop=True)

def _sample_negatives_for_peptide(df_pos, df_neg, cdr3, peptide, pep, use_pep, negative_ratio):
    """Return negatives for ``pep``: real negatives first, padded with swapped positives."""
    is_pep = df_pos[peptide] == pep
    total = int(is_pep.sum()) * negative_ratio
    cur_neg = df_neg[df_neg[peptide] == pep]
    if len(cur_neg) >= total:
        return [cur_neg.sample(n=total, random_state=42)]
    # Not enough real negatives: fill only the remaining shortfall with swaps.
    total -= len(cur_neg)
    # A CDR3b that is a positive for `pep` must not be reused as a negative for it.
    binders = set(df_pos.loc[is_pep, cdr3])
    swapped = df_pos[~is_pep & ~df_pos[cdr3].isin(binders)].copy()
    swapped[peptide] = pep
    swapped['use_pep'] = use_pep
    swapped['sign'] = 0
    swapped = swapped.drop_duplicates()
    if len(swapped) > total:
        swapped = swapped.sample(n=total, random_state=42)
    return [cur_neg, swapped]


def build_negative_sample(positive_data, negative_data, train_data, cdr3, peptide, negative_ratio = 5):
    """Build negatives by sampling ``negative_data``, padding with swapped positives.

    Unseen peptides (``use_pep == 0``) are handled first. If the negative pool holds
    more unseen-peptide rows than ``negative_ratio`` times the unseen positives, that
    many are sampled directly. Otherwise each unseen peptide is handled like a seen one.

    For each seen peptide (``use_pep == 1``), up to ``negative_ratio`` times its
    positive count is sampled from ``negative_data``. Any shortfall is filled with
    TCRs from other positive peptides, relabelled to this peptide with ``sign = 0``.
    TCRs whose CDR3b is a positive for the peptide are excluded from the fill.

    Parameters
    ----------
    positive_data, negative_data : pandas.DataFrame
        Positive and negative pairs, both passed through ``flag_dataset``.
    train_data :
        Forwarded to ``flag_dataset``; unused there.
    cdr3, peptide : str
        Column names of the CDR3-beta and peptide fields.
    negative_ratio : int
        Target number of negatives per positive record.

    Returns
    -------
    pandas.DataFrame
        Shuffled negatives (``random_state=42``) with a fresh RangeIndex. If nothing
        is produced, an empty frame with the flagged-positive columns is returned.
    """
    df_pos = flag_dataset(positive_data, train_data, cdr3, peptide)
    df_pos['sign'] = 1
    # prepare negative data
    df_neg = flag_dataset(negative_data, train_data, cdr3, peptide)
    df_neg['sign'] = 0

    dfs = []
    # When peptides are unseen, add unseen-peptide negatives.
    unseen_mask = df_pos['use_pep'] == 0
    unseen_pep_record_total = int(unseen_mask.sum())
    if unseen_pep_record_total > 0:
        df_neg_unseen = df_neg[df_neg['use_pep'] == 0]
        wanted = unseen_pep_record_total * negative_ratio
        if len(df_neg_unseen) > wanted:
            # The negative pool has enough unseen-peptide rows on its own.
            dfs.append(df_neg_unseen.sample(n=wanted, random_state=42))
        else:
            for pep in df_pos.loc[unseen_mask, peptide].unique():
                dfs.extend(_sample_negatives_for_peptide(df_pos, df_neg, cdr3, peptide, pep, 0, negative_ratio))
    print(f'***dataset contains {unseen_pep_record_total} records using peptide outside training dataset***')

    print('*********')
    print('dataset contains peptides')
    for pep in df_pos.loc[df_pos['use_pep'] == 1, peptide].unique():
        count = int((df_pos[peptide] == pep).sum())
        print(f'### {pep} - {count} ###')
        dfs.extend(_sample_negatives_for_peptide(df_pos, df_neg, cdr3, peptide, pep, 1, negative_ratio))
    print('*********')

    if not dfs:
        return df_pos.iloc[0:0].copy()
    # Join all data together; sample() returns a new frame, so assign the shuffle.
    df_final = pd.concat(dfs, axis=0)
    return df_final.sample(frac=1, random_state=42).reset_index(drop=True)

# Negative ratio defaults to 5. In mixed mode, 3/5 come from random swap and 2/5 from negative controls.
def load_dataset(positive_data, negative_data, train_data, cdr3, peptide, mode, negative_ratio = 5):
    """Return flagged positives plus negatives built according to ``mode``.

    Parameters
    ----------
    positive_data, negative_data : pandas.DataFrame
        Positive pairs, and the negative pool used by ``build_negative_sample``.
    train_data :
        Forwarded to ``flag_dataset``; unused there.
    cdr3, peptide : str
        Column names of the CDR3-beta and peptide fields.
    mode : int
        1 builds negatives with ``build_negative_swap`` (swapped positives).
        2 builds them with ``build_negative_sample`` (negative pool, padded with swaps).
        3 is mixed: ``round(negative_ratio * 3 / 5)`` per positive from swaps and the
        rest from sampling, i.e. 3 + 2 for the default ratio of 5.
    negative_ratio : int
        Negatives per positive record.

    Returns
    -------
    pandas.DataFrame
        Positives (``sign = 1``) and negatives (``sign = 0``), shuffled with
        ``random_state=42``, with a fresh RangeIndex.

    Raises
    ------
    ValueError
        If ``mode`` is not 1, 2 or 3.
    """
    df_pos = flag_dataset(positive_data, train_data, cdr3, peptide)
    df_pos['sign'] = 1

    dfs = [df_pos]
    if mode == 1:
        dfs.append(build_negative_swap(positive_data, train_data, cdr3, peptide, negative_ratio))
    elif mode == 2:
        dfs.append(build_negative_sample(positive_data, negative_data, train_data, cdr3, peptide, negative_ratio))
    elif mode == 3:
        swap_ratio = round(negative_ratio * 3 / 5)
        sample_ratio = negative_ratio - swap_ratio
        # NOTE(review): both builders may fall back to swapping the same positives with
        # random_state=42, so some negative rows can appear twice in mixed mode.
        dfs.append(build_negative_swap(positive_data, train_data, cdr3, peptide, swap_ratio))
        dfs.append(build_negative_sample(positive_data, negative_data, train_data, cdr3, peptide, sample_ratio))
    else:
        raise ValueError(f'unknown negative mode {mode!r}; expected 1, 2 or 3')

    df_final = pd.concat(dfs, axis=0)
    # DataFrame.sample returns a new frame, so the shuffle must be assigned to take effect.
    return df_final.sample(frac=1, random_state=42).reset_index(drop=True)

In [7]:
from torch.utils.data import DataLoader
from EPACT.utils import load_config, set_seed
from EPACT.dataset import UnlabeledDataset, UnlabeledBacthConverter
from EPACT.trainer import PairedCDR3pMHCCoembeddingTrainer, PairedCDR123pMHCCoembeddingTrainer

In [8]:
def predict(model_name, input_data_path, result_dir, batch_size=128):
    """Score an input CSV with the per-fold checkpoints of an EPACT binding model.

    Each fold checkpoint is run over ``input_data_path`` and writes its output under
    ``result_dir/Fold_<k>``. The per-fold ``Pred`` columns are averaged, and the
    average is scored against the input's ``Target`` column.

    Parameters
    ----------
    model_name : str
        ``"CDR3 binding model"`` or ``"CDR123 binding model"``.
    input_data_path : str
        CSV readable by ``UnlabeledDataset`` that also contains a ``Target`` column.
    result_dir : str
        Output directory; per-fold subdirectories are created as needed.
    batch_size : int
        DataLoader batch size (default 128).

    Returns
    -------
    tuple[float, float]
        ``(roc_auc, partial_roc_auc)``. The second value uses ``max_fpr=0.1``
        (sklearn's standardized partial AUC). Both are NaN when ``Target`` contains
        fewer than two classes, where ROC-AUC is undefined.

    Raises
    ------
    ValueError
        If ``model_name`` is unknown, or the prediction and target lengths differ.
    """
    model_specs = {
        "CDR3 binding model": ('configs/config-paired-cdr3-pmhc-binding.yml', 'paired-cdr3-pmhc-binding'),
        "CDR123 binding model": ('configs/config-paired-cdr123-pmhc-binding.yml', 'paired-cdr123-pmhc-binding'),
    }
    if model_name not in model_specs:
        raise ValueError(f'unknown model_name {model_name!r}; expected one of {sorted(model_specs)}')
    config_path, stem = model_specs[model_name]
    n_folds = 5
    model_location_list = [f'checkpoints/{stem}/{stem}-model-fold-{i+1}.pt' for i in range(n_folds)]

    config = load_config(config_path)
    set_seed(config.training.seed)
    config.training.gpu_device = 0

    dataset = UnlabeledDataset(data_path = input_data_path, hla_lib_path = config.data.hla_lib_path)
    data_loader = DataLoader(
        dataset = dataset, batch_size = batch_size, num_workers = 1,
        collate_fn = UnlabeledBacthConverter(max_mhc_len = config.model.mhc_seq_len, use_cdr123=config.data.use_cdr123),
        # Keep input order so each fold's predictions line up row-by-row with the CSV.
        shuffle = False
    )

    trainer_cls = PairedCDR123pMHCCoembeddingTrainer if config.data.use_cdr123 else PairedCDR3pMHCCoembeddingTrainer
    fold_preds = []
    for i, model_location in enumerate(model_location_list):
        result_fold_dir = os.path.join(result_dir, f'Fold_{i+1}')
        os.makedirs(result_fold_dir, exist_ok=True)
        trainer = trainer_cls(config, result_fold_dir)
        trainer.predict(data_loader, model_location=model_location)
        # Trainer.predict is expected to write predictions.csv into the fold directory.
        fold_preds.append(pd.read_csv(os.path.join(result_fold_dir, 'predictions.csv'))['Pred'])

    avg_pred = sum(pred / n_folds for pred in fold_preds)
    targets = pd.read_csv(input_data_path)['Target']
    if len(avg_pred) != len(targets):
        raise ValueError(f'{len(avg_pred)} predictions for {len(targets)} input rows in {input_data_path}')
    if targets.nunique() < 2:
        # roc_auc_score raises when only one class is present; report NaN instead.
        return float('nan'), float('nan')

    auc = roc_auc_score(targets, avg_pred)
    partial_auc = roc_auc_score(targets, avg_pred, max_fpr=0.1)
    return auc, partial_auc

In [9]:
# Shared negative pool used for the icon / itrap / kevin evaluations below.
df_10x_neg = pd.read_csv(
    'processed-data/10x_negative_all.csv',
)
df_10x_neg.head()

Unnamed: 0,TRAV,TRAJ,CDR3a,TRBV,TRBJ,CDR3b,Peptide,MHC
0,TRAV29DV5,TRAJ44,CAASVSIWTGTASKLTF,TRBV10-3,TRBJ2-1,CAISDPGLAGGGGEQFF,VTEHDTLLY,A0101
1,TRAV8-6,TRAJ47,CAAWDMEYGNKLVF,TRBV10-3,TRBJ2-1,CAISDPGLAGGGGEQFF,VTEHDTLLY,A0101
2,TRAV38-2DV8,TRAJ34,CASYTDKLIF,TRBV5-1,TRBJ2-3,CASSGGSISTDTQYF,VTEHDTLLY,A0101
3,TRAV29DV5,TRAJ5,CAASGYGNTGRRALTF,TRBV4-3,TRBJ2-1,CASSQDPAGGYNEQFF,VTEHDTLLY,A0101
4,TRAV29DV5,TRAJ48,CAAHLSNFGNEKLTF,TRBV15,TRBJ1-3,CATSRDRGHGDTIYF,VTEHDTLLY,A0101


In [10]:
df_10x_neg = df_10x_neg.rename(columns={'MHC': 'HLA'})  # align with the positive sets' 'HLA' column

In [11]:
# Convert compact allele codes ('A0101') to 'HLA-A*01:01'. Values that already start
# with 'HLA-' are left untouched, so re-running this cell does not mangle them.
df_10x_neg['HLA'] = df_10x_neg['HLA'].map(
    lambda code: code if code.startswith('HLA-') else f'HLA-{code[0]}*{code[1:3]}:{code[3:]}'
)

In [12]:
df_icon_pos = pd.read_csv('processed-data/10x_positive_ICON.csv')
# Compact allele code -> 'HLA-<locus>*<2 digits>:<rest>' (e.g. 'A0101' -> 'HLA-A*01:01').
df_icon_pos['HLA'] = df_icon_pos['HLA'].map(lambda code: f'HLA-{code[0]}*{code[1:3]}:{code[3:]}')

In [13]:
df_itrap_pos = pd.read_csv('processed-data/10x_positive_ITRAP.csv')
# Compact allele code -> 'HLA-<locus>*<2 digits>:<rest>' (e.g. 'A0101' -> 'HLA-A*01:01').
df_itrap_pos['HLA'] = df_itrap_pos['HLA'].map(lambda code: f'HLA-{code[0]}*{code[1:3]}:{code[3:]}')

In [14]:
df_kevin = pd.read_csv("processed-data/kevin_positive_remain_gene.csv")
# Compact allele code -> 'HLA-<locus>*<2 digits>:<rest>' (e.g. 'A0101' -> 'HLA-A*01:01').
df_kevin['HLA'] = df_kevin['HLA'].map(lambda code: f'HLA-{code[0]}*{code[1:3]}:{code[3:]}')

In [15]:
# Fingerprinting assay: carries both binders and non-binders, labelled by the `sp` column.
df_finger = pd.read_csv(
    'processed-data/fingerprinting_all_remain_gene.csv',
)
df_finger.head()

Unnamed: 0,TCR,TRAV,TRAJ,CDR3a,TRBV,TRBJ,CDR3b,ID,COORD,Plate,log2foldchangevalue,Peptide,HLA,sp
0,SVAR-1,TRAV12-2*01,TRAJ30*01,AVNRDDKII,TRBV7-9*01,TRBJ1-1*01,ASSPDIEAF,A01_SVAR-1_P1.fcs,A01,P1,5.062388,YLQPRTFLL,HLA-A*02:01,1
1,SVAR-1,TRAV12-2*01,TRAJ30*01,AVNRDDKII,TRBV7-9*01,TRBJ1-1*01,ASSPDIEAF,A01_SVAR-1_P2.fcs,A01,P2,-0.31166,YLQPLTFLL,HLA-A*02:01,0
2,SVAR-1,TRAV12-2*01,TRAJ30*01,AVNRDDKII,TRBV7-9*01,TRBJ1-1*01,ASSPDIEAF,A02_SVAR-1_P1.fcs,A02,P1,5.066687,ALQPRTFLL,HLA-A*02:01,1
3,SVAR-1,TRAV12-2*01,TRAJ30*01,AVNRDDKII,TRBV7-9*01,TRBJ1-1*01,ASSPDIEAF,A02_SVAR-1_P2.fcs,A02,P2,-0.263034,YLQPKTFLL,HLA-A*02:01,0
4,SVAR-1,TRAV12-2*01,TRAJ30*01,AVNRDDKII,TRBV7-9*01,TRBJ1-1*01,ASSPDIEAF,A03_SVAR-1_P1.fcs,A03,P1,5.046788,RLQPRTFLL,HLA-A*02:01,1


In [16]:
# Drop allele suffixes from V/J gene names, e.g. 'TRAV12-2*01' -> 'TRAV12-2'.
df_finger = df_finger.assign(**{
    gene_col: df_finger[gene_col].map(lambda gene: gene.split('*')[0])
    for gene_col in ['TRAV', 'TRAJ', 'TRBV', 'TRBJ']
})
df_finger.head()

Unnamed: 0,TCR,TRAV,TRAJ,CDR3a,TRBV,TRBJ,CDR3b,ID,COORD,Plate,log2foldchangevalue,Peptide,HLA,sp
0,SVAR-1,TRAV12-2,TRAJ30,AVNRDDKII,TRBV7-9,TRBJ1-1,ASSPDIEAF,A01_SVAR-1_P1.fcs,A01,P1,5.062388,YLQPRTFLL,HLA-A*02:01,1
1,SVAR-1,TRAV12-2,TRAJ30,AVNRDDKII,TRBV7-9,TRBJ1-1,ASSPDIEAF,A01_SVAR-1_P2.fcs,A01,P2,-0.31166,YLQPLTFLL,HLA-A*02:01,0
2,SVAR-1,TRAV12-2,TRAJ30,AVNRDDKII,TRBV7-9,TRBJ1-1,ASSPDIEAF,A02_SVAR-1_P1.fcs,A02,P1,5.066687,ALQPRTFLL,HLA-A*02:01,1
3,SVAR-1,TRAV12-2,TRAJ30,AVNRDDKII,TRBV7-9,TRBJ1-1,ASSPDIEAF,A02_SVAR-1_P2.fcs,A02,P2,-0.263034,YLQPKTFLL,HLA-A*02:01,0
4,SVAR-1,TRAV12-2,TRAJ30,AVNRDDKII,TRBV7-9,TRBJ1-1,ASSPDIEAF,A03_SVAR-1_P1.fcs,A03,P1,5.046788,RLQPRTFLL,HLA-A*02:01,1


In [17]:
# Remove assay bookkeeping columns that are not model inputs. errors='ignore' keeps
# this cell safe to re-run: the columns are already gone on a second execution.
df_finger = df_finger.drop(
    columns=['TCR', 'ID', 'COORD', 'Plate', 'log2foldchangevalue'],
    errors='ignore',
)
df_finger.head()

Unnamed: 0,TRAV,TRAJ,CDR3a,TRBV,TRBJ,CDR3b,Peptide,HLA,sp
0,TRAV12-2,TRAJ30,AVNRDDKII,TRBV7-9,TRBJ1-1,ASSPDIEAF,YLQPRTFLL,HLA-A*02:01,1
1,TRAV12-2,TRAJ30,AVNRDDKII,TRBV7-9,TRBJ1-1,ASSPDIEAF,YLQPLTFLL,HLA-A*02:01,0
2,TRAV12-2,TRAJ30,AVNRDDKII,TRBV7-9,TRBJ1-1,ASSPDIEAF,ALQPRTFLL,HLA-A*02:01,1
3,TRAV12-2,TRAJ30,AVNRDDKII,TRBV7-9,TRBJ1-1,ASSPDIEAF,YLQPKTFLL,HLA-A*02:01,0
4,TRAV12-2,TRAJ30,AVNRDDKII,TRBV7-9,TRBJ1-1,ASSPDIEAF,RLQPRTFLL,HLA-A*02:01,1


In [18]:
negative_mode = {
    1: 'random swap',
    2: 'random sample',
    3: 'mixed'
}

# Rename to the column names used in the EPACT input CSVs; predict() reads 'Target'
# back as the label.
epact_columns = {
    'CDR3a': 'CDR3.alpha.aa', 'TRAV': 'V.alpha', 'TRAJ': 'J.alpha',
    'CDR3b': 'CDR3.beta.aa', 'TRBV': 'V.beta', 'TRBJ': 'J.beta',
    'Peptide': 'Epitope.peptide', 'sign': 'Target', 'HLA': 'MHC',
}

# (positives, negatives) per benchmark. The three 10x-derived sets share one negative
# pool; the fingerprinting table holds both classes and is split on `sp`.
benchmarks = {
    'icon': (df_icon_pos, df_10x_neg),
    'itrap': (df_itrap_pos, df_10x_neg),
    'kevin': (df_kevin, df_10x_neg),
    'fingerprint': (df_finger[df_finger['sp'] == 1].copy(), df_finger[df_finger['sp'] == 0].copy()),
}

os.makedirs('input', exist_ok=True)  # to_csv does not create missing directories
records = []
for ds, (dpos, dneg) in benchmarks.items():
    for i in negative_mode:
        df_process = load_dataset(dpos, dneg, '', 'CDR3b', 'Peptide', i).rename(columns=epact_columns)
        # Score seen and unseen peptides separately; both CSVs are written without the index.
        for split, use_pep in (('seen', 1), ('unseen', 0)):
            subset = df_process[df_process['use_pep'] == use_pep].reset_index(drop=True)
            if len(subset) > 1:
                filepath = f'input/{ds}-{i}-{split}.csv'
                subset.to_csv(filepath, index=False)
                roc, p_auc = predict('CDR3 binding model', filepath, f'demo/binding/{ds}-{i}-{split}')
                records.append([ds, negative_mode[i], split, roc, p_auc])

# 'macro 0.1' holds sklearn's standardized partial ROC-AUC at max_fpr=0.1.
output_df = pd.DataFrame(records, columns=['name', 'negative', 'seen', 'roc-auc', 'macro 0.1'])
output_df

***dataset contains 43 records using peptide outside training dataset***
### IVTDFSVIK - 8 ###
### KLGGALQAK - 66 ###
### GLCTLVAML - 1 ###
### ELAGIGILTV - 17 ###
### RAKFKQLL - 16 ###
### GILGFVFTL - 20 ###
### AVFDRKSDAK - 2 ###
### AYAQKIFKI - 1 ###
### IMDQVPFSV - 1 ###
### QYDPVAALF - 1 ###
### SLLMWITQV - 1 ###
### RLRAEAQVK - 1 ###
### SLFNTVATL - 1 ###
***dataset contains 136 records using peptide outside training dataset***


  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 5/5 [00:21<00:00,  4.25s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 5/5 [00:21<00:00,  4.30s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 5/5 [00:21<00:00,  4.31s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_mo

***dataset contains 43 records using peptide outside training dataset***
*********
dataset contains peptides
### IVTDFSVIK - 8 ###
### KLGGALQAK - 66 ###
### GLCTLVAML - 1 ###
### ELAGIGILTV - 17 ###
### RAKFKQLL - 16 ###
### GILGFVFTL - 20 ###
### AVFDRKSDAK - 2 ###
### AYAQKIFKI - 1 ###
### IMDQVPFSV - 1 ###
### QYDPVAALF - 1 ###
### SLLMWITQV - 1 ###
### RLRAEAQVK - 1 ###
### SLFNTVATL - 1 ###
*********


  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 7/7 [00:27<00:00,  3.99s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 7/7 [00:27<00:00,  3.98s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 7/7 [00:27<00:00,  3.96s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_mo

***dataset contains 43 records using peptide outside training dataset***
### IVTDFSVIK - 8 ###
### KLGGALQAK - 66 ###
### GLCTLVAML - 1 ###
### ELAGIGILTV - 17 ###
### RAKFKQLL - 16 ###
### GILGFVFTL - 20 ###
### AVFDRKSDAK - 2 ###
### AYAQKIFKI - 1 ###
### IMDQVPFSV - 1 ###
### QYDPVAALF - 1 ###
### SLLMWITQV - 1 ###
### RLRAEAQVK - 1 ###
### SLFNTVATL - 1 ###
***dataset contains 136 records using peptide outside training dataset***
***dataset contains 43 records using peptide outside training dataset***
*********
dataset contains peptides
### IVTDFSVIK - 8 ###
### KLGGALQAK - 66 ###
### GLCTLVAML - 1 ###
### ELAGIGILTV - 17 ###
### RAKFKQLL - 16 ###
### GILGFVFTL - 20 ###
### AVFDRKSDAK - 2 ###
### AYAQKIFKI - 1 ###
### IMDQVPFSV - 1 ###
### QYDPVAALF - 1 ###
### SLLMWITQV - 1 ###
### RLRAEAQVK - 1 ###
### SLFNTVATL - 1 ###
*********


  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 6/6 [00:25<00:00,  4.19s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 6/6 [00:25<00:00,  4.19s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 6/6 [00:25<00:00,  4.19s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_mo

***dataset contains 119 records using peptide outside training dataset***
### GILGFVFTL - 10 ###
### GLCTLVAML - 2 ###
### ELAGIGILTV - 23 ###
### AVFDRKSDAK - 11 ###
### RAKFKQLL - 4 ###
### KLGGALQAK - 21 ###
### IVTDFSVIK - 6 ###
### YLLEMLWRL - 1 ###
### SLLMWITQV - 1 ###
### RLRAEAQVK - 2 ###
***dataset contains 81 records using peptide outside training dataset***


  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 4/4 [00:19<00:00,  4.81s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 4/4 [00:19<00:00,  4.77s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 4/4 [00:19<00:00,  4.76s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_mo

***dataset contains 119 records using peptide outside training dataset***
*********
dataset contains peptides
### GILGFVFTL - 10 ###
### GLCTLVAML - 2 ###
### ELAGIGILTV - 23 ###
### AVFDRKSDAK - 11 ###
### RAKFKQLL - 4 ###
### KLGGALQAK - 21 ###
### IVTDFSVIK - 6 ###
### YLLEMLWRL - 1 ###
### SLLMWITQV - 1 ###
### RLRAEAQVK - 2 ###
*********


  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 4/4 [00:19<00:00,  4.76s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 4/4 [00:18<00:00,  4.74s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 4/4 [00:19<00:00,  4.78s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_mo

***dataset contains 119 records using peptide outside training dataset***
### GILGFVFTL - 10 ###
### GLCTLVAML - 2 ###
### ELAGIGILTV - 23 ###
### AVFDRKSDAK - 11 ###
### RAKFKQLL - 4 ###
### KLGGALQAK - 21 ###
### IVTDFSVIK - 6 ###
### YLLEMLWRL - 1 ###
### SLLMWITQV - 1 ###
### RLRAEAQVK - 2 ###
***dataset contains 81 records using peptide outside training dataset***
***dataset contains 119 records using peptide outside training dataset***
*********
dataset contains peptides
### GILGFVFTL - 10 ###
### GLCTLVAML - 2 ###
### ELAGIGILTV - 23 ###
### AVFDRKSDAK - 11 ###
### RAKFKQLL - 4 ###
### KLGGALQAK - 21 ###
### IVTDFSVIK - 6 ###
### YLLEMLWRL - 1 ###
### SLLMWITQV - 1 ###
### RLRAEAQVK - 2 ###
*********


  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 4/4 [00:18<00:00,  4.75s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 4/4 [00:19<00:00,  4.78s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 4/4 [00:19<00:00,  4.79s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_mo

***dataset contains 8 records using peptide outside training dataset***
### FLRGRAYGL - 12 ###
### RAKFKQLL - 82 ###
### GLCTLVAML - 26 ###
### CLGGLLTMV - 1 ###
### QAKWRLQTL - 7 ###
### YVLDHLIVV - 26 ###
### CTELKLSDY - 16 ###
### GILGFVFTL - 6 ###
### YLLEMLWRL - 7 ###
### ELRSRYWAI - 1 ###
### LLSLFSLWL - 1 ###
### YLQQNWWTL - 4 ###
### KLQVFLIVL - 1 ###
### FLYALALLL - 1 ###
### RLLPLLALL - 1 ###
### NLAQDLATV - 1 ###
### RGPGRAFVTI - 1 ###
***dataset contains 194 records using peptide outside training dataset***


  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 7/7 [00:28<00:00,  4.05s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 7/7 [00:28<00:00,  4.03s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 7/7 [00:28<00:00,  4.03s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_mo

***dataset contains 8 records using peptide outside training dataset***
*********
dataset contains peptides
### FLRGRAYGL - 12 ###
### RAKFKQLL - 82 ###
### GLCTLVAML - 26 ###
### CLGGLLTMV - 1 ###
### QAKWRLQTL - 7 ###
### YVLDHLIVV - 26 ###
### CTELKLSDY - 16 ###
### GILGFVFTL - 6 ###
### YLLEMLWRL - 7 ###
### ELRSRYWAI - 1 ###
### LLSLFSLWL - 1 ###
### YLQQNWWTL - 4 ###
### KLQVFLIVL - 1 ###
### FLYALALLL - 1 ###
### RLLPLLALL - 1 ###
### NLAQDLATV - 1 ###
### RGPGRAFVTI - 1 ###
*********


  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 10/10 [00:35<00:00,  3.54s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 10/10 [00:35<00:00,  3.51s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 10/10 [00:36<00:00,  3.61s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrai

***dataset contains 8 records using peptide outside training dataset***
### FLRGRAYGL - 12 ###
### RAKFKQLL - 82 ###
### GLCTLVAML - 26 ###
### CLGGLLTMV - 1 ###
### QAKWRLQTL - 7 ###
### YVLDHLIVV - 26 ###
### CTELKLSDY - 16 ###
### GILGFVFTL - 6 ###
### YLLEMLWRL - 7 ###
### ELRSRYWAI - 1 ###
### LLSLFSLWL - 1 ###
### YLQQNWWTL - 4 ###
### KLQVFLIVL - 1 ###
### FLYALALLL - 1 ###
### RLLPLLALL - 1 ###
### NLAQDLATV - 1 ###
### RGPGRAFVTI - 1 ###
***dataset contains 194 records using peptide outside training dataset***
***dataset contains 8 records using peptide outside training dataset***
*********
dataset contains peptides
### FLRGRAYGL - 12 ###
### RAKFKQLL - 82 ###
### GLCTLVAML - 26 ###
### CLGGLLTMV - 1 ###
### QAKWRLQTL - 7 ###
### YVLDHLIVV - 26 ###
### CTELKLSDY - 16 ###
### GILGFVFTL - 6 ###
### YLLEMLWRL - 7 ###
### ELRSRYWAI - 1 ###
### LLSLFSLWL - 1 ###
### YLQQNWWTL - 4 ###
### KLQVFLIVL - 1 ###
### FLYALALLL - 1 ###
### RLLPLLALL - 1 ###
### NLAQDLATV - 1 ###
### RGPGRAF

  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 9/9 [00:32<00:00,  3.57s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 9/9 [00:32<00:00,  3.56s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 9/9 [00:32<00:00,  3.57s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_mo

***dataset contains 1589 records using peptide outside training dataset***
***dataset contains 0 records using peptide outside training dataset***


  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 34/34 [01:50<00:00,  3.26s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 34/34 [01:50<00:00,  3.26s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 34/34 [01:50<00:00,  3.26s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrai

***dataset contains 1589 records using peptide outside training dataset***
*********
dataset contains peptides
*********


  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 75/75 [03:57<00:00,  3.17s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 75/75 [03:59<00:00,  3.19s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 75/75 [03:58<00:00,  3.19s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrai

***dataset contains 1589 records using peptide outside training dataset***
***dataset contains 0 records using peptide outside training dataset***
***dataset contains 1589 records using peptide outside training dataset***
*********
dataset contains peptides
*********


  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 57/57 [03:01<00:00,  3.19s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 57/57 [03:02<00:00,  3.20s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 57/57 [03:01<00:00,  3.18s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrai

In [19]:
output_df.to_csv(path_or_buf='EPACT-performance.csv', index=False)  # one row per (dataset, negative mode, seen/unseen)