In [8]:
import transformers
import torch
import numpy as np
import gc
import tempfile

from transformers import (LlamaForCausalLM, 
                          LlamaTokenizer,
                          AutoTokenizer, 
                          AutoModelForCausalLM)

from typing import List, Dict, Any


  from .autonotebook import tqdm as notebook_tqdm


### 加载模型

In [14]:
# Local checkpoint directory of the Llama-2 7B chat model (Hugging Face format).
model_path = '/workspace/acl/model_zoo/llama/llama-2-7b-chat-hf'

# The tokenizer holds no weights, so it takes no `device_map`; the stray kwarg
# that used to be passed here has been removed.
# `trust_remote_code` is kept False for both calls: the checkpoint resolves to
# the stock LlamaForCausalLM class, so no repository-supplied code is needed.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=False)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.32s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
    (rotary_emb)

### 加载模型参数

In [None]:
def load_model(model_name_or_path, trust_remote_code:bool=True, device_map="auto"):
    """Load a causal language model from a checkpoint path or hub name.

    Args:
        model_name_or_path: Local directory or model identifier handed to
            ``AutoModelForCausalLM.from_pretrained``.
        trust_remote_code: Forwarded to ``from_pretrained``.
        device_map: Forwarded to ``from_pretrained``.

    Returns:
        The model instance produced by ``AutoModelForCausalLM.from_pretrained``.
    """
    # Auto classes cannot be instantiated directly; they must be built through
    # the `from_pretrained` factory, otherwise construction raises.
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=trust_remote_code,
        device_map=device_map,
    )
    return model

def get_model_param_list(model_names: List[str], model_type:str):
    """Load every model in ``model_names`` and collect their state dicts.

    Args:
        model_names: Checkpoint paths or model identifiers, each passed to
            ``load_model``.
        model_type: Accepted for interface compatibility; currently unused,
            because ``load_model`` has no such parameter (forwarding it raised
            ``TypeError``).

    Returns:
        A list with one ``state_dict()`` per entry of ``model_names``, in the
        same order.
    """
    model_param_list = []
    for name in model_names:
        print(f"loading {name} -----------------")
        # Only pass arguments that load_model actually accepts.
        model = load_model(name)
        model_param_list.append(model.state_dict())
    return model_param_list


### 模型融合

In [None]:
def merge_param(model_param_list: List[Dict], weights: List[float]):
    """Merge several state dicts into one by a weighted sum of their tensors.

    For every key of the first state dict, floating-point (or complex) tensors
    are combined as ``sum(w_i * param_i[key])``. Tensors of any other dtype
    (integer, bool, ...) are not averaged; the value from the last state dict
    is carried over unchanged.

    Args:
        model_param_list: State dicts to merge. Every dict must contain all
            keys of the first one.
        weights: One scalar weight per state dict, in the same order.

    Returns:
        A new dict mapping each key to the merged tensor. Weighted tensors are
        freshly allocated, so the inputs are not modified.

    Raises:
        ValueError: If ``model_param_list`` is empty or its length differs
            from ``len(weights)``.
    """
    if len(model_param_list) == 0:
        raise ValueError("model_param_list must contain at least one state dict")
    # zip() would silently drop the surplus models/weights, so reject a
    # mismatch up front instead of returning a partially merged result.
    if len(model_param_list) != len(weights):
        raise ValueError(
            f"got {len(model_param_list)} state dicts but {len(weights)} weights"
        )

    new_param = {}
    for k in model_param_list[0].keys():
        for w, param in zip(weights, model_param_list):
            tensor = param[k]
            if not (tensor.is_floating_point() or tensor.is_complex()):
                # Counters, index buffers and masks must keep their dtype;
                # scaling them by a float weight would corrupt them.
                new_param[k] = tensor
            elif k not in new_param:
                # `w * tensor` allocates a new tensor, so the in-place
                # accumulation below never touches the source state dicts.
                new_param[k] = w * tensor
            else:
                new_param[k] += w * tensor
    return new_param

In [6]:
def test(**kwargs):
    """Print the value passed as keyword ``a`` and report that it was present.

    Returns:
        ``True`` when a keyword argument named ``a`` was supplied (after
        printing its value); ``None`` otherwise.
    """
    try:
        value = kwargs['a']
    except KeyError:
        # No `a` keyword: fall through without printing anything.
        return None
    print(value)
    return True

test(a=2)

2


True

In [33]:
import copy

# Experiment: double the first state-dict entry on a deep copy of `model`,
# show that the original is unaffected, then load the modified weights back
# into `model` and show that it now matches the copy.
a = copy.deepcopy(model)
for k in list(a.state_dict().keys())[:1]:
    # Value before scaling.
    print(k, a.state_dict()[k][0, 0])

    # The in-place multiply writes through to the copy's own parameter storage.
    copied_state = a.state_dict()
    copied_state[k] *= 2

    # The copy changed ...
    print(k, a.state_dict()[k][0, 0])
    # ... while the original model still holds the old value.
    print(k, model.state_dict()[k][0, 0])

    # Overwrite the original model's weights with the copy's.
    model.load_state_dict(a.state_dict())
    print(k, model.state_dict()[k][0, 0])
    


model.embed_tokens.weight tensor(0.0012)
model.embed_tokens.weight tensor(0.0024)
model.embed_tokens.weight tensor(0.0012)
model.embed_tokens.weight tensor(0.0024)


### 在 Python 中调用 shell 命令

In [9]:
import subprocess
import os
# Run the helper script wbw_test.py as an external command; whatever it prints
# (including a traceback) shows up in the cell output.
# Alternative via subprocess, kept for reference and not executed:
# output = subprocess.run(['python', 'wbw_test.py'])
# print(output.decode())
# os.system hands back the command's status value, not its printed output.
output2 = os.system('python wbw_test.py')
# print(output2)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.09s/it]
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
Traceback (most recent call last):
  File "/workspace/acl/wbw_test.py", line 17, in <module>
    print(tokenizer.decode(outputs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/amadeus/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 4034, in decode
    return self._decode(
           ^^^^^^^^^^^^^
  File "/opt/conda/envs/amadeus/lib/python3.11/site-packages/transformers/tokenization_utils_fast.py", line 651, in _decode
    text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: argument 'ids': 'dict' obj

In [10]:
import os
# Opens test.txt in write mode (UTF-8). Nothing is written to the file here:
# print(1) is called without a file= argument, so the 1 goes to the cell output.
with open('test.txt', mode='w', encoding='utf-8') as f:
    print(1)

1


## 加载CITB数据集

In [6]:
from datasets import *
import os


In [None]:
# Directory of the CITB training data to load (path is machine-specific).
path = '/workspace/acl/datasets/CITB/data/CIT_data/initial_multitask_learning/defintion_pos_2/train'
# load_from_disk is presumably provided by the `from datasets import *` cell above.
dataset = load_from_disk(path)