# one_half_model.py
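# Build a half-width copy of a Llama checkpoint: for every 2-D weight matrix keep
# the top-left block (first half of the rows and columns), for every norm and
# embedding keep the first half of the hidden dimension, then save the result as a
# new checkpoint.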
import json

import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaConfig, LlamaForCausalLM
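# NOTE: config_half.json is assumed to be the original config.json with the width
# fields (hidden_size, intermediate_size, and the attention-head layout) halved, so
# that LlamaForCausalLM(config) below allocates modules matching the sliced weights.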
config = json.load(open("config_half.json"))
config = LlamaConfig(**config)
new_model = LlamaForCausalLM(config)
MODEL_DIR = "."
model = AutoModelForCausalLM.from_pretrained(MODEL_DIR, torch_dtype="auto")
# embedding: keep every vocab row, halve the hidden (feature) dimension
w = model.model.embed_tokens.weight
ww = w[:, : w.size(-1)//2]
new_model.model.embed_tokens.weight = nn.Parameter(ww)
# lm_head: same slicing; the output vocab size is unchanged
w = model.lm_head.weight
ww = w[:, : w.size(-1)//2]
new_model.lm_head.weight = nn.Parameter(ww)
def w_2x2(w):
    # Keep the top-left quadrant: first half of the output rows and input columns.
    s0, s1 = w.size(0), w.size(1)
    return w[:s0//2, :s1//2]
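# Quick illustrative check (not part of the original script): slicing both
# dimensions in half should leave a quarter-size matrix.
_probe = torch.randn(8, 8)
assert w_2x2(_probe).shape == (4, 4)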
# layers
for idx, l in enumerate(model.model.layers):
    print(f"layer {idx}")
    new_layer = new_model.model.layers[idx]
    # input_layernorm
    w = l.input_layernorm.weight
    ww = w[:w.size(-1)//2]
    new_layer.input_layernorm.weight = nn.Parameter(ww)
    # self_attn
    qw = l.self_attn.q_proj.weight
    kw = l.self_attn.k_proj.weight
    vw = l.self_attn.v_proj.weight
    ow = l.self_attn.o_proj.weight
    qww = w_2x2(qw)
    kww = w_2x2(kw)
    vww = w_2x2(vw)
    oww = w_2x2(ow)
    new_layer.self_attn.q_proj.weight = nn.Parameter(qww)
    new_layer.self_attn.k_proj.weight = nn.Parameter(kww)
    new_layer.self_attn.v_proj.weight = nn.Parameter(vww)
    new_layer.self_attn.o_proj.weight = nn.Parameter(oww)
    # swiglu
    f1w = l.mlp.gate_proj.weight
    f2w = l.mlp.up_proj.weight
    f3w = l.mlp.down_proj.weight
    f1ww = w_2x2(f1w)
    f2ww = w_2x2(f2w)
    f3ww = w_2x2(f3w)
    new_layer.mlp.gate_proj.weight = nn.Parameter(f1ww)
    new_layer.mlp.up_proj.weight = nn.Parameter(f2ww)
    new_layer.mlp.down_proj.weight = nn.Parameter(f3ww)
    # post_attention_layernorm
    w = l.post_attention_layernorm.weight
    ww = w[: w.size(-1) // 2]
    new_layer.post_attention_layernorm.weight = nn.Parameter(ww)
# last norm
w = model.model.norm.weight
ww = w[: w.size(-1) // 2]
new_model.model.norm.weight = nn.Parameter(ww)
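# Rough size check (illustrative, not in the original script): the halved model
# should end up with a small fraction of the original parameter count.
print("params:", sum(p.numel() for p in model.parameters()),
      "->", sum(p.numel() for p in new_model.parameters()))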
# Save the new model
new_model.save_pretrained("model_ds", safe_serialization=False)
# save again to remove the shared parameter
model = AutoModelForCausalLM.from_pretrained("model_ds", torch_dtype="auto")
model.save_pretrained("model_ds_A", max_shard_size="10GB", safe_serialization=False)