自定义修改预训练模型层级的5种方法

In [None]:
# Method 1: direct layer replacement.
# Load the pretrained BERT encoder whose layers are modified below.
from transformers import BertModel 
model = BertModel.from_pretrained("bert-base-uncased") 
 
# 替换第3个Transformer层 
from torch import nn 
class CustomLayer(nn.Module):
    """Feed-forward stand-in for one Transformer encoder layer.

    The hidden states are expanded to ``2 * hidden_size``, passed through GELU
    and projected back to ``hidden_size`` so the output can be fed to the next
    encoder layer unchanged in shape.

    Args:
        hidden_size: Width of the incoming (and outgoing) hidden states.
        return_tuple: When True (default) ``forward`` returns ``(hidden_states,)``;
            when False it returns the bare tensor.
            NOTE(review): the tuple default assumes the host encoder reads the
            hidden states from index 0 of the layer's return value (the
            ``BertLayer`` convention in transformers 4.x) - confirm against the
            installed version and pass ``return_tuple=False`` if it expects a
            plain tensor.
    """

    def __init__(self, hidden_size, return_tuple=True):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size * 2)
        self.act = nn.GELU()
        # Project back so the layer is shape-preserving; without this the next
        # encoder layer would receive 2 * hidden_size features.
        self.proj = nn.Linear(hidden_size * 2, hidden_size)
        self.return_tuple = return_tuple

    def forward(self, hidden_states, *args, **kwargs):
        """Transform ``hidden_states``; extra encoder arguments are accepted and ignored.

        Returns:
            ``(hidden_states,)`` or ``hidden_states`` depending on ``return_tuple``,
            with the same shape as the input.
        """
        hidden_states = self.proj(self.act(self.linear(hidden_states)))
        if self.return_tuple:
            return (hidden_states,)
        return hidden_states
        
# Swap the third encoder layer (index 2) for a CustomLayer sized to the model's hidden width.
# NOTE(review): the replacement must accept whatever arguments the encoder passes to its
# layers and return output in the format the encoder expects - confirm before running.
model.encoder.layer[2]  = CustomLayer(model.config.hidden_size)   # replaces layer 3

In [None]:
# 参数冻结+微调
# Freeze the first six encoder layers; the remaining layers stay trainable.
for frozen_layer in model.encoder.layer[:6]:
    # Module.requires_grad_ flips the flag on every parameter of the layer.
    frozen_layer.requires_grad_(False)


In [None]:
# 插入适配器模块
import torch 
import torch.nn  as nn 
import torch.nn.functional  as F 
 
class Adapter(nn.Module):
    """Residual bottleneck adapter.

    The input is projected down to ``dim // reduction`` features, passed
    through GELU, projected back up to ``dim`` and added to the input after
    multiplication by a learnable scalar.

    Args:
        dim: Size of the last dimension of the input.
        reduction: Divisor applied to ``dim`` to get the bottleneck width.
    """

    def __init__(self, dim, reduction=4):
        super().__init__()
        bottleneck = dim // reduction
        self.down = nn.Linear(dim, bottleneck)   # down-projection
        self.up = nn.Linear(bottleneck, dim)     # up-projection
        self.act = nn.GELU()
        self.scale = nn.Parameter(torch.ones(1))  # learnable residual scale

    def forward(self, x):
        """Return ``x`` plus the scaled bottleneck update."""
        hidden = self.down(x)
        hidden = self.act(hidden)
        update = self.up(hidden)
        return x + self.scale * update
    
from transformers import BertModel 
 
# Run on GPU with mixed precision when available; otherwise fall back to plain CPU execution.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = BertModel.from_pretrained("bert-base-uncased")
model.to(device)


def _apply_adapter(layer, _inputs, output):
    """Forward hook: pass a layer's hidden states through ``layer.adapter``."""
    # Handle both return formats: a tuple whose first item is the hidden
    # states, or a bare tensor.
    if isinstance(output, tuple):
        return (layer.adapter(output[0]),) + output[1:]
    return layer.adapter(output)


# Attach one adapter per encoder layer. Registering the submodule alone is not
# enough: the layer's own forward never calls it, so a forward hook applies the
# adapter to the layer output and puts it on the gradient path.
for layer in model.encoder.layer:
    layer.adapter = Adapter(model.config.hidden_size).to(device)
    layer.register_forward_hook(_apply_adapter)

# Freeze the backbone, then re-enable gradients for the adapters only.
model.requires_grad_(False)
for layer in model.encoder.layer:
    layer.adapter.requires_grad_(True)

optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-3,
    weight_decay=0.01,
)

# --- One mixed-precision training step -------------------------------------
# NOTE(review): from_pretrained is expected to hand back the model in eval
# mode; switch to train mode for the update step - confirm this is wanted.
model.train()

# Placeholder batch of random token ids; replace with real tokenized data.
input_ids = torch.randint(0, model.config.vocab_size, (2, 16), device=device)

# NOTE(review): newer PyTorch releases spell this torch.amp.GradScaler("cuda");
# the torch.cuda.amp form is kept for compatibility with older installs.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

optimizer.zero_grad(set_to_none=True)  # do not accumulate stale gradients
with torch.amp.autocast(device_type=device, enabled=use_amp):
    outputs = model(input_ids)
    # BertModel has no task head and therefore returns no `.loss`.
    # Placeholder objective so the step is runnable; replace with the real task loss.
    loss = outputs.last_hidden_state.float().pow(2).mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# 动态适配器路由
class DynamicAdapter(nn.Module):
    """Mixture of adapters with a learned per-sample soft gate.

    A linear gate scores each adapter from the input averaged over dimension 1;
    the output is the softmax-weighted sum of all adapter outputs.

    Args:
        dim: Size of the last dimension of the input.
        num_adapters: Number of parallel ``Adapter`` modules to mix.
    """

    def __init__(self, dim, num_adapters=4):
        super().__init__()
        self.adapters = nn.ModuleList([Adapter(dim) for _ in range(num_adapters)])
        self.gate = nn.Linear(dim, num_adapters)

    def forward(self, x):
        """Mix the adapter outputs for a 3-D input ``[B, T, dim]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        gate_scores = F.softmax(self.gate(x.mean(dim=1)), dim=-1)  # [B, num_adapters]
        # Split along the adapter axis (not the batch axis) so every adapter
        # gets its own per-sample weight column of shape [B].
        per_adapter_weights = gate_scores.unbind(dim=-1)
        return sum(
            weight.view(-1, 1, 1) * adapter(x)  # [B, 1, 1] broadcasts over [B, T, dim]
            for weight, adapter in zip(per_adapter_weights, self.adapters)
        )

# 低秩适配器（LoRA） 
class LoRA(nn.Module):
    """Low-rank update ``x @ A @ B * (alpha / rank)``.

    ``A`` is initialised from a standard normal and ``B`` with zeros, so the
    module outputs zeros until ``B`` is trained.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        rank: Inner dimension of the factorisation.
        alpha: Numerator of the scaling factor; the default of 1.0 gives the
            scale ``1 / rank``.
    """

    def __init__(self, in_dim, out_dim, rank=8, alpha=1.0):
        super().__init__()
        self.A = nn.Parameter(torch.randn(in_dim, rank))
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.scale = alpha / rank

    def forward(self, x):
        """Return the scaled low-rank projection of ``x``."""
        # Multiply through the rank-sized bottleneck first; forming A @ B would
        # build the full in_dim x out_dim matrix on every call.
        return (x @ self.A) @ self.B * self.scale
  

In [None]:
# 动态权重混合
# Blend the pretrained word-embedding matrix with random values (30% original, 70% random)
# and install the result as a new parameter.
# NOTE(review): the embedding then holds a new Parameter object, so an optimizer built
# earlier would still reference the old one - confirm the intended ordering.
word_embeddings = model.embeddings.word_embeddings
original_weight = word_embeddings.weight
custom_weight = torch.randn_like(original_weight)
blended_weight = 0.3 * original_weight + 0.7 * custom_weight
word_embeddings.weight = nn.Parameter(blended_weight)

In [None]:
# 结构重参数化
import torch 
import torch.nn  as nn 
import torch.nn.functional  as F 
 
class RepBlock(nn.Module):
    """Training-time multi-branch block: ``BN(conv3x3(x) + conv1x1(x) + x)``.

    Args:
        in_channels: Number of input channels; the block keeps the channel
            count unchanged so the identity branch can be added.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Branch 1: 3x3 convolution.
        self.conv3x3 = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False)
        # Branch 2: 1x1 convolution.
        self.conv1x1 = nn.Conv2d(in_channels, in_channels, 1, bias=False)
        # Branch 3 is the identity; one BatchNorm is applied to the sum.
        self.bn = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        """Sum the three branches and normalise the result."""
        return self.bn(self.conv3x3(x) + self.conv1x1(x) + x)

    def reparameterize(self):
        """Fold all branches and the BatchNorm into one 3x3 convolution.

        The result reproduces this block's output when BatchNorm uses its
        running statistics, i.e. in ``eval()`` mode.

        Returns:
            A new ``nn.Conv2d`` (kernel 3, padding 1, with bias) on the same
            device and dtype as the block's weights.
        """
        channels = self.conv3x3.in_channels
        with torch.no_grad():
            # 1x1 kernel zero-padded to 3x3 so it can be added to the 3x3 kernel.
            kernel = self.conv3x3.weight + F.pad(self.conv1x1.weight, [1, 1, 1, 1])

            # Identity branch as a 3x3 kernel: 1 at the centre tap of channel i -> i.
            channel_idx = torch.arange(channels, device=kernel.device)
            identity_kernel = torch.zeros_like(kernel)
            identity_kernel[channel_idx, channel_idx, 1, 1] = 1.0
            kernel = kernel + identity_kernel

            # BN(y) = gamma * (y - mean) / std + beta  ->  per-output-channel
            # scale on the kernel plus a bias.
            bn_scale = self.bn.weight / torch.sqrt(self.bn.running_var + self.bn.eps)

            fused_conv = nn.Conv2d(
                channels,
                self.conv3x3.out_channels,
                kernel_size=3,
                padding=1,
                bias=True,
            ).to(device=kernel.device, dtype=kernel.dtype)
            fused_conv.weight.copy_(kernel * bn_scale.view(-1, 1, 1, 1))
            fused_conv.bias.copy_(self.bn.bias - self.bn.running_mean * bn_scale)
        return fused_conv


class RepResNet(nn.Module):
    """Example model: a stem convolution followed by two ``RepBlock``s."""

    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            RepBlock(64),
            RepBlock(64),
        )
        # ... further layers would be defined here

    def forward(self, x):
        """Run the input through ``stage1``."""
        return self.stage1(x)

    def deploy(self):
        """Replace every ``RepBlock`` in the model, at any depth, with its fused conv.

        The model is modified in place; call ``eval()`` first so the fused
        convolutions match the blocks they replace.

        Returns:
            ``self``, to allow chaining.
        """
        # Snapshot the module tree first: children are swapped while walking it.
        # The blocks live inside nn.Sequential containers, so scanning only the
        # direct children of the model would never find them.
        for parent in list(self.modules()):
            for name, child in list(parent.named_children()):
                if isinstance(child, RepBlock):
                    setattr(parent, name, child.reparameterize())
        return self
    
# Training phase: a forward pass in train mode (BatchNorm uses batch statistics here).
model = RepResNet()
sample_input = torch.randn(1, 3, 224, 224)
train_output = model(sample_input)

# Deployment phase: switch to eval mode BEFORE fusing so BatchNorm uses its running statistics.
model.eval()
with torch.no_grad():
    # deploy() modifies the model in place, so take the reference output first,
    # on the same input and in the same (eval) mode as the deployed model.
    reference_output = model(sample_input)
    deployed_model = model.deploy()
    deployed_output = deployed_model(sample_input)

# Check that the deployed model matches the eval-mode model on the same input.
# The tolerance leaves room for float32 rounding introduced by folding the weights.
print(torch.allclose(reference_output, deployed_output, atol=1e-4))  # expected: True