
Commit

improve hf conversion script
hamishivi committed Aug 13, 2024
1 parent 02726bf commit dbf2212
Showing 1 changed file with 67 additions and 41 deletions.
EasyLM/models/llama/convert_hf_to_easylm.py (108 changes: 67 additions & 41 deletions)
@@ -67,109 +67,130 @@
         "n_kv_heads": 8,
         "norm_eps": 1e-5,
     },
+    '8b31': {
+        'dim': 4096,
+        'intermediate_size': 14336,
+        'n_layers': 32,
+        'n_heads': 32,
+        'n_kv_heads': 8,
+        'norm_eps': 1e-6,
+        'vocab_size': 128256,
+        'rope_theta': 500000,
+        'max_position_embeddings': 131072,
+        'rope_scaling': {
+            "factor": 8.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_max_position_embeddings": 8192,
+            "rope_type": "llama3"
+        },
+    },
 }


-def inverse_permute(params, w):
-    n_layers = params["n_layers"]
-    n_heads = params["n_heads"]
-    dim = params["dim"]
-    reshaped_w = w.reshape(n_heads, 2, dim // n_heads // 2, dim)
+def inverse_permute(w, n_heads, input_dim, output_dim):
+    reshaped_w = w.reshape(n_heads, 2, output_dim // n_heads // 2, input_dim)
     transposed_w = reshaped_w.transpose(0, 2, 1, 3)
-    inverted_w = transposed_w.reshape(dim, dim)
-    return inverted_w
-
-
-def inverse_permute_kv(params, w):
-    n_layers = params["n_layers"]
-    n_kv_heads = params["n_kv_heads"]
-    n_heads = params["n_heads"]
-    dim = params["dim"]
-    reshaped_w = w.reshape(n_kv_heads, 2, dim // n_heads // 2, dim)
-    transposed_w = reshaped_w.transpose(0, 2, 1, 3)
-    inverted_w = transposed_w.reshape(dim, n_kv_heads * (dim // n_heads))
+    inverted_w = transposed_w.reshape(output_dim, input_dim)
     return inverted_w


 def main(args):
     start = time.time()
     params = LLAMA_STANDARD_CONFIGS[args.model_size]

-    ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.bin"))
     ckpt = {}
-    for i, ckpt_path in enumerate(ckpt_paths):
-        checkpoint = torch.load(ckpt_path, map_location="cpu")
-        for k, v in checkpoint.items():
-            if k.startswith("model."):
-                k = k[6:]
-            ckpt[k] = v
+    if args.use_safetensors:
+        from safetensors import safe_open
+        ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.safetensors"))
+        for i, ckpt_path in enumerate(ckpt_paths):
+            with safe_open(ckpt_path, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    if key.startswith("model."):
+                        k = key[6:]
+                        ckpt[k] = f.get_tensor(key)
+                    else:
+                        ckpt[key] = f.get_tensor(key)
+
+    else:
+        ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.bin"))
+        for i, ckpt_path in enumerate(ckpt_paths):
+            checkpoint = torch.load(ckpt_path, map_location="cpu")
+            for k, v in checkpoint.items():
+                if k.startswith("model."):
+                    k = k[6:]
+                ckpt[k] = v
     print(f"Start convert weight to easylm format...")
     jax_weights = {
         "transformer": {
-            "wte": {"embedding": ckpt["embed_tokens.weight"].to(torch.float32).numpy()},
-            "ln_f": {"kernel": ckpt["norm.weight"].to(torch.float32).numpy()},
+            "wte": {"embedding": ckpt["embed_tokens.weight"].to(torch.float16).numpy()[:-8, :]},
+            "ln_f": {"kernel": ckpt["norm.weight"].to(torch.float16).numpy()},
             "h": {
                 "%d"
                 % (layer): {
                     "attention": {
                         "wq": {
                             "kernel": inverse_permute(
-                                params,
-                                ckpt[f"layers.{layer}.self_attn.q_proj.weight"].to(torch.float32).numpy(),
+                                ckpt[f"layers.{layer}.self_attn.q_proj.weight"].to(torch.float16).numpy(),
+                                n_heads=params["n_heads"],
+                                input_dim=params["dim"],
+                                output_dim=params["dim"],
                             ).transpose()
                         },
                         "wk": {
-                            "kernel": inverse_permute_kv(
-                                params,
-                                ckpt[f"layers.{layer}.self_attn.k_proj.weight"].to(torch.float32).numpy(),
-                            )
+                            "kernel": inverse_permute(
+                                ckpt[f"layers.{layer}.self_attn.k_proj.weight"].to(torch.float16).numpy(),
+                                n_heads=params.get("n_kv_heads", params["n_heads"]),
+                                input_dim=params["dim"],
+                                output_dim=(params["dim"] // (params["n_heads"] // params.get("n_kv_heads", params["n_heads"]))),
+                            ).transpose()
                         },
                         "wv": {
                             "kernel": ckpt[f"layers.{layer}.self_attn.v_proj.weight"]
-                            .to(torch.float32)
+                            .to(torch.float16)
                             .numpy()
                             .transpose()
                         },
                         "wo": {
                             "kernel": ckpt[f"layers.{layer}.self_attn.o_proj.weight"]
-                            .to(torch.float32)
+                            .to(torch.float16)
                             .numpy()
                             .transpose()
                         },
                     },
                     "feed_forward": {
                         "w1": {
                             "kernel": ckpt[f"layers.{layer}.mlp.gate_proj.weight"]
-                            .to(torch.float32)
+                            .to(torch.float16)
                             .numpy()
                             .transpose()
                         },
                         "w2": {
                             "kernel": ckpt[f"layers.{layer}.mlp.down_proj.weight"]
-                            .to(torch.float32)
+                            .to(torch.float16)
                             .numpy()
                             .transpose()
                         },
                         "w3": {
                             "kernel": ckpt[f"layers.{layer}.mlp.up_proj.weight"]
-                            .to(torch.float32)
+                            .to(torch.float16)
                             .numpy()
                             .transpose()
                         },
                     },
                     "attention_norm": {
-                        "kernel": ckpt[f"layers.{layer}.input_layernorm.weight"].to(torch.float32).numpy()
+                        "kernel": ckpt[f"layers.{layer}.input_layernorm.weight"].to(torch.float16).numpy()
                     },
                     "ffn_norm": {
                         "kernel": ckpt[
                             f"layers.{layer}.post_attention_layernorm.weight"
-                        ].to(torch.float32).numpy()
+                        ].to(torch.float16).numpy()
                     },
                 }
                 for layer in range(params["n_layers"])
             },
         },
-        "lm_head": {"kernel": ckpt["lm_head.weight"].to(torch.float32).numpy().transpose()},
+        "lm_head": {"kernel": ckpt["lm_head.weight"].to(torch.float16).numpy().transpose()[:, :-8]},
     }
     print(f"Convert weight to easylm format finished...")
     print(f"Start to save...")
@@ -200,7 +221,7 @@ def main(args):
         "--model_size",
         type=str,
         default="7b",
-        choices=["7b", "13b", "30b", "65b", "70b"],
+        choices=list(LLAMA_STANDARD_CONFIGS.keys()),
         help="model size",
     )
     parser.add_argument(
@@ -209,6 +230,11 @@
         default=True,
         help="whether is model weight saved stream format",
     )
+    parser.add_argument(
+        '--use_safetensors',
+        action='store_true',
+        help='Load SafeTensors for model weights',
+    )

     args = parser.parse_args()

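A note on the inverse_permute refactor: Hugging Face's Llama conversion stores the q/k projection weights with the two rotary halves of each head interleaved, and this helper undoes that interleaving before the weights go into the EasyLM tree. Below is a minimal numpy sketch of that round trip; the permute helper mimics the Hugging Face-side transform as I understand it, and the toy dimensions are made up for illustration.

    import numpy as np

    def permute(w, n_heads, output_dim, input_dim):
        # HF-side transform: interleave the two rotary halves of each head.
        return (w.reshape(n_heads, output_dim // n_heads // 2, 2, input_dim)
                .transpose(0, 2, 1, 3)
                .reshape(output_dim, input_dim))

    def inverse_permute(w, n_heads, input_dim, output_dim):
        # Same logic as the function in this commit: undo the HF interleaving.
        reshaped_w = w.reshape(n_heads, 2, output_dim // n_heads // 2, input_dim)
        transposed_w = reshaped_w.transpose(0, 2, 1, 3)
        return transposed_w.reshape(output_dim, input_dim)

    # Toy GQA-style shapes (illustrative only): 8 query heads, 2 kv heads.
    dim, n_heads, n_kv_heads = 128, 8, 2
    kv_dim = n_kv_heads * (dim // n_heads)
    q = np.random.randn(dim, dim).astype(np.float32)
    k = np.random.randn(kv_dim, dim).astype(np.float32)

    # permute followed by inverse_permute should give back the original weights
    assert np.allclose(inverse_permute(permute(q, n_heads, dim, dim), n_heads, dim, dim), q)
    assert np.allclose(inverse_permute(permute(k, n_kv_heads, kv_dim, dim), n_kv_heads, dim, kv_dim), k)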

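Since the new --use_safetensors path reads shards with safe_open, a quick way to sanity-check a checkpoint directory before converting is to list a few tensor names and shapes per shard. This is a minimal sketch; the checkpoint path is a placeholder.

    from pathlib import Path
    from safetensors import safe_open

    checkpoint_dir = Path("/path/to/Meta-Llama-3.1-8B")  # placeholder path
    for shard in sorted(checkpoint_dir.glob("*.safetensors")):
        with safe_open(shard, framework="pt", device="cpu") as f:
            # Print the first few tensor names and their shapes from each shard.
            for key in list(f.keys())[:5]:
                print(shard.name, key, tuple(f.get_tensor(key).shape))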
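The new 8b31 entry also records Llama 3.1's rope_scaling block. For reference, the sketch below shows how those four numbers are commonly applied to RoPE inverse frequencies under the "llama3" rope type; this follows my understanding of the Hugging Face implementation, while the conversion script itself only copies the dict and EasyLM may consume it differently.

    import math
    import numpy as np

    def llama3_scale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                              high_freq_factor=4.0, original_max_position_embeddings=8192):
        # Wavelengths longer than this cutoff get divided by `factor` (long-context stretch).
        low_freq_wavelen = original_max_position_embeddings / low_freq_factor
        # Wavelengths shorter than this cutoff are left untouched.
        high_freq_wavelen = original_max_position_embeddings / high_freq_factor
        wavelen = 2 * math.pi / inv_freq

        scaled = np.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
        # Frequencies between the two cutoffs interpolate smoothly between the two regimes.
        smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        smoothed = (1 - smooth) / factor * inv_freq + smooth * inv_freq
        mid_band = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        return np.where(mid_band, smoothed, scaled)

    # Base inverse frequencies for head_dim=128 and rope_theta=500000, as in the 8b31 config.
    head_dim, rope_theta = 128, 500000.0
    inv_freq = 1.0 / (rope_theta ** (np.arange(0, head_dim, 2) / head_dim))
    print(llama3_scale_inv_freq(inv_freq)[:4])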