In [62]:
import torch
import torch.nn as nn

import copy


# Pretraining model: a stack of identical self-attention encoder layers.
class TransformerEncoder(nn.Module):
    """Stack of ``num_encoder_layers`` independent copies of ``Encoder_Layer``.

    Note: despite its name, ``dim_feedforward`` is passed to every
    ``Encoder_Layer`` as its model (attention embedding) width; the inner
    feed-forward width is chosen inside ``Encoder_Layer`` itself.

    Args:
        activation: activation module handed to each encoder layer (e.g. ``nn.GELU()``).
        dim_feedforward: model width of the stacked layers.
        nhead: number of attention heads per layer.
        num_encoder_layers: how many layers to stack.
        dropout: dropout probability used inside each layer.
    """

    def __init__(
        self,
        activation,
        dim_feedforward=2048,
        nhead=8,
        num_encoder_layers=6,
        dropout=0.1,
    ):
        super().__init__()
        self.activation = activation
        self.dim_feedforward = dim_feedforward
        self.num_encoder_layers = num_encoder_layers
        # Build one prototype layer, then deep-copy it so that every layer in
        # the stack owns its own parameters.
        prototype = Encoder_Layer(dim_feedforward, nhead, dropout, activation)
        self.encoder = nn.ModuleList(
            copy.deepcopy(prototype) for _ in range(num_encoder_layers)
        )
        self.nhead = nhead

    def forward(self, src):
        """Feed ``src`` through every layer in order and return the last output."""
        hidden = src
        for layer in self.encoder:
            hidden = layer(hidden)
        return hidden


# Single post-norm Transformer encoder layer used by TransformerEncoder.
class Encoder_Layer(nn.Module):
    """Post-norm encoder layer.

    It applies self-attention and then a feed-forward block. Each sub-block is
    wrapped in a residual connection followed by LayerNorm.

    Args:
        dim_feedforward: model (embedding) width ``D`` of the input and output.
            It is also the attention ``embed_dim``; the name is historical.
        nhead: number of attention heads. It must divide ``dim_feedforward``.
        dropout: dropout probability used for the attention weights, the
            attention residual branch and the feed-forward block.
        activation: activation module inserted between the two FFN linears.
        hidden_dim: inner width of the feed-forward block. The default 2048 is
            the value that used to be hard-coded, so existing checkpoints
            still load.

    Input and output are sequence-first ``(L, B, D)``, because
    ``nn.MultiheadAttention`` is built with its default ``batch_first=False``.
    """

    def __init__(
        self,
        dim_feedforward,
        nhead,  # 8 in the demo cells
        dropout,
        activation,
        hidden_dim=2048,
    ):
        super().__init__()

        # One dropout module for the attention residual branch. The original
        # assigned self.dropout twice; the second assignment only replaced the
        # first and had no other effect.
        self.dropout = nn.Dropout(dropout)
        self.self_attn = nn.MultiheadAttention(dim_feedforward, nhead, dropout=dropout)
        # Keep the Sequential layout (the Linears sit at indices 0 and 3) so
        # that the state_dict keys fc.0.* / fc.3.* match saved checkpoints.
        self.fc = nn.Sequential(
            nn.Linear(dim_feedforward, hidden_dim),
            activation,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim_feedforward),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(dim_feedforward)
        self.norm2 = nn.LayerNorm(dim_feedforward)

    def forward(
        self,
        src,  # e.g. (3, 32, 512) in the demo cells
    ):
        """Apply the layer to ``src`` of shape ``(L, B, D)``; returns the same shape."""
        # The attention map was never used, so skip computing and averaging it.
        attn_out, _ = self.self_attn(src, src, src, need_weights=False)
        src = self.norm1(src + self.dropout(attn_out))
        src = self.norm2(src + self.fc(src))
        return src

In [31]:
# Build the encoder that serves as the "pretrained" model and check it on a
# random batch. nn.MultiheadAttention reads the batch as (seq=3, batch=32,
# dim=512). Then save the weights for the fine-tuning cells below.
model = TransformerEncoder(
    activation=nn.GELU(),
    dim_feedforward=512,
    nhead=8,
    num_encoder_layers=6,
    dropout=0.1,
)
src = torch.rand(3, 32, 512)
output = model(src)
print(output.shape)  # expected: torch.Size([3, 32, 512])
torch.save(model.state_dict(), "demo_model.pth")

torch.Size([3, 32, 512])


In [32]:
class Finetune(nn.Module):
    """Classifier head on top of a pretrained ``TransformerEncoder``.

    The encoder is rebuilt from the given hyper-parameters and initialised from
    ``checkpoint_path``. Its ``(L, B, D)`` output is reordered to ``(B, L, D)``,
    flattened to ``(B, L * D)`` and mapped to ``class_num`` sigmoid scores.

    Args:
        dim_feedforward: model width ``D``; must match the checkpoint.
        nhead: attention heads; must match the checkpoint.
        dropout: dropout probability of the encoder layers.
        class_num: number of output scores.
        num_encoder_layers: encoder depth; must match the checkpoint.
        seq_len: sequence length ``L`` expected in ``forward``. The head's input
            width is ``seq_len * dim_feedforward``. The default 3 is the value
            that used to be hard-coded.
        checkpoint_path: state_dict file loaded into the encoder.
    """

    def __init__(
        self,
        dim_feedforward,
        nhead,
        dropout,
        class_num,
        num_encoder_layers,
        seq_len=3,
        checkpoint_path="demo_model.pth",
    ):
        super().__init__()
        # Rebuild the pretrained encoder and load its weights.
        pretrain_model = TransformerEncoder(
            nn.GELU(),
            dim_feedforward=dim_feedforward,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            dropout=dropout,
        )
        # load_state_dict copies values into the model's own (CPU) parameters.
        # map_location="cpu" therefore only avoids needing a GPU when the
        # checkpoint was saved from one.
        pretrain_model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        self.pretrain = pretrain_model
        self.seq_len = seq_len

        # NOTE(review): two Linears with no activation between them compose
        # into a single affine map. Consider inserting a nonlinearity; this
        # would change the model and its checkpoints.
        self.fc = nn.Sequential(
            nn.Linear(dim_feedforward * seq_len, dim_feedforward),
            nn.Linear(dim_feedforward, class_num),
            nn.Sigmoid(),
        )

    def forward(
        self,
        input,  # (seq_len, B, dim_feedforward), e.g. (3, 32, 512)
    ):
        """Return ``(B, class_num)`` scores in (0, 1) for a sequence-first batch."""
        output = self.pretrain(input)
        # (L, B, D) -> (B, L, D) -> (B, L * D)
        output = output.transpose(0, 1).flatten(1)
        return self.fc(output)

In [33]:
# Smoke test: a 10-class head on the 6-layer, 8-head, width-512 checkpoint.
finetune_model = Finetune(512, 8, 0.1, 10, 6)
# Random sequence-first (3, 32, 512) batch. It is deliberately not named
# `input`, which would shadow the built-in input() for the rest of the kernel.
finetune_input = torch.rand(3, 32, 512)
output = finetune_model(finetune_input)
print(output.shape)  # expected: torch.Size([32, 10])

torch.Size([32, 10])


In [70]:
# NOTE(review): this cell redefines TransformerEncoder identically to the
# earlier cell. Keep a single definition so re-running cells out of order
# cannot silently swap the class that subclasses bind to.
class TransformerEncoder(nn.Module):
    """A stack of ``num_encoder_layers`` deep-copied ``Encoder_Layer`` modules.

    ``dim_feedforward`` is forwarded to each ``Encoder_Layer`` as its model
    width (the name is historical). ``activation`` is also kept as a submodule
    attribute.
    """

    def __init__(
        self,
        activation,
        dim_feedforward=2048,
        nhead=8,
        num_encoder_layers=6,
        dropout=0.1,
    ):
        super().__init__()
        self.activation = activation
        self.dim_feedforward = dim_feedforward
        self.num_encoder_layers = num_encoder_layers
        template_layer = Encoder_Layer(
            dim_feedforward,
            nhead,
            dropout,
            activation,
        )
        # Independent copies: the layers share structure but not parameters.
        copies = [copy.deepcopy(template_layer) for _ in range(num_encoder_layers)]
        self.encoder = nn.ModuleList(copies)
        self.nhead = nhead

    def forward(self, src):
        """Apply each layer of ``self.encoder`` in sequence to ``src``."""
        result = src
        for encoder_layer in self.encoder:
            result = encoder_layer(result)
        return result

In [75]:
class Finetune(TransformerEncoder):
    """``TransformerEncoder`` with a cross-attention step before every encoder layer.

    For layer ``i``, the running sequence (query) first attends to ``input2``
    (key and value) through ``cross_layer1[i]``, and the result then goes
    through ``encoder[i]``. The inherited ``encoder`` stack keeps its parameter
    names, so a ``TransformerEncoder`` state_dict maps onto it key-for-key.

    Args:
        activation, dim_feedforward, nhead, num_encoder_layers, dropout:
            forwarded unchanged to ``TransformerEncoder``.
        class_num: accepted for interface compatibility; not used yet.
    """

    def __init__(
        self, activation, dim_feedforward, nhead, num_encoder_layers, dropout, class_num
    ):
        # The parent constructor also stores self.num_encoder_layers.
        super().__init__(
            activation, dim_feedforward, nhead, num_encoder_layers, dropout
        )
        # New (not pretrained) cross-attention modules, one per encoder layer.
        # The attribute name stays `cross_layer1` because it appears in the
        # state_dict keys.
        cross_attention = nn.MultiheadAttention(dim_feedforward, nhead, dropout=dropout)
        self.cross_layer1 = nn.ModuleList(
            [copy.deepcopy(cross_attention) for _ in range(num_encoder_layers)]
        )

    def forward(self, input1, input2):
        """Interleave cross-attention to ``input2`` with the pretrained encoder layers.

        Both inputs are sequence-first, matching ``nn.MultiheadAttention``'s
        default ``batch_first=False``: ``input1`` is ``(L1, B, D)`` and
        ``input2`` is ``(L2, B, D)``. Returns ``(L1, B, D)``.
        """
        output = input1
        # The original looped over an int and indexed a non-existent
        # `cross_layer`; zip pairs each cross-attention with its encoder layer.
        for cross_attn, encoder_layer in zip(self.cross_layer1, self.encoder):
            # MultiheadAttention returns (attn_output, attn_weights); keep only
            # the output tensor.
            output, _ = cross_attn(output, input2, input2, need_weights=False)
            output = encoder_layer(output)
        return output

In [82]:
# Build the fine-tuning model and initialise every parameter it shares with
# the pretrained checkpoint. The new cross-attention layers keep their random
# init.
finetune_model = Finetune(nn.GELU(), 512, 8, 6, 0.1, 10)
pretrained_state_dict = torch.load("demo_model.pth", map_location="cpu")
model_state_dict = finetune_model.state_dict()

# Keep only checkpoint tensors whose name AND shape match the new model.
# A same-named tensor with a different shape would otherwise make
# load_state_dict abort with a size-mismatch error.
pretrained_state_dict = {
    k: v
    for k, v in pretrained_state_dict.items()
    if k in model_state_dict and v.shape == model_state_dict[k].shape
}
# Make it visible which parameters did NOT come from the checkpoint.
not_pretrained_keys = sorted(set(model_state_dict) - set(pretrained_state_dict))
print(f"{len(not_pretrained_keys)} tensors keep their fresh initialisation")

# Overwrite the matching entries, then load the merged state_dict.
model_state_dict.update(pretrained_state_dict)
finetune_model.load_state_dict(model_state_dict)

<All keys matched successfully>

In [81]:
# Assumes finetune_model already holds the pretrained weights, so this cell
# must run after the weight-loading cell.

# Freeze the inherited TransformerEncoder stack (`encoder`).
for frozen_param in finetune_model.encoder.parameters():
    frozen_param.requires_grad = False

# Hand the optimizer only the parameters that still require gradients, i.e.
# the newly added ones such as cross_layer1.
trainable_params = [p for p in finetune_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

# List every parameter and whether it will be trained.
for name, param in finetune_model.named_parameters():
    print(f"{name}: {'Yes' if param.requires_grad else 'No'}")

encoder.0.self_attn.in_proj_weight: No
encoder.0.self_attn.in_proj_bias: No
encoder.0.self_attn.out_proj.weight: No
encoder.0.self_attn.out_proj.bias: No
encoder.0.fc.0.weight: No
encoder.0.fc.0.bias: No
encoder.0.fc.3.weight: No
encoder.0.fc.3.bias: No
encoder.0.norm1.weight: No
encoder.0.norm1.bias: No
encoder.0.norm2.weight: No
encoder.0.norm2.bias: No
encoder.1.self_attn.in_proj_weight: No
encoder.1.self_attn.in_proj_bias: No
encoder.1.self_attn.out_proj.weight: No
encoder.1.self_attn.out_proj.bias: No
encoder.1.fc.0.weight: No
encoder.1.fc.0.bias: No
encoder.1.fc.3.weight: No
encoder.1.fc.3.bias: No
encoder.1.norm1.weight: No
encoder.1.norm1.bias: No
encoder.1.norm2.weight: No
encoder.1.norm2.bias: No
encoder.2.self_attn.in_proj_weight: No
encoder.2.self_attn.in_proj_bias: No
encoder.2.self_attn.out_proj.weight: No
encoder.2.self_attn.out_proj.bias: No
encoder.2.fc.0.weight: No
encoder.2.fc.0.bias: No
encoder.2.fc.3.weight: No
encoder.2.fc.3.bias: No
encoder.2.norm1.weight: No
enc

In [67]:
class Animal:
    """Minimal base class for the inheritance demo.

    Args:
        name: the animal's name. It is stored as ``self.name`` so that
            subclasses calling ``super().__init__(name)`` actually keep it;
            the original constructor discarded it.
    """

    def __init__(self, name):
        self.name = name

    def speak(self):
        """Return a generic sound; subclasses override this."""
        return "Some sound"

In [68]:
class Cat(Animal):
    """``Animal`` subclass that adds a ``color`` attribute and a ``purr`` method."""

    def __init__(self, name, color):
        # Let the parent constructor handle the shared `name` argument first.
        super().__init__(name)
        self.color = color  # attribute that only Cat has

    def purr(self):
        """Method that exists only on Cat."""
        return "Purr..."

    def speak(self):
        """Override of ``Animal.speak``."""
        return "Meow!"

In [69]:
xianluo = Cat(name="xianluo", color="black")  # instance used to try out the subclass

In [79]:
import torch.nn as nn


# Stand-in for the original (already trained) model.
class OriginalModel(nn.Module):
    """Two stacked linear layers, 10 -> 20 -> 10 features, with no activation between them."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.layer2 = nn.Linear(20, 10)

    def forward(self, x):
        """Apply ``layer1`` and then ``layer2`` to ``x``."""
        return self.layer2(self.layer1(x))


# New model that subclasses the original one. It inherits layer1/layer2, and
# therefore the same state_dict keys, and only overrides forward().
class ModifiedModel(OriginalModel):
    """``OriginalModel`` with a ReLU inserted between ``layer1`` and ``layer2``."""

    # No __init__ override is needed: OriginalModel.__init__ builds the layers.

    def forward(self, x):
        """Same layers as the parent, plus the new ReLU step in the middle."""
        hidden = torch.relu(self.layer1(x))  # the newly added logic
        return self.layer2(hidden)


# Stand-in for a trained model: save the weights of a fresh OriginalModel.
MODEL_PATH = "model.pth"
original_model = OriginalModel()
torch.save(original_model.state_dict(), MODEL_PATH)

# ModifiedModel has exactly the same parameters, so the checkpoint loads with
# the default strict=True.
modified_model = ModifiedModel()
checkpoint = torch.load(MODEL_PATH)
modified_model.load_state_dict(checkpoint)

# Run the modified model on one random 10-feature sample.
input_tensor = torch.randn(1, 10)
output = modified_model(input_tensor)

In [109]:
import torch
import sys

# The project root must be importable so that unpickling can resolve the
# classes the checkpoint was saved with.
# NOTE(review): absolute, user-specific paths; consider making them configurable.
PROJECT_ROOT = "/home/Kioedru/code/SSGO"
PPI_MODEL_PATH = (
    f"{PROJECT_ROOT}/codespace/pretrain/bimamba_seq1024/9606/bimamba_seq1024.pkl"
)
# Guard against appending the same entry again every time the cell is re-run.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# The .pkl holds a whole pickled nn.Module, not just a state_dict. Unpickling
# can execute arbitrary code, so only load trusted files.
# NOTE(review): from PyTorch 2.6 torch.load defaults to weights_only=True,
# which rejects full modules; pass weights_only=False there (trusted files only).
ppi_feature_pre_model = torch.load(PPI_MODEL_PATH, map_location="cuda:0")
print(ppi_feature_pre_model.transformerEncoder.encoder.layers)

# Set to True to freeze every parameter first, so that only the
# cross-attention blocks below are trained. False keeps the previous
# behaviour: other parameters keep whatever requires_grad they were saved with.
FREEZE_NON_CROSS_ATTN = False
if FREEZE_NON_CROSS_ATTN:
    for param in ppi_feature_pre_model.parameters():
        param.requires_grad = False

# (Re-)enable gradients for the cross-attention block of every encoder layer.
# This unfreezes rather than freezes.
for layer in ppi_feature_pre_model.transformerEncoder.encoder.layers:
    for param in layer.cross_attn.parameters():
        param.requires_grad = True

# Report which parameters will be trained.
for name, param in ppi_feature_pre_model.named_parameters():
    print(f"{name}: {'Yes' if param.requires_grad else 'No'}")

ModuleList(
  (0-5): 6 x TransformerEncoderLayer(
    (dropout): Dropout(p=0.1, inplace=False)
    (self_attn): BiMamba(
      (in_proj): Linear(in_features=1, out_features=4, bias=False)
      (conv1d): Conv1d(2, 2, kernel_size=(4,), stride=(1,), padding=(3,), groups=2)
      (conv1d_b): Conv1d(2, 2, kernel_size=(4,), stride=(1,), padding=(3,), groups=2)
      (act): SiLU()
      (x_proj): Linear(in_features=2, out_features=33, bias=False)
      (x_proj_b): Linear(in_features=2, out_features=33, bias=False)
      (dt_proj): Linear(in_features=1, out_features=2, bias=True)
      (dt_proj_b): Linear(in_features=1, out_features=2, bias=True)
      (out_proj): Linear(in_features=2, out_features=1, bias=False)
    )
    (linear1): Linear(in_features=512, out_features=2048, bias=True)
    (linear2): Linear(in_features=2048, out_features=512, bias=True)
    (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (d

In [100]:
ppi_state = ppi_feature_pre_model.state_dict()  # weights only, not the pickled module
torch.save(ppi_state, "ppi_feature_pre_model.pth")

In [87]:
import copy
from typing import Optional, List

import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn import MultiheadAttention