In [17]:
import numpy as np
import h5py
from tqdm import tqdm
import os
import torch
from torch.utils.data import Dataset, DataLoader
from selfpeptide.utils.constants import *
import torch.nn as nn
import time

In [2]:
class Self_NonSelf_PeptideDataset(Dataset):
    def __init__(self, hdf5_dataset_fname, gen_size=1000, 
                 val_size=0,
                 negative_label=-1):
        self.hdf5_dataset_fname = hdf5_dataset_fname
        self.gen_size = gen_size
        self.val_size = val_size//2
        self.negative_label = negative_label        

        if not os.path.exists(self.hdf5_dataset_fname):
            raise FileNotFoundError("Specify a valid HDF5 file for the dataset")
        self._get_n_peptides()
        
        self.idx_self = self.val_size
        self.idx_nonself = self.val_size
    
        self._load_peptides(gen_size)
        
    def _get_n_peptides(self):
        with h5py.File(self.hdf5_dataset_fname, 'r') as f:
            self.n_self_peptides = len(f["reference_human_peptides"])
            self.n_nonself_peptides = len(f["nonself_peptides"])

                    
    def _load_peptides(self, n_peptides):
        peptides = torch.zeros((n_peptides, MAX_PEPTIDE_LEN)).long()
        labels = torch.ones(n_peptides).long()
        

        with h5py.File(self.hdf5_dataset_fname, 'r') as f:
            peptides[::2, :] = torch.from_numpy(f["reference_human_peptides"][self.idx_self:self.idx_self+n_peptides//2])
            peptides[1::2, :] = torch.from_numpy(f["nonself_peptides"][self.idx_nonself:self.idx_nonself+n_peptides//2])
            
        labels[1::2] = self.negative_label
        
        self.peptides = peptides.long()
        self.labels = labels.long()
        
        self.idx_self += n_peptides//2
        self.idx_nonself += n_peptides//2
        
        
    def refresh_data(self):
        if self.n_self_peptides-self.idx_self<self.gen_size:
            self.idx_self = self.val_size
        if self.n_nonself_peptides-self.idx_nonself<self.gen_size:
            self.idx_nonself = self.val_size
        self._load_peptides(self.gen_size)
    
    def __len__(self):
        return len(self.peptides)
    
    def __getitem__(self, idx):
        return self.peptides[idx], self.labels[idx]

In [3]:
hdf5_file = "../processed_data/pre_tokenized_peptides_dataset.hdf5"
dset = Self_NonSelf_PeptideDataset(hdf5_file, gen_size=10000, val_size=20)

In [4]:
dset[0]

(tensor([10,  2, 20, 17, 17,  5,  0, 14, 16,  4, 22, 22]), tensor(1))

In [5]:
dloader = DataLoader(dset, batch_size=16)

In [6]:
batch = next(iter(dloader))

In [7]:
peptides, labels = batch

In [8]:
peptides

tensor([[10,  2, 20, 17, 17,  5,  0, 14, 16,  4, 22, 22],
        [16,  2, 15,  0, 14,  7, 13,  5, 16, 17,  4, 17],
        [12, 15, 12, 14, 17,  7,  7,  6, 11, 14,  1, 12],
        [17, 12,  5, 14,  9,  9, 14, 12, 14, 22, 22, 22],
        [ 2, 20,  5,  8, 13,  6,  4, 11,  2, 17, 22, 22],
        [ 8, 11,  1,  2,  9,  2, 17,  0,  1, 14,  6,  7],
        [ 3, 10,  7,  9, 17,  1,  2,  0, 20, 14,  8, 22],
        [ 3,  9,  4, 16, 15,  3, 11,  9, 14, 22, 22, 22],
        [ 5,  0, 14,  0, 13,  9,  9,  8, 15,  9, 14,  3],
        [ 8, 17, 17, 10,  0,  9, 10, 20,  3,  3,  8, 22],
        [ 5, 13,  9,  9,  5, 16,  0, 14,  3, 12, 11, 12],
        [ 3, 14, 12, 20,  3,  1, 15, 13, 14,  5,  8, 12],
        [12,  0,  3,  3, 13,  2, 12, 15, 12,  3, 22, 22],
        [ 9, 15,  3,  5,  5,  7, 16, 13,  8,  5, 20,  3],
        [17,  3, 13, 12,  8,  5,  3,  3,  9, 22, 22, 22],
        [ 7,  9,  0,  4,  9, 17, 12,  9, 17,  0,  1, 22]])

In [9]:
labels

tensor([ 1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1])

In [14]:
pos_ix = (labels==1)
neg_ix = (labels==-1)
pos_ix

tensor([ True, False,  True, False,  True, False,  True, False,  True, False,
         True, False,  True, False,  True, False])

In [15]:
peptides[pos_ix]

tensor([[10,  2, 20, 17, 17,  5,  0, 14, 16,  4, 22, 22],
        [12, 15, 12, 14, 17,  7,  7,  6, 11, 14,  1, 12],
        [ 2, 20,  5,  8, 13,  6,  4, 11,  2, 17, 22, 22],
        [ 3, 10,  7,  9, 17,  1,  2,  0, 20, 14,  8, 22],
        [ 5,  0, 14,  0, 13,  9,  9,  8, 15,  9, 14,  3],
        [ 5, 13,  9,  9,  5, 16,  0, 14,  3, 12, 11, 12],
        [12,  0,  3,  3, 13,  2, 12, 15, 12,  3, 22, 22],
        [17,  3, 13, 12,  8,  5,  3,  3,  9, 22, 22, 22]])

In [16]:
peptides[neg_ix]

tensor([[16,  2, 15,  0, 14,  7, 13,  5, 16, 17,  4, 17],
        [17, 12,  5, 14,  9,  9, 14, 12, 14, 22, 22, 22],
        [ 8, 11,  1,  2,  9,  2, 17,  0,  1, 14,  6,  7],
        [ 3,  9,  4, 16, 15,  3, 11,  9, 14, 22, 22, 22],
        [ 8, 17, 17, 10,  0,  9, 10, 20,  3,  3,  8, 22],
        [ 3, 14, 12, 20,  3,  1, 15, 13, 14,  5,  8, 12],
        [ 9, 15,  3,  5,  5,  7, 16, 13,  8,  5, 20,  3],
        [ 7,  9,  0,  4,  9, 17, 12,  9, 17,  0,  1, 22]])

In [18]:
sim = nn.CosineSimilarity()

In [20]:
sim(peptides[pos_ix].float(), peptides[neg_ix].float())

tensor([0.7500, 0.8391, 0.7359, 0.8883, 0.5887, 0.7336, 0.7778, 0.5810])

In [64]:
def cosine_similarity_all_pairs(a, b, eps=1e-8):
    """
    added eps for numerical stability
    """
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    a_norm = a / torch.clamp(a_n, min=eps)
    b_norm = b / torch.clamp(b_n, min=eps)
    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
    return sim_mt



class CustomDistanceHingeLoss(nn.Module):
    def __init__(self, margin=0.8, device="cpu"):
        super().__init__()
        self.margin = margin
        self.device = device
        self.hinge_loss = nn.HingeEmbeddingLoss(margin=margin)
    
    def forward(self, embeddings, labels):
        pos_ix = (labels==1)
        neg_ix = (labels==-1)
        
        pos_embeddings = embeddings[pos_ix]
        neg_embeddings = embeddings[neg_ix]        
        
        # Similarities Pos-Pos
        pos_distance = 1 - cosine_similarity_all_pairs(pos_embeddings, pos_embeddings)
        ixs = torch.triu_indices(*pos_distance.shape, offset=1)
        pos_cos_distances = pos_distance[ixs[0], ixs[1]]
        
        
        # Similarities Pos-Neg
        neg_distance = 1 - cosine_similarity_all_pairs(pos_embeddings, neg_embeddings)
        ixs = torch.triu_indices(*neg_distance.shape, offset=0)
        neg_cos_distances = neg_distance[ixs[0], ixs[1]]
        
        cos_distances = torch.cat([neg_cos_distances, pos_cos_distances])
        hinge_labels = torch.ones(len(neg_cos_distances)+len(pos_cos_distances), device=self.device)
        hinge_labels[:len(neg_cos_distances)] = -1
        loss = self.hinge_loss(cos_distances, hinge_labels)
        return loss

In [55]:
distance = 1 - cosine_similarity_all_pairs(peptides[pos_ix].float(), peptides[pos_ix].float())

In [56]:
distance

tensor([[0.0000e+00, 2.2815e-01, 2.3377e-01, 1.8987e-01, 1.9783e-01, 2.3078e-01,
         1.4403e-01, 1.4610e-01],
        [2.2815e-01, 1.1921e-07, 2.3462e-01, 1.4739e-01, 2.8076e-01, 2.0262e-01,
         3.6034e-01, 2.2079e-01],
        [2.3377e-01, 2.3462e-01, 0.0000e+00, 2.0740e-01, 3.5977e-01, 1.2431e-01,
         2.3306e-01, 1.7408e-01],
        [1.8987e-01, 1.4739e-01, 2.0740e-01, 5.9605e-08, 2.9806e-01, 3.3031e-01,
         2.7234e-01, 1.8936e-01],
        [1.9783e-01, 2.8076e-01, 3.5977e-01, 2.9806e-01, 0.0000e+00, 3.1362e-01,
         2.3439e-01, 2.6291e-01],
        [2.3078e-01, 2.0262e-01, 1.2431e-01, 3.3031e-01, 3.1362e-01, 1.1921e-07,
         3.3884e-01, 2.2054e-01],
        [1.4403e-01, 3.6034e-01, 2.3306e-01, 2.7234e-01, 2.3439e-01, 3.3884e-01,
         5.9605e-08, 2.0270e-01],
        [1.4610e-01, 2.2079e-01, 1.7408e-01, 1.8936e-01, 2.6291e-01, 2.2054e-01,
         2.0270e-01, 1.1921e-07]])

In [57]:
distance.shape

torch.Size([8, 8])

In [58]:
torch.triu_indices(*distance.shape)

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3,
         3, 3, 4, 4, 4, 4, 5, 5, 5, 6, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 3, 4, 5,
         6, 7, 4, 5, 6, 7, 5, 6, 7, 6, 7, 7]])

In [59]:
ixs = torch.triu_indices(*distance.shape, offset=1)
distance[ixs[0], ixs[1]]

tensor([0.2281, 0.2338, 0.1899, 0.1978, 0.2308, 0.1440, 0.1461, 0.2346, 0.1474,
        0.2808, 0.2026, 0.3603, 0.2208, 0.2074, 0.3598, 0.1243, 0.2331, 0.1741,
        0.2981, 0.3303, 0.2723, 0.1894, 0.3136, 0.2344, 0.2629, 0.3388, 0.2205,
        0.2027])

In [60]:
ixs.shape

torch.Size([2, 28])

In [65]:
loss = CustomDistanceHingeLoss()

In [66]:
loss(peptides.float(), labels)

tensor(0.3972)

In [63]:
labels

tensor([ 1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1])

In [53]:
peptides

tensor([[10,  2, 20, 17, 17,  5,  0, 14, 16,  4, 22, 22],
        [16,  2, 15,  0, 14,  7, 13,  5, 16, 17,  4, 17],
        [12, 15, 12, 14, 17,  7,  7,  6, 11, 14,  1, 12],
        [17, 12,  5, 14,  9,  9, 14, 12, 14, 22, 22, 22],
        [ 2, 20,  5,  8, 13,  6,  4, 11,  2, 17, 22, 22],
        [ 8, 11,  1,  2,  9,  2, 17,  0,  1, 14,  6,  7],
        [ 3, 10,  7,  9, 17,  1,  2,  0, 20, 14,  8, 22],
        [ 3,  9,  4, 16, 15,  3, 11,  9, 14, 22, 22, 22],
        [ 5,  0, 14,  0, 13,  9,  9,  8, 15,  9, 14,  3],
        [ 8, 17, 17, 10,  0,  9, 10, 20,  3,  3,  8, 22],
        [ 5, 13,  9,  9,  5, 16,  0, 14,  3, 12, 11, 12],
        [ 3, 14, 12, 20,  3,  1, 15, 13, 14,  5,  8, 12],
        [12,  0,  3,  3, 13,  2, 12, 15, 12,  3, 22, 22],
        [ 9, 15,  3,  5,  5,  7, 16, 13,  8,  5, 20,  3],
        [17,  3, 13, 12,  8,  5,  3,  3,  9, 22, 22, 22],
        [ 7,  9,  0,  4,  9, 17, 12,  9, 17,  0,  1, 22]])