In [266]:
import torch 
import torch.nn as nn
import pytorch_lightning as pl
from torchsummary import summary
from torch.utils.data import DataLoader, Dataset, random_split


In [2]:

# Embedding vector --> From LLM Encoder & Graph encoder || MHC & peptide 
# Siam model ( MHC, Peptide)-> 1 if binds else 0 
# input_dim = 1280

In [None]:
class EmbeddingDataset(Dataset):
    
    def __init__(self, hla_embedding, peptide_embedding) -> None:
        super(EmbeddingDataset).__init__()

        self.hla_embedding = hla_embedding
        self.peptide_embedding = peptide_embedding

    def __len__(self):
        assert len(self.hla_embedding) == len(self.peptide_embedding)
        return len(self.hla_embedding)
    
    def __getitem__(self, idx):        
        return self.hla_embedding[idx], self.peptide_embedding[idx]



class EmbeddingDataModule(pl.LightningDataModule):

    def __init__(self, hla, peptide ,batch_size):
        super(EmbeddingDataModule).__init__()
        self.hla = hla     
        self.peptide = peptide 
        self.batch_size = batch_size

    def setup(self, stage=None):
        full_dataset = EmbeddingDataset(self.hla,self.peptide)
        # datasplit = random_split(full_dataset, [int(len(full_dataset) * 0.8), len(full_dataset) - int(len(full_dataset) * 0.8)])
        # self.train_ids,self.val_ids = datasplit[0].dataset.identifiers,datasplit[1].dataset.identifiers
        self.train, self.val = random_split(full_dataset, [int(len(full_dataset) * 0.8), len(full_dataset) - int(len(full_dataset) * 0.8)])
        # self.train, self.val = datasplit[0].dataset.data,datasplit[1].dataset.data

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size, num_workers=4)





In [234]:


class SiameseModel(nn.Module):
    def __init__(self,input_dim ) -> None:
        super(SiameseModel,self).__init__()
        self.input_dim = input_dim

        self.featurizer = nn.Sequential(
                nn.Conv1d(self.input_dim, 512, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv1d(512, 256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv1d(256, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveMaxPool1d(1) ,
                nn.Flatten() 

            )
        
        # self.conv2d_block = nn.Sequential(
        #     nn.Conv2d(1, 64, kernel_size=(3, 3), stride=1, padding=1), 
        #     nn.ReLU(),
        #     nn.AdaptiveMaxPool2d(1),  # [batch_size, 64, 1, 1]
        #     nn.Flatten(),
        # )

        self.scorecompute_block = nn.Sequential(
            
                nn.Linear(128,32),
                nn.ReLU(),
                nn.Linear(32,1),
                nn.Sigmoid(),
        )

    
    # def comb_vecs(self, vec1, vec2):

    #     vec1, vec2 = torch.reshape(vec1,(1,128)),torch.reshape(vec2,(1,128))
    #     combined = torch.cat((vec1, vec2), dim=0)

    #     return torch.reshape(combined, (1,1,256))


    def forward(self,x1,x2):
        feat_x1,feat_x2 = self.featurizer(x1), self.featurizer(x2)

        # print(feat_x1)
        # print(feat_x1.shape)

        # combined_representation = self.comb_vecs(feat_x1,feat_x2)
        # print(combined_representation.shape)

        combined_representation =torch.reshape(feat_x1 - feat_x2,(1,128))

        # conv_out = self.conv2d_block(combined_representation)
        # print(conv_out)
        
        score = self.scorecompute_block(combined_representation)
        print(score.shape)

        return score





In [235]:

test = torch.randn((100,1,1280))



In [254]:
model = SiameseModel(1)

score = model(test[0],test[1])

torch.Size([1, 1])


In [265]:
score.detach().numpy()

array([[0.5289478]], dtype=float32)