In [8]:
import os

from torch.utils.data import DataLoader
from EPACT.utils import load_config, set_seed
from EPACT.dataset import UnlabeledDataset, UnlabeledBacthConverter
from EPACT.trainer import PairedCDR3pMHCCoembeddingTrainer, PairedCDR123pMHCCoembeddingTrainer

In [9]:
# User-tunable parameters for this demo. The `#@markdown` / `#@param` comments are
# Colab form annotations, so keep each `#@param` on the same line as its assignment.
#  - model_name selects the config file and the fold checkpoints loaded in the next cell.
#  - input_data_path is read again by the evaluation cell, which uses its 'Target'
#    column as the ground-truth labels for ROC AUC.
#  - result_dir receives one Fold_<k>/ subfolder per fold checkpoint.
#@markdown Select the EPACT model:
model_name = "CDR3 binding model" #@param ['CDR3 binding model', 'CDR123 binding model']

#@markdown By default, we will use `sample/VDJdb-GLCTLVAML.csv` for prediction.
input_data_path = "sample/VDJdb-GLCTLVAML.csv" #@param {type:"string"}

#@markdown Specify the name of the result folder:
result_dir = "demo/binding" #@param {type:"string"}

#@markdown Specify the batch size:
batch_size = 128 #@param {type: "integer"}

In [10]:
# Resolve the selected model to its config file and per-fold checkpoint paths.
# Both models follow the same layout, keyed by a shared prefix:
#   configs/config-<prefix>.yml
#   checkpoints/<prefix>/<prefix>-model-fold-<k>.pt   for k = 1..N_FOLDS
N_FOLDS = 5  # number of cross-validation fold checkpoints shipped per model

MODEL_PREFIXES = {
    "CDR3 binding model": "paired-cdr3-pmhc-binding",
    "CDR123 binding model": "paired-cdr123-pmhc-binding",
}

# Fail fast with a clear message instead of a NameError on `config_path` later.
if model_name not in MODEL_PREFIXES:
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of {list(MODEL_PREFIXES)}"
    )

model_prefix = MODEL_PREFIXES[model_name]
config_path = f'configs/config-{model_prefix}.yml'
model_location_list = [
    f'checkpoints/{model_prefix}/{model_prefix}-model-fold-{fold}.pt'
    for fold in range(1, N_FOLDS + 1)
]

config = load_config(config_path)
set_seed(config.training.seed)
# Force device index 0 (presumably the first CUDA GPU - confirm against the trainer).
config.training.gpu_device = 0

In [11]:
# Build the inference loader once and reuse it for every fold.
# shuffle=False keeps batches in input-CSV order; the evaluation cell aligns
# predictions to input rows by position (presumably predict() preserves order).
dataset = UnlabeledDataset(data_path=input_data_path, hla_lib_path=config.data.hla_lib_path)
data_loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    num_workers=1,
    collate_fn=UnlabeledBacthConverter(
        max_mhc_len=config.model.mhc_seq_len,
        use_cdr123=config.data.use_cdr123,
    ),
    shuffle=False,
)

# exist_ok avoids the exists()/makedirs() race and is a no-op on re-runs.
os.makedirs(result_dir, exist_ok=True)

# The trainer class depends only on the config, so pick it once outside the loop.
trainer_cls = (
    PairedCDR123pMHCCoembeddingTrainer
    if config.data.use_cdr123
    else PairedCDR3pMHCCoembeddingTrainer
)

# Run every fold checkpoint; each writes its outputs under <result_dir>/Fold_<k>/.
# Iterating the list itself (not range(5)) keeps this in sync with the checkpoints.
for fold, model_location in enumerate(model_location_list, start=1):
    result_fold_dir = os.path.join(result_dir, f'Fold_{fold}')
    os.makedirs(result_fold_dir, exist_ok=True)

    trainer = trainer_cls(config, result_fold_dir)
    trainer.predict(data_loader, model_location=model_location)

  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 2/2 [00:12<00:00,  6.21s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 2/2 [00:12<00:00,  6.17s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.tcr_model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))
  self.model.load_state_dict(torch.load(model_location, map_location='cpu'), strict=False)
100%|██████████| 2/2 [00:12<00:00,  6.18s/it]
  self.pmhc_model.load_state_dict(torch.load(pretrained_mo

In [12]:
#@title Display prediction results
import pandas as pd
from sklearn.metrics import roc_auc_score

# Ensemble the fold models: average each sample's 'Pred' score across folds,
# then score the ensemble against the input CSV's 'Target' labels.
data = pd.read_csv(input_data_path)
n_folds = len(model_location_list)  # same fold count the prediction cell ran

fold_preds = []
for fold in range(1, n_folds + 1):
    prediction = pd.read_csv(os.path.join(result_dir, f'Fold_{fold}', 'predictions.csv'))
    # Predictions are matched to input rows by position, so the row counts must agree.
    assert len(prediction) == len(data), (
        f"Fold {fold}: {len(prediction)} predictions for {len(data)} input rows"
    )
    fold_preds.append(prediction['Pred'].reset_index(drop=True))

# Column-wise stack of per-fold scores -> row-wise mean; assign positionally.
data['Pred'] = pd.concat(fold_preds, axis=1).mean(axis=1).to_numpy()
auc = roc_auc_score(data['Target'], data['Pred'])

In [13]:
# Report the fold-averaged ROC AUC computed in the evaluation cell.
print(f"{auc}")

0.8673130193905817
