In [1]:
import torch
from vit import ViT
import numpy as np

In [2]:
def load_official_transformer_params(custom_model_state_dict, official_model_state_dict, dim = 384, depth = 12):
    """Copy arrays from an official-format checkpoint into a custom ViT state dict.

    For each of the ``depth`` encoder blocks, the two LayerNorms, the
    attention q/k/v/out projections and the two MLP dense layers are looked up
    under their ``Transformer/encoderblock_{i}/...`` names, converted to torch
    tensors, and stored under the matching ``transformer.layers.{i}...`` names.
    The position embedding, class token and patch-embedding projection are
    copied afterwards.

    Args:
        custom_model_state_dict: Mutable mapping of parameter name -> tensor.
            It is updated in place and also returned.
        official_model_state_dict: Mapping of official parameter name ->
            array-like accepted by ``torch.tensor`` (e.g. a plain dict of
            NumPy arrays, or the object returned by ``np.load`` on an ``.npz``).
        dim: Embedding width; used as the reshape target for the attention and
            patch-embedding kernels.
        depth: Number of encoder blocks to map. Defaults to 12, the value that
            was previously hard-coded.

    Returns:
        ``custom_model_state_dict`` (the same object that was passed in).

    Raises:
        KeyError: If an expected official parameter name is missing from
            ``official_model_state_dict`` (entries written before the failure
            stay in ``custom_model_state_dict``).
    """
    # (custom suffix, official suffix) pairs shared by every encoder block.
    block_suffixes = (
        ("0.residual.norm.weight",        "LayerNorm_0/scale"),
        ("0.residual.norm.bias",          "LayerNorm_0/bias"),
        ("0.residual.fn.to_q.weight",     "MultiHeadDotProductAttention_1/query/kernel"),
        ("0.residual.fn.to_k.weight",     "MultiHeadDotProductAttention_1/key/kernel"),
        ("0.residual.fn.to_v.weight",     "MultiHeadDotProductAttention_1/value/kernel"),
        ("0.residual.fn.to_out.0.weight", "MultiHeadDotProductAttention_1/out/kernel"),
        ("0.residual.fn.to_out.0.bias",   "MultiHeadDotProductAttention_1/out/bias"),
        ("1.residual.norm.weight",        "LayerNorm_2/scale"),
        ("1.residual.norm.bias",          "LayerNorm_2/bias"),
        ("1.residual.fn.FFN1.0.weight",   "MlpBlock_3/Dense_0/kernel"),
        ("1.residual.fn.FFN1.0.bias",     "MlpBlock_3/Dense_0/bias"),
        ("1.residual.fn.FFN2.0.weight",   "MlpBlock_3/Dense_1/kernel"),
        ("1.residual.fn.FFN2.0.bias",     "MlpBlock_3/Dense_1/bias"),
    )

    # Expand the per-block table into the full custom-name -> official-name map.
    param_mapping = {}
    for layer_index in range(depth):
        for custom_suffix, official_suffix in block_suffixes:
            custom_name = f"transformer.layers.{layer_index}.{custom_suffix}"
            param_mapping[custom_name] = f"Transformer/encoderblock_{layer_index}/{official_suffix}"

    # Convert each array, choosing the layout change from the parameter names.
    for custom_param, official_param in param_mapping.items():
        weight = torch.tensor(official_model_state_dict[official_param])
        if "kernel" not in official_param:
            # Biases and LayerNorm scale/bias are copied unchanged.
            pass
        elif "FFN" in custom_param:
            # MLP kernels are transposed.
            weight = weight.transpose(0, 1)
        elif "out" in custom_param:
            # Attention output kernel: flatten the leading axes to (-1, dim).
            # NOTE(review): unlike the FFN kernels, the attention kernels are only
            # reshaped, not transposed. If to_q/to_k/to_v/to_out are nn.Linear
            # layers and the official kernels are stored input-major, this loads
            # transposed matrices (silently, when the matrix is square) —
            # confirm against the `vit` module.
            weight = weight.reshape(-1, dim)
        else:
            # Query/key/value kernels: flatten the trailing axes to (dim, -1).
            weight = weight.reshape(dim, -1)
        custom_model_state_dict[custom_param] = weight

    # Parameters that live outside the encoder blocks.
    custom_model_state_dict['patch_embedding.pos_embedding'] = torch.tensor(official_model_state_dict['Transformer/posembed_input/pos_embedding'])
    custom_model_state_dict['patch_embedding.cls_token'] = torch.tensor(official_model_state_dict['cls'])
    custom_model_state_dict['patch_embedding.patch_to_embedding.bias'] = torch.tensor(official_model_state_dict['embedding/bias'])
    # Patch-embedding kernel: flatten to (-1, dim), then transpose to (dim, -1).
    custom_model_state_dict['patch_embedding.patch_to_embedding.weight'] = torch.tensor(official_model_state_dict['embedding/kernel']).reshape(-1, dim).transpose(0, 1)

    return custom_model_state_dict


    

In [3]:
if __name__ == "__main__":
    # Build the ViT whose parameters will be overwritten with the converted
    # checkpoint values.
    model_small = ViT(
            image_size = 224,
            patch_size = 16,
            num_classes = 100,
            dim = 384,
            depth = 12,
            heads = 6,
            mlp_dim = 1536,
            dropout = 0.01,
            emb_dropout = 0.01)

    # Load the pretrained weights. np.load on an .npz archive keeps the file
    # open and reads arrays on access, so the conversion runs inside the
    # with-block and the archive is closed afterwards.
    with np.load(r"../pretrain/augreg_S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz") as pretrained_weights_small:
        # Start from the model's own state dict and replace the mapped entries.
        custom_model_state_dict = model_small.state_dict()
        custom_model_state_dict = load_official_transformer_params(custom_model_state_dict, pretrained_weights_small)

    # Push the converted tensors into the model, then persist them.
    model_small.load_state_dict(custom_model_state_dict)

    torch.save(model_small.state_dict(), r'../pretrain/S_16_model_parameters.pth')