<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Generate-validation-prediction-files" data-toc-modified-id="Generate-validation-prediction-files-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Generate validation prediction files</a></span></li><li><span><a href="#create-metrics-from-validation-files" data-toc-modified-id="create-metrics-from-validation-files-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>create metrics from validation files</a></span></li></ul></div>

In [1]:
import numpy as np
import pandas as pd
import pdb
from argparse import ArgumentParser
import shlex
from tqdm import tqdm

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from pMHC import OUTPUT_FOLDER, SEP, \
    SPLITS, SPLIT_TRAIN, SPLIT_VAL, SPLIT_TEST, \
    SPLIT_VAL_PROTEINS, SPLIT_TEST_PROTEINS, SPLIT_VAL_MHC_ALLELES, SPLIT_TEST_MHC_ALLELES, \
    VIEWS, VIEW_SA, VIEW_SAMA, VIEW_DECONV, \
    INPUT_PEPTIDE, INPUT_CONTEXT, \
    BENCHMARK_SA, BENCHMARK_MA, \
    VALIDATION_PERFORMANCE_FILENAME, TEST_PERFORMANCE_FILENAME
from pMHC.logic import PresentationPredictor
from pMHC.data import from_data, to_input
from pMHC.data.example import Sample, Peptide, Observation

tqdm.pandas()

In [2]:
import torch

from pytorch_lightning.utilities.seed import seed_everything

import pMHC
from pMHC import POSITIVE_THRESHOLD
from pMHC.logic.utils import load_latest, list_versions, validate_checkpoints, predict_checkpoints, evaluate_checkpoints
from pMHC.data import from_input
from pMHC.data.utils import get_input_rep_PSEUDO, get_input_rep_FULL, convert_examples_to_batch
from pMHC.data.mhc_allele import MhcAllele
from pMHC.data.protein import Protein

In [3]:
from pMHC.data.split import load_split_mhc_alleles, load_split_proteins
from pMHC.data.view import create_views

In [4]:
# Point the pMHC package at the project root folder.
# The location can be overridden through the PMHC_PROJECT_DIR environment
# variable so the notebook is not tied to one machine; when the variable is
# unset, the original hard-coded path is used unchanged.
import os

pMHC.set_paths(os.environ.get("PMHC_PROJECT_DIR", r"/home/tux/Documents/MScProject"))

Update project folder to: /home/tux/Documents/MScProject
Load permutation


# Generate validation prediction files

In [5]:
# Proportion value passed as the last element of every `to_assess` entry below.
# NOTE(review): presumably the fraction of observations that is kept -- the
# setup log reduces 1,959,736 observations to 195,973 -- confirm in pMHC.
proportion = 0.1

In [6]:
# Checkpoints for which validation predictions are generated in this run.
# Each entry is (experiment folder, model name, checkpoint name, proportion);
# the first three form the checkpoint path printed by `predict_checkpoints`.
#
# The earlier, longer `to_assess` list that used to precede this one was
# overwritten immediately and never used, so only the effective selection
# is kept here.
to_assess = [
    ("hparam_search", "CONTEXT-PSEUDO-HEAD_Attn-DECOY_99-LR_0.000001", "epoch=0-step=64559", proportion),
    ("hparam_search", "CONTEXT-PSEUDO-HEAD_Cls-DECOY_99-LR_0.000001", "epoch=0-step=64559", proportion),
]

In [7]:
# Run every checkpoint in `to_assess` on the validation split, SA view.
# NOTE(review): presumably this also writes the per-checkpoint prediction
# files that the metrics section loads later -- confirm in pMHC.logic.utils.
predict_checkpoints(to_assess, SPLIT_VAL, VIEW_SA)

Validate: /home/tux/Documents/MScProject/output/hparam_search/CONTEXT-PSEUDO-HEAD_Attn-DECOY_99-LR_0.000001/checkpoints/epoch=0-step=64559


Global seed set to 42
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Global seed set to 42


PresentationPredictor.setup: 2021-07-24 15:35:19
MhcAllele.from_input


MhcAlleles from input: 11074it [00:00, 17433.17it/s]


Protein.from_input


Proteins from input: 185930it [00:09, 19558.78it/s]


Sample.from_input


Samples from input: 472it [00:00, 18990.77it/s]


Peptide.from_input


Peptides from input: 429339it [00:22, 19028.23it/s]


Observation.from_input


Observations from input: 1959736it [01:39, 19773.84it/s]


Reduce observations


100%|██████████| 195973/195973 [00:00<00:00, 220788.83it/s]


Decoy.to_input
Load decoys for 195974 observations
decoys_50000_100000.csv
decoys_0_50000.csv
decoys_100000_150000.csv
decoys_150000_200000.csv


Global seed set to 42


PresentationPredictor.setup finished: 2021-07-24 15:37:44


Datasources



Global seed set to 42
Global seed set to 42


OBSERVATIONS                        /   MHC/Obs comb
   Edi                 :  1,550,250 /      7,209,446
   Atlas               :    409,486 /      2,436,897
   TOTAL               :  1,959,736 /      9,646,343




Splits

OBSERVATIONS                          SA            SAMA          Deconv
   train               :           20,659         140,792         673,075
   val                 :            3,338          27,612         131,327
   test                :            5,385          27,569         130,791
   val-prot            :              986           6,943          33,483
   test-prot           :            1,101           6,945          32,751
   val-mhc             :            2,352          20,669          97,844
   test-mhc            :            4,284          20,624          98,040



Dataloaders

EXAMPLES

   train               
      SA               :  2,065,900 /         64,559
      SAMA             : 14,079,200 /        439,975
      Deconv           :    6

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


OBSERVATIONS                        /   MHC/Obs comb
   Edi                 :  1,550,250 /      7,209,446
   Atlas               :    409,486 /      2,436,897
   TOTAL               :  1,959,736 /      9,646,343




Splits

OBSERVATIONS                          SA            SAMA          Deconv
   train               :           20,659         140,792         673,075
   val                 :            3,338          27,612         131,327
   test                :            5,385          27,569         130,791
   val-prot            :              986           6,943          33,483
   test-prot           :            1,101           6,945          32,751
   val-mhc             :            2,352          20,669          97,844
   test-mhc            :            4,284          20,624          98,040



Dataloaders

EXAMPLES

   train               
      SA               :  2,065,900 /         64,559
      SAMA             : 14,079,200 /        439,975
      Deconv           :    6

  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

   val: tp... 0, tn... 31, fp... 0, fn... 1
   val: tp... 0, tn... 31711, fp... 0, fn... 321
   val: tp... 0, tn... 63391, fp... 0, fn... 641
   val: tp... 1, tn... 95071, fp... 0, fn... 960
   val: tp... 1, tn... 126751, fp... 0, fn... 1280
   val: tp... 2, tn... 158430, fp... 1, fn... 1599
   val: tp... 2, tn... 190110, fp... 1, fn... 1919
   val: tp... 3, tn... 221790, fp... 1, fn... 2238
   val: tp... 3, tn... 253470, fp... 1, fn... 2558
   val: tp... 3, tn... 285150, fp... 1, fn... 2878
   val: tp... 3, tn... 316830, fp... 1, fn... 3198
PresentationPredictor: test_epoch_end - SA 
   val: tp... 3, tn... 330461, fp... 1, fn... 3335
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_Accuracy': 0.9900059700012207,
 'val_Precision': 0.75,
 'val_Recall': 0.0008987417677417397,
 'val_fn': 3335.0,
 'val_fp': 1.0,
 'val_loss': 0.051798805594444275,
 'val_tn': 330461.0,
 'val_tp': 3.0}
--------------------------------------------

Global seed set to 42
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Global seed set to 42
Global seed set to 42


PresentationPredictor.setup: 2021-07-24 15:46:56
PresentationPredictor.setup finished: 2021-07-24 15:46:56


Datasources



Global seed set to 42
Global seed set to 42


OBSERVATIONS                        /   MHC/Obs comb
   Edi                 :  1,550,250 /      7,209,446
   Atlas               :    409,486 /      2,436,897
   TOTAL               :  1,959,736 /      9,646,343




Splits

OBSERVATIONS                          SA            SAMA          Deconv
   train               :           20,659         140,792         673,075
   val                 :            3,338          27,612         131,327
   test                :            5,385          27,569         130,791
   val-prot            :              986           6,943          33,483
   test-prot           :            1,101           6,945          32,751
   val-mhc             :            2,352          20,669          97,844
   test-mhc            :            4,284          20,624          98,040



Dataloaders

EXAMPLES

   train               
      SA               :  2,065,900 /         64,559
      SAMA             : 14,079,200 /        439,975
      Deconv           :    6

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


OBSERVATIONS                        /   MHC/Obs comb
   Edi                 :  1,550,250 /      7,209,446
   Atlas               :    409,486 /      2,436,897
   TOTAL               :  1,959,736 /      9,646,343




Splits

OBSERVATIONS                          SA            SAMA          Deconv
   train               :           20,659         140,792         673,075
   val                 :            3,338          27,612         131,327
   test                :            5,385          27,569         130,791
   val-prot            :              986           6,943          33,483
   test-prot           :            1,101           6,945          32,751
   val-mhc             :            2,352          20,669          97,844
   test-mhc            :            4,284          20,624          98,040



Dataloaders

EXAMPLES

   train               
      SA               :  2,065,900 /         64,559
      SAMA             : 14,079,200 /        439,975
      Deconv           :    6

Testing: 0it [00:00, ?it/s]

   val: tp... 0, tn... 31, fp... 0, fn... 1
   val: tp... 1, tn... 31711, fp... 0, fn... 320
   val: tp... 1, tn... 63389, fp... 2, fn... 640
   val: tp... 4, tn... 95068, fp... 3, fn... 957
   val: tp... 5, tn... 126748, fp... 3, fn... 1276
   val: tp... 6, tn... 158427, fp... 4, fn... 1595
   val: tp... 6, tn... 190107, fp... 4, fn... 1915
   val: tp... 7, tn... 221786, fp... 5, fn... 2234
   val: tp... 8, tn... 253466, fp... 5, fn... 2553
   val: tp... 8, tn... 285146, fp... 5, fn... 2873
   val: tp... 8, tn... 316826, fp... 5, fn... 3193
PresentationPredictor: test_epoch_end - SA 
   val: tp... 8, tn... 330456, fp... 6, fn... 3330
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_Accuracy': 0.9900059700012207,
 'val_Precision': 0.5714285969734192,
 'val_Recall': 0.0023966447915881872,
 'val_fn': 3330.0,
 'val_fp': 6.0,
 'val_loss': 0.05223966762423515,
 'val_tn': 330456.0,
 'val_tp': 8.0}
-------------------------------

# create metrics from validation files

In [5]:
# Proportion value used for data setup and checkpoint selection in this
# section; it shows up as the "_0.1" suffix of the checkpoint keys loaded below.
# NOTE(review): the execution counter restarts here, so this section was
# presumably run on a fresh kernel -- confirm before relying on earlier cells.
proportion = 0.1

In [6]:
# Load the pMHC data (alleles, proteins, samples, peptides, observations,
# decoys -- see the log below) before evaluating the prediction files.
# NOTE(review): 99 is presumably the number of decoys per observation
# (333,800 loaded examples / 3,338 val SA observations = 100 = 1 + 99);
# confirm against the signature of setup_system.
pMHC.data.setup_system(proportion, 99)

MhcAllele.from_input


MhcAlleles from input: 11074it [00:00, 17831.10it/s]


Protein.from_input


Proteins from input: 185930it [00:09, 19331.21it/s]


Sample.from_input


Samples from input: 472it [00:00, 18667.89it/s]


Peptide.from_input


Peptides from input: 429339it [00:22, 18743.10it/s]


Observation.from_input


Observations from input: 1959736it [01:42, 19074.41it/s]


Reduce observations


100%|██████████| 195973/195973 [00:00<00:00, 212992.81it/s]


Decoy.to_input
Load decoys for 195974 observations
decoys_50000_100000.csv
decoys_0_50000.csv
decoys_100000_150000.csv
decoys_150000_200000.csv


In [7]:
# Checkpoints to evaluate: the Attn, Cls and Avg heads of the LR=0.00001 runs,
# for both decoy settings and both saved steps. Each entry is
# (experiment folder, model name, checkpoint name, proportion).
# The LR=0.000001 and LR=0.0001 variants are not part of this selection.
to_assess = [
    ("hparam_search", f"CONTEXT-PSEUDO-HEAD_{head}-DECOY_{decoys}-LR_0.00001", checkpoint, proportion)
    for decoys, checkpoint in [
        (19, "epoch=0-step=12911"),
        (19, "epoch=4-step=64559"),
        (99, "epoch=0-step=12911"),
        (99, "epoch=0-step=64559"),
    ]
    for head in ["Attn", "Cls", "Avg"]
]

In [8]:
# Evaluate every checkpoint in `to_assess` on the val-prot and val-mhc splits,
# SA view only.
# NOTE(review): judging by how it is indexed further down, the result is
# keyed as info[checkpoint_name][split][view][metric] -- confirm.
info = evaluate_checkpoints(to_assess, [SPLIT_VAL_PROTEINS, SPLIT_VAL_MHC_ALLELES], [VIEW_SA])

Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_19-LR_0.00001_epoch=0-step=12911_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 65279.65it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_19-LR_0.00001_epoch=0-step=12911_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32081.86it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.00001_epoch=0-step=12911_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 64753.83it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.00001_epoch=0-step=12911_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32244.73it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_19-LR_0.00001_epoch=0-step=12911_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 64963.95it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_19-LR_0.00001_epoch=0-step=12911_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32200.53it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_19-LR_0.00001_epoch=4-step=64559_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 64925.95it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_19-LR_0.00001_epoch=4-step=64559_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32252.04it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.00001_epoch=4-step=64559_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 64984.99it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.00001_epoch=4-step=64559_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32195.22it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_19-LR_0.00001_epoch=4-step=64559_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 64609.42it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_19-LR_0.00001_epoch=4-step=64559_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32156.15it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_99-LR_0.00001_epoch=0-step=12911_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 64916.39it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_99-LR_0.00001_epoch=0-step=12911_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32099.29it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_99-LR_0.00001_epoch=0-step=12911_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 65066.83it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_99-LR_0.00001_epoch=0-step=12911_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32049.82it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_99-LR_0.00001_epoch=0-step=12911_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 64879.74it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_99-LR_0.00001_epoch=0-step=12911_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32089.00it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_99-LR_0.00001_epoch=0-step=64559_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 64946.43it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_99-LR_0.00001_epoch=0-step=64559_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32126.28it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_99-LR_0.00001_epoch=0-step=64559_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 65333.91it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_99-LR_0.00001_epoch=0-step=64559_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32410.03it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_99-LR_0.00001_epoch=0-step=64559_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 65315.39it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_99-LR_0.00001_epoch=0-step=64559_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32408.77it/s]


In [9]:
# Show the checkpoint keys under which the evaluation results are stored.
list(info.keys())

['hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_19-LR_0.00001_epoch=0-step=12911_0.1',
 'hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.00001_epoch=0-step=12911_0.1',
 'hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_19-LR_0.00001_epoch=0-step=12911_0.1',
 'hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_19-LR_0.00001_epoch=4-step=64559_0.1',
 'hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.00001_epoch=4-step=64559_0.1',
 'hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_19-LR_0.00001_epoch=4-step=64559_0.1',
 'hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_99-LR_0.00001_epoch=0-step=12911_0.1',
 'hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_99-LR_0.00001_epoch=0-step=12911_0.1',
 'hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_99-LR_0.00001_epoch=0-step=12911_0.1',
 'hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_99-LR_0.00001_epoch=0-step=64559_0.1',
 'hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_99-LR_0.00001_epoch=0-step=64559_0.1',
 'hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_99-LR_0.00001_e

In [14]:
# Print one "&"-separated row per (head, decoys, step) for the LR=0.00001 runs:
# the row label followed by AP, roc_auc and accuracy on val-mhc and then on
# val-prot, all read from the 'SA' entries of `info` and shown with 3 decimals.
learning_rate = "0.00001"
proportion = 0.1

for head in ["Cls", "Attn", "Avg"]:
    for decoys_per_obs in [19, 99]:
        for steps in [12911, 64559]:
            # Only the 19-decoy checkpoint at step 64559 is named "epoch=4";
            # every other checkpoint is named "epoch=0".
            epoch = 4 if (decoys_per_obs != 99 and steps != 12911) else 0
            model_name = f"hparam_search_CONTEXT-PSEUDO-HEAD_{head}-DECOY_{decoys_per_obs}-LR_{learning_rate}"
            checkpoint_name = f"{model_name}_epoch={epoch}-step={steps}_{proportion}"
            checkpoint_metrics = info[checkpoint_name]

            # Label first, then the six metric cells in split-major order.
            row = [f"{head} {decoys_per_obs} {steps}"]
            for split in ["val-mhc", "val-prot"]:
                for metric in ["AP", "roc_auc", "accuracy"]:
                    row.append(f"{checkpoint_metrics[split]['SA'][metric]:.3f}")

            print(" & ".join(row))

Cls 19 12911 & 0.319 & 0.938 & 0.986 & 0.447 & 0.956 & 0.987
Cls 19 64559 & 0.394 & 0.949 & 0.982 & 0.552 & 0.964 & 0.983
Cls 99 12911 & 0.062 & 0.823 & 0.990 & 0.232 & 0.917 & 0.990
Cls 99 64559 & 0.349 & 0.945 & 0.991 & 0.492 & 0.960 & 0.992
Attn 19 12911 & 0.298 & 0.936 & 0.973 & 0.447 & 0.957 & 0.982
Attn 19 64559 & 0.410 & 0.941 & 0.983 & 0.535 & 0.964 & 0.982
Attn 99 12911 & 0.044 & 0.792 & 0.990 & 0.184 & 0.884 & 0.990
Attn 99 64559 & 0.315 & 0.938 & 0.990 & 0.496 & 0.958 & 0.992
Avg 19 12911 & 0.252 & 0.926 & 0.982 & 0.417 & 0.954 & 0.985
Avg 19 64559 & 0.369 & 0.949 & 0.977 & 0.506 & 0.960 & 0.980
Avg 99 12911 & 0.026 & 0.689 & 0.990 & 0.145 & 0.835 & 0.990
Avg 99 64559 & 0.287 & 0.924 & 0.990 & 0.471 & 0.956 & 0.992


In [7]:
# Checkpoints to evaluate: the Attn, Cls and Avg heads of the LR=0.000001 runs,
# for both decoy settings and both saved steps. Each entry is
# (experiment folder, model name, checkpoint name, proportion).
# The LR=0.00001 and LR=0.0001 variants are not part of this selection.
to_assess = [
    ("hparam_search", f"CONTEXT-PSEUDO-HEAD_{head}-DECOY_{decoys}-LR_0.000001", checkpoint, proportion)
    for decoys, checkpoint in [
        (19, "epoch=0-step=12911"),
        (19, "epoch=4-step=64559"),
        (99, "epoch=0-step=12911"),
        (99, "epoch=0-step=64559"),
    ]
    for head in ["Attn", "Cls", "Avg"]
]


In [8]:
# Re-evaluate with the LR=0.000001 selection; this replaces the previous
# contents of `info`. Splits: val-prot and val-mhc, SA view only.
# NOTE(review): judging by how it is indexed further down, the result is
# keyed as info[checkpoint_name][split][view][metric] -- confirm.
info = evaluate_checkpoints(to_assess, [SPLIT_VAL_PROTEINS, SPLIT_VAL_MHC_ALLELES], [VIEW_SA])

Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_19-LR_0.000001_epoch=0-step=12911_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 66674.17it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_19-LR_0.000001_epoch=0-step=12911_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32805.18it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.000001_epoch=0-step=12911_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:04<00:00, 67171.23it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.000001_epoch=0-step=12911_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 32954.20it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_19-LR_0.000001_epoch=0-step=12911_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 65840.00it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_19-LR_0.000001_epoch=0-step=12911_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 31786.15it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_19-LR_0.000001_epoch=4-step=64559_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 64751.74it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_19-LR_0.000001_epoch=4-step=64559_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 31360.21it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.000001_epoch=4-step=64559_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 63679.53it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.000001_epoch=4-step=64559_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 31113.32it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_19-LR_0.000001_epoch=4-step=64559_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 63646.58it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_19-LR_0.000001_epoch=4-step=64559_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 31094.64it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_99-LR_0.000001_epoch=0-step=12911_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 63677.20it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_99-LR_0.000001_epoch=0-step=12911_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 31021.47it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_99-LR_0.000001_epoch=0-step=12911_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 63470.93it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_99-LR_0.000001_epoch=0-step=12911_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 31068.43it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_99-LR_0.000001_epoch=0-step=12911_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 63588.76it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_99-LR_0.000001_epoch=0-step=12911_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 31085.62it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_99-LR_0.000001_epoch=0-step=64559_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 63854.74it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Attn-DECOY_99-LR_0.000001_epoch=0-step=64559_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 31061.83it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_99-LR_0.000001_epoch=0-step=64559_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 63780.87it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Cls-DECOY_99-LR_0.000001_epoch=0-step=64559_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 31196.65it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_99-LR_0.000001_epoch=0-step=64559_0.1 - val-prot - SA


100%|██████████| 333800/333800 [00:05<00:00, 63912.19it/s]


Load: hparam_search_CONTEXT-PSEUDO-HEAD_Avg-DECOY_99-LR_0.000001_epoch=0-step=64559_0.1 - val-mhc - SA


100%|██████████| 333800/333800 [00:10<00:00, 31344.03it/s]


In [9]:
# Print one "&"-separated row per (head, decoys, step) for the LR=0.000001 runs:
# the row label followed by AP, roc_auc and accuracy on val-mhc and then on
# val-prot, all read from the 'SA' entries of `info` and shown with 3 decimals.
learning_rate = "0.000001"
proportion = 0.1

for head in ["Cls", "Attn", "Avg"]:
    for decoys_per_obs in [19, 99]:
        for steps in [12911, 64559]:
            # Only the 19-decoy checkpoint at step 64559 is named "epoch=4";
            # every other checkpoint is named "epoch=0".
            epoch = 4 if (decoys_per_obs != 99 and steps != 12911) else 0
            model_name = f"hparam_search_CONTEXT-PSEUDO-HEAD_{head}-DECOY_{decoys_per_obs}-LR_{learning_rate}"
            checkpoint_name = f"{model_name}_epoch={epoch}-step={steps}_{proportion}"
            checkpoint_metrics = info[checkpoint_name]

            # Label first, then the six metric cells in split-major order.
            row = [f"{head} {decoys_per_obs} {steps}"]
            for split in ["val-mhc", "val-prot"]:
                for metric in ["AP", "roc_auc", "accuracy"]:
                    row.append(f"{checkpoint_metrics[split]['SA'][metric]:.3f}")

            print(" & ".join(row))

Cls 19 12911 & 0.020 & 0.660 & 0.990 & 0.061 & 0.771 & 0.990
Cls 19 64559 & 0.209 & 0.902 & 0.974 & 0.356 & 0.943 & 0.975
Cls 99 12911 & 0.013 & 0.584 & 0.990 & 0.033 & 0.689 & 0.990
Cls 99 64559 & 0.027 & 0.721 & 0.990 & 0.119 & 0.835 & 0.990
Attn 19 12911 & 0.018 & 0.649 & 0.990 & 0.073 & 0.774 & 0.990
Attn 19 64559 & 0.210 & 0.902 & 0.975 & 0.349 & 0.937 & 0.975
Attn 99 12911 & 0.014 & 0.594 & 0.990 & 0.034 & 0.694 & 0.990
Attn 99 64559 & 0.029 & 0.731 & 0.990 & 0.132 & 0.829 & 0.990
Avg 19 12911 & 0.020 & 0.657 & 0.990 & 0.081 & 0.776 & 0.990
Avg 19 64559 & 0.202 & 0.892 & 0.969 & 0.349 & 0.937 & 0.970
Avg 99 12911 & 0.016 & 0.601 & 0.990 & 0.037 & 0.700 & 0.990
Avg 99 64559 & 0.031 & 0.733 & 0.990 & 0.134 & 0.840 & 0.990
