In [1]:
import torchvision.datasets as dset 
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import numpy as np
import torch
from torch.autograd import Variable   
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import pandas as pd
from sklearn.model_selection import train_test_split
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
import glob


In [2]:
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

In [3]:
class SiameseNetwork(nn.Module):
    '''
    使用 Transformer 架构和多头注意力机制进行序列特征提取，输出特征提取后的信息。
    '''

    def __init__(self, embedding_dim = 1152, NormalizedSequenceLength = 50,):
        """
        初始化 Transformer 架构。

        参数:
            embedding_dim (int): 输入嵌入的维度（默认 1152）。
        """
        super(SiameseNetwork, self).__init__()

        self.embedding_dim = embedding_dim
        self.NormalizedSequenceLength = NormalizedSequenceLength


        self.transformer_encoder_layer1 = nn.TransformerEncoderLayer(
            d_model = self.embedding_dim,  # 输入特征维度
            nhead=8,      # 多头注意力头数
            dim_feedforward=512,  # 前馈网络隐藏层维度
            dropout=0.1,  # 丢弃率
            activation='relu'  # 激活函数
        )

        self.transformer_encoder1 = nn.TransformerEncoder(
            self.transformer_encoder_layer1,  # 编码层
            num_layers=2  # 编码层数
        )

        self.linear_layers1 = nn.Sequential(
            nn.Linear(self.NormalizedSequenceLength * 36, 512),
            nn.ReLU()
            )
        

        self.linear_layers2 = nn.Sequential(
            nn.Linear(self.embedding_dim, 512),
            nn.ReLU()
            )
    
        self.transformer_encoder_layer2 = nn.TransformerEncoderLayer(
        d_model = 512,  # 输入特征维度
        nhead=8,      # 多头注意力头数
        dim_feedforward=512,  # 前馈网络隐藏层维度
        dropout=0.1,  # 丢弃率
        activation='relu'  # 激活函数
        )

        self.transformer_encoder2 = nn.TransformerEncoder(
            self.transformer_encoder_layer1,  # 编码层
            num_layers=2  # 编码层数
        )





    def _forward_ligand(self, x:torch.Tensor):
        '''
        前向传播。

        参数:
            x (torch.Tensor): 输入张量，形状为 [batch_size, ESMC_size, 1, sequence_length, embedding_dim]。

        返回:
            torch.Tensor: 输出张量，形状为 
        '''
        x = x.squeeze(2)  # 去掉第二维度 [batch_size, ESMC_size, sequence_length, embedding_dim]
        x = x.reshape(x.size(0), -1, x.size(-1))  # 展平 [batch_size, ESMC_size * sequence_length, embedding_dim]
        x = x.permute(1, 0, 2)  # 转置 [ESMC_size * sequence_length, batch_size, embedding_dim]
        x = self.transformer_encoder1(x)  # 编码器  [ESMC_size * sequence_length, batch_size, embedding_dim] [1800, 32, 1152]
        x = x.permute(1, 2, 0)  # 转置 [batch_size, embedding_dim, ESMC_size * sequence_length]
        x = self.linear_layers1(x)  # 全连接层 [batch_size, embedding_dim , 512]
        x = x.permute(2, 0, 1)  # [512, batch_size, embedding_dim]
        x = self.transformer_encoder1(x)  # 编码器  [512, batch_size, embedding_dim]
        x = self.linear_layers2(x) # 全连接层 [512, batch_size, 512]
        x = self.transformer_encoder_layer2(x)  # 编码器  [512, batch_size, 512]
        x = x.permute(1, 0, 2)   # [batch_size, 512, 512]
        print(x.shape)

        


        return x
    

    def _forward_receptor(self, x:torch.Tensor):
        '''
        前向传播。

        参数:
            x (torch.Tensor): 输入张量，形状为 [batch_size, ESMC_size, 1, sequence_length, embedding_dim]。

        返回:            
            torch.Tensor: 输出张量，形状为
        '''
        x = x.squeeze(2)  # 去掉第二维度 [batch_size, ESMC_size, sequence_length, embedding_dim]
        x = x.reshape(x.size(0), -1, x.size(-1))  # 展平 [batch_size, ESMC_size * sequence_length, embedding_dim]
        x = x.permute(1, 0, 2)  # 转置 [ESMC_size * sequence_length, batch_size, embedding_dim]
        x = self.transformer_encoder1(x)  # 编码器  [ESMC_size * sequence_length, batch_size, embedding_dim] [1800, 32, 1152]
        x = x.permute(1, 2, 0)  # 转置 [batch_size, embedding_dim, ESMC_size * sequence_length]
        x = self.linear_layers1(x)  # 全连接层 [batch_size, embedding_dim , 512]
        x = x.permute(2, 0, 1)  # [512, batch_size, embedding_dim]
        x = self.transformer_encoder1(x)  # 编码器  [512, batch_size, embedding_dim]
        x = self.linear_layers2(x) # 全连接层 [512, batch_size, 512]
        x = self.transformer_encoder_layer2(x)  # 编码器  [512, batch_size, 512]
        x = x.permute(1, 0, 2)   # [batch_size, 512, 512]
        print(x.shape)

        return x


    def forward(self, input1, input2):
        output1 = self._forward_ligand(input1)
        output2 = self._forward_receptor(input2)
        return output1, output2

In [4]:
SNW = SiameseNetwork()


print(SNW)

SiameseNetwork(
  (transformer_encoder_layer1): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
    )
    (linear1): Linear(in_features=1152, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=512, out_features=1152, bias=True)
    (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder1): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
        )
        (linear1): Linear(in_features=1152, out_features=512, bias=True)
        (dropout): Dropout(p=0.



In [5]:
from torchsummary import summary
summary(SNW)

TypeError: summary() missing 1 required positional argument: 'input_size'