In [2]:
import torch 
import torch.nn as nn
import pytorch_lightning as pl
from torchsummary import summary
from torch.utils.data import DataLoader, Dataset, random_split


In [3]:

# Embedding vector --> From LLM Encoder & Graph encoder || MHC & peptide 
# Siam model ( MHC, Peptide)-> 1 if binds else 0 
# input_dim = 1280

In [4]:
class EmbeddingDataset(Dataset):
    
    def __init__(self, hla_embedding, peptide_embedding) -> None:
        super(EmbeddingDataset).__init__()

        self.hla_embedding = hla_embedding
        self.peptide_embedding = peptide_embedding

    def __len__(self):
        assert len(self.hla_embedding) == len(self.peptide_embedding)
        return len(self.hla_embedding)
    
    def __getitem__(self, idx):        
        return self.hla_embedding[idx], self.peptide_embedding[idx]



class EmbeddingDataModule(pl.LightningDataModule):

    def __init__(self, hla, peptide ,batch_size):
        super(EmbeddingDataModule).__init__()
        self.hla = hla     
        self.peptide = peptide 
        self.batch_size = batch_size


    def setup(self, stage=None):
        full_dataset = EmbeddingDataset(self.hla,self.peptide)
        # datasplit = random_split(full_dataset, [int(len(full_dataset) * 0.8), len(full_dataset) - int(len(full_dataset) * 0.8)])
        # self.train_ids,self.val_ids = datasplit[0].dataset.identifiers,datasplit[1].dataset.identifiers
        self.train, self.val = random_split(full_dataset, [int(len(full_dataset) * 0.8), len(full_dataset) - int(len(full_dataset) * 0.8)])
        # self.train, self.val = datasplit[0].dataset.data,datasplit[1].dataset.data

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size, num_workers=4)





In [26]:


class SiameseModel(nn.Module):
    def __init__(self,input_dim ) -> None:
        super(SiameseModel,self).__init__()
        self.input_dim = input_dim

        self.featurizer = nn.Sequential(
                nn.Conv1d(self.input_dim, 512, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv1d(512, 256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv1d(256, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveMaxPool1d(1) ,
                nn.Flatten() 

            )
        
        self.conv2d_block = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(3, 3), stride=1, padding=1), 
            nn.ReLU(),
            nn.AdaptiveMaxPool2d(1),  # [batch_size, 64, 1, 1]
            nn.Flatten(),
        )

        self.scorecompute_block = nn.Sequential(
            
                nn.Linear(128,32),
                nn.ReLU(),
                nn.Linear(32,1),
                nn.Sigmoid(),
        )

    
    # def comb_vecs(self, vec1, vec2):

    #     vec1, vec2 = torch.reshape(vec1,(1,128)),torch.reshape(vec2,(1,128))
    #     combined = torch.cat((vec1, vec2), dim=0)

    #     return torch.reshape(combined, (1,1,256))


    def contrastive_loss(vec1,vec2,score_assoc):

        margin = 1.5
        D  = (vec1 - vec2).pow(2).sum(1).sqrt()

        return 1/2( score_assoc * D.pow(2) + (1-score_assoc)*max(0,margin-D)**2 )




    def forward(self,x1,x2):
        feat_x1,feat_x2 = self.featurizer(x1), self.featurizer(x2)

        print(feat_x1)
        print(feat_x2)
        print(feat_x1.shape)
        print(feat_x2.shape)

        # combined_representation = self.comb_vecs(feat_x1,feat_x2)
        # print(combined_representation.shape)

        combined_representation =torch.reshape(feat_x1 - feat_x2,(1,128))

        # conv_out = self.conv2d_block(combined_representation)
        # print(conv_out)
        
        score = self.scorecompute_block(combined_representation)
        print(score.shape)

        return score



    def training_step(self,batch,):
        mha, peptide, association = batch 

        predicted_association  = self.forward(mha,peptide)

        train_loss = 
        








In [27]:

test = torch.randn((100,1,1280))



In [28]:
model = SiameseModel(1)

score = model(test[0],test[1])

tensor([[0.0778],
        [0.2020],
        [0.2769],
        [0.1718],
        [0.2195],
        [0.0522],
        [0.5762],
        [0.4831],
        [0.2684],
        [0.3621],
        [0.5332],
        [0.1550],
        [0.0638],
        [0.2826],
        [0.4244],
        [0.2985],
        [0.3986],
        [0.1462],
        [0.2447],
        [0.3184],
        [0.5004],
        [0.1554],
        [0.2892],
        [0.2846],
        [0.2496],
        [0.2505],
        [0.3420],
        [0.2486],
        [0.0232],
        [0.4192],
        [0.3390],
        [0.1923],
        [0.2035],
        [0.1990],
        [0.3161],
        [0.2595],
        [0.2711],
        [0.4170],
        [0.2540],
        [0.2942],
        [0.2324],
        [0.0769],
        [0.2822],
        [0.1591],
        [0.1210],
        [0.1413],
        [0.1273],
        [0.0135],
        [0.3423],
        [0.3325],
        [0.0124],
        [0.2872],
        [0.1100],
        [0.1212],
        [0.3964],
        [0

In [14]:
score.detach().numpy()


array([[0.5213428]], dtype=float32)

In [29]:
test[0]

tensor([[-0.7917,  1.5126,  0.3925,  ..., -0.9828,  0.4899,  0.3739]])

In [30]:
test[1]

tensor([[-0.4881, -0.6606, -1.8935,  ..., -0.8714,  0.2259, -2.2273]])

In [39]:
(test[0] - test[1]).pow(2).sum(1).sqrt()

tensor([50.8710])

In [41]:


def contrastive_loss(vec1,vec2,score_assoc):

    margin = 1.5
    D  = (vec1 - vec2).pow(2).sum(1).sqrt()

    return (1/2)*( score_assoc * D.pow(2) + (1-score_assoc)*max(0,margin-D)**2 )



In [43]:
contrastive_loss(test[0],test[1],0.1)

tensor([129.3928])