In [5]:
import torch
import torch.nn as nn
import math

# from codespace.model.multihead_attention_transformer import _get_activation_fn
from mamba_ssm import Mamba


class FC_Decoder(nn.Module):
    """Classification head that fuses encoder outputs with pooled residue features.

    The residue sequence is passed through a ``Mamba`` block (imported from
    ``mamba_ssm``) whose output is turned into one softmax weight per sequence
    position; the residue embeddings are then summed with those weights,
    projected, concatenated with the projected encoder outputs and mapped to
    ``num_class`` logits.

    Args:
        num_class: Number of output logits.
        dim_feedforward: Width of each encoder output vector.
        dropout: Dropout probability applied before the final linear layer.
        input_num: Number of encoder outputs stacked along dim 0 of ``hs``.
        residue_dim: Width of the last dimension of ``residue``. Defaults to
            480, the value that was previously hard-coded.
    """

    def __init__(
        self, num_class, dim_feedforward, dropout, input_num=3, residue_dim=480
    ):
        super().__init__()
        self.num_class = num_class
        self.residue_dim = residue_dim

        # Both branches (encoder outputs and pooled residues) are projected to
        # this width before being concatenated.
        hidden_dim = dim_feedforward // input_num

        self.output_layer1 = nn.Linear(dim_feedforward * input_num, hidden_dim)
        self.activation1 = torch.nn.Sigmoid()
        self.dropout1 = torch.nn.Dropout(dropout)

        self.output_layer3 = nn.Linear(2 * hidden_dim, num_class)
        self.residue_attn = Mamba(d_model=residue_dim, d_state=16, d_conv=4, expand=2)
        self.residue_linear1 = nn.Linear(residue_dim, hidden_dim)
        # NOTE(review): output_layer2 is never called in forward(). It is kept
        # so that state_dict keys (and parameter initialisation order) stay
        # unchanged; confirm before removing.
        self.output_layer2 = nn.Linear(2 * hidden_dim, num_class)

    def forward(self, hs, residue):
        """Compute class logits.

        The shape annotations below follow the original example call:
        ``hs`` is [3, 32, 512] and ``residue`` is [32, 2000, 480].

        Args:
            hs: Stacked encoder outputs, [input_num, batch, dim_feedforward].
            residue: Residue embeddings, [batch, length, residue_dim].

        Returns:
            Logits of shape [batch, num_class]. No final activation is applied.
        """
        residue_score = self.residue_attn(residue)  # [32, 2000, 480]
        # Average over the feature dimension: one scalar score per position.
        residue_score = nn.functional.adaptive_avg_pool1d(
            residue_score, output_size=1
        )  # [32, 2000, 1]
        # Normalise the scores across sequence positions.
        residue_score = nn.functional.softmax(residue_score, dim=1)  # [32, 2000, 1]
        # Weighted sum of the residue embeddings over the sequence.
        residue = torch.sum(residue * residue_score, dim=1)  # [32, 480]

        residue = self.residue_linear1(residue)  # [32, 512 // 3]

        # Swap dims 0 and 1 so the batch dimension comes first.
        hs = hs.permute(1, 0, 2)  # [32, 3, 512]
        # Flatten everything after the batch dimension.
        hs = hs.flatten(1)  # [32, 512 * 3]

        hs = self.output_layer1(hs)  # [32, 512 // 3]

        conca_hs_residue = torch.cat((hs, residue), dim=1)  # [32, 2 * (512 // 3)]
        # Sigmoid followed by dropout on the fused representation.
        conca_hs_residue = self.activation1(conca_hs_residue)
        conca_hs_residue = self.dropout1(conca_hs_residue)
        # Final projection: (2 * (512 // 3)) -> number of GO labels.
        out = self.output_layer3(conca_hs_residue)
        return out


if __name__ == "__main__":
    # Smoke test on GPU 0: 3 encoder outputs, batch 32, width 512, plus
    # residue embeddings of 2000 positions x 480 features.
    hs = torch.rand(3, 32, 512).cuda(0)
    residue = torch.rand(32, 2000, 480).cuda(0)
    model = FC_Decoder(num_class=45, dim_feedforward=512, dropout=0.1).cuda(0)
    out = model(hs, residue)
    print(out.shape)

torch.Size([32, 45])


In [None]:
def read_residue(usefor, aspect, model_name, organism_num):
    """Load a pickled residue tensor and min-max normalise it along dim 2.

    The file is read from
    ``<finetune_data_path>/<organism_num>/residue_<model_name>/<usefor>_residue_<aspect>.pkl``,
    where ``finetune_data_path`` is a module-level variable.

    Args:
        usefor: First part of the file name (e.g. ``"train"``).
        aspect: Last part of the file name before the extension (e.g. ``"P"``).
        model_name: Suffix of the ``residue_<model_name>`` directory.
        organism_num: Name of the organism sub-directory, as a string.

    Returns:
        The loaded tensor rescaled to [0, 1] along dim 2. Slices whose values
        are all equal along dim 2 are mapped to 0 instead of NaN.
    """
    residue_name = f"{usefor}_residue_{aspect}.pkl"
    file_path = os.path.join(
        finetune_data_path, organism_num, f"residue_{model_name}", residue_name
    )

    # NOTE(review): unpickling can execute arbitrary code; only load files
    # from a trusted source.
    residue = pd.read_pickle(file_path)

    # Minimum and maximum along dim 2, kept as size-1 dims for broadcasting.
    residue_min = residue.min(dim=2, keepdim=True).values
    residue_max = residue.max(dim=2, keepdim=True).values

    # A constant slice has max == min; dividing by that zero range would give
    # NaN (0 / 0). Use 1 as the divisor there so the slice becomes all zeros.
    value_range = residue_max - residue_min
    safe_range = value_range.masked_fill(value_range == 0, 1)

    # Min-max normalisation.
    return (residue - residue_min) / safe_range

In [7]:
import os
import pandas as pd
import torch

In [4]:
# Two candidate locations for the data directory; they differ only in the
# capitalisation of the home directory name.
dataset_path_in_kioedru = "/home/kioedru/code/SSGO/data"
dataset_path_in_Kioedru = "/home/Kioedru/code/SSGO/data"

# Use the lower-case location when it exists; otherwise fall back to the
# capitalised one (which is not checked for existence).
dataset_path = (
    dataset_path_in_kioedru
    if os.path.exists(dataset_path_in_kioedru)
    else dataset_path_in_Kioedru
)

# Root directory of the fine-tuning data.
finetune_data_path = os.path.join(dataset_path, "finetune")

In [10]:
# Which residue file to load.
usefor, aspect, model_name, organism_num = "train", "P", "esm2", "9606"

residue_name = f"{usefor}_residue_{aspect}.pkl"
file_path = os.path.join(
    finetune_data_path,
    organism_num,
    f"residue_{model_name}",
    residue_name,
)
residue = pd.read_pickle(file_path)

# The min-max normalisation used in read_residue is not applied here;
# LayerNorm over the last dimension (size 480) is used instead, and the
# result is checked for NaN values.
layernorm = torch.nn.LayerNorm(480)
residue = layernorm(residue)
print(torch.isnan(residue).any())

tensor(False)


In [5]:
import torch
import torch.nn as nn


class FC_Decoder(nn.Module):
    """Gated fusion head over six stacked encoder outputs.

    Each of the six encoder outputs is scaled by a learned weight that is the
    product of an encoder-level softmax weight (2 values) and an aspect-level
    softmax weight (3 values). The rescaled outputs are flattened and passed
    through a three-layer MLP that produces ``num_class`` logits.

    ``forward`` indexes positions 0..5 explicitly and uses ``Linear(3, 1)`` /
    ``Linear(2, 1)`` gates, so it expects exactly six encoder outputs.

    Args:
        num_class: Number of output logits.
        dim_feedforward: Width of each encoder output vector.
        dropout: Dropout probability used after both hidden layers.
        input_num: Number of stacked encoder outputs (default 6).
    """

    def __init__(self, num_class, dim_feedforward, dropout, input_num=6):
        super().__init__()
        self.num_class = num_class

        half_inputs = int(input_num / 2)
        wide_dim = dim_feedforward * half_inputs
        narrow_dim = dim_feedforward // half_inputs

        # MLP: (d * input_num) -> (d * input_num / 2) -> (d // (input_num / 2)) -> num_class
        self.output_layer1 = nn.Linear(dim_feedforward * input_num, wide_dim)
        self.activation1 = nn.GELU()
        self.dropout1 = torch.nn.Dropout(dropout)

        self.output_layer2 = nn.Linear(wide_dim, narrow_dim)
        self.activation2 = torch.nn.Sigmoid()
        self.dropout2 = torch.nn.Dropout(dropout)

        self.output_layer3 = nn.Linear(narrow_dim, num_class)

        # Gates over the pooled encoder outputs: one score per encoder type
        # (3 inputs each) and one score per aspect (2 inputs each).
        self.hs_transformer_linear = nn.Linear(3, 1)
        self.hs_mamba_linear = nn.Linear(3, 1)
        self.hs_P_linear = nn.Linear(2, 1)
        self.hs_F_linear = nn.Linear(2, 1)
        self.hs_C_linear = nn.Linear(2, 1)

    def forward(self, hs):
        """Compute class logits from stacked encoder outputs.

        Args:
            hs: Tensor of shape [6, batch, dim_feedforward]. Per the original
                annotation, the first three entries are transformer outputs
                and the last three are mamba outputs.

        Returns:
            Logits of shape [batch, num_class]. No sigmoid is applied here;
            per the original note it is applied by the caller on the outputs.
        """
        # Put the batch dimension first.
        hs = hs.permute(1, 0, 2)  # [32, 6, 512]

        # Mean over the feature dimension: one scalar per encoder output.
        pooled = nn.functional.adaptive_avg_pool1d(hs, output_size=1)[:, :, 0]  # [32, 6]

        # Encoder-level weights: transformer (first half) vs. mamba (second half).
        encoder_logits = torch.cat(
            (
                self.hs_transformer_linear(pooled[:, :3]),  # [32, 1]
                self.hs_mamba_linear(pooled[:, 3:]),  # [32, 1]
            ),
            dim=1,
        )  # [32, 2]
        encoder_weight = nn.functional.softmax(encoder_logits, dim=1)  # [32, 2]

        # Aspect-level weights: each aspect pairs entry i with entry i + 3.
        aspect_logits = torch.cat(
            (
                self.hs_P_linear(pooled[:, [0, 3]]),  # [32, 1]
                self.hs_F_linear(pooled[:, [1, 4]]),  # [32, 1]
                self.hs_C_linear(pooled[:, [2, 5]]),  # [32, 1]
            ),
            dim=1,
        )  # [32, 3]
        aspect_weight = nn.functional.softmax(aspect_logits, dim=1)  # [32, 3]

        # Outer product of the two weight sets, laid out as
        # [transformer x (P, F, C), mamba x (P, F, C)].
        joint_weight = encoder_weight.unsqueeze(2) * aspect_weight.unsqueeze(1)  # [32, 2, 3]
        joint_weight = joint_weight.flatten(1).unsqueeze(2)  # [32, 6, 1]
        hs = joint_weight * hs  # [32, 6, 512]

        # Flatten everything after the batch dimension.
        hs = hs.flatten(1)  # [32, 512 * 6]

        hs = self.dropout1(self.activation1(self.output_layer1(hs)))  # [32, 512 * (6 / 2)]
        hs = self.dropout2(self.activation2(self.output_layer2(hs)))  # [32, 512 // (6 / 2)]

        # Final projection to the number of GO labels.
        return self.output_layer3(hs)


if __name__ == "__main__":
    # Smoke test on GPU 0: 6 stacked encoder outputs, batch 32, width 512.
    hs = torch.rand(6, 32, 512).cuda(0)
    model = FC_Decoder(num_class=45, dim_feedforward=512, dropout=0.1).cuda(0)
    out = model(hs)
    print(out.shape)

torch.Size([32, 45])


In [2]:
import torch

# Example inputs: a has shape [32, 2] and b has shape [32, 3].
a = torch.randn(32, 2)
b = torch.randn(32, 3)

# Take columns 0 and 1 of a, keeping a trailing dimension of size 1 so they
# broadcast against b.
a0 = a[:, 0:1]  # shape [32, 1]
a1 = a[:, 1:2]  # shape [32, 1]

# Scale every column of b by column 0 of a, then by column 1 of a.
result0 = b * a0  # shape [32, 3]
result1 = b * a1  # shape [32, 3]

# Concatenate both products along the column dimension.
result = torch.cat([result0, result1], dim=1)  # shape [32, 6]

# Print the shape to check it; it should be [32, 6].
print("result shape:", result.shape)

result shape: torch.Size([32, 6])
