In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import pandas as pd
import os
import numpy as np
from sklearn.model_selection import train_test_split
from scipy.stats import beta

from selfpeptide.model.binding_affinity_classifier import Peptide_HLA_BindingClassifier
from selfpeptide.model.components import ResMLP_Network
from selfpeptide.utils.data_utils import load_immunogenicity_dataframes, filter_peptide_dataset, load_immunogenicity_dataframes_jointseqs
from selfpeptide.utils.constants import *
from selfpeptide.model.immunogenicity_predictor import JointPeptidesNetwork_Beta
from selfpeptide.utils.data_utils import SequencesInteractionDataset, load_binding_affinity_dataframes_jointseqs

from torch.utils.data import Dataset, DataLoader



In [2]:
# Load the training config of the pretrained binding-affinity model, then rewrite
# the absolute paths stored in it so they point at this machine's checkout
# (presumably the config was written on hosts with a different project root).
LOCAL_PROJECT_ROOT = "/home/gvisona/Projects/SelfPeptides"
FOREIGN_PROJECT_ROOTS = ("/home/gvisona/SelfPeptides", "/fast/gvisona/SelfPeptides")

with open("../trained_models/BindingModels/floral-sweep-3/config.json", "r") as cfg_file:
    binding_config = json.load(cfg_file)

for cfg_key, cfg_value in list(binding_config.items()):
    if not isinstance(cfg_value, str):
        continue  # only string entries can contain paths
    for foreign_root in FOREIGN_PROJECT_ROOTS:
        cfg_value = cfg_value.replace(foreign_root, LOCAL_PROJECT_ROOT)
    binding_config[cfg_key] = cfg_value

binding_config

{'PMA_ln': True,
 'PMA_num_heads': 4,
 'accumulate_batches': 64,
 'batch_size': 16,
 'binding_affinity_df': '/home/gvisona/Projects/SelfPeptides/processed_data/Binding_Affinity/DHLAP_binding_affinity_data.csv',
 'cool_down': 0.8,
 'dropout_p': 0.05,
 'early_stopping': True,
 'embedding_dim': 512,
 'experiment_group': 'Binding_affinity_training_LS_joint',
 'experiment_name': 'Binding_model_LS_joint_seeds',
 'hla_embedding_dim': 1024,
 'hla_prot_seq_file': '/home/gvisona/Projects/SelfPeptides/processed_data/HLA_embeddings/HLA_proteins_T5/hla_proteins_mapping.csv',
 'joint_embedder_hidden_dim': 2048,
 'ligand_atlas_binding_df': '/home/gvisona/Projects/SelfPeptides/processed_data/Binding_Affinity/HLA_Ligand_Atlas_processed.csv',
 'lr': 0.00030967294546198955,
 'ls_alpha': 0.05,
 'max_updates': 20000,
 'min_frac': 0.1,
 'mlp_hidden_dim': 2048,
 'mlp_input_dim': 512,
 'mlp_num_layers': 2,
 'momentum': 0.9,
 'n_attention_layers': 2,
 'nesterov_momentum': True,
 'num_heads': 4,
 'output_dim': 

In [3]:
# Paths and split settings consumed by `load_immunogenicity_dataframes_jointseqs`.
# Every path hangs off one project root so relocating the data means editing one line.
PROJECT_ROOT = "/home/gvisona/Projects/SelfPeptides"
DATA_ROOT = f"{PROJECT_ROOT}/processed_data"
BINDING_MODEL_DIR = f"{PROJECT_ROOT}/trained_models/BindingModels/floral-sweep-3"

# NOTE(review): `imm_config` is never filled or read in the visible cells; kept in
# case code outside this view relies on the name.
imm_config = {}
config = {
    "immunogenicity_df": f"{DATA_ROOT}/Immunogenicity/Processed_TCell_IEDB_beta_summed.csv",
    "pseudo_seq_file": f"{DATA_ROOT}/HLA_embeddings/HLA_pseudoseqs_T5/hla_pseudoseq_mapping.csv",
    "hla_prot_seq_file": f"{DATA_ROOT}/HLA_embeddings/HLA_proteins_T5/hla_proteins_mapping.csv",
    "binding_affinity_df": f"{DATA_ROOT}/Binding_Affinity/DHLAP_binding_affinity_data.csv",
    "binding_model_checkpoint": f"{BINDING_MODEL_DIR}/checkpoints/001_checkpoint.pt",
    "binding_model_config": f"{BINDING_MODEL_DIR}/config.json",
    "dhlap_df": f"{DATA_ROOT}/Immunogenicity/DHLAP_immunogenicity_data.csv",
    # Fractions of the data held out for testing / validation.
    "test_size": 0.15,
    "val_size": 0.1,
    "seed": 0,
}


In [24]:
# Hyper-parameters and data paths for the Beta immunogenicity model
# (JointPeptidesNetwork_Beta). All file paths derive from one project root,
# which is also recorded under "project_folder", so the two cannot drift apart.
PROJECT_ROOT = "/home/gvisona/Projects/SelfPeptides"
DATA_ROOT = f"{PROJECT_ROOT}/processed_data"
BINDING_MODEL_DIR = f"{PROJECT_ROOT}/trained_models/BindingModels/floral-sweep-3"

model_config = {
    "seed": 0,
    "run_number": None,
    "experiment_name": "devel",
    "experiment_group": "Beta_regs_js",
    "project_folder": PROJECT_ROOT,
    "init_checkpoint": None,
    "force_restart": False,
    "min_subjects_tested": 1,
    "hla_filter": None,
    "immunogenicity_df": f"{DATA_ROOT}/Immunogenicity/Processed_TCell_IEDB_beta_summed.csv",
    "dhlap_df": f"{DATA_ROOT}/Immunogenicity/DHLAP_immunogenicity_data.csv",
    "pseudo_seq_file": f"{DATA_ROOT}/HLA_embeddings/HLA_pseudoseqs_T5/hla_pseudoseq_mapping.csv",
    "hla_prot_seq_file": f"{DATA_ROOT}/HLA_embeddings/HLA_proteins_T5/hla_proteins_mapping.csv",
    "binding_model_config": f"{BINDING_MODEL_DIR}/config.json",
    "binding_model_checkpoint": f"{BINDING_MODEL_DIR}/checkpoints/001_checkpoint.pt",
    "dropout_p": 0.15,
    "mlp_input_dim": 1025,
    "mlp_hidden_dim": 2048,
    "mlp_num_layers": 2,
    "mlp_output_dim": 512,
    "imm_regression_hidden_dim": 2048,
    # NOTE(review): very small run length / validation interval, presumably for the
    # "devel" experiment only — raise for real training.
    "max_updates": 50,
    "patience": 1000,
    "validate_every_n_updates": 1,
    "test_size": 0.15,
    "val_size": 0.1,
    "trainval_df": None,
    "test_df": None,
    "augmentation_df": None,
    "batch_size": 8,
    "accumulate_batches": 1,
    "IR_n_bins": 50,
    "LDS_kernel": "triang",
    "LDS_ks": 11,
    "LDS_sigma": 1.0,
    "loss_weights": "LDS_weights",
    "lr": 0.0001,
    "weight_decay": 1e-5,
    "momentum": 0.9,
    "nesterov_momentum": True,
    "min_frac": 0.1,
    "ramp_up": 0.1,
    "cool_down": 0.8,
    "use_posterior_mean": False,
    "wandb_sweep": None
}

# Device used for tensors moved with `.to(device)` in later cells.
# NOTE(review): the model itself is constructed with device="cpu", so on a CUDA
# machine tensors moved here end up on a different device than the model.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [25]:
# Build the Beta immunogenicity network, loading the pretrained binding model
# from the checkpoint listed in `model_config`.
# NOTE(review): pinned to "cpu" although later cells move targets to `device`;
# confirm this is intended on machines with a GPU.
model = JointPeptidesNetwork_Beta(
    model_config,
    binding_config,
    binding_checkpoint=model_config["binding_model_checkpoint"],
    device="cpu",
)

In [31]:
# Binding-affinity train / val / test frames, built from the binding model's own config.
(train_ba_df,
 val_ba_df,
 test_ba_df) = load_binding_affinity_dataframes_jointseqs(binding_config)
train_ba_df

Unnamed: 0,HLA,Peptide,Label,Allele Pseudo-sequence,Allele Protein sequence,Stratification_index
690521,HLA-C08:02,APRHGSLGF,0,YYAGYREKYRQTDVSNLYLRYNFYTWAERAYTWY,MRVMAPRTLILLLSGALALTETWACSHSMRYFYTAVSRPGRGEPRF...,HLA-C08:02_0
582137,HLA-C06:02,NHAEIMQAV,1,YDSGYREKYRQADVNKLYLWYDSYTWAEWAYTWY,MRVMAPRTLILLLSGALALTETWACSHSMRYFDTAVSRPGRGEPRF...,HLA-C06:02_1
52936,HLA-C06:02,GYIHVTQTF,1,YDSGYREKYRQADVNKLYLWYDSYTWAEWAYTWY,MRVMAPRTLILLLSGALALTETWACSHSMRYFDTAVSRPGRGEPRF...,HLA-C06:02_1
12726,HLA-A11:01,ARDLYDAGVKR,1,YYAMYQENVAQTDVDTLYIIYRDYTWAAQAYRWY,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFYTSVSRPGRGEPRF...,HLA-A11:01_1
222412,HLA-C12:03,FAYTSRIVV,1,YYAGYREKYRQADVSNLYLWYDSYTWAEWAYTWY,MRVMAPRTLILLLSGALALTETWACSHSMRYFYTAVSRPGRGEPRF...,HLA-C12:03_1
...,...,...,...,...,...,...
123558,HLA-B27:09,VAFTSHEHF,0,YHTEYREICAKTDEDTLYLNYHHYTWAVLAYEWY,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFHTSVSRPGRGEPRF...,HLA-B27:09_0
125124,HLA-A02:09,YLEPAIAKY,0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,HLA-A02:09_0
265161,HLA-B50:02,NEIKDSVVA,1,YHTKYREISTNTYESNLYWRYNLYTWAELAYLSY,MRVTAPRTVLLLLSAALALTETWAGSHSMRYFHTAMSRPGRGEPRF...,HLA-B50:02_1
308960,HLA-A02:08,YLEPAIAKY,1,YYAMYGENVAHTHVDTLYLRYHYYTWAVWAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFYTSVSRPGRGEPRF...,HLA-A02:08_1


In [33]:
# Immunogenicity train / val / test frames plus a fourth frame (presumably the
# DHLAP immunogenicity set referenced by config["dhlap_df"]).
(train_df,
 val_df,
 test_df,
 dhlap_imm_df) = load_immunogenicity_dataframes_jointseqs(config)
train_df

IEDB N. training samples: 21435
IEDB N. val samples: 2373
IEDB N. test samples: 4187


Unnamed: 0,Peptide,HLA,Qualitative Measurement,Peptide length,Number of Subjects Tested,Number of Subjects Positive,Alpha,Beta,Allele Pseudo-sequence,Allele Protein sequence,Target,Sample,Peptide Length,Distr. Mean,Distr. Variance,Distr. Mode,Distr. Precision,Stratification_index
18101,KVIALATEDK,HLA-A03:01,Negative,10,1.0,0.0,1.0,2.0,YFAMYQENVAQTDVDTLYIIYRDYTWAELAYTWY,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,0,KVIALATEDK_HLA-A03:01,10,0.333333,0.055556,0.0,3.0,HLA-A03:01_0
6632,LEYSISNDL,HLA-B44:02,Negative,9,4.0,0.0,1.0,5.0,YYTKYREISTNTYENTAYIRYDDYTWAVDAYLSY,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,0,LEYSISNDL_HLA-B44:02,9,0.166667,0.019841,0.0,6.0,HLA-B44:02_0
8003,TTWEDVPYL,HLA-A02:01,Positive,9,4.0,4.0,5.0,1.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,TTWEDVPYL_HLA-A02:01,9,0.833333,0.019841,1.0,6.0,HLA-A02:01_1
14414,KEKAKEMNAL,HLA-B40:01,Negative,10,2.0,0.0,1.0,3.0,YHTKYREISTNTYESNLYLRYNYYSLAVLAYEWY,MRVTAPRTVLLLLSAALALTETWAGSHSMRYFHTAMSRPGRGEPRF...,0,KEKAKEMNAL_HLA-B40:01,10,0.250000,0.037500,0.0,4.0,HLA-B40:01_0
12127,IEYAKLYVL,HLA-B40:02,Negative,9,2.0,0.0,1.0,3.0,YHTKYREISTNTYESNLYLSYNYYTWAVLAYEWY,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFHTSVSRPGRGEPRF...,0,IEYAKLYVL_HLA-B40:02,9,0.250000,0.037500,0.0,4.0,HLA-B40:02_0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21490,ILKEPVHGVY,HLA-B15:10,Positive,10,1.0,1.0,2.0,1.0,YYSEYRNICTNTYESNLYLRYDYYTWAELAYLWY,MRVTAPRTVLLLLSGALALTETWAGSHSMRYFYTAMSRPGRGEPRF...,1,ILKEPVHGVY_HLA-B15:10,10,0.666667,0.055556,1.0,3.0,HLA-B15:10_1
21645,RYYDGNIYEL,HLA-A24:07,Positive,10,1.0,1.0,2.0,1.0,YSAMYEEKVAQTDENIAYLMFHYYTWAVQAYTGY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRF...,1,RYYDGNIYEL_HLA-A24:07,10,0.666667,0.055556,1.0,3.0,HLA-A24:07_1
25970,LLWTLVVLL,HLA-A02:09,Positive,9,1.0,1.0,2.0,1.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,LLWTLVVLL_HLA-A02:09,9,0.666667,0.055556,1.0,3.0,HLA-A02:09_1
26450,SLGLVILLVL,HLA-A02:11,Positive,10,1.0,1.0,2.0,1.0,YFAMYGEKVAHIDVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,SLGLVILLVL_HLA-A02:11,10,0.666667,0.055556,1.0,3.0,HLA-A02:11_1


In [34]:
# Expose the immunogenicity target under the "Label" column name that the
# binding-affinity frames use (presumably what SequencesInteractionDataset reads).
train_df = train_df.assign(Label=train_df["Target"])

In [27]:

# Binding-affinity training loader, batched like the binding model's own training run.
# The explicitly seeded generator makes the shuffled order (and thus the batch
# inspected in the next cell) reproducible after a kernel restart.
train_dset = SequencesInteractionDataset(
    train_ba_df, hla_repr=["Allele Pseudo-sequence", "Allele Protein sequence"]
)
train_loader = DataLoader(
    train_dset,
    batch_size=binding_config["batch_size"],
    drop_last=True,
    shuffle=True,
    generator=torch.Generator().manual_seed(model_config["seed"]),
)

In [29]:
# Draw one batch from the binding-affinity loader and run only the binding sub-model on it.
train_batch = next(iter(train_loader))
peptides, hla_pseudoseqs, hla_protein_seq = train_batch[:3]
batch_targets = train_batch[-1].float().to(device)
# Inspection only, so no autograd graph is needed.
# NOTE(review): if the binding model is in train mode its dropout is active here;
# confirm whether eval() should be set before comparing logits.
with torch.no_grad():
    predicted_logits, embs = model.binding_model(peptides, hla_pseudoseqs, hla_protein_seq)

In [30]:
# Display the binding-model logits for the binding-affinity batch drawn in the previous cell.
predicted_logits

tensor([[-0.5924],
        [-2.0259],
        [-1.8275],
        [-1.7415],
        [-1.8562],
        [ 1.3742],
        [-1.8530],
        [-1.8272],
        [ 2.0037],
        [ 1.2928],
        [-1.7638],
        [ 0.7614],
        [ 1.9362],
        [-1.2502],
        [-0.3659],
        [-1.8471]])

In [35]:

# Immunogenicity training loader, fed to the binding model for comparison.
# NOTE(review): batched with the binding model's batch_size rather than
# model_config["batch_size"]; confirm this is intentional.
# The explicitly seeded generator makes the shuffled order reproducible.
train_imm_dset = SequencesInteractionDataset(
    train_df, hla_repr=["Allele Pseudo-sequence", "Allele Protein sequence"]
)
train_imm_loader = DataLoader(
    train_imm_dset,
    batch_size=binding_config["batch_size"],
    drop_last=True,
    shuffle=True,
    generator=torch.Generator().manual_seed(model_config["seed"]),
)

In [39]:
# Draw one immunogenicity batch and score it with the binding sub-model only.
train_batch = next(iter(train_imm_loader))
peptides, hla_pseudoseqs, hla_protein_seq = train_batch[:3]
batch_targets = train_batch[-1].float().to(device)
# Inspection only, so no autograd graph is needed.
# NOTE(review): if the binding model is in train mode its dropout is active here;
# confirm whether eval() should be set before comparing logits.
with torch.no_grad():
    predicted_logits, embs = model.binding_model(peptides, hla_pseudoseqs, hla_protein_seq)

In [40]:
# Display the binding-model logits for the immunogenicity batch drawn in the previous cell.
predicted_logits

tensor([[ 1.8741],
        [ 1.9191],
        [ 1.8584],
        [ 1.9974],
        [-1.6596],
        [ 1.4052],
        [ 1.5796],
        [ 1.8153],
        [ 1.7295],
        [ 1.6750],
        [ 1.8168],
        [ 1.2311],
        [ 1.6782],
        [ 1.2800],
        [ 0.2196],
        [ 0.3276]])