In [8]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
import torch
from selfpeptide.model.immunogenicity_predictor import JointPeptidesNetwork_Beta

In [2]:
# Load the configuration saved next to the binding-model checkpoint, then
# rewrite every string value that embeds one of the two other project roots
# so that it points at the local checkout instead.
with open("../trained_models/BindingModels/floral-sweep-3/config.json", "r") as f:
    binding_config = json.load(f)
for k in binding_config:
    # Only string entries can hold a path; everything else is left untouched.
    if isinstance(binding_config[k], str):
        # str.replace is a no-op when the prefix is absent, so no `in` guard is needed.
        binding_config[k] = (
            binding_config[k]
            .replace("/home/gvisona/SelfPeptides", "/home/gvisona/Projects/SelfPeptides")
            .replace("/fast/gvisona/SelfPeptides", "/home/gvisona/Projects/SelfPeptides")
        )
binding_config

{'PMA_ln': True,
 'PMA_num_heads': 4,
 'accumulate_batches': 64,
 'batch_size': 16,
 'binding_affinity_df': '/home/gvisona/Projects/SelfPeptides/processed_data/Binding_Affinity/DHLAP_binding_affinity_data.csv',
 'cool_down': 0.8,
 'dropout_p': 0.05,
 'early_stopping': True,
 'embedding_dim': 512,
 'experiment_group': 'Binding_affinity_training_LS_joint',
 'experiment_name': 'Binding_model_LS_joint_seeds',
 'hla_embedding_dim': 1024,
 'hla_prot_seq_file': '/home/gvisona/Projects/SelfPeptides/processed_data/HLA_embeddings/HLA_proteins_T5/hla_proteins_mapping.csv',
 'joint_embedder_hidden_dim': 2048,
 'ligand_atlas_binding_df': '/home/gvisona/Projects/SelfPeptides/processed_data/Binding_Affinity/HLA_Ligand_Atlas_processed.csv',
 'lr': 0.00030967294546198955,
 'ls_alpha': 0.05,
 'max_updates': 20000,
 'min_frac': 0.1,
 'mlp_hidden_dim': 2048,
 'mlp_input_dim': 512,
 'mlp_num_layers': 2,
 'momentum': 0.9,
 'n_attention_layers': 2,
 'nesterov_momentum': True,
 'num_heads': 4,
 'output_dim': 

In [6]:
# Load the configuration of the immunogenicity run, then rewrite every string
# value that embeds one of the two other project roots so that it points at
# the local checkout instead.
with open("../outputs/BetaRegression_MeanRegression/SavingBestPearsonR/config.json", "r") as f:
    config = json.load(f)
for k in config:
    # Only string entries can hold a path; everything else is left untouched.
    if isinstance(config[k], str):
        # str.replace is a no-op when the prefix is absent, so no `in` guard is needed.
        config[k] = (
            config[k]
            .replace("/home/gvisona/SelfPeptides", "/home/gvisona/Projects/SelfPeptides")
            .replace("/fast/gvisona/SelfPeptides", "/home/gvisona/Projects/SelfPeptides")
        )
config

{'seed': 42,
 'run_number': 'SavingBestPearsonR',
 'experiment_name': 'BetaRegr_Exploration',
 'experiment_group': 'BetaRegression_MeanRegression',
 'project_folder': '/home/gvisona/Projects/SelfPeptides',
 'init_checkpoint': None,
 'force_restart': True,
 'min_subjects_tested': 1,
 'hla_filter': None,
 'immunogenicity_df': '/home/gvisona/Projects/SelfPeptides/processed_data/Immunogenicity/Processed_TCell_IEDB_beta_summed.csv',
 'dhlap_df': '/home/gvisona/Projects/SelfPeptides/processed_data/Immunogenicity/DHLAP_immunogenicity_data.csv',
 'pseudo_seq_file': '/home/gvisona/Projects/SelfPeptides/processed_data/HLA_embeddings/HLA_pseudoseqs_T5/hla_pseudoseq_mapping.csv',
 'hla_prot_seq_file': '/home/gvisona/Projects/SelfPeptides/processed_data/HLA_embeddings/HLA_proteins_T5/hla_proteins_mapping.csv',
 'binding_model_config': '/home/gvisona/Projects/SelfPeptides/trained_models/BindingModels/floral-sweep-3/config.json',
 'binding_model_checkpoint': '/home/gvisona/Projects/SelfPeptides/train

In [11]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build the joint network from the two configs loaded above; the binding
# sub-model checkpoint path comes from the immunogenicity run's config.
model = JointPeptidesNetwork_Beta(config, binding_config, 
                         binding_checkpoint=config["binding_model_checkpoint"], 
                         device=device)

# The checkpoint is fed straight into load_state_dict, so only tensors are
# needed: weights_only=True keeps torch.load from executing arbitrary pickled
# code and avoids the unsafe-load warning on recent PyTorch releases.
model.load_state_dict(
    torch.load(
        "../outputs/BetaRegression_MeanRegression/SavingBestPearsonR/checkpoints/001_checkpoint.pt",
        map_location=device,
        weights_only=True,
    )
)
# The network contains Dropout layers; switch to evaluation mode so that any
# inference run with the restored weights is deterministic.
model.eval()
model

JointPeptidesNetwork_Beta(
  (binding_model): Peptide_HLA_BindingClassifier(
    (peptide_embedder): PeptideEmbedder(
      (tokenizer): AA_Tokenizer()
      (aa_embs): Embedding(23, 512, padding_idx=22)
      (transformer_encoder): TransformerEncoder(
        (pos_encoding): PositionalEncoding(
          (dropout): Dropout(p=0.05, inplace=False)
        )
        (dropout): Dropout(p=0.05, inplace=False)
        (encoder_layers): ModuleList(
          (0-1): 2 x TEncoderLayer(
            (multihead_attention): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
            )
            (dropout1): Dropout(p=0.05, inplace=False)
            (res_norm1): ResNorm(
              (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            )
            (feed_forward): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): ReLU()
              (2): Linear(i

In [12]:
# Read the test-set predictions that were saved for the SavingBestPearsonR run.
# NOTE(review): the rendered frame has an `Unnamed: 0` column, which suggests
# the CSV was written together with its index; `index_col=0` would absorb it —
# confirm that nothing downstream relies on that column before changing this.
results_df = pd.read_csv(
    "../outputs/BetaRegression_MeanRegression/results/test_predictions_SavingBestPearsonR.csv"
)
results_df

Unnamed: 0,Peptide,HLA,Qualitative Measurement,Peptide length,Number of Subjects Tested,Number of Subjects Positive,Alpha,Beta,Allele Pseudo-sequence,Allele Protein sequence,Target,Sample,Peptide Length,Distr. Mean,Distr. Variance,Distr. Mode,Distr. Precision,Stratification_index,Prediction Distr. Mean,Prediction Distr. Precision
0,QILKGGSGT,HLA-A02:01,Positive,9,15.0,12.0,13.0,4.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,QILKGGSGT_HLA-A02:01,9,0.764706,0.009996,0.800000,17.0,HLA-A02:01_1,0.080882,11.139398
1,AITEVECFL,HLA-A02:01,Positive,9,13.0,6.0,7.0,8.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,AITEVECFL_HLA-A02:01,9,0.466667,0.015556,0.461538,15.0,HLA-A02:01_1,0.580371,10.835062
2,LYNKYSFKL,HLA-A24:02,Negative,9,5.0,0.0,1.0,6.0,YSAMYEEKVAHTDENIAYLMFHYYTWAVQAYTGY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRF...,0,LYNKYSFKL_HLA-A24:02,9,0.142857,0.015306,0.000000,7.0,HLA-A24:02_0,0.088064,11.074095
3,ALDGNLVSMDV,HLA-A02:01,Positive,11,34.0,1.0,2.0,34.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,ALDGNLVSMDV_HLA-A02:01,11,0.055556,0.001418,0.029412,36.0,HLA-A02:01_1,0.191032,31.765368
4,EDLLMGTLGIV,HLA-A02:01,Positive,11,8.0,2.0,3.0,7.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,EDLLMGTLGIV_HLA-A02:01,11,0.300000,0.019091,0.250000,10.0,HLA-A02:01_1,0.210606,19.122175
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4182,IPFLGIRET,HLA-B07:02,Negative,9,52.0,0.0,1.0,53.0,YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY,MLVMAPRTVLLLLSAALALTETWAGSHSMRYFYTSVSRPGRGEPRF...,0,IPFLGIRET_HLA-B07:02,9,0.018519,0.000330,0.000000,54.0,HLA-B07:02_0,0.097966,9.873003
4183,SMNATLVQA,HLA-A02:01,Positive,9,14.0,2.0,3.0,13.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,SMNATLVQA_HLA-A02:01,9,0.187500,0.008961,0.142857,16.0,HLA-A02:01_1,0.119097,16.244989
4184,AYRRRWRRL,HLA-A24:02,Positive,9,4.0,2.0,3.0,3.0,YSAMYEEKVAHTDENIAYLMFHYYTWAVQAYTGY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRF...,1,AYRRRWRRL_HLA-A24:02,9,0.500000,0.035714,0.500000,6.0,HLA-A24:02_1,0.174572,5.655990
4185,NWKNFYPSY,HLA-A24:07,Negative,9,1.0,0.0,1.0,2.0,YSAMYEEKVAQTDENIAYLMFHYYTWAVQAYTGY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRF...,0,NWKNFYPSY_HLA-A24:07,9,0.333333,0.055556,0.000000,3.0,HLA-A24:07_0,0.284538,3.337332
