In [1]:
import torch
from copy import deepcopy

In [2]:
# Name of the dataset whose refined Localizer is merged. It selects the refined
# checkpoint directory and is embedded in the output file name.
dataset = "nextqa"

# Root of the SeViLA checkout. Every path below is derived from it, so moving the
# notebook to another machine only requires editing this one line.
sevila_root = "/root/VideoQA/sevila"

# pretrained_path: checkpoint of the pretrained model (the weights we keep).
pretrained_path = f"{sevila_root}/sevila_checkpoints/sevila_pretrained.pth"
# refined_path: best checkpoint of the refined (fine-tuned) model for `dataset`.
# It is derived from `dataset` so that switching datasets cannot silently merge another
# dataset's Localizer into a file named for this one.
# NOTE(review): assumes refined runs are written to "lavis/results/{dataset}_sr/";
# confirm for datasets other than nextqa.
refined_path = f"{sevila_root}/lavis/results/{dataset}_sr/checkpoint_best.pth"
# output_path: where the merged checkpoint is saved; the file name contains `dataset`.
output_path = f"{sevila_root}/sevila_checkpoints/sevila_pretrained_refined_{dataset}.pth"

加载模型检查点并提取模型参数

In [3]:
# Load the pretrained and the refined checkpoints onto the CPU; no GPU is needed to merge them.
# NOTE(review): torch.load unpickles the file, so only point these paths at trusted checkpoints.
pretrained_ckpt = torch.load(pretrained_path, map_location="cpu")
refined_ckpt = torch.load(refined_path, map_location="cpu")

# Untouched deep copy of the pretrained checkpoint, kept so the merged result can later
# be compared against the original values.
old_pretrained_ckpt = deepcopy(pretrained_ckpt)

# Each checkpoint stores its parameter dictionary under the "model" key.
# pretrained_ckpt_model is the very same dict object as pretrained_ckpt["model"], so
# assigning into it also changes what gets saved from pretrained_ckpt.
pretrained_ckpt_model, refined_ckpt_model = (
    checkpoint["model"] for checkpoint in (pretrained_ckpt, refined_ckpt)
)

打印模型参数的键名

In [None]:
# Show the parameter names (state-dict keys) of the pretrained checkpoint and then,
# after a separator line of "=" characters, those of the refined checkpoint.
for position, state_dict in enumerate((pretrained_ckpt_model, refined_ckpt_model)):
    if position > 0:
        print("\n" + "=" * 100 + "\n")
    print(state_dict.keys())

用微调后的定位器（Localizer）参数更新预训练模型

In [4]:
# List the parameter names present in BOTH checkpoints that are NOT Localizer parameters.
# The merge keeps these shared keys at their pretrained values, so this cell shows which
# refined entries will not be carried over.
# Localizer parameters are recognised by a whole "loc" name token (for example
# "query_tokens_loc", "ln_vision_loc.weight", "Qformer_loc.bert..."). A plain substring
# test ("loc" in key) would also match unrelated names such as "block", "blocks" or "local".
[
    key
    for key in refined_ckpt["model"]
    if key in pretrained_ckpt["model"]
    and "loc" not in key.replace(".", "_").split("_")
]

['ln_vision.weight',
 'ln_vision.bias',
 'Qformer.bert.embeddings.position_ids',
 't5_proj.weight',
 't5_proj.bias']

In [5]:

# Copy every Localizer parameter of the refined checkpoint into the pretrained one and
# leave all other pretrained parameters untouched. For each refined key:
#   - missing from the pretrained checkpoint -> reported and skipped;
#   - not a Localizer parameter              -> reported and skipped;
#   - Localizer parameter                    -> copied over and reported.
# Localizer parameters are recognised by a whole "loc" name token (for example
# "query_tokens_loc", "ln_vision_loc.weight", "Qformer_loc.bert..."). A plain substring
# test ("loc" in key) would also match unrelated names such as "block", "blocks" or "local".
updated_keys = []
for key, refined_value in refined_ckpt_model.items():
    if key not in pretrained_ckpt_model:
        print(f"Key '{key}' does not exist.")
        continue
    if "loc" not in key.replace(".", "_").split("_"):
        print(f"Key '{key}' does not belong to Localizer.")
        continue
    # Writing a tensor of a different shape would produce a checkpoint that does not load
    # cleanly into the model, so fail loudly instead.
    pretrained_shape = getattr(pretrained_ckpt_model[key], "shape", None)
    refined_shape = getattr(refined_value, "shape", None)
    if pretrained_shape != refined_shape:
        raise ValueError(
            f"Shape mismatch for '{key}': pretrained {pretrained_shape}, refined {refined_shape}"
        )
    pretrained_ckpt_model[key] = refined_value
    updated_keys.append(key)
    print(f"Key '{key}' has been updated.")

# If nothing was copied, the "merged" file would just be the pretrained checkpoint
# saved under a misleading name.
assert updated_keys, "No Localizer parameters were copied from the refined checkpoint."


Key 'query_tokens_loc' has been updated.
Key 'ln_vision.weight' does not belong to Localizer.
Key 'ln_vision.bias' does not belong to Localizer.
Key 'ln_vision_loc.weight' has been updated.
Key 'ln_vision_loc.bias' has been updated.
Key 't5_model.encoder.embed_tokens.weight' does not exist.
Key 't5_model.decoder.embed_tokens.weight' does not exist.
Key 'Qformer.bert.embeddings.position_ids' does not belong to Localizer.
Key 't5_proj.weight' does not belong to Localizer.
Key 't5_proj.bias' does not belong to Localizer.
Key 'Qformer_loc.bert.embeddings.position_ids' has been updated.
Key 'Qformer_loc.bert.embeddings.LayerNorm.weight' has been updated.
Key 'Qformer_loc.bert.embeddings.LayerNorm.bias' has been updated.
Key 'Qformer_loc.bert.encoder.layer.0.attention.self.query.weight' has been updated.
Key 'Qformer_loc.bert.encoder.layer.0.attention.self.query.bias' has been updated.
Key 'Qformer_loc.bert.encoder.layer.0.attention.self.key.weight' has been updated.
Key 'Qformer_loc.bert.en

In [6]:
# Write the merged checkpoint (pretrained weights with the refined Localizer copied in)
# to `output_path`. First check that the merge only overwrote existing entries and
# neither added nor dropped a parameter name, so a structurally broken file is never written.
assert pretrained_ckpt["model"].keys() == old_pretrained_ckpt["model"].keys(), (
    "The merge changed the set of parameter names; refusing to save."
)
torch.save(pretrained_ckpt, output_path)

In [7]:
# Sanity check: the merge must only overwrite existing entries, so the set of parameter
# names before and after must be identical. The cell displays True when that holds.
set(pretrained_ckpt["model"]) == set(old_pretrained_ckpt["model"])

True

In [8]:
# Report Localizer parameters whose merged value is still identical to the original
# pretrained value. This is not necessarily a problem: a key can be copied from the
# refined checkpoint and still be unchanged because the refined value is the same. That
# happens here for Qformer_loc.bert.embeddings.position_ids, which the merge copies but
# which stays equal. The message therefore states which of the two cases applies.
# Localizer parameters are matched by a whole "loc" name token; a substring test would
# also match names such as "block" or "local".
for key, merged_value in pretrained_ckpt["model"].items():
    if "loc" not in key.replace(".", "_").split("_"):
        continue
    if not torch.equal(merged_value, old_pretrained_ckpt["model"][key]):
        continue
    if key in refined_ckpt_model:
        print(f"{key}没更新（微调检查点中的取值与预训练相同）")
    else:
        print(f"{key}没更新（微调检查点中不存在该参数）")

Qformer_loc.bert.embeddings.position_ids没更新
