In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from selfpeptide.utils.data_utils import Self_NonSelf_PeptideDataset

import matplotlib.pyplot as plt
import seaborn as sns
import math

In [50]:

def hypershperical_cosine_margin_loss(emb1, emb2, s=1.0, m=0.5, eps=1e-8):
    """Row-wise exponentiated cosine similarity with a margin.

    Both inputs are L2-normalised along ``dim=1``; the cosine of each pair of
    matching rows is then computed and ``exp(s * (cos - m))`` is returned.

    Args:
        emb1: tensor whose rows (dim 1) are embeddings.
        emb2: tensor of the same shape as ``emb1``; row ``i`` is paired with
            row ``i`` of ``emb1``.
        s: scale multiplying the margin-shifted cosine.
        m: margin subtracted from the cosine before scaling.
        eps: lower bound for the row norms, so that an all-zero row yields a
            cosine of 0 instead of NaN.

    Returns:
        Tensor with one value ``exp(s * (cos - m))`` per row pair.
    """
    # Clamp the norms so zero vectors do not trigger a 0/0 division.
    emb1 = emb1 / emb1.norm(dim=1, keepdim=True).clamp_min(eps)
    emb2 = emb2 / emb2.norm(dim=1, keepdim=True).clamp_min(eps)
    cosine = torch.sum(emb1 * emb2, dim=-1)
    return torch.exp(s * (cosine - m))

In [23]:
# Toy anchor (a), positive (p) and negative (n) embeddings: two rows of
# four components each, stored as float32.
a = torch.tensor(
    [[1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 0.0, 0.0]], dtype=torch.float32
)

p = torch.tensor(
    [[2.0, 1.0, 3.0, 1.0], [0.0, 0.0, 0.0, 1.0]], dtype=torch.float32
)


n = torch.tensor(
    [[-2.0, 1.0, 0.0, 1.0], [0.0, 4.0, 5.0, 1.0]], dtype=torch.float32
)

In [24]:
hypershperical_cosine_margin_loss(a, p)

tensor([1.4973, 0.6065])

In [25]:
def cosine_margin_triplet_loss(emb_anchor, emb_pos, emb_neg, s=1.0, m=0.5):
    """Per-row triplet loss built on ``hypershperical_cosine_margin_loss``.

    The anchor/positive score is computed with margin ``m``; the
    anchor/negative score is computed with a margin of 0. Both use scale
    ``s``. The result is ``-log(pos / (pos + neg))`` for every row.
    """
    pos_score = hypershperical_cosine_margin_loss(emb_anchor, emb_pos, s=s, m=m)
    neg_score = hypershperical_cosine_margin_loss(emb_anchor, emb_neg, s=s, m=0.0)

    # Share of the total score that belongs to the positive pair.
    pos_share = pos_score / (pos_score + neg_score)
    return torch.neg(torch.log(pos_share))

In [26]:
cosine_margin_triplet_loss(a, p, n)

tensor([0.5115, 0.9741])

In [27]:
np.arccos(0.7)

0.7953988301841436

In [51]:
# Open the pre-tokenized peptide HDF5 file through the project dataset class
# and pull the first batch (batch_size=8) from a DataLoader over it.
# NOTE(review): the meaning of gen_size / val_size is defined by
# Self_NonSelf_PeptideDataset and is not visible here — confirm before changing.
hdf5_file = "../processed_data/pre_tokenized_peptides_dataset.hdf5"
dset = Self_NonSelf_PeptideDataset(hdf5_file, gen_size=1000, val_size=20)
dloader = DataLoader(dset, batch_size=8)
batch = next(iter(dloader))

# A batch unpacks into (peptides, labels); peptides are cast to float so they
# can be fed to the floating-point loss code below.
peptides, labels = batch
peptides = peptides.float()

In [62]:
# Drop the last sample of the batch together with its label, then show the labels.
# NOTE(review): this rebinds `peptides` / `labels` to a slice of themselves, so
# every re-run of this cell removes one more row — the cell is not idempotent.
peptides = peptides[:-1]
labels = labels[:-1]
labels

tensor([ 1, -1,  1, -1,  1, -1,  1])

In [104]:
# def hypershperical_cosine_margin_similarity(emb1, emb2, s=1.0, m=0.5):
#     # Embs must be normalized
#     # emb1 = emb1 / emb1.norm(dim=1)[:, None]
#     # emb2 = emb2 / emb2.norm(dim=1)[:, None]    
#     c = torch.sum(emb1 * emb2, dim=-1)
#     c -= m
#     return torch.exp(s*c)

def hypershperical_cosine_margin_similarity(emb1, emb2, s=1.0, m=0.5):
    """All-pairs exponentiated similarity ``exp(s * (emb1 @ emb2^T - m))``.

    The rows of both inputs must already be L2-normalised (this function does
    not normalise them), so that the matrix product holds cosine similarities.
    Entry ``[i, j]`` of the result compares row ``i`` of ``emb1`` with row
    ``j`` of ``emb2``.
    """
    pairwise = torch.mm(emb1, emb2.transpose(0, 1))
    # Shift by the margin in place; `pairwise` is a fresh tensor owned here.
    pairwise.sub_(m)
    return torch.exp(pairwise * s)


class CustomCMT_Loss(nn.Module):
    """Cosine-margin triplet loss mined inside one labelled batch.

    Every row whose label equals 1 acts as an anchor. For each anchor the most
    similar *other* positive row ("easy positive") and the most similar
    non-positive row ("hard negative") are selected, using the exponentiated
    similarities from ``hypershperical_cosine_margin_similarity``. The loss is
    ``-log(pos / (pos + neg))`` averaged over the anchors.

    Args:
        s: scale applied to the cosine similarities.
        m: margin subtracted from positive-pair cosines only; negative pairs
            always use a margin of 0.
    """

    def __init__(self, s=1.0, m=0.5):
        super().__init__()
        self.s = s
        self.m = m

    def forward(self, embeddings, labels):
        """Compute the batch loss.

        Args:
            embeddings: 2-D tensor with one embedding per row; rows are
                L2-normalised here before use.
            labels: 1-D tensor aligned with the rows of ``embeddings``;
                entries equal to 1 are positives, all others negatives.

        Returns:
            Scalar tensor with the mean loss. When the batch holds fewer than
            two positives or no negatives, no triplet exists and a zero that
            is still attached to ``embeddings`` is returned.
        """
        # Clamp the norms so an all-zero row cannot produce NaNs.
        norms = embeddings.norm(dim=1, keepdim=True).clamp_min(1e-8)
        embeddings = embeddings / norms

        is_pos = (labels == 1)
        pos_embs = embeddings[is_pos]
        neg_embs = embeddings[~is_pos]

        n_pos = pos_embs.shape[0]
        if n_pos < 2 or neg_embs.shape[0] == 0:
            # Nothing to mine; keep the result differentiable w.r.t. the input.
            return embeddings.sum() * 0.0

        pos_sims = hypershperical_cosine_margin_similarity(pos_embs, pos_embs, s=self.s, m=self.m)
        neg_sims = hypershperical_cosine_margin_similarity(pos_embs, neg_embs, s=self.s, m=0.0)

        # Exclude self-pairs with an explicit mask. This is done out of place
        # because autograd needs the untouched exp() output for the backward
        # pass, and with -inf so the exclusion holds for any value of s and m.
        self_pairs = torch.eye(n_pos, dtype=torch.bool, device=pos_sims.device)
        easy_pos_sims, _ = torch.max(pos_sims.masked_fill(self_pairs, float("-inf")), dim=1)
        hard_neg_sims, _ = torch.max(neg_sims, dim=1)

        loss = -torch.log(easy_pos_sims / (easy_pos_sims + hard_neg_sims))
        return torch.mean(loss)

In [111]:
# Smoke test: evaluate the batch-mined loss (margin 0.4, default scale) on the
# `peptides` / `labels` tensors prepared in the cells above and show the result.
loss = CustomCMT_Loss(m=0.4)
out = loss(peptides, labels)
out

tensor(0.9220)

In [112]:
out

tensor(0.9220)

In [116]:
out

tensor(0.9220)

In [117]:
# 21 single-letter codes plus '-', one list entry per character, sorted
# (which places '-' first).
amino_acids = sorted("ACDEFGHIKLMNPQRSTVWXY-")

In [118]:
len(amino_acids)

22

In [119]:
from selfpeptide.utils.processing_utils import get_vocabulary_tokens

In [126]:
vocab = get_vocabulary_tokens()
vocab

['A',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'K',
 'L',
 'M',
 'N',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'V',
 'W',
 'X',
 'Y',
 '-',
 '*']

In [124]:
[aa for aa in amino_acids if aa not in vocab]

[]

In [3]:
   

def hypershperical_cosine_margin_similarity(emb1, emb2, s=1.0, m=0.5):
    """Pairwise ``exp(s * (cos - m))`` between the rows of two matrices.

    Callers must pass rows that are already L2-normalised; no normalisation
    happens here, the matrix product is simply treated as the cosine matrix.
    The output has one row per row of ``emb1`` and one column per row of
    ``emb2``.
    """
    cos_matrix = torch.mm(emb1, emb2.transpose(0, 1))
    # In-place margin shift on the freshly created product tensor.
    cos_matrix.sub_(m)
    scaled = cos_matrix * s
    return scaled.exp()


class CustomCMT_Loss(nn.Module):
    """Batch-mined cosine-margin triplet loss that also reports log values.

    Rows labelled 1 are anchors. Each anchor is paired with its most similar
    other positive and its most similar non-positive row, using the
    exponentiated similarities from
    ``hypershperical_cosine_margin_similarity``; the loss is the mean of
    ``-log(pos / (pos + neg))``.

    Args:
        s: scale applied to the cosine similarities; also the target value
            for the logged norm regulariser.
        m: margin subtracted from positive-pair cosines only.
    """

    def __init__(self, s=1.0, m=0.5):
        super().__init__()
        self.s = s
        self.m = m

    def forward(self, embeddings, labels):
        """Compute the loss and a dictionary of scalar diagnostics.

        Args:
            embeddings: 2-D tensor with one embedding per row (normalised
                here before the similarities are computed).
            labels: 1-D tensor aligned with the rows of ``embeddings``;
                entries equal to 1 are positives, all others negatives.

        Returns:
            ``(loss, logs)`` where ``loss`` is a scalar tensor and ``logs``
            maps ``"loss"``, ``"cmt_loss"`` and ``"hypersphere_reg"`` to
            Python floats. With fewer than two positives or no negatives the
            loss is a zero that stays attached to ``embeddings``.
        """
        # Norms of the raw embeddings: reused for the normalisation and for
        # the regulariser, so they are computed only once.
        emb_norm = embeddings.norm(dim=1)
        embeddings = embeddings / emb_norm.clamp_min(1e-8)[:, None]

        is_pos = (labels == 1)
        pos_embs = embeddings[is_pos]
        neg_embs = embeddings[~is_pos]

        n_pos = pos_embs.shape[0]
        if n_pos < 2 or neg_embs.shape[0] == 0:
            # No triplet can be formed from this batch.
            cmt_loss = embeddings.sum() * 0.0
        else:
            pos_sims = hypershperical_cosine_margin_similarity(pos_embs, pos_embs, s=self.s, m=self.m)
            neg_sims = hypershperical_cosine_margin_similarity(pos_embs, neg_embs, s=self.s, m=0.0)

            # Mask self-pairs explicitly (out of place) so the exclusion is
            # valid for every s / m, not only while exp(s * (1 - m)) < e.
            self_pairs = torch.eye(n_pos, dtype=torch.bool, device=pos_sims.device)
            easy_pos_sims, _ = torch.max(pos_sims.masked_fill(self_pairs, float("-inf")), dim=1)
            hard_neg_sims, _ = torch.max(neg_sims, dim=1)

            cmt_loss = torch.mean(-torch.log(easy_pos_sims / (easy_pos_sims + hard_neg_sims)))

        # Mean squared gap between the raw row norms and s. It is reported in
        # the logs but not added to the returned loss.
        hypersphere_reg = torch.mean(torch.square(emb_norm - self.s))
        loss = cmt_loss

        # Store plain floats so the logs do not keep the autograd graph alive.
        logs = {
            "loss": loss.item(),
            "cmt_loss": cmt_loss.item(),
            'hypersphere_reg': hypersphere_reg.item(),
        }
        return loss, logs

In [66]:
embs = torch.tensor([
    [1.0, 0, 2.0, 0, 3.0],
    [0.001, -4.0, -0.03, -2, -0.0001],
    [0.1, 0.4, 0.3, 0.8, 0.9],
    [0.2, 0, 0, 0, 0.01]
])
embs

tensor([[ 1.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00,  3.0000e+00],
        [ 1.0000e-03, -4.0000e+00, -3.0000e-02, -2.0000e+00, -1.0000e-04],
        [ 1.0000e-01,  4.0000e-01,  3.0000e-01,  8.0000e-01,  9.0000e-01],
        [ 2.0000e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e-02]])

In [67]:
embeddings = embs / torch.clamp(embs.norm(dim=1)[:, None], min=1e-8)
embeddings

tensor([[ 2.6726e-01,  0.0000e+00,  5.3452e-01,  0.0000e+00,  8.0178e-01],
        [ 2.2360e-04, -8.9441e-01, -6.7081e-03, -4.4720e-01, -2.2360e-05],
        [ 7.6472e-02,  3.0589e-01,  2.2942e-01,  6.1178e-01,  6.8825e-01],
        [ 9.9875e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  4.9938e-02]])

In [68]:
labels = torch.tensor([1, -1, 1, -1])
labels

tensor([ 1, -1,  1, -1])

In [87]:
ix = (labels==1)
pos_embs = embeddings[ix]
neg_embs = embeddings[~ix]

In [70]:
pos_embs

tensor([[0.2673, 0.0000, 0.5345, 0.0000, 0.8018],
        [0.0765, 0.3059, 0.2294, 0.6118, 0.6882]])

In [86]:
neg_embs

tensor([[ 2.2360e-04, -8.9441e-01, -6.7081e-03, -4.4720e-01, -2.2360e-05],
        [ 9.9875e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  4.9938e-02]])

In [76]:
# Exponentiated cosine similarities with scale 1 and no margin:
# positives vs positives and positives vs negatives.
# Self-pairs on the diagonal of pos_sims are not masked at this point.
pos_sims = hypershperical_cosine_margin_similarity(pos_embs, pos_embs, s=1.0, m=0.0)
neg_sims = hypershperical_cosine_margin_similarity(pos_embs, neg_embs, s=1.0, m=0.0)

In [77]:
pos_sims

tensor([[2.7183, 2.0035],
        [2.0035, 2.7183]])

In [78]:
neg_sims

tensor([[0.9965, 1.3593],
        [0.5777, 1.1171]])

In [79]:
easy_pos_sims, _ = torch.max(pos_sims - (math.e * torch.eye(len(pos_sims), device=pos_sims.device)), dim=1)
easy_pos_sims

tensor([2.0035, 2.0035])

In [80]:
hard_neg_sims, _ = torch.max(neg_sims, dim=1)
hard_neg_sims

tensor([1.3593, 1.1171])

In [81]:
aa = easy_pos_sims/(easy_pos_sims+hard_neg_sims)

In [82]:
torch.log(easy_pos_sims/(easy_pos_sims+hard_neg_sims))

tensor([-0.5179, -0.4431])

In [83]:
torch.mean(-1*torch.log(aa))

tensor(0.4805)

In [84]:
torch.dot(embeddings[0], embeddings[1])

tensor(-0.0035)

In [88]:
torch.mm(pos_embs, neg_embs.transpose(1, 0))

tensor([[-0.0035,  0.3070],
        [-0.5487,  0.1107]])