In [1]:
#!source /usr/local/Ascend/ascend-toolkit/set_env.sh
import torch

import torch.nn as nn
import math
from functools import partial
from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer, MultiheadAttention 
import esm
from typing import Union

In [2]:
# print(f"Number of NPU devices: {torch.npu.device_count()}")
# print(f"Current NPU device index: {torch.npu.current_device()}")
# print(f"NPU device name: {torch.npu.get_device_name(torch.npu.current_device())}")

In [3]:
class GroundingAttention(nn.Module):
    def __init__(self, dim, num_heads=4, qkv_bias=True,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.kv = nn.Linear(dim, dim*2, bias=qkv_bias)
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        # self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, r):
        B, N, C = x.shape
        B_, N_, C_ = r.shape

        kv = self.kv(r).reshape(B_, N_, 2, self.num_heads, C_ //
                                self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)
        q = self.q(x).reshape(B, N, self.num_heads, C //
                              self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N_)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [4]:
class FFN(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(FFN, self).__init__()
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        """
        输入 x 的形状: (batch_size, seq_len, input_dim)
        输出 y 的形状: (batch_size, seq_len, output_dim)
        """
        residual = x
        # x = x.view(-1, x.size(-1))  # 将 x 展平成二维
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        # x = x.view(batch_size, seq_len, -1)  # 将 x 恢复成三维
        x = residual + x
        return x

In [5]:
class bertlayer(nn.Module):
    def __init__(self, embeddingdim, hidden_dim, num_head = 16):
        super(bertlayer, self).__init__()
        self.atte_norm = ESM1bLayerNorm(embeddingdim)
        self.ffn_norm = ESM1bLayerNorm(embeddingdim)
        self.atte = torch.nn.MultiheadAttention(embed_dim = embeddingdim, num_heads = num_head, dropout = 0.0)
        self.ffn = FFN(embeddingdim, hidden_dim)
    def forward(self, x, x_padding_mask):
        residual = x
        x = self.atte_norm(x)
        x, attn = self.atte(x, x, x, key_padding_mask = x_padding_mask)
        x = x + residual
        x = x + self.ffn(self.ffn_norm(x))
        return x, attn



In [6]:
# class InteractionBlock(nn.Module):
#     def __init__(self, embed_dim, ffn_dim, BertLayerNorm = ESM1bLayerNorm, attention_heads = 20, add_bias_kv = True, use_rotary_embeddings = True):
#         super(InteractionBlock, self).__init__()
#         # self.injector_query_norm = norm_layer(embedding_dim)
#         # self.injector_kv_norm = norm_layer(embedding_dim)
#         # self.extractor_query_norm = norm_layer(embedding_dim)
#         # self.extractor_kv_norm = norm_layer(embedding_dim)
#         # self.extractor_norm = norm_layer(embedding_dim)
#         # self.injector = GroundingAttention(embedding_dim)
#         # self.block = GroundingAttention(embedding_dim)
#         # self.extractor = GroundingAttention(embedding_dim)
#         # self.extractor_ffn = FFN(embedding_dim * ffn_dim, ffn_dim_rate * embedding_dim * ffn_dim)
#         self.attention_heads = attention_heads
#         self.embed_dim = embed_dim
#         self.ffn_dim = ffn_dim
#         self.injector_q_norm = BertLayerNorm(embed_dim)
#         self.injector_kv_norm = BertLayerNorm(embed_dim)
#         self.injector = GroundingAttention(embed_dim)
#         self.block = TransformerLayer(
#             self.embed_dim,
#             4 * self.embed_dim, # embed_dim = 1280
#             self.attention_heads, # 20
#             add_bias_kv=False,
#             use_esm1b_layer_norm=True,
#             use_rotary_embeddings=True,
#         )
#         self.extractor_q_norm = BertLayerNorm(embed_dim)
#         self.extractor_kv_norm = BertLayerNorm(embed_dim)
#         self.extractor = GroundingAttention(embed_dim)
#         self.ffn = FFN(embed_dim, ffn_dim)
#     def forward(self, x, r, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False):
#         # x = self.injector(self.injector_query_norm(x), self.injector_kv_norm(r)) + x
#         # x = self.block(x, x)
#         # r = self.extractor(self.extractor_query_norm(r), self.extractor_kv_norm(x)) + r
#         # r = r + self.extractor_ffn(self.extractor_norm(r))
#         
#         
#         # x, _ = self.injector_attention(
#         #     query=self.injector_q_norm(x),
#         #     key=self.injector_kv_norm(r),
#         #     value=self.injector_kv_norm(r),
#         #     key_padding_mask=self_attn_padding_mask,
#         #     need_weights=True,
#         #     need_head_weights=need_head_weights,
#         #     attn_mask=self_attn_mask,
#         # )
#         # print(self.injector_q_norm(x).shape)
#         # print(self.injector_q_norm(r).shape)
#         x = x + self.injector(self.injector_q_norm(x),self.injector_kv_norm(r))
#         x, _ = self.block(x,
#                           self_attn_padding_mask=self_attn_padding_mask,
#                           need_head_weights=need_head_weights,
#                           )
#         r = r + self.extractor(self.extractor_q_norm(r),self.extractor_kv_norm(x))
#         r = self.ffn(r)
#         return x, r
class InteractionBlock(nn.Module):
    def __init__(self, embed_dim, ffn_dim, BertLayerNorm = ESM1bLayerNorm, attention_heads = 16, add_bias_kv = True, use_rotary_embeddings = True):
        super(InteractionBlock, self).__init__()
        # self.injector_query_norm = norm_layer(embedding_dim)
        # self.injector_kv_norm = norm_layer(embedding_dim)
        # self.extractor_query_norm = norm_layer(embedding_dim)
        # self.extractor_kv_norm = norm_layer(embedding_dim)
        # self.extractor_norm = norm_layer(embedding_dim)
        # self.injector = GroundingAttention(embedding_dim)
        # self.block = GroundingAttention(embedding_dim)
        # self.extractor = GroundingAttention(embedding_dim)
        # self.extractor_ffn = FFN(embedding_dim * ffn_dim, ffn_dim_rate * embedding_dim * ffn_dim)
        self.attention_heads = attention_heads
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_dim
        self.injector_q_norm = BertLayerNorm(embed_dim)
        self.injector_kv_norm = BertLayerNorm(embed_dim)
        # self.injector = GroundingAttention(embed_dim)

        self.injector = torch.nn.MultiheadAttention(embed_dim = embed_dim, num_heads = attention_heads, dropout = 0.0)
        self.block = bertlayer(embed_dim, embed_dim * 4)
        self.extractor_q_norm = BertLayerNorm(embed_dim)
        self.extractor_kv_norm = BertLayerNorm(embed_dim)
        self.extractor = torch.nn.MultiheadAttention(embed_dim = embed_dim, num_heads = attention_heads, dropout = 0.0)
        self.ffn = FFN(embed_dim, ffn_dim)
    def forward(self, x, r, x_attn_padding_mask=None, r_attn_padding_mask = None, need_head_weights=False):
        # x = self.injector(self.injector_query_norm(x), self.injector_kv_norm(r)) + x
        # x = self.block(x, x)
        # r = self.extractor(self.extractor_query_norm(r), self.extractor_kv_norm(x)) + r
        # r = r + self.extractor_ffn(self.extractor_norm(r))
        
        
        # x, _ = self.injector_attention(
        #     query=self.injector_q_norm(x),
        #     key=self.injector_kv_norm(r),
        #     value=self.injector_kv_norm(r),
        #     key_padding_mask=self_attn_padding_mask,
        #     need_weights=True,
        #     need_head_weights=need_head_weights,
        #     attn_mask=self_attn_mask,
        # )
        # print(self.injector_q_norm(x).shape)
        # print(self.injector_q_norm(r).shape)
        
        
        # x = x + self.injector(self.injector_q_norm(x),self.injector_kv_norm(r))
        residual_x = x
        residual_r = r
        # x = self.injector_q_norm(x)
        # r = self.injector_kv_norm(r)
        x = x.transpose(0, 1)
        r = r.transpose(0, 1)
        x = self.injector_q_norm(x)
        r = self.injector_kv_norm(r)
        # print("r.shape")
        # print(r.shape)
        # print("r_attn_padding_mask.shape")
        # print(r_attn_padding_mask.shape)
        # print("x.shape")
        # print(x.shape)
        x, attni = self.injector(x, r, r, key_padding_mask=r_attn_padding_mask )
        x = x.transpose(0, 1)
        x = x + residual_x
        r = residual_r
        x = x.transpose(0, 1)
        x,attnb = self.block(x, x_attn_padding_mask)
        
        # r = r + self.extractor(self.extractor_q_norm(r),self.extractor_kv_norm(x))
        residual_r = r
        residual_x = x.transpose(0, 1)
        # x = self.extractor_kv_norm(x)
        # r = self.extractor_q_norm(r)
        # x = x.transpose(0, 1)
        r = r.transpose(0, 1)
        # print(r.shape)
        # print(x.shape)
        # print(x_attn_padding_mask.shape)
        x = self.extractor_kv_norm(x)
        r = self.extractor_q_norm(r)
        r, attne = self.extractor(r, x, x, key_padding_mask=x_attn_padding_mask)
        r = r.transpose(0, 1)
        r = r + residual_r
        
        r = self.ffn(r)
        return residual_x, r, [attni.detach(), attnb.detach(), attne.detach()]

In [7]:
# class part_test(nn.Module):
#     def __init__(self, embedding_dim, ffn_dim, device, num_layers =36):
#         super(part_test, self).__init__()
#         self.layers = nn.ModuleList()
#         for _ in range(num_layers):
#             block = MultiheadAttention(
#                     embedding_dim,
#                     20,
#                     add_bias_kv=False,
#                     add_zero_attn=False,
#                     use_rotary_embeddings=False,
#                     encoder_decoder_attention=True
#                 )
#             self.layers.append(block)
#     def forward(self, x, x_attn_padding_mask):
#         for layer in self.layers:
#             x, attn = layer(
#                 query=x,
#                 key=x,
#                 value=x,
#                 key_padding_mask=x_attn_padding_mask, #【[fffffttttt]】之类的
#                 need_weights=False,
#                 need_head_weights=False,# false
#                 attn_mask=None, # None
#             )
#         return x
# device = torch.device('npu:0')
# modelb = part_test(1280, 1280  * 4,device)
# X = torch.randn(30,15,1280)
# x_attn_padding_mask = torch.zeros(15,30, dtype=torch.bool)
# X = X.to(device)
# x_attn_padding_mask = x_attn_padding_mask.to(device)
# modelb.to(device)
# opt = torch.optim.Adam(modelb.parameters(), lr=0.00001)
# nuuu = 500000
# for i in range(nuuu):
#     x= modelb(X, x_attn_padding_mask)
#     y = x.sum()
#     y.backward()
#     opt.step()
#     # print(y)
#     if i ==2:
#         print("start")
#     if i ==nuuu - 1:
#         print("done")

In [8]:
# embedding_dim = 1280
# modelb = TransformerLayer(
#                 embedding_dim,
#                 embedding_dim *4, # embed_dim = 1280
#                 20, # 20
#                 add_bias_kv=False,
#                 use_esm1b_layer_norm=True,
#                 use_rotary_embeddings=False,
#             )
# X = torch.randn(50,15,embedding_dim)
# x_attn_padding_mask = torch.zeros(15,50, dtype=torch.bool)
# X = X.to(device)
# x_attn_padding_mask = x_attn_padding_mask.to(device)
# modelb.to(device)
# opt = torch_npu.optim.NpuFusedAdamW(modelb.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False)
# nuuu = 5000
# for i in range(nuuu):
#     x, _ = modelb(X,
#         self_attn_padding_mask=x_attn_padding_mask,
#         need_head_weights=False,
#     )
#     y = x.sum()
#     y.backward()
#     opt.step()
#     # print(y)
#     if i ==2:
#         print("start")
#     if i ==nuuu - 1:
#         print("done")

In [9]:
import esm
class DynamicFeatureSelector(nn.Module):
    def __init__(self, input_size, num_features, num_layers=4):
        super(DynamicFeatureSelector, self).__init__()
        
        # 自动计算每层的节点数量
        self.hidden_layers = []
        self.batch_norm_layers = []  # 用于存储 BatchNorm 层
        current_size = input_size
        decrement = (input_size - num_features) // num_layers  # 递减的步长

        # 添加隐藏层和 BatchNorm 层
        for _ in range(num_layers):
            if current_size <= num_features:
                break
            next_size = max(current_size - decrement, num_features)
            self.hidden_layers.append(nn.Linear(current_size, next_size))
            self.batch_norm_layers.append(nn.BatchNorm1d(next_size))  # 添加 BatchNorm 层
            current_size = next_size

        # 将隐藏层和 BatchNorm 层转换为 ModuleList
        self.hidden_layers = nn.ModuleList(self.hidden_layers)
        self.batch_norm_layers = nn.ModuleList(self.batch_norm_layers)
        self.output_layer = nn.Linear(current_size, num_features)  # 输出层

    def forward(self, x):
        for layer, batch_norm in zip(self.hidden_layers, self.batch_norm_layers):
            x = layer(x)  # 前向传播
            x = batch_norm(x)  # 归一化
            x = torch.relu(x)  # 激活函数
        x = self.output_layer(x)
        return x

class mpi_adapter(nn.Module):
    def __init__(self, embedding_dim, ffn_dim, num_layers =39):
        super(mpi_adapter, self).__init__()
        # self.embedding = nn.Embedding(num_embeddings=22, embedding_dim=embedding_dim)
        # self.par_position = PositionalEncoding(num_hiddens=embedding_dim, max_len = par_len)
        # self.mut_position = PositionalEncoding(num_hiddens=embedding_dim, max_len = mut_len)
        # self.mut_esm, _ = esm.pretrained.esm2_t33_650M_UR50D()
        # _, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        # self.padding_idx = self.alphabet.padding_idx
        # self.device = device
        # self.esm.to(self.device)
        # self.esm.eval()
        self.liner_mut = nn.Linear(1280,embedding_dim)
        self.liner_par = nn.Linear(1280,embedding_dim)
        # self.batch_converter = self.alphabet.get_batch_converter()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(InteractionBlock(embedding_dim, ffn_dim))
    def forward(self, mut0, mut1, par, mut0_padding_mask, par_padding_mask):
        # print([(str(0), sequence) for sequence in mut0s])
        # print([(str(0), sequence) for sequence in mut1s])
        # print([(str(0), sequence) for sequence in pars])
        # _, _, mut0 = self.batch_converter([(str(0), sequence) for sequence in mut0s])
        # _, _, mut1 = self.batch_converter([(str(0), sequence) for sequence in mut1s])
        # _, _, par = self.batch_converter([(str(0), sequence) for sequence in pars])
        # print(par)
        # mut0 = mut0s.to(self.device)
        # mut1 = mut1s.to(self.device)
        # par = pars.to(self.device)
        # mut0_padding_mask = mut0_padding_mask.to(self.device)
        # par_padding_mask = par_padding_mask.to(self.device)
        # mut0_padding_mask = mut0.eq(self.padding_idx)
        # mut1_padding_mask = mut1.eq(self.padding_idx)
        # par_padding_mask = par.eq(self.padding_idx)
        # mut0_padding_mask = torch.cat((mut0_padding_mask,mut1_padding_mask),dim=1)

        attn_list = []
        mut0 = torch.cat((mut0,mut1),dim=1)
        mut0 = self.liner_mut(mut0)
        par = self.liner_par(par)
        for layer in self.layers:
            mut0, par, lists = layer(mut0, par, x_attn_padding_mask = mut0_padding_mask, r_attn_padding_mask = par_padding_mask)
            attn_list.append(lists)
        return par, attn_list
        # mut0 = torch.cat((mut0,mut1),dim=1)
        # mut0 = self.liner_mut(mut0)
        # par = self.liner_par(par)
        # for layer in self.layers:
        #     mut0, par = layer(mut0, par, x_attn_padding_mask = mut0_padding_mask, r_attn_padding_mask = par_padding_mask)
        # return par
    
# device = torch.device('npu:0')
# # modelb =  TransformerLayer(
# #             1280,
# #             4000, # embed_dim = 1280
# #             20, # 20
# #             add_bias_kv=False,
# #             use_esm1b_layer_norm=True,
# #             use_rotary_embeddings=False,
# #         )
# modelb = mpi_adapter(1280*3, 1280 * 3 * 4,device)
# X = torch.randn(10,15,1280)
# X2 = torch.randn(10,15,1280)
# R = torch.randn(10,100,1280)
# x_attn_padding_mask = torch.zeros(10,30, dtype=torch.bool)
# R_attn_padding_mask = torch.zeros(10,100, dtype=torch.bool)
# X = X.to(device)
# X2 = X2.to(device)
# R = R.to(device)
# x_attn_padding_mask = x_attn_padding_mask.to(device)
# R_attn_padding_mask = R_attn_padding_mask.to(device)
# modelb.to(device)
# opt = torch.optim.Adam(modelb.parameters(), lr=0.00001)
# nuuu = 500000
# for i in range(nuuu):
#     x= modelb(X, X2, R, x_attn_padding_mask, R_attn_padding_mask)
#     y = x.sum()
#     y.backward()
#     opt.step()
#     # print(y)
#     if i ==2:
#         print("start")
#     if i ==nuuu - 1:
#         print("done")


In [10]:
# for layer in modelb.layers:
#     print(next(layer.parameters()).device)

In [11]:
# embedding_dim = 1280*4
# batch_size = 400
# modelb = InteractionBlock(embedding_dim, embedding_dim * 4)
# X = torch.randn(batch_size ,30,embedding_dim)
# # X2 = torch.randn(batch_size ,15,embedding_dim)
# R = torch.randn(batch_size ,100,embedding_dim)
# x_attn_padding_mask = torch.zeros(batch_size ,30, dtype=torch.bool)
# R_attn_padding_mask = torch.zeros(batch_size ,100, dtype=torch.bool)
# X = X.to(device)
# # X2 = X2.to(device)
# R = R.to(device)
# x_attn_padding_mask = x_attn_padding_mask.to(device)
# R_attn_padding_mask = R_attn_padding_mask.to(device)
# modelb.to(device)
# opt = torch.optim.Adam(modelb.parameters(), lr=0.00001)
# nuuu = 500000
# for i in range(nuuu):
#     x , _= modelb(X, R, x_attn_padding_mask, R_attn_padding_mask)
#     y = x.sum()
#     y.backward()
#     opt.step()
#     if i ==2:
#         print("start")
#     if i ==nuuu - 1:
#         print("done")

In [12]:
class MLP_head(nn.Module):
    def __init__(self, embedding_dim, num_layers = 2):
        super(MLP_head, self).__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(GroundingAttention(embedding_dim))
    def forward(self, x):
        for layer in self.layers:
            x = layer(x, x)
        return x[:, 0, :]

In [13]:
class MLP_head_one_layer(nn.Module):
    def __init__(self, embedding_dim, attention_heads=16):
        super(MLP_head_one_layer, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=attention_heads, dropout=0.0)

    def forward(self, x, x_attn_padding_mask, need_head_weights=False):
        # 处理输入
        result = x[:, 0:1, :]  # 取出第一位置的特征
        x = x.transpose(0, 1)  # 转换为 (seq_length, batch_size, embedding_dim)
        result = result.transpose(0, 1)  # 转换为 (batch_size, 1, embedding_dim)

        # 多头自注意力计算
        result, attn = self.attention(result, x, x, key_padding_mask=x_attn_padding_mask)

        # 处理输出
        result = result.transpose(0, 1)  # 转换回 (batch_size, 1, embedding_dim)
        result = result.squeeze(1)  # 去掉大小为1的维度
        return result, attn.detach()

In [14]:
#bert中的带权交叉熵顺势函数
class cross_entropy_bert(nn.Module):
    def __init__(self, device):
        super(cross_entropy_bert, self).__init__()
        # self.liner = nn.Linear(embedding_dim, 4)
        self.activate = nn.Sigmoid()
        self.device = device
    def forward(self, b_labels, outputs, weights):
        # weights = weights
        # outputs = self.liner(outputs)
        # print("output:")
        # print(outputs)
        # outputs = self.activate(outputs)
        labels = []
        for index, fue in enumerate(b_labels):
            labels.append(fue)
        loss_sum = torch.tensor(0)
        for i in range(outputs.shape[0]):
            back_part = torch.tensor(0).to(device)
            for j in outputs[i]:
                back_part = back_part + torch.exp(j).to(device)
            back_part = torch.log(back_part).to(device)
            loss_sum = weights[labels[i]] * ((-1) * outputs[i][labels[i]] + back_part)
        loss_sum = loss_sum/outputs.shape[0]
        return loss_sum


In [15]:
# class Sum_model(nn.Module):
#     def __init__(self, device, embedding_dim = 192,):
#         super(Sum_model, self).__init__()
#         self.device = device
#         self.result = nn.Parameter(torch.randn(1280))
#         self.backbone = mpi_adapter(embedding_dim, embedding_dim*4,device)
#         self.neck = MLP_head_one_layer(embedding_dim)
#         # self.head = cross_entropy_bert(embedding_dim, device)
#         self.head = nn.Linear(embedding_dim, 4)
#     def forward(self,  mut0s, mut1s, pars, mut0_padding_mask, par_padding_mask, weight, label):
#         # mut0s = mut0s.to(self.device)
#         # mut1s = mut1s.to(self.device)
#         # pars = pars.to(self.device)
#         # mut0_padding_mask = mut0_padding_mask.to(self.device)
#         # par_padding_mask = par_padding_mask.to(self.device)


#         res = self.result.repeat(pars.shape[0], 1).unsqueeze(1) 
#         pars = torch.concat([res,pars], dim = 1)
#         false_column = torch.zeros(par_padding_mask.size(0), 1, dtype=torch.bool).to(self.device)
#         # 在第一列前添加一排 False
#         par_padding_mask = torch.cat((false_column, par_padding_mask), dim=1)


#         x = self.backbone(mut0s, mut1s, pars, mut0_padding_mask, par_padding_mask)
#         # print("x1:")
#         # x = self.neck(pars, par_padding_mask)
#         x = self.neck(x, par_padding_mask)
#         # # print("x2")
#         # # print(x)
#         # # x = self.head(label, x, weight)
#         x = self.head(x)
#         return x

In [16]:
import torch
import torch.nn as nn

class Sum_model(nn.Module):
    def __init__(self, embedding_dim=192):
        super(Sum_model, self).__init__()
        self.result = nn.Parameter(torch.randn(1280))  # 不再指定设备
        self.backbone = mpi_adapter(embedding_dim, embedding_dim * 4)
        self.neck = MLP_head_one_layer(embedding_dim)
        self.head = nn.Linear(embedding_dim, 4)

    def forward(self, mut0s, mut1s, pars, mut0_padding_mask, par_padding_mask, weight, label):
        # 确保所有输入在同一设备上
        mut0s = mut0s.to(pars.device)  # 使用 pars 的设备
        mut1s = mut1s.to(pars.device)
        pars = pars.to(pars.device)
        mut0_padding_mask = mut0_padding_mask.to(pars.device)
        par_padding_mask = par_padding_mask.to(pars.device)

        # 扩展结果并拼接
        res = self.result.repeat(pars.shape[0], 1).unsqueeze(1)  # (batch_size, 1, embedding_dim)
        pars = torch.cat([res, pars], dim=1)  # 拼接操作

        # 创建并拼接 padding mask
        false_column = torch.zeros(par_padding_mask.size(0), 1, dtype=torch.bool).to(pars.device)
        par_padding_mask = torch.cat((false_column, par_padding_mask), dim=1)  # 在第一列前添加一排 False

        # 通过 backbone 和 neck 进行前向传播
        x, attn_backbone = self.backbone(mut0s, mut1s, pars, mut0_padding_mask, par_padding_mask)
        x_neck, attn_neck = self.neck(x, par_padding_mask)
        x = self.head(x_neck)
        return x, x_neck, attn_backbone, attn_neck

In [17]:
# device = torch.device('cuda:0')
# model = Sum_model()
# model.to(device)
# mut0 = torch.randn(7, 21, 1280).to(device)
# mut1 = torch.randn(7, 21, 1280).to(device)
# par0 = torch.randn(7, 593, 1280).to(device)
# mut0_mask = torch.randint(0, 2, (7, 21), dtype=torch.bool).to(device) # 随机生成布尔值
# mut1_mask = torch.randint(0, 2, (7, 21), dtype=torch.bool).to(device) # 随机生成布尔值
# par0_mask = torch.randint(0, 2, (7, 593), dtype=torch.bool).to(device) # 随机生成布尔值
# model(mut0, mut1, par0,torch.cat((mut0_mask,mut1_mask),dim=1),par0_mask, None,None).shape

In [18]:
# model = Sum_model()
# model.to(device)
# model = nn.DataParallel(model, device_ids=[0, 1, 2])  # 使用 DataParallel

# num_epch =21 
# mut0 = torch.randn(num_epch, 21, 1280).to(device)
# mut1 = torch.randn(num_epch, 21, 1280).to(device)
# par0 = torch.randn(num_epch, 1000, 1280).to(device)
# mut0_mask = torch.randint(0, 2, (num_epch, 21), dtype=torch.bool).to(device) # 随机生成布尔值
# mut1_mask = torch.randint(0, 2, (num_epch, 21), dtype=torch.bool).to(device) # 随机生成布尔值
# par0_mask = torch.randint(0, 2, (num_epch, 1000), dtype=torch.bool).to(device) # 随机生成布尔值
# # 定义损失和优化器
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# criterion = nn.MSELoss()  # 示例损失函数

# # 模拟训练循环
# for epoch in range(100000):
#     model.train()
#     optimizer.zero_grad()

#     # 前向传播
#     # output = model(x)
#     mut = model(mut0, mut1, par0 ,torch.cat((mut0_mask, mut1_mask),dim = 1) ,par0_mask, None, None)
#     # 计算损失
#     loss = mut.mean()

#     # 反向传播
#     loss.backward()
#     optimizer.step()

#     print(f'Epoch {epoch+1}, Loss: {loss.item()}')

In [19]:


# # model = MultiHeadAttentionModel(embed_dim, num_heads)
# model = mpi_adapter(1280, 1280 *4, device)
# model = nn.DataParallel(model, device_ids=[0, 1, 2])  # 使用 DataParallel
# model.to(device)
# num_epch =21 
# mut0 = torch.randn(num_epch, 21, 1280).to(device)
# mut1 = torch.randn(num_epch, 21, 1280).to(device)
# par0 = torch.randn(num_epch, 1000, 1280).to(device)
# mut0_mask = torch.randint(0, 2, (num_epch, 21), dtype=torch.bool).to(device) # 随机生成布尔值
# mut1_mask = torch.randint(0, 2, (num_epch, 21), dtype=torch.bool).to(device) # 随机生成布尔值
# par0_mask = torch.randint(0, 2, (num_epch, 1000), dtype=torch.bool).to(device) # 随机生成布尔值
# # 定义损失和优化器
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# criterion = nn.MSELoss()  # 示例损失函数

# # 模拟训练循环
# for epoch in range(100000):
#     model.train()
#     optimizer.zero_grad()

#     # 前向传播
#     # output = model(x)
#     mut = model(mut0, mut1, par0 ,torch.cat((mut0_mask, mut1_mask),dim = 1) ,par0_mask)
#     # 计算损失
#     loss = mut.mean()

#     # 反向传播
#     loss.backward()
#     optimizer.step()

#     print(f'Epoch {epoch+1}, Loss: {loss.item()}')

In [20]:
# import torch

# # 定义设备和模型
# device = torch.device('cuda:0')
# model = Sum_model(device)
# model = torch.nn.DataParallel(model, device_ids=[0, 1, 2])  # 使用 DataParallel 在多张卡上训练
# model.to(device)

# # 生成随机输入数据
# mut0 = torch.randn(7, 21, 1280).to(device)
# mut1 = torch.randn(7, 21, 1280).to(device)
# par0 = torch.randn(7, 593, 1280).to(device)
# mut0_mask = torch.randint(0, 2, (7, 21), dtype=torch.bool).to(device)  # 随机生成布尔值
# mut1_mask = torch.randint(0, 2, (7, 21), dtype=torch.bool).to(device)  # 随机生成布尔值
# par0_mask = torch.randint(0, 2, (7, 593), dtype=torch.bool).to(device)  # 随机生成布尔值

# # 在模型前向传播时使用 DataParallel
# output_shape = model(mut0, mut1, par0, torch.cat((mut0_mask, mut1_mask), dim=1), par0_mask, None, None).shape

In [21]:
print('hello')

hello


In [22]:
class PandasDataReader:
    def __init__(self, df, batch_size=1, shuffle=False):
        self.df = df.sample(frac=1).reset_index(drop=True) if shuffle else df
        self.batch_size = batch_size
        self.current_index = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.current_index >= len(self.df):
            raise StopIteration()

        batch = self.df.iloc[self.current_index:self.current_index+self.batch_size]
        self.current_index += self.batch_size

        # 如果 batch_size 为 1, 则直接返回 batch 的第一行
        return batch.iloc[0] if self.batch_size == 1 else batch
# data_reader = PandasDataReader(df, batch_size=7, shuffle=True)

In [23]:
class ESMModelWrapper(nn.Module):
    def __init__(self, model):
        super(ESMModelWrapper, self).__init__()
        self.model = model

    def forward(self, batch_tokens, repr_layers=[33], return_contacts=False):
        return self.model(batch_tokens, repr_layers=repr_layers, return_contacts=return_contacts)


class ESMFeatureEncoder(nn.Module):
    def __init__(self):
        super(ESMFeatureEncoder, self).__init__()
        self.device = 'cuda:2'
        self.model, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.model.to(self.device)
        # self.model = torch.nn.DataParallel(self.model, device_ids=[0, 1, 2])
        self.model.eval()  # Set the model to evaluation mode
        self.batch_converter = self.alphabet.get_batch_converter()
        self.padding_idx = self.alphabet.padding_idx
        # Wrap the model with the ESMModelWrapper
        self.model = ESMModelWrapper(self.model)

    def encode(self, sequences):
        batch_labels, batch_strs, batch_tokens = self.batch_converter([(str(0), sequence) for sequence in sequences])
        batch_tokens = batch_tokens.to(self.device)
        batch_mask = batch_tokens.eq(self.padding_idx)
        # print(batch_tokens.shape)
        with torch.no_grad():
            results = self.model(batch_tokens, repr_layers=[33], return_contacts=False)
        token_representations = results['representations'][33]
        # print(results['representations'][33].mean(dim=1).unsqueeze(1).shape)
        return token_representations,batch_mask

In [24]:
# import pandas as pd
# data_path = 'mippi/processed_mutations.dataset'
# df = pd.read_pickle(data_path)

# new_df1 = df[df['label'] == 1].copy()
# new_df1['mut0'], new_df1['mut1'] = new_df1['mut1'], new_df1['mut0']
# new_df1['label'] = 3
# new_df2 = df[df['label'] == 3].copy()
# new_df2['mut0'], new_df2['mut1'] = new_df2['mut1'], new_df2['mut0']
# new_df2['label'] = 1

# new_df3 = df[df['label'] == 2].copy()
# new_df3['mut0'], new_df3['mut1'] = new_df3['mut1'], new_df3['mut0']
# new_df3['label'] = 2
# df = pd.concat([df, new_df1, new_df2, new_df3], ignore_index=True)
# df["position_total"] = df["Feature range(s)"].apply(
#     lambda x: sorted(set([int(y.split("-")[0]) for y in x] + [int(y.split("-")[1]) for y in x]))
# )
# df = df[df['mut0'] != df['mut1']]
# df["position"] = df["position_total"].apply(lambda x: math.ceil((min(x) + max(x)) / 2))
# print(df.shape)
# df.head()

In [25]:
len_half_protein_used =10 
# seq_lengths = df['Original sequence'].apply(lambda x: len(x[0]))
# # 选择唯一元素长度小于等于 10 的行
# df_filtered = df.loc[seq_lengths <= len_half_protein_used]
# seq_lengths = df['Resulting sequence'].apply(lambda x: len(x[0]))
# # 选择唯一元素长度小于等于 10 的行
# df_filtered = df.loc[seq_lengths <= len_half_protein_used]
# df = df[df["position_total"].apply(lambda x: max(x) - min(x) < 2* len_half_protein_used)]
# df.shape

In [26]:
# df = df[df['mut0'].str.len() <= 1500]
# df = df[df['par0'].str.len() <= 1000]
# df = df[df['label'] != 4]
# df.shape

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class GHMC_Loss(nn.Module):
    def __init__(self, device, bins=8, momentum=0.3):
        super(GHMC_Loss, self).__init__()
        self.device = device
        self.bins = bins
        self.momentum = momentum
        self.edges = [float(x) / self.bins for x in range(self.bins + 1)]
        if momentum > 0:
            self.acc_sum = np.zeros(bins)

    def forward(self, targets, logits, no_meaning):
        targets = torch.tensor(targets.to_list())
        targets = F.one_hot(targets, num_classes=4).float().to(self.device)
        # Calculate gradient norm
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(logits)
        g = torch.abs(logits.softmax(dim=1).detach().to(self.device) - targets.to(self.device))

        tot = logits.shape[0] * logits.shape[1]  # Total number of elements
        n = 0  # n valid bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1])
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    self.acc_sum[i] = mmt * self.acc_sum[i] + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            weights = weights / n

        # Flatten targets to match logits shape
        targets = targets.argmax(dim=1)

        # Calculate the loss
        loss = F.cross_entropy(logits, targets, reduction='none')

        # Apply weights to the loss
        weights = weights.max(dim=1)[0]  # Get the maximum weight for each sample
        loss = (loss * weights).sum() / tot

        return loss

In [28]:
def positional_encoding(max_len, d_model):
    # 创建位置编码矩阵
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # 形状为 (max_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))  # 计算分母
    
    # 计算位置编码
    pe = torch.zeros(max_len, d_model)  # 初始化位置编码矩阵
    pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维度
    pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维度
    pe.requires_grad = False
    return pe



In [29]:
import pandas as pd
x_train_fold = pd.read_csv('../data/x_train_fold_mirror_multi.csv')
x_test_fold = pd.read_csv('../data/x_test_fold_mirror_multi.csv')

In [30]:
from sklearn.model_selection import StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
from IPython.display import clear_output
from sklearn.metrics import accuracy_score
import json
batch_size = 4
# 五折交叉验证
splits = 10
# Create StratifiedKFold object.
# 创建StratifiedKFold对象    StratifiedKFold是sklearn库中的一个类，用于将数据集进行分层抽样设置了4个分割（n_splits=splits），打乱数据顺序（shuffle=True）并设置随机种子（random_state=1）。
skf = StratifiedKFold(n_splits=splits, shuffle=True)# , random_state=1
# device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
device = torch.device('cuda:2')
# 定义优化器
model = Sum_model()
model.load_state_dict(torch.load('../mippi_variation/model_params_mirror_multi.pth'))
print(f"Using device: {device}")
model.to(device)
model.eval()
model_loss = GHMC_Loss(device)
esm_model = ESMFeatureEncoder()

Using device: cuda:2


In [31]:
import sys
def print_progress_bar(iteration, total, length=40):
    percent = (iteration / total) * 100
    filled_length = int(length * iteration // total)
    bar = '█' * filled_length + '-' * (length - filled_length)
    sys.stdout.write(f'\r|{bar}| {percent:.2f}% Complete')
    sys.stdout.flush()

In [32]:
position_embedding = positional_encoding(4 * len_half_protein_used + 2 , 1280).to(device)
def predict(positions, mut0, mut1, par0, device = device, esm_model = esm_model, model = model):
    # print([sequence for sequence in batch["mut0"]][0])
    # print(mut0)
    mut0,mut0_mask = esm_model.encode(mut0)
    # print(mut0_mask)
    mut1,mut1_mask = esm_model.encode(mut1)
    par0,par0_mask = esm_model.encode(par0)
    # zero_row = torch.zeros(mut0.shape[0], mut0.shape[1], 1).to(device)
    # mut0 = torch.cat((mut0, zero_row), dim=-1)
    # zero_row = torch.ones(mut1.shape[0], mut1.shape[1], 1).to(device)
    # mut1 = torch.cat((mut1, zero_row), dim=-1)
    #处理mut0，mut1，去除0，找到位置，接入全局变量，01拼接在模型里
    # mut0 = mut0[:, 1:, :]
    if mut0.shape[1] < 2*len_half_protein_used:
        mut0 = torch.cat((mut0.mean(dim=1, keepdim=True),mut0),dim=1)
        mut0_mask = mut0_mask.cpu().numpy()
        mut0_mask = np.concatenate([np.full((mut0_mask.shape[0], 1), False), mut0_mask], axis=1)
        mut1 = torch.cat((mut1.mean(dim=1, keepdim=True),mut1),dim=1)
        mut1_mask = mut1_mask.cpu().numpy()
        mut1_mask = np.concatenate([np.full((mut1_mask.shape[0], 1), False), mut1_mask], axis=1)
    else:
        result = torch.randn(mut0.shape[0],2*len_half_protein_used,mut0.shape[2])
        result_padding = (np.random.randint(0, 2, size=(mut0.shape[0], 2*len_half_protein_used)) == 1)

        for ia in range(len(positions)):
            position = int(positions[ia])
            if position - len_half_protein_used < 0 :
                result[ia, :, :] = mut0[ia,:2 * len_half_protein_used,:]
                result_padding[ia, :] = mut0_mask[ia, : 2 * len_half_protein_used].cpu()
            elif position + len_half_protein_used > mut0.shape[1] :
                result[ia, :, :] = mut0[ia,-2 * len_half_protein_used:,:]
                result_padding[ia, :] = mut0_mask[ia, -2 * len_half_protein_used].cpu()
            else:
                result[ia, :, :] = mut0[ia,position - len_half_protein_used:position + len_half_protein_used,:]
                result_padding[ia, :] = mut0_mask[ia, position - len_half_protein_used : position + len_half_protein_used].cpu()

        mut0 = torch.cat((mut0.mean(dim=1, keepdim=True).cpu(), result),dim=1)
        mut0_mask = result_padding
        mut0_mask = np.concatenate([np.full((mut0_mask.shape[0], 1), False), mut0_mask], axis=1)
        
        
        
        result = torch.randn(mut1.shape[0],2*len_half_protein_used,mut1.shape[2])
        result_padding = (np.random.randint(0, 2, size=(mut1.shape[0], 2*len_half_protein_used)) == 1)
        for ia in range(len(positions)):
            position = int(positions[ia])
            if position - len_half_protein_used < 0 :
                result[ia, :, :] = mut1[ia,:2 * len_half_protein_used,:]
                result_padding[ia, :] = mut1_mask[ia , : 2 * len_half_protein_used].cpu()
            elif position + len_half_protein_used > mut1.shape[1] :
                result[ia, :, :] = mut1[ia,-2 * len_half_protein_used : ,:]
                result_padding[ia, :] = mut1_mask[ia , -2 * len_half_protein_used].cpu()
            else:
                # print(i)
                result[ia, :, :] = mut1[ia,position - len_half_protein_used:position + len_half_protein_used,:]
                result_padding[ia, :] = mut1_mask[ia, position - len_half_protein_used : position + len_half_protein_used].cpu()
        mut1 = torch.cat((mut1.mean(dim=1, keepdim=True).cpu(), result),dim=1)
        mut1_mask = result_padding
        mut1_mask = np.concatenate([np.full((mut1_mask.shape[0], 1), False), mut1_mask], axis=1)
        
        
    mut0_mask = torch.from_numpy(mut0_mask)
    mut0_mask = mut0_mask.to(device)
    mut1_mask = torch.from_numpy(mut1_mask)
    mut1_mask = mut1_mask.to(device)
    mut0 = mut0.to(device)
    mut1 = mut1.to(device)
    par0 = par0.to(device)
    mut0 = mut0 + position_embedding[:mut0.shape[1] , : ]
    mut1 = mut1 + position_embedding[2 * len_half_protein_used + 1 : 2 * len_half_protein_used + 1 + mut1.shape[1] , : ]
    x, tensors, attn_backbone, attn_neck = model(mut0, mut1, par0,torch.cat((mut0_mask,mut1_mask),dim=1),par0_mask, None, None)
    _, predicted = torch.max(x, 1)
    return predicted, x, tensors, attn_backbone, attn_neck

label_def = []
pred_def = []
# a,b,c,d = [batch["position"]], [batch["mut0"]], [batch["mut1"]], [batch["par0"]]
# predicted_dif = predict(a,b,c,d)
# mut0,mut0_mask = esm_model.encode(['VSFRYIFGLPPLILVLLPVASSDCDIEGKDGKQYE','PPLILVLLPVASSDCDIEGKDGK'])
# predict([5], ['VSFRYIFGLPPLILVLLPVASSDCDIEGKDGKQYE'], ['VSFRRIFGLPPLILVLLPVASSDCDIEGKDGKQYE'], ['PPLILVLLPVASSDCDIEGKDGK'])
data_reader = PandasDataReader(x_test_fold, batch_size=1, shuffle=True)
tss = []
for batch in data_reader:
    # if len(batch["mut0"]) > 2*len_half_protein_used:
    #     continue
    predicted_dif, e, ts, attn_backbone, attn_neck = predict([batch["position"]], [batch["mut0"]], [batch["mut1"]], [batch["par0"]])
    label_def = label_def + [batch['label']]
    pred_def = pred_def + predicted_dif.tolist()
    tss.append(ts.detach().cpu())
    print_progress_bar(len(pred_def), x_test_fold.shape[0])
    tensors = torch.cat(tss, dim = 0)
    break
accuracy_score(pred_def, label_def)

|----------------------------------------| 0.02% Complete

0.0

In [33]:
import networkx as nx

# 从 GML 文件读取图
G = nx.read_gml('../manything/graph_ppi_only_p.gml')


In [34]:
import pandas as pd
pros = pd.read_csv('../manything/9606.protein.info.v12.0.txt', sep='\t')
pros.head()

Unnamed: 0,#string_protein_id,preferred_name,protein_size,annotation
0,9606.ENSP00000000233,ARF5,180,ADP-ribosylation factor 5; GTP-binding protein...
1,9606.ENSP00000000412,M6PR,277,Cation-dependent mannose-6-phosphate receptor;...
2,9606.ENSP00000001008,FKBP4,459,"Peptidyl-prolyl cis-trans isomerase FKBP4, N-t..."
3,9606.ENSP00000001146,CYP26B1,512,Cytochrome P450 26B1; Involved in the metaboli...
4,9606.ENSP00000002125,NDUFAF7,441,"Protein arginine methyltransferase NDUFAF7, mi..."


In [35]:
import pandas as pd
from Bio import SeqIO

# 定义FASTA文件路径
file_path = '../manything/9606.protein.sequences.v12.0.fa'

# 创建一个空的列表来存储序列数据
data = []

# 解析FASTA文件
for record in SeqIO.parse(file_path, "fasta"):
    data.append({"id": record.id, "sequence": str(record.seq)})

# 将数据转换为DataFrame
seqs = pd.DataFrame(data)
seqs = seqs.merge(pros[['#string_protein_id', 'preferred_name']], 
                                   left_on='id', 
                                   right_on='#string_protein_id', 
                                   how='left')

# 替换 protein1 列
seqs['id'] = seqs['preferred_name']

# 删除不需要的列
seqs.drop(columns=['#string_protein_id', 'preferred_name'], inplace=True)
print(seqs.shape)
seqs.head()

(19699, 2)


Unnamed: 0,id,sequence
0,ARF5,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...
1,M6PR,MFPFYSCWRTGLLLLLLAVAVRESWQTEEKTCDLVGEKGKESEKEL...
2,FKBP4,MTAEEMKATESGAQSAPLPMEGVDISPKQDEGVLKVIKREGTGTEM...
3,CYP26B1,MLFEGLDLVSALATLAACLVSVTLLLAVSQQLWQLRWAATRDKSCK...
4,NDUFAF7,MSVLLRSGLGPLCAVARAAIPFIWRGKYFSSGNEPAENPVTPMLRH...


In [36]:
list(G.nodes())

['ARF5',
 'M6PR',
 'FKBP4',
 'CYP26B1',
 'NDUFAF7',
 'FUCA2',
 'HS3ST1',
 'SEMA3F',
 'CFTR',
 'CYP51A1',
 'USP28',
 'SLC7A2',
 'HSPB6',
 'PDK4',
 'USH1C',
 'RALA',
 'BAIAP2L1',
 'CACNG3',
 'TMEM132A',
 'DVL2',
 'RPAP3',
 'SKAP2',
 'PRSS21',
 'HOXA11',
 'CX3CL1',
 'TRAPPC6A',
 'WDR54',
 'SPATA20',
 'CEACAM7',
 'RHBDD2',
 'TSR3',
 'OSBPL7',
 'YBX2',
 'KRT33A',
 'TFAP2D',
 'CRY1',
 'PGLYRP1',
 'STARD3NL',
 'CAMK1G',
 'CD74',
 'FAM76A',
 'CPA1',
 'SYPL1',
 'RANBP9',
 'CD4',
 'TSPAN9',
 'QPCTL',
 'PPP5C',
 'UBR7',
 'MAP4K5',
 'INMT',
 'ERCC1',
 'GPRC5A',
 'HEBP1',
 'COX15',
 'MS4A12',
 'RGPD5',
 'XYLT2',
 'SCTR',
 'SYT13',
 'SNAI2',
 'SLC7A9',
 'RTF2',
 'RB1CC1',
 'AKAP11',
 'PIGQ',
 'CDH17',
 'B4GALT7',
 'FAM13B',
 'CHPF2',
 'GABARAPL2',
 'MYOC',
 'OTC',
 'TTC17',
 'HOXC8',
 'MRI1',
 'BOD1L1',
 'TARBP1',
 'RTN4R',
 'PSMA4',
 'RIPOR3',
 'ZPBP',
 'LCP2',
 'DCN',
 'TNFRSF17',
 'MRPS10',
 'GUCA1A',
 'GRN',
 'THAP3',
 'VAMP3',
 'UTS2',
 'RCN1',
 'RFC2',
 'PPP1R3F',
 'NEXMIF',
 'ARHGEF5',
 'NFE2

In [37]:
import pandas as pd

# 读取 Excel 文件中的所有表格
file_path = '../manything/protein_net/41588_2018_130_MOESM6_ESM.xlsx'
sheets = pd.read_excel(file_path, sheet_name=None, header=None)

# 将每个表格存储到单独的 DataFrame
dataframes = {sheet_name: df for sheet_name, df in sheets.items()}
importants = []
# 打印每个 DataFrame 的名称和前几行
for name, df in dataframes.items():
    print(f"表格名称: {name}")
    print(len(set(df[0].tolist()) & set(list(G.nodes()))), len(df[0].tolist()))  # 打印前几行
    importants.append(df[0].tolist())


表格名称: FMRP
757 794
表格名称: CHM
396 408
表格名称: EMB
1747 1865
表格名称: PSD
1332 1395
表格名称: SFARI
866 881
表格名称: SFARIhq
140 141
表格名称: DN65
63 65


In [38]:
# samples_all = pd.read_csv('manything/ASD_use.csv')
# samples_all.head()

In [39]:
disease = []
control = []
for i in range(len(importants)):
    disease.append([[],[],[],[]])
    control.append([[],[],[],[]])

In [40]:
def get_nerb(node_a, G):
    if node_a in G:
        connected_nodes = list(G.neighbors(node_a))
    else:
        connected_nodes = []
    return connected_nodes

In [41]:
# def get_min_and_node(source_node, target_nodes, G):
#     shortest_distances = {}
#     for target in target_nodes:
#         if target in G:
#             distance = nx.shortest_path_length(G, source=source_node, target=target)
#             shortest_distances[target] = distance

#     if shortest_distances:
#         min_distance = min(shortest_distances.values())
#         closest_nodes = [node for node, distance in shortest_distances.items() if distance == min_distance]
#     else:
#         min_distance = None
#     return min_distance, closest_nodes

In [42]:
# distances = []
# for i in range(len(importants)):
#     distance = {}
#     for node in list(G.nodes):
#         distance[node] = get_min(node, importants[i], G)
#     distances.append(distance)
#     print(i)

In [43]:
# # 随机抽取 "Autism (ASD)" 和 "Sibling Control" 的行
# autism_samples = samples_all[samples_all['PrimaryPhenotype'] == 'Autism (ASD)'].sample(n=150)
# sibling_control_samples = samples_all[samples_all['PrimaryPhenotype'] == 'Sibling Control'].sample(n=150)

# # 合并样本
# samples = pd.concat([autism_samples, sibling_control_samples])
import pandas as pd
# # 打乱顺序
# samples = samples.sample(frac=1, random_state=1).reset_index(drop=True)
samples = pd.read_csv('../sequ/processed_psymukb.csv')
samples.head()

  samples = pd.read_csv('../sequ/processed_psymukb.csv')


Unnamed: 0.1,Unnamed: 0,EntrezID,cytoBand,Chr,Start,End,Ref,Alt,Variant,Gene.refGene,...,cDNA_var,protein_var,peptide,pep_len,protein_ori,protein_mut,protein_pos,ori_win51,mut_win51,mut_peptide
0,0,1778,14q32.31,14,102505753,102505753,A,C,A>C,DYNC1H1,...,c.A11465C,p.H3822P,MSEPGGGGGEDGSAGLEVSAVQNVADVSVLQKHLRKLVPLLLEDGG...,4646.0,H,P,3822,VSQQYLPLSTACSSIYFTMESLKQIHFLYQYSLQFFLDIYHNVLYE...,VSQQYLPLSTACSSIYFTMESLKQIPFLYQYSLQFFLDIYHNVLYE...,MSEPGGGGGEDGSAGLEVSAVQNVADVSVLQKHLRKLVPLLLEDGG...
1,1,148103,19q13.11,19,35251174,35251174,G,A,G>A,ZNF599,...,c.C532T,p.L178F,MAAPALALVSFEDVVVTFTGEEWGHLDLAQRTLYQEVMLETCRLLV...,588.0,L,F,178,KHDDLEPDDSLGLRVLQERVTPQDALHECDSQGPGKDPMTDARNNP...,KHDDLEPDDSLGLRVLQERVTPQDAFHECDSQGPGKDPMTDARNNP...,MAAPALALVSFEDVVVTFTGEEWGHLDLAQRTLYQEVMLETCRLLV...
2,2,116442,Xq28,X,154490173,154490173,G,A,G>A,RAB39B,...,c.G557T,p.W186L,MEAIWLYQFRLIVIGDSTVGKSCLIRRFTEGRFAQVSDPTVGVDFF...,213.0,W,L,186,EKAFTDLTRDIYELVKRGEITIQEGWEGVKSGFVPNVVHSSEEVVK...,EKAFTDLTRDIYELVKRGEITIQEGLEGVKSGFVPNVVHSSEEVVK...,MEAIWLYQFRLIVIGDSTVGKSCLIRRFTEGRFAQVSDPTVGVDFF...
3,3,7528,14q32.2,14,100743830,100743830,G,T,G>T,YY1,...,c.G1138T,p.D380Y,MASGDTLYIATDGSEMPAEIVELHEIEVETIPVETIETTVVGEEEE...,414.0,D,Y,380,CTFEGCGKRFSLDFNLRTHVRIHTGDRPYVCPFDGCNKKFAQSTNL...,CTFEGCGKRFSLDFNLRTHVRIHTGYRPYVCPFDGCNKKFAQSTNL...,MASGDTLYIATDGSEMPAEIVELHEIEVETIPVETIETTVVGEEEE...
4,4,128859,20q11.21,20,31626755,31626755,G,A,G>A,BPIFB6,...,c.G887A,p.R296H,MLRILCLALCSLLTGTRADPGALLRLGMDIMNQVQSAMDESHILEK...,453.0,R,H,296,QKSFHVNIQDTMIGELPPQTTKTLARFIPEVAVAYPKSKPLTTQIK...,QKSFHVNIQDTMIGELPPQTTKTLAHFIPEVAVAYPKSKPLTTQIK...,MLRILCLALCSLLTGTRADPGALLRLGMDIMNQVQSAMDESHILEK...


In [44]:
samples['PrimaryPhenotype'].value_counts(), samples.shape

(PrimaryPhenotype
 Autism (ASD)                                                          4622
 Developmental Delay (DD)                                              4339
 Uncharacterized (Mixed healthy control)                               2501
 Schizophrenia (SCZ)                                                   1656
 Sibling Control                                                       1502
 Intellectual Disability (ID)                                          1456
 Congenital Heart Disease (CHD)                                        1116
 Tourette Disorder (TD)                                                 347
 Epileptic Encephalopathies (EE)                                        306
 Fetal non-Preterm birth (non-PTB)                                      284
 Developmental and Epileptic Encephalopathies (DEE)                     235
 Infantile Spasms (IS)                                                  151
 Fetal preterm birth (PTB)                                            

In [45]:
samples_h = samples[samples['PrimaryPhenotype'] == 'Sibling Control']
samples_h['PrimaryPhenotype'].value_counts()

PrimaryPhenotype
Sibling Control    1502
Name: count, dtype: int64

In [46]:
samples_h.head()

Unnamed: 0.1,Unnamed: 0,EntrezID,cytoBand,Chr,Start,End,Ref,Alt,Variant,Gene.refGene,...,cDNA_var,protein_var,peptide,pep_len,protein_ori,protein_mut,protein_pos,ori_win51,mut_win51,mut_peptide
39,41,4065;100526664,2q24.2,2,160671934,160671934,T,G,T>G,LY75;LY75-CD302,...,c.T4531C,p.C1511R,MRTGWATPRRPAGLLMLLFWFFDLAEPSGRAANDPFTIVHGNTGKC...,1873.0,C,R,1511,GNCVLLDPKGTWKHEKCNSVKDGAICYKPTKSKKLSRLTYSSRCPA...,GNCVLLDPKGTWKHEKCNSVKDGAIRYKPTKSKKLSRLTYSSRCPA...,MRTGWATPRRPAGLLMLLFWFFDLAEPSGRAANDPFTIVHGNTGKC...
40,42,4645,18q21.1,18,47500753,47500753,A,G,A>G,MYO5B,...,c.A1289C,p.Q430P,MSVGELYSQCTRVWIPDPDEVWRSAELTKDYKEGDKSLQLRLEDET...,1848.0,Q,P,430,AKHIYAQLFGWIVEHINKALHTSLKQHSFIGVLDIYGFETFEVNSF...,AKHIYAQLFGWIVEHINKALHTSLKPHSFIGVLDIYGFETFEVNSF...,MSVGELYSQCTRVWIPDPDEVWRSAELTKDYKEGDKSLQLRLEDET...
41,43,10898,7q22.1,7,99048352,99048352,G,A,G>A,CPSF4,...,c.G431A,p.R144Q,MQEIIASVDHIKFDLEIAVEQQLGAQPLPFPGMDKSGAAVCEFFLK...,244.0,R,Q,144,SKIKDCPWYDRGFCKHGPLCRHRHTRRVICVNYLVGFCPEGPSCKF...,SKIKDCPWYDRGFCKHGPLCRHRHTQRVICVNYLVGFCPEGPSCKF...,MQEIIASVDHIKFDLEIAVEQQLGAQPLPFPGMDKSGAAVCEFFLK...
42,44,949,12q24.31,12,125298857,125298857,G,T,G>T,SCARB1,...,c.G521A,p.R174H,MGCSAKARWAAGALGVAGLLCAVLGAVMIVMVPSLIKQQVLKNVRI...,506.0,R,H,174,ENKPMTLKLIMTLAFTTLGERAFMNRTVGEIMWGYKDPLVNLINKY...,ENKPMTLKLIMTLAFTTLGERAFMNHTVGEIMWGYKDPLVNLINKY...,MGCSAKARWAAGALGVAGLLCAVLGAVMIVMVPSLIKQQVLKNVRI...
44,46,705,6p21.1,6,41897935,41897935,T,C,T>C,BYSL,...,c.T497C,p.M166T,MPKFKAARGVGGQEKHAPLADQILAGNAVRAGVREKRRGRGTGEAE...,437.0,M,T,166,PPARRTLADIIMEKLTEKQTEVETVMSEVSGFPMPQLDPRVLEVYR...,PPARRTLADIIMEKLTEKQTEVETVTSEVSGFPMPQLDPRVLEVYR...,MPKFKAARGVGGQEKHAPLADQILAGNAVRAGVREKRRGRGTGEAE...


In [47]:
samples_h.iloc[3]

Unnamed: 0                                                    44
EntrezID                                                     949
cytoBand                                                12q24.31
Chr                                                           12
Start                                                  125298857
                                     ...                        
protein_mut                                                    H
protein_pos                                                  174
ori_win51      ENKPMTLKLIMTLAFTTLGERAFMNRTVGEIMWGYKDPLVNLINKY...
mut_win51      ENKPMTLKLIMTLAFTTLGERAFMNHTVGEIMWGYKDPLVNLINKY...
mut_peptide    MGCSAKARWAAGALGVAGLLCAVLGAVMIVMVPSLIKQQVLKNVRI...
Name: 42, Length: 168, dtype: object

In [48]:
samples_h = samples_h.reset_index()
samples_h.head()

Unnamed: 0.1,index,Unnamed: 0,EntrezID,cytoBand,Chr,Start,End,Ref,Alt,Variant,...,cDNA_var,protein_var,peptide,pep_len,protein_ori,protein_mut,protein_pos,ori_win51,mut_win51,mut_peptide
0,39,41,4065;100526664,2q24.2,2,160671934,160671934,T,G,T>G,...,c.T4531C,p.C1511R,MRTGWATPRRPAGLLMLLFWFFDLAEPSGRAANDPFTIVHGNTGKC...,1873.0,C,R,1511,GNCVLLDPKGTWKHEKCNSVKDGAICYKPTKSKKLSRLTYSSRCPA...,GNCVLLDPKGTWKHEKCNSVKDGAIRYKPTKSKKLSRLTYSSRCPA...,MRTGWATPRRPAGLLMLLFWFFDLAEPSGRAANDPFTIVHGNTGKC...
1,40,42,4645,18q21.1,18,47500753,47500753,A,G,A>G,...,c.A1289C,p.Q430P,MSVGELYSQCTRVWIPDPDEVWRSAELTKDYKEGDKSLQLRLEDET...,1848.0,Q,P,430,AKHIYAQLFGWIVEHINKALHTSLKQHSFIGVLDIYGFETFEVNSF...,AKHIYAQLFGWIVEHINKALHTSLKPHSFIGVLDIYGFETFEVNSF...,MSVGELYSQCTRVWIPDPDEVWRSAELTKDYKEGDKSLQLRLEDET...
2,41,43,10898,7q22.1,7,99048352,99048352,G,A,G>A,...,c.G431A,p.R144Q,MQEIIASVDHIKFDLEIAVEQQLGAQPLPFPGMDKSGAAVCEFFLK...,244.0,R,Q,144,SKIKDCPWYDRGFCKHGPLCRHRHTRRVICVNYLVGFCPEGPSCKF...,SKIKDCPWYDRGFCKHGPLCRHRHTQRVICVNYLVGFCPEGPSCKF...,MQEIIASVDHIKFDLEIAVEQQLGAQPLPFPGMDKSGAAVCEFFLK...
3,42,44,949,12q24.31,12,125298857,125298857,G,T,G>T,...,c.G521A,p.R174H,MGCSAKARWAAGALGVAGLLCAVLGAVMIVMVPSLIKQQVLKNVRI...,506.0,R,H,174,ENKPMTLKLIMTLAFTTLGERAFMNRTVGEIMWGYKDPLVNLINKY...,ENKPMTLKLIMTLAFTTLGERAFMNHTVGEIMWGYKDPLVNLINKY...,MGCSAKARWAAGALGVAGLLCAVLGAVMIVMVPSLIKQQVLKNVRI...
4,44,46,705,6p21.1,6,41897935,41897935,T,C,T>C,...,c.T497C,p.M166T,MPKFKAARGVGGQEKHAPLADQILAGNAVRAGVREKRRGRGTGEAE...,437.0,M,T,166,PPARRTLADIIMEKLTEKQTEVETVMSEVSGFPMPQLDPRVLEVYR...,PPARRTLADIIMEKLTEKQTEVETVTSEVSGFPMPQLDPRVLEVYR...,MPKFKAARGVGGQEKHAPLADQILAGNAVRAGVREKRRGRGTGEAE...


In [49]:
samples_h.iloc[1030]

index                                                      16430
Unnamed: 0                                                 17854
EntrezID                                                  196403
cytoBand                                                 12q13.3
Chr                                                           12
                                     ...                        
protein_mut                                                    L
protein_pos                                                  217
ori_win51      ITRALQVKKACPMCGRFYGQLVGNQPQNGRMLVSKDATLLLPSYEK...
mut_win51      ITRALQVKKACPMCGRFYGQLVGNQLQNGRMLVSKDATLLLPSYEK...
mut_peptide    MPILSSSGSKMAACGGTCKNKVTVSKPVWDFLSKETPARLARLREE...
Name: 1030, Length: 169, dtype: object

In [None]:
import json

n_c = 0
with open('../manything/disease_onlyp_Control.json', 'r') as f:
    diseases = json.load(f)
n_d = 1031
controls = []
for index in range(n_d, len(samples_h)):
    disease_now = [[] for i in range(4)]
    control_now = [[] for i in range(4)]
    batch = samples_h.iloc[index]
    if batch['gene_symbol'] not in list(G.nodes):
        continue
    if len(batch["peptide"]) > 10000 :
        continue
    # for numm in range(len(importants)):
    #     if batch['gene_symbol'] in importants[numm]:
    #         continue
    nnnn = 0
    nodes = get_nerb(batch['gene_symbol'], G)
    for protein in nodes:
        if seqs[seqs['id'] == protein].empty:
            continue
        if len(seqs[seqs['id'] == protein].iloc[0]['sequence']) > 10000:
            continue
        predicted_dif, _, _, _, _ = predict([batch["protein_pos"]], [batch["peptide"]], [batch["mut_peptide"]], [seqs[seqs['id'] == protein].iloc[0]['sequence']])
        if True:
            disease_now[predicted_dif.item()].append(protein)
            nnnn = nnnn + 1
        else:
            control_now[predicted_dif.item()].append(protein)
            nnnn = nnnn + 1
        print_progress_bar(nnnn, len(nodes))
    if True:
        # disease[numm][predicted_dif.item()].append(distances[numm][protein])
        diseases.append({'mutation': batch['gene_symbol'],'nab':disease_now})
        n_d = n_d + 1
    else:
        # control[numm][predicted_dif.item()].append(distances[numm][protein])
        controls.append({'mutation': batch['gene_symbol'],'nab':control_now})
        n_c = n_c + 1
    # diseases.append(disease_now)
    # controls.append(control_now)
    if index % 10 == 0:
        with open('../manything/disease_onlyp_Control1.json', 'w') as f:
            json.dump(diseases, f)
        # with open('../manything/healthy_oldPPI_CDH.json', 'w') as f:
        #     json.dump(controls, f)
        # with open('manything/disease_now4.json', 'w') as f:
        #     json.dump(diseases, f)

        with open('../manything/count_onlyp_Control1.json', 'w') as f:
            json.dump({'disease': n_d, 'control': n_c}, f)

with open('../manything/disease_onlyp_Control1.json', 'w') as f:
    json.dump(diseases, f)
# with open('../manything/healthy_oldPPI_CDH.json', 'w') as f:
#     json.dump(controls, f)
# with open('manything/disease_now4.json', 'w') as f:
#     json.dump(diseases, f)

with open('../manything/count_onlyp_Control1.json', 'w') as f:
    json.dump({'disease': n_d, 'control': n_c}, f)

|██████████------------------------------| 27.36% Complete

In [52]:
with open('../manything/disease_onlyp_Control2.json', 'w') as f:
    json.dump(diseases, f)
# with open('../manything/healthy_oldPPI_CDH.json', 'w') as f:
#     json.dump(controls, f)
# with open('manything/disease_now4.json', 'w') as f:
#     json.dump(diseases, f)

with open('../manything/count_onlyp_Control2.json', 'w') as f:
    json.dump({'disease': n_d, 'control': n_c}, f)

In [None]:
# import json
# n_d = 0
# n_c = 0
# diseases = []
# controls = []
# for index in range(len(samples_h)):
#     disease_now = [[] for i in range(4)]
#     control_now = [[] for i in range(4)]
#     batch = samples_h.iloc[index]
#     if batch['gene_symbol'] not in list(G.nodes):
#         continue
#     if len(batch["peptide"]) > 10000 :
#         continue
#     # for numm in range(len(importants)):
#     #     if batch['gene_symbol'] in importants[numm]:
#     #         continue
#     nnnn = 0
#     nodes = get_nerb(batch['gene_symbol'], G)
#     for protein in nodes:
#         if seqs[seqs['id'] == protein].empty:
#             continue
#         if len(seqs[seqs['id'] == protein].iloc[0]['sequence']) > 10000:
#             continue
#         predicted_dif, _, _, _, _ = predict([batch["protein_pos"]], [batch["peptide"]], [batch["mut_peptide"]], [seqs[seqs['id'] == protein].iloc[0]['sequence']])
#         if True:
#             disease_now[predicted_dif.item()].append(protein)
#             nnnn = nnnn + 1
#         else:
#             control_now[predicted_dif.item()].append(protein)
#             nnnn = nnnn + 1
#         print_progress_bar(nnnn, len(nodes))
#     if True:
#         # disease[numm][predicted_dif.item()].append(distances[numm][protein])
#         diseases.append({'mutation': batch['gene_symbol'],'nab':disease_now})
#         n_d = n_d + 1
#     else:
#         # control[numm][predicted_dif.item()].append(distances[numm][protein])
#         controls.append({'mutation': batch['gene_symbol'],'nab':control_now})
#         n_c = n_c + 1
#     # diseases.append(disease_now)
#     # controls.append(control_now)
#     if index % 10 == 0:
#         with open('../manything/disease_onlyp_Control.json', 'w') as f:
#             json.dump(diseases, f)
#         # with open('../manything/healthy_oldPPI_CDH.json', 'w') as f:
#         #     json.dump(controls, f)
#         # with open('manything/disease_now4.json', 'w') as f:
#         #     json.dump(diseases, f)

#         with open('../manything/count_onlyp_Control.json', 'w') as f:
#             json.dump({'disease': n_d, 'control': n_c}, f)

# with open('../manything/disease_onlyp_Control.json', 'w') as f:
#     json.dump(diseases, f)
# # with open('../manything/healthy_oldPPI_CDH.json', 'w') as f:
# #     json.dump(controls, f)
# # with open('manything/disease_now4.json', 'w') as f:
# #     json.dump(diseases, f)

# with open('../manything/count_onlyp_Control.json', 'w') as f:
#     json.dump({'disease': n_d, 'control': n_c}, f)