In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import pandas as pd
import os
import numpy as np
from sklearn.model_selection import train_test_split
from scipy.stats import beta

from selfpeptide.model.binding_affinity_classifier import Peptide_HLA_BindingClassifier
from selfpeptide.model.components import ResMLP_Network
from selfpeptide.utils.data_utils import load_immunogenicity_dataframes, filter_peptide_dataset, load_immunogenicity_dataframes_jointseqs
from selfpeptide.utils.constants import *

from torch.utils.data import Dataset, DataLoader

In [3]:
# Read the hyper-parameters the binding model was trained with.
with open("../trained_models/BindingModels/floral-sweep-3/config.json", "r") as cfg_file:
    binding_config = json.load(cfg_file)

# The stored config still points at older project roots; rewrite every string
# value so that it refers to the current checkout. Non-string values are left as-is.
_PATH_REWRITES = (
    ("/home/gvisona/SelfPeptides", "/home/gvisona/Projects/SelfPeptides"),
    ("/fast/gvisona/SelfPeptides", "/home/gvisona/Projects/SelfPeptides"),
)
for key, value in binding_config.items():
    if isinstance(value, str):
        # str.replace is a no-op when the old prefix is absent.
        for old_root, new_root in _PATH_REWRITES:
            value = value.replace(old_root, new_root)
        binding_config[key] = value
binding_config

{'PMA_ln': True,
 'PMA_num_heads': 4,
 'accumulate_batches': 64,
 'batch_size': 16,
 'binding_affinity_df': '/home/gvisona/Projects/SelfPeptides/processed_data/Binding_Affinity/DHLAP_binding_affinity_data.csv',
 'cool_down': 0.8,
 'dropout_p': 0.05,
 'early_stopping': True,
 'embedding_dim': 512,
 'experiment_group': 'Binding_affinity_training_LS_joint',
 'experiment_name': 'Binding_model_LS_joint_seeds',
 'hla_embedding_dim': 1024,
 'hla_prot_seq_file': '/home/gvisona/Projects/SelfPeptides/processed_data/HLA_embeddings/HLA_proteins_T5/hla_proteins_mapping.csv',
 'joint_embedder_hidden_dim': 2048,
 'ligand_atlas_binding_df': '/home/gvisona/Projects/SelfPeptides/processed_data/Binding_Affinity/HLA_Ligand_Atlas_processed.csv',
 'lr': 0.00030967294546198955,
 'ls_alpha': 0.05,
 'max_updates': 20000,
 'min_frac': 0.1,
 'mlp_hidden_dim': 2048,
 'mlp_input_dim': 512,
 'mlp_num_layers': 2,
 'momentum': 0.9,
 'n_attention_layers': 2,
 'nesterov_momentum': True,
 'num_heads': 4,
 'output_dim': 

In [4]:
# Empty placeholder; nothing is stored in it in this cell.
imm_config = {}

# Data paths and split settings handed to the immunogenicity data loader.
config = {
    "immunogenicity_df": "/home/gvisona/Projects/SelfPeptides/processed_data/Immunogenicity/Processed_TCell_IEDB_beta_summed.csv",
    "pseudo_seq_file": "/home/gvisona/Projects/SelfPeptides/processed_data/HLA_embeddings/HLA_pseudoseqs_T5/hla_pseudoseq_mapping.csv",
    "hla_prot_seq_file": "/home/gvisona/Projects/SelfPeptides/processed_data/HLA_embeddings/HLA_proteins_T5/hla_proteins_mapping.csv",
    "binding_affinity_df": '/home/gvisona/Projects/SelfPeptides/processed_data/Binding_Affinity/DHLAP_binding_affinity_data.csv',
    "binding_model_checkpoint": '/home/gvisona/Projects/SelfPeptides/trained_models/BindingModels/floral-sweep-3/checkpoints/001_checkpoint.pt',
    "binding_model_config": '/home/gvisona/Projects/SelfPeptides/trained_models/BindingModels/floral-sweep-3/config.json',
    "dhlap_df": '/home/gvisona/Projects/SelfPeptides/processed_data/Immunogenicity/DHLAP_immunogenicity_data.csv',
    "test_size": 0.15,
    "val_size": 0.1,
    "seed": 0,
}


In [5]:
# Load the IEDB immunogenicity data as train / validation / test frames, plus the DHLAP frame.
# NOTE(review): the second positional argument (True here, False in the next cell) is defined in
# selfpeptide.utils.data_utils — presumably it toggles whether the IEDB data is split; confirm there.
train_df, val_df, test_df, dhlap_imm_df = load_immunogenicity_dataframes_jointseqs(config, True)
train_df

IEDB N. training samples: 21435
IEDB N. val samples: 2373
IEDB N. test samples: 4187


Unnamed: 0,Peptide,HLA,Qualitative Measurement,Peptide length,Number of Subjects Tested,Number of Subjects Positive,Alpha,Beta,Allele Pseudo-sequence,Allele Protein sequence,Target,Sample,Peptide Length,Distr. Mean,Distr. Variance,Distr. Mode,Distr. Precision,Stratification_index
18101,KVIALATEDK,HLA-A03:01,Negative,10,1.0,0.0,1.0,2.0,YFAMYQENVAQTDVDTLYIIYRDYTWAELAYTWY,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,0,KVIALATEDK_HLA-A03:01,10,0.333333,0.055556,0.0,3.0,HLA-A03:01_0
6632,LEYSISNDL,HLA-B44:02,Negative,9,4.0,0.0,1.0,5.0,YYTKYREISTNTYENTAYIRYDDYTWAVDAYLSY,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,0,LEYSISNDL_HLA-B44:02,9,0.166667,0.019841,0.0,6.0,HLA-B44:02_0
8003,TTWEDVPYL,HLA-A02:01,Positive,9,4.0,4.0,5.0,1.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,TTWEDVPYL_HLA-A02:01,9,0.833333,0.019841,1.0,6.0,HLA-A02:01_1
14414,KEKAKEMNAL,HLA-B40:01,Negative,10,2.0,0.0,1.0,3.0,YHTKYREISTNTYESNLYLRYNYYSLAVLAYEWY,MRVTAPRTVLLLLSAALALTETWAGSHSMRYFHTAMSRPGRGEPRF...,0,KEKAKEMNAL_HLA-B40:01,10,0.250000,0.037500,0.0,4.0,HLA-B40:01_0
12127,IEYAKLYVL,HLA-B40:02,Negative,9,2.0,0.0,1.0,3.0,YHTKYREISTNTYESNLYLSYNYYTWAVLAYEWY,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFHTSVSRPGRGEPRF...,0,IEYAKLYVL_HLA-B40:02,9,0.250000,0.037500,0.0,4.0,HLA-B40:02_0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21490,ILKEPVHGVY,HLA-B15:10,Positive,10,1.0,1.0,2.0,1.0,YYSEYRNICTNTYESNLYLRYDYYTWAELAYLWY,MRVTAPRTVLLLLSGALALTETWAGSHSMRYFYTAMSRPGRGEPRF...,1,ILKEPVHGVY_HLA-B15:10,10,0.666667,0.055556,1.0,3.0,HLA-B15:10_1
21645,RYYDGNIYEL,HLA-A24:07,Positive,10,1.0,1.0,2.0,1.0,YSAMYEEKVAQTDENIAYLMFHYYTWAVQAYTGY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRF...,1,RYYDGNIYEL_HLA-A24:07,10,0.666667,0.055556,1.0,3.0,HLA-A24:07_1
25970,LLWTLVVLL,HLA-A02:09,Positive,9,1.0,1.0,2.0,1.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,LLWTLVVLL_HLA-A02:09,9,0.666667,0.055556,1.0,3.0,HLA-A02:09_1
26450,SLGLVILLVL,HLA-A02:11,Positive,10,1.0,1.0,2.0,1.0,YFAMYGEKVAHIDVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,SLGLVILLVL_HLA-A02:11,10,0.666667,0.055556,1.0,3.0,HLA-A02:11_1


In [6]:
# Same loader with the flag set to False: it returns two frames instead of four.
# NOTE: this rebinds `dhlap_imm_df`, replacing the object bound by the previous cell.
iedb_df, dhlap_imm_df = load_immunogenicity_dataframes_jointseqs(config, False)
iedb_df

Unnamed: 0,Peptide,HLA,Qualitative Measurement,Peptide length,Number of Subjects Tested,Number of Subjects Positive,Alpha,Beta,Allele Pseudo-sequence,Allele Protein sequence,Target,Sample,Peptide Length,Distr. Mean,Distr. Variance,Distr. Mode,Distr. Precision,Stratification_index
0,GILGFVFTL,HLA-A02:01,Positive,9,898.0,513.0,514.0,386.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,GILGFVFTL_HLA-A02:01,9,0.571111,0.000272,0.571269,900.0,HLA-A02:01_1
1,NLVPMVATV,HLA-A02:01,Positive,9,578.0,467.0,468.0,112.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,NLVPMVATV_HLA-A02:01,9,0.806897,0.000268,0.807958,580.0,HLA-A02:01_1
2,QYIKWPWYI,HLA-A24:02,Positive,9,461.0,291.0,292.0,171.0,YSAMYEEKVAHTDENIAYLMFHYYTWAVQAYTGY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRF...,1,QYIKWPWYI_HLA-A24:02,9,0.630670,0.000502,0.631236,463.0,HLA-A24:02_1
3,FLPSDFFPSV,HLA-A02:01,Positive,10,313.0,216.0,217.0,98.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,FLPSDFFPSV_HLA-A02:01,10,0.688889,0.000678,0.690096,315.0,HLA-A02:01_1
4,YLQPRTFLL,HLA-A02:01,Positive-High,9,274.0,186.0,187.0,89.0,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1,YLQPRTFLL_HLA-A02:01,9,0.677536,0.000789,0.678832,276.0,HLA-A02:01_1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
27990,KEYTPQIYTY,HLA-B49:01,Negative,10,1.0,0.0,1.0,2.0,YHTKYREISTNTYENIAYWRYNLYTWAELAYLWY,MRVTAPRTVLLLLSAALALTETWAGSHSMRYFHTAMSRPGRGEPRF...,0,KEYTPQIYTY_HLA-B49:01,10,0.333333,0.055556,0.000000,3.0,HLA-B49:01_0
27991,IQQLYEFRK,HLA-A33:03,Negative,9,1.0,0.0,1.0,2.0,YTAMYRNNVAHIDVDTLYIMYQDYTWAVLAYTWY,MAVMAPRTLLLLLLGALALTQTWAGSHSMRYFTTSVSRPGRGEPRF...,0,IQQLYEFRK_HLA-A33:03,9,0.333333,0.055556,0.000000,3.0,HLA-A33:03_0
27992,IQQLYEFRK,HLA-A33:01,Negative,9,1.0,0.0,1.0,2.0,YTAMYRNNVAHIDVDTLYIMYQDYTWAVLAYTWH,MAVMAPRTLLLLLLGALALTQTWAGSHSMRYFTTSVSRPGRGEPRF...,0,IQQLYEFRK_HLA-A33:01,9,0.333333,0.055556,0.000000,3.0,HLA-A33:01_0
27993,DELVDPINY,HLA-B49:01,Negative,9,1.0,0.0,1.0,2.0,YHTKYREISTNTYENIAYWRYNLYTWAELAYLWY,MRVTAPRTVLLLLSAALALTETWAGSHSMRYFHTAMSRPGRGEPRF...,0,DELVDPINY_HLA-B49:01,9,0.333333,0.055556,0.000000,3.0,HLA-B49:01_0


# Peptides fine-tuned ESM2 embeddings

In [7]:
# Index -> peptide lookup as stored on disk. JSON object keys are always
# strings, so the indices are strings here and in the inverted mapping below.
with open("../processed_data/Peptide_embeddings/ESM2_idx2peptide.json", "r") as mapping_file:
    idx2peptide = json.load(mapping_file)

# Inverse lookup: peptide sequence -> (string) index.
peptide2idx = {peptide: idx for idx, peptide in idx2peptide.items()}
len(idx2peptide)

17784

In [8]:
# Precomputed peptide embeddings stored as a single 2-D array (its shape is displayed below).
# NOTE(review): row i is assumed to correspond to idx2peptide[str(i)] — confirm against the
# script that wrote the .npy file.
esm2_peptide_embeddings = np.load("../processed_data/Peptide_embeddings/FinetunedESM2_imm_peptides_embeddings.npy")
esm2_peptide_embeddings.shape

(17784, 480)

In [9]:

class BetaDistributionDataset_ESM2Embs(Dataset):
    """Dataset pairing immunogenicity rows with precomputed peptide embeddings.

    Each item is a tuple ``(row, embedding)``:

    * ``row`` is the list ``[peptide, *hla columns, alpha, beta, target]``
      read from ``df`` (column order as listed);
    * ``embedding`` is the entry of ``peptide_embs`` selected through
      ``peptide2idx_mapping`` for that row's peptide.

    Args:
        df: DataFrame containing the columns "Peptide", "Alpha", "Beta",
            "Target" and every column named in ``hla_repr``.
        peptide_embs: Integer-indexable collection of embeddings
            (for example a 2-D numpy array, one row per peptide).
        peptide2idx_mapping: Mapping from peptide string to its index in
            ``peptide_embs``. Indices may be ints or int-like strings
            (they are passed through ``int``).
        hla_repr: Column name, or sequence of column names, holding the HLA
            representation(s) to include in each row.

    Raises:
        KeyError: If ``df`` lacks a requested column, or if a peptide in
            ``df`` has no entry in ``peptide2idx_mapping``.
    """

    def __init__(self, df, peptide_embs, peptide2idx_mapping, hla_repr=("Allele Pseudo-sequence",)):
        super().__init__()
        # A bare string would otherwise be unpacked character by character below.
        if isinstance(hla_repr, str):
            hla_repr = [hla_repr]
        cols = ["Peptide", *hla_repr, "Alpha", "Beta", "Target"]
        self.data_matrix = df[cols].values.tolist()

        peptides = df["Peptide"].values
        # Fail early with a readable message instead of a bare KeyError on the
        # first peptide that is missing from the mapping.
        missing = [p for p in dict.fromkeys(peptides) if p not in peptide2idx_mapping]
        if missing:
            raise KeyError(
                f"{len(missing)} peptide(s) have no entry in peptide2idx_mapping, e.g. {missing[:5]}"
            )
        # Mapping values may be strings (e.g. when loaded from JSON), hence int().
        self.peptide_esm2_embeddings = [peptide_embs[int(peptide2idx_mapping[p])] for p in peptides]

    def __len__(self):
        """Return the number of rows taken from ``df``."""
        return len(self.data_matrix)

    def __getitem__(self, ix):
        """Return ``(row fields, peptide embedding)`` for position ``ix``."""
        return self.data_matrix[ix], self.peptide_esm2_embeddings[ix]


In [10]:
# Wrap the training split; each row carries two HLA representation columns
# (pseudo-sequence and full protein sequence) next to the peptide and the Beta parameters.
dset = BetaDistributionDataset_ESM2Embs(train_df, esm2_peptide_embeddings, peptide2idx, hla_repr=["Allele Pseudo-sequence", "Allele Protein sequence"])
dset[0]

(['KVIALATEDK',
  'YFAMYQENVAQTDVDTLYIIYRDYTWAELAYTWY',
  'MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDQETRNVKAQSQTDRVDLGTLRGYYNQSEAGSHTIQIMYGCDVGSDGRFLRGYRQDAYDGKDYIALNEDLRSWTAADMAAQITKRKWEAAHEAEQLRAYLDGTCVEWLRRYLENGKETLQRTDPPKTHMTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWELSSQPTIPIVGIIAGLVLLGAVITGAVVAAVMWRRKSSDRKGGSYTQAASSDSAQGSDVSLTACKV',
  1.0,
  2.0,
  0],
 array([ 5.78934774e-02, -1.36778578e-01,  1.81181401e-01,  1.09881297e-01,
        -1.04171313e-01,  4.92570773e-02, -5.58215305e-02, -8.63942578e-02,
         9.30502564e-02,  1.23541392e-01,  7.60633498e-02, -5.95053546e-02,
        -3.97013240e-02, -8.75522047e-02,  1.13377191e-01,  6.28435519e-03,
        -5.06509021e-02,  1.14172935e-01,  2.37336040e-01, -7.76094496e-02,
        -4.64581326e-02, -1.37919843e-01, -4.87093404e-02,  7.26822317e-02,
         4.06974033e-02, -4.10680920e-02,  1.19964100e-01,  1.37769848e-01,
        -6.40182868e-

In [11]:
# Shuffled mini-batches of 32 samples; drop_last=True discards a final incomplete batch.
train_loader = DataLoader(dset, batch_size=32, drop_last=True, shuffle=True)

In [12]:
# Pull one batch for the smoke tests in the cells below.
# NOTE: the loader shuffles and no seed is set in this notebook, so the batch differs between runs.
batch = next(iter(train_loader))

In [13]:
# Settings for a small development run of the immunogenicity model.
# Key order matches the original literal; values are unchanged.
model_config = dict(
    seed=0,
    run_number=None,
    experiment_name="devel",
    experiment_group="Beta_regs_js",
    project_folder="/home/gvisona/Projects/SelfPeptides",
    init_checkpoint=None,
    force_restart=False,
    min_subjects_tested=1,
    hla_filter=None,
    immunogenicity_df="/home/gvisona/Projects/SelfPeptides/processed_data/Immunogenicity/Processed_TCell_IEDB_beta_summed.csv",
    dhlap_df="/home/gvisona/Projects/SelfPeptides/processed_data/Immunogenicity/DHLAP_immunogenicity_data.csv",
    pseudo_seq_file="/home/gvisona/Projects/SelfPeptides/processed_data/HLA_embeddings/HLA_pseudoseqs_T5/hla_pseudoseq_mapping.csv",
    hla_prot_seq_file="/home/gvisona/Projects/SelfPeptides/processed_data/HLA_embeddings/HLA_proteins_T5/hla_proteins_mapping.csv",
    binding_model_config="/home/gvisona/Projects/SelfPeptides/trained_models/BindingModels/floral-sweep-3/config.json",
    binding_model_checkpoint="/home/gvisona/Projects/SelfPeptides/trained_models/BindingModels/floral-sweep-3/checkpoints/001_checkpoint.pt",
    dropout_p=0.15,
    mlp_input_dim=480,
    mlp_hidden_dim=2048,
    mlp_num_layers=2,
    mlp_output_dim=512,
    imm_regression_hidden_dim=2048,
    max_updates=50,
    patience=1000,
    validate_every_n_updates=1,
    test_size=0.15,
    val_size=0.1,
    trainval_df=None,
    test_df=None,
    augmentation_df=None,
    batch_size=8,
    accumulate_batches=1,
    IR_n_bins=50,
    LDS_kernel="triang",
    LDS_ks=11,
    LDS_sigma=1.0,
    loss_weights="LDS_weights",
    lr=0.0001,
    weight_decay=1e-5,
    momentum=0.9,
    nesterov_momentum=True,
    min_frac=0.1,
    ramp_up=0.1,
    cool_down=0.8,
    use_posterior_mean=False,
    wandb_sweep=None,
)


In [33]:

class ImmunogenicityBetaModel_ESM2Embeddings(nn.Module):
    """Predict Beta-distribution parameters from a binding logit and a peptide embedding.

    The binding logit is squashed to a probability, rescaled to (-1, 1) and used
    to scale the peptide embedding element-wise. The scaled embedding goes
    through ``ResMLP_Network`` and a two-layer head that emits two raw values
    per sample, which are turned into a mean and a precision.

    Args:
        config: Dict read for "mlp_output_dim" and "imm_regression_hidden_dim";
            it is also forwarded unchanged to ``ResMLP_Network``.
        device: Forwarded to ``ResMLP_Network`` and stored on the instance.
        epsilon: Margin that keeps the predicted mean inside
            ``[epsilon, 1 - epsilon]``.
    """

    def __init__(self, config, device="cpu", epsilon=1e-3):
        super().__init__()
        self.config, self.device, self.epsilon = config, device, epsilon

        # Trunk first, then the head: same construction order as before so that
        # random initialisation consumes the RNG identically.
        self.joint_mlp = ResMLP_Network(config, device)

        head_width = config["imm_regression_hidden_dim"]
        self.output_module = nn.Sequential(
            nn.Linear(config["mlp_output_dim"], head_width),
            nn.ReLU(),
            nn.Linear(head_width, 2),
        )

    def forward(self, binding_score_logit, peptides_embs, *args):
        """Return ``(stacked, joint_embs)``.

        ``stacked`` has three columns per sample: mean, posterior mean
        (binding probability times mean) and precision. ``joint_embs`` is the
        output of the MLP trunk. Extra positional arguments are ignored.
        """
        binding_prob = torch.sigmoid(binding_score_logit).view(-1)
        # Rescale (0, 1) -> (-1, 1) so weak binders flip the sign of the embedding.
        signed_binding = binding_prob * 2 - 1

        joint_embs = self.joint_mlp(signed_binding[:, None] * peptides_embs)
        raw = self.output_module(joint_embs)
        raw_mean, raw_log_precision = raw[:, 0], raw[:, 1]

        eps = self.epsilon
        means = eps + (1 - 2 * eps) * torch.sigmoid(raw_mean)
        posterior_means = binding_prob * means
        # The constant offset keeps the precision strictly above 2.
        precisions = 2 + torch.exp(raw_log_precision)

        columns = [means[:, None], posterior_means[:, None], precisions[:, None]]
        return torch.hstack(columns), joint_embs
    
    
class JointPeptidesNetwork_Beta_ESM2Embeddings(nn.Module):
    """Frozen peptide-HLA binding classifier feeding a trainable Beta immunogenicity head.

    Args:
        imm_config: Config dict passed to ``ImmunogenicityBetaModel_ESM2Embeddings``.
        binding_config: Config dict for ``Peptide_HLA_BindingClassifier``, or a
            path to a JSON file containing it. A dict argument is copied, so the
            caller's object is left untouched.
        binding_checkpoint: Optional path to a state dict for the binding model.
            When ``None`` the binding model keeps its initial weights and a
            warning is emitted.
        device: Device handed to both sub-models and used as ``map_location``
            when loading the checkpoint.

    Note:
        The binding model is put in eval mode and its parameters are frozen
        here. Calling ``.train()`` on this module switches every sub-module,
        including the binding model, back to training mode; call
        ``self.binding_model.eval()`` again afterwards if that is not wanted.
    """

    def __init__(self, imm_config, binding_config, binding_checkpoint=None, device="cpu"):
        # Local import: `warnings` is not imported explicitly anywhere else in this file.
        import warnings

        super().__init__()
        if isinstance(binding_config, dict):
            # Copy so that the override below does not leak into the caller's dict.
            binding_config = dict(binding_config)
        else:
            with open(binding_config, "r") as f:
                binding_config = json.load(f)
        # NOTE(review): presumably disables loading of external pretrained amino-acid
        # embeddings because the checkpoint supplies the weights — confirm in
        # Peptide_HLA_BindingClassifier.
        binding_config["pretrained_aa_embeddings"] = "none"

        self.binding_model = Peptide_HLA_BindingClassifier(binding_config, device=device)
        if binding_checkpoint is not None:
            self.binding_model.load_state_dict(torch.load(binding_checkpoint, map_location=device))
        else:
            warnings.warn("Binding model not initialized")
        self.binding_model.eval()

        # Freeze binding model
        for p in self.binding_model.parameters():
            p.requires_grad = False

        self.immunogenicity_model = ImmunogenicityBetaModel_ESM2Embeddings(imm_config, device=device)
        self.immunogenicity_model.train()

    def forward(self, peptides, hla_pseudoseqs, hla_prots, peptide_embeddings, *args):
        """Return one row per sample: binding probability followed by the head's outputs.

        Column 0 is ``sigmoid(binding logit)``; the remaining columns are whatever
        the immunogenicity model returns as its first output. Extra positional
        arguments are ignored.
        """
        # The binding model also returns intermediate embeddings; they are not used here.
        binding_score_logit, _ = self.binding_model(peptides, hla_pseudoseqs, hla_prots)
        output, _ = self.immunogenicity_model(binding_score_logit, peptide_embeddings)
        binding_prob = torch.sigmoid(binding_score_logit).view(-1, 1)
        return torch.hstack([binding_prob, output])

In [34]:

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


# Build the joint network: the immunogenicity head is configured from `model_config`,
# the binding classifier from `binding_config`, and its weights come from the checkpoint path in `config`.
# NOTE(review): the checkpoint path is read from `config`, while `model_config` carries its own
# "binding_model_checkpoint" entry — both currently hold the same string.
model = JointPeptidesNetwork_Beta_ESM2Embeddings(model_config, binding_config, 
                         binding_checkpoint=config["binding_model_checkpoint"], 
                         device=device)

In [35]:
# batch[0] holds the six row fields in dataset column order:
# peptide, HLA pseudo-sequence, HLA protein sequence, then alpha, beta and target.
peptides, hla_pseudoseqs, hla_prots = batch[0][:3]
imm_alpha, imm_beta, imm_target = batch[0][3:]
# batch[1] holds the peptide embeddings stacked into a single tensor (shape shown below).
pep_embs = batch[1]
pep_embs.shape

torch.Size([32, 480])

In [36]:
# Forward pass of the joint network on the sample batch (smoke test; no loss is computed).
predictions = model(peptides, hla_pseudoseqs, hla_prots, pep_embs)

In [37]:
# Columns: binding probability, predicted mean, posterior mean (binding probability * mean), precision.
predictions

tensor([[0.5798, 0.5777, 0.3350, 2.6086],
        [0.8622, 0.6025, 0.5195, 3.0161],
        [0.8684, 0.5699, 0.4949, 3.1186],
        [0.7707, 0.5644, 0.4350, 2.8389],
        [0.6028, 0.5688, 0.3429, 3.0027],
        [0.7558, 0.5343, 0.4039, 2.7915],
        [0.7601, 0.6127, 0.4658, 2.8575],
        [0.1697, 0.4652, 0.0790, 3.1029],
        [0.5673, 0.6278, 0.3561, 2.9338],
        [0.4504, 0.4896, 0.2205, 3.6823],
        [0.3138, 0.4880, 0.1532, 3.5628],
        [0.8358, 0.5325, 0.4451, 2.9861],
        [0.6319, 0.6170, 0.3899, 2.7740],
        [0.8485, 0.6176, 0.5240, 2.8208],
        [0.8108, 0.5760, 0.4670, 2.9522],
        [0.6703, 0.5593, 0.3749, 2.9863],
        [0.8666, 0.5948, 0.5155, 2.9046],
        [0.8218, 0.6111, 0.5022, 2.8877],
        [0.1579, 0.4810, 0.0759, 3.4077],
        [0.8509, 0.6051, 0.5149, 3.0506],
        [0.7983, 0.6245, 0.4985, 2.8186],
        [0.8609, 0.5812, 0.5004, 2.8557],
        [0.4034, 0.4601, 0.1856, 3.3145],
        [0.8373, 0.5416, 0.4535, 2

In [21]:
# Show the module tree of the frozen binding sub-model.
model.binding_model

Peptide_HLA_BindingClassifier(
  (peptide_embedder): PeptideEmbedder(
    (tokenizer): AA_Tokenizer()
    (aa_embs): Embedding(23, 512, padding_idx=22)
    (transformer_encoder): TransformerEncoder(
      (pos_encoding): PositionalEncoding(
        (dropout): Dropout(p=0.05, inplace=False)
      )
      (dropout): Dropout(p=0.05, inplace=False)
      (encoder_layers): ModuleList(
        (0-1): 2 x TEncoderLayer(
          (multihead_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (dropout1): Dropout(p=0.05, inplace=False)
          (res_norm1): ResNorm(
            (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          )
          (feed_forward): Sequential(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): ReLU()
            (2): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout2): Dropout(p

In [22]:
# NOTE(review): column 0 of `predictions` is torch.sigmoid(binding_score_logit) as built in the
# joint model's forward, i.e. a probability in (0, 1) — not a raw logit, despite the variable name.
binding_logits = predictions[:, 0]
binding_logits

tensor([0.5798, 0.8622, 0.8684, 0.7707, 0.6028, 0.7558, 0.7601, 0.1697, 0.5673,
        0.4504, 0.3138, 0.8358, 0.6319, 0.8485, 0.8108, 0.6703, 0.8666, 0.8218,
        0.1579, 0.8509, 0.7983, 0.8609, 0.4034, 0.8373, 0.8446, 0.8079, 0.8792,
        0.8576, 0.7882, 0.8216, 0.3618, 0.8693], grad_fn=<SelectBackward0>)

In [59]:
# Force the first entry to a very negative value to probe the sign flip in the cells below.
# NOTE(review): `binding_logits` was obtained by slicing `predictions`, so this in-place write
# presumably also changes predictions[0, 0]; clone first if `predictions` must stay intact.
# NOTE(review): the recorded output has 4 entries, so this cell was last run on a different
# (smaller) batch than the cells above — re-run top to bottom before trusting it.
binding_logits[0] = -20.1
binding_logits

tensor([-20.1000,   0.7982,   0.7540,   0.4755], grad_fn=<AsStridedBackward0>)

In [60]:
# Map to (-1, 1) with the same sigmoid * 2 - 1 rescaling the immunogenicity model applies in its forward.
# NOTE(review): entries taken from predictions[:, 0] have already been through a sigmoid, so they are
# squashed twice here; only the hand-set -20.1 behaves like a real logit.
binding_scores = torch.sigmoid(binding_logits)*2 -1
binding_scores

tensor([-1.0000,  0.3792,  0.3601,  0.2334], grad_fn=<SubBackward0>)

In [62]:
# Scale each embedding row by its signed binding score; a score near -1 flips the sign of the whole row.
binding_scores[:, None] * pep_embs

tensor([[-0.0569,  0.0901, -0.1406,  ...,  0.1967,  0.0453, -0.1113],
        [-0.0076, -0.0029,  0.0438,  ..., -0.0354,  0.0038,  0.0353],
        [ 0.0018, -0.0079,  0.0908,  ..., -0.0262, -0.0030,  0.0200],
        [ 0.0254, -0.0143,  0.0422,  ..., -0.0338, -0.0284,  0.0411]],
       grad_fn=<MulBackward0>)

In [63]:
# Unscaled embeddings, shown for comparison with the product in the previous cell.
pep_embs

tensor([[ 0.0569, -0.0901,  0.1406,  ..., -0.1967, -0.0453,  0.1113],
        [-0.0202, -0.0078,  0.1154,  ..., -0.0934,  0.0101,  0.0930],
        [ 0.0051, -0.0221,  0.2522,  ..., -0.0727, -0.0084,  0.0554],
        [ 0.1089, -0.0613,  0.1809,  ..., -0.1448, -0.1219,  0.1762]])