In [None]:
import os

# Point the Hugging Face cache at the workspace volume. Keep this cell ahead of the
# `transformers` import so the libraries see the variable when they resolve their cache path.
HF_CACHE_DIR = "/workspace/cache"
os.environ['HF_HOME'] = HF_CACHE_DIR

In [None]:
import safetensors
import torch
from glob import glob
from transformers import Qwen2AudioForConditionalGeneration
from safetensors import safe_open
from tqdm import tqdm

In [None]:
from transformers.trainer_utils import get_last_checkpoint

# Directory searched for `checkpoint-*` sub-folders.
# NOTE(review): presumably the output directory of the LoRA fine-tuning run -- confirm.
CHECKPOINT_ROOT = "lora-embedding-128-qwen2audio-7b"

latest = get_last_checkpoint(CHECKPOINT_ROOT)
# get_last_checkpoint returns None when the directory contains no checkpoint; stop here with
# a clear message instead of letting os.path.join(latest, ...) fail later with a TypeError.
if latest is None:
    raise FileNotFoundError(f"no checkpoint-* directory found under {CHECKPOINT_ROOT!r}")
latest

In [None]:
# Base (un-adapted) model whose weights receive the LoRA update further down.
BASE_MODEL_ID = 'Qwen/Qwen2-Audio-7B-Instruct'
ori_model = Qwen2AudioForConditionalGeneration.from_pretrained(BASE_MODEL_ID)

In [None]:
# Name -> tensor mapping of the base model's weights. The merge loop looks weights up by key
# in this mapping and updates them in place.
# NOTE(review): this relies on the tensors returned by state_dict() sharing storage with
# ori_model's parameters, so that the in-place edits reach the model itself -- confirm.
state_dict = ori_model.state_dict()

In [None]:
# Fold the LoRA adapter tensors stored in the latest checkpoint into the base weights held
# in `state_dict`.
#
# NOTE: the update is applied in place, so this cell is not idempotent -- running it a second
# time adds the same delta again. Re-create `ori_model` / `state_dict` before re-running.

# Multiplier applied to every low-rank product.
# NOTE(review): presumably lora_alpha / r from the training config -- confirm before reuse.
LORA_SCALING = 2

files = glob(os.path.join(latest, '*.safetensors'))
for path in files:
    print(path)
    # Open the shard as a context manager so its handle is released once the shard is merged.
    with safe_open(path, framework="pt", device='cpu') as f:
        # Module prefixes that carry LoRA tensors: everything before the first '.lora'.
        keys = sorted({k.split('.lora')[0] for k in f.keys() if '.lora' in k})

        for k in tqdm(keys):
            # Translate the adapter-wrapped module name into the base model's weight key.
            if 'lm_head' in k:
                actual_k = 'language_model.lm_head.weight'
            else:
                actual_k = k.replace('.base_model.model.model.', '.model.') + '.weight'

            # Embedding adapters are stored under different tensor names than the others.
            if 'embed_tokens' in k:
                post_A = '.lora_embedding_A.default'
                post_B = '.lora_embedding_B.default'
            else:
                post_A = '.lora_A.default.weight'
                post_B = '.lora_B.default.weight'

            W = state_dict[actual_k]
            if 'embed_tokens' not in k:
                # Operate on the transposed view so the same A^T @ B^T product can be added;
                # the view shares storage with the weight, so the in-place add still lands
                # in the original tensor.
                W = W.t()

            A = f.get_tensor(k + post_A).to(W.dtype)
            B = f.get_tensor(k + post_B).to(W.dtype)

            with torch.no_grad():
                # W += LORA_SCALING * (A^T @ B^T), written into the existing tensor.
                W.addmm_(A.t(), B.t(), alpha=LORA_SCALING)