In [2]:
import gc
import os
import sys
import json
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
from argparse import ArgumentParser, Namespace

import torch
from tqdm.auto import trange
from transformers import LlamaConfig, LlamaForCausalLM, AutoModelForCausalLM


from megatron.training.tokenizer import build_tokenizer

def permute_qkv(qkv_w: torch.Tensor, dim: int, n_heads: int,
                n_heads_kv: int, revert: bool = False) -> torch.Tensor:

    def permute(x):
        if revert:
            return x.view(head_dim//2, 2, dim).transpose(0, 1).reshape(head_dim, dim)
        return x.view(2, head_dim//2, dim).transpose(0, 1).reshape(head_dim, dim)

    head_dim = dim//n_heads
    n_qs_per_kv = n_heads//n_heads_kv
    n_groups = qkv_w.size(0)//head_dim//(n_qs_per_kv + 2)
    groups = torch.chunk(qkv_w, n_groups, dim=0)
    new = []
    for group in groups:
        *qs, k, v = torch.split(group, head_dim, dim=0)
        assert len(qs) == n_qs_per_kv, f"{len(qs)}, {n_qs_per_kv}"
        new += list(map(permute, qs)) + [permute(k), v]
    return torch.cat(new, dim=0)

def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def convert_wqkv(llama_mega, layer_idx=0, n_heads=32, n_heads_kv=8):
    qkv_w = llama_mega[f'decoder.layers.{layer_idx}.self_attention.linear_qkv.weight']
    n_hidden = qkv_w.size(1)
    hidden_dim = n_hidden//n_heads
    print(n_hidden, hidden_dim, n_heads, n_heads_kv)
    qkv_w = permute_qkv(qkv_w, n_hidden, n_heads, n_heads_kv, revert=True)

    n_qs_per_kv = n_heads//n_heads_kv
    n_groups = qkv_w.size(0)//hidden_dim//(n_qs_per_kv + 2)
    qkv_w = list(torch.split(qkv_w, hidden_dim, dim=0))

    wq, wk, wv = [], [], []
    for group in range(n_groups):
        for qs in range(n_qs_per_kv):
            wq.append(qkv_w[0])
            del qkv_w[0]
        wk.append(qkv_w[0])
        del qkv_w[0]
        wv.append(qkv_w[0])
        del qkv_w[0]
    assert len(qkv_w) == 0

    wq = torch.concat(wq, dim=0)
    wk = torch.concat(wk, dim=0)
    wv = torch.concat(wv, dim=0)
    return wq, wk, wv


def convert_ffn(llama_mega, layer_idx=0, n_dense=11008):
    mega_ffn = llama_mega[f"decoder.layers.{layer_idx}.mlp.linear_fc1.weight"]
    ffn_w3, ffn_w1 = mega_ffn.split(n_dense, dim=0)
    return ffn_w1, ffn_w3

# copy the functions from the original MergeLM code
def copy_params_to_model(params, model):  # copying code from model_merging_methods.merging_methods
    for param_name, param_value in model.named_parameters():
        if param_name in params:
            param_value.data.copy_(params[param_name])
        else:
            print(f"param_name {param_name} not in params")

def write_llama_model(model_path,
                input_base_path,
                num_output_shards: int=2,
                norm_eps: float=1e-05,
                rope_theta: float=1e4):

    # Preliminaries
    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
    os.makedirs(model_path, exist_ok=True)
    with open(os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')) as f:
        iteration = f.read()
    if iteration != "release":
        iteration = f"iter_{int(iteration):07d}"
    print(f"Fetching iteration {iteration}")

    # Load weights
    base_path = Path(input_base_path)/iteration
    assert len(list(base_path.glob("mp_rank_*"))) == 1, "Unshard your model with checkpoint_util.py first!"
    loaded = torch.load(base_path/"mp_rank_00"/"model_optim_rng.pt", map_location="cpu")
    args = loaded['args']

    loaded = loaded['model']
    if False:  # 'transformer' not in loaded:  # normalize key names
        loaded["transformer"] = loaded.pop("encoder")
        for key in list(loaded["transformer"].keys()):
            loaded["transformer"][key.replace("self_attention", "attention")] = loaded["transformer"].pop(key)
        loaded["embedding"]["word_embeddings.weight"] = loaded["embedding"].pop("word_embeddings")["weight"]
        args.num_layers = args.encoder_num_layers

    # Load arguments
    n_layers = args.num_layers
    n_heads = args.num_attention_heads
    n_heads_kv = getattr(args, "num_query_groups", n_heads)
    n_dense = args.ffn_hidden_size
    n_hidden = args.hidden_size
    hidden_per_head = n_hidden // n_heads
    dim = n_hidden // n_heads
    q_size = n_heads//n_heads_kv
    intermediate_size = args.ffn_hidden_size
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))

    print('Llama-Megatron Loaded!')
    param_count = 0
    index_dict = {"weight_map": {}}
    
    weight_parameters = {}
    
    # Start conversion
    # with TemporaryDirectory(prefix=model_path) as tmp_model_path:
    print(f'Weighted Converting for {n_layers} layers...')
    for layer_i in range(n_layers):
        # filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
        megatron_qkv_weight = loaded[f'decoder.layers.{layer_i}.self_attention.linear_qkv.weight'].view(n_heads_kv, (q_size+2)*dim, -1)
        q_proj_w_megatron = megatron_qkv_weight[:, :q_size*dim, :]
        k_proj_w_megatron = megatron_qkv_weight[:, q_size*dim:(q_size+1)*dim, :]
        v_proj_w_megatron = megatron_qkv_weight[:, (q_size+1)*dim:, :]
        wq_proj = q_proj_w_megatron.reshape(-1, n_hidden)
        wk_proj = k_proj_w_megatron.reshape(-1, n_hidden)
        wv_proj = v_proj_w_megatron.reshape(-1, n_hidden)
        print(wq_proj.shape, wk_proj.shape, wv_proj.shape)

        ffn_w1, ffn_w3 = convert_ffn(llama_mega=loaded, 
                                    layer_idx=layer_i, 
                                    n_dense=n_dense)
        state_dict = {
            f"model.layers.{layer_i}.self_attn.q_proj.weight": wq_proj,
            f"model.layers.{layer_i}.self_attn.k_proj.weight": wk_proj,
            f"model.layers.{layer_i}.self_attn.v_proj.weight": wv_proj,
            f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"decoder.layers.{layer_i}.self_attention.linear_proj.weight"],
            f"model.layers.{layer_i}.mlp.gate_proj.weight": ffn_w3,
            f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"decoder.layers.{layer_i}.mlp.linear_fc2.weight"],
            f"model.layers.{layer_i}.mlp.up_proj.weight": ffn_w1,
            f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"decoder.layers.{layer_i}.self_attention.linear_qkv.layer_norm_weight"],
            f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"decoder.layers.{layer_i}.mlp.linear_fc1.layer_norm_weight"],
            f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq": inv_freq
        }
        for k, v in state_dict.items():
            print(f'key: {k}, shape: {v.shape}')
        
        weight_parameters.update(state_dict)

        # for k, v in state_dict.items():
        #     index_dict["weight_map"][k] = filename
        #     param_count += v.numel()
        # # print("start writing...", flush=True)
        # # print(state_dict)
        # # print("saving to", os.path.join(tmp_model_path, filename))
        # torch.save(state_dict, os.path.join(tmp_model_path, filename))
        print(f'Sharded file saved')

    # filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
    state_dict = {
        "model.norm.weight": loaded['decoder.final_layernorm.weight'],
        "lm_head.weight": loaded['output_layer.weight'],
        "model.embed_tokens.weight": loaded['embedding.word_embeddings.weight']
    }
    
    weight_parameters.update(state_dict)

    for k, v in weight_parameters.items():
        param_count += v.numel()
    torch_dtype = state_dict["lm_head.weight"].dtype
    # print("start writing...", flush=True)
    # torch.save(state_dict, os.path.join(tmp_model_path, filename))
    # print(f'Sharded file saved to {filename}')

    # # Write configs and save
    # index_dict["metadata"] = {"total_size": param_count * 2}
    # write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
    # config = LlamaConfig.from_json_file('/home/tian/mtian8/LLM4Tutorial/model/HF_model/Meta-Llama-3-8B/config.json')
    # config.rope_theta = rope_theta
    # config.save_pretrained(tmp_model_path)

    # Make space so we can load the model properly now.
    del state_dict
    del loaded
    gc.collect()

    print("Loading the checkpoint in a Llama model...")
    model = LlamaForCausalLM.from_pretrained("/home/tian/mtian8/LLM4Tutorial/model/HF_model/Meta-Llama-3-8B/", torch_dtype=torch_dtype)
    copy_params_to_model(weight_parameters, model)
    
    # Avoid saving this as part of the config.
    del model.config._name_or_path

    print("Saving in the Transformers format.")
    max_num_params_per_shard = param_count*2 // max(1,(num_output_shards-1))
    print(f"Saving with a maximum of {max_num_params_per_shard} parameters per shard.")
    print(f"There are {param_count*2} parameters in total. num_output_shards={num_output_shards}")
    model.save_pretrained(model_path, max_shard_size=max_num_params_per_shard)

def main():
    # make sure megatron is importable
    output_dir = "/home/tian/mtian8/LLM4Tutorial/model/HF_converted_model"
    input_dir = "/home/tian/mtian8/LLM4Tutorial/model/MG_model/llama3_8B_tp_1_pp_4"
    num_output_shards = 2
    if True:
        eps = 1e-5
        rotary_base = 1e4
        rope_theta = rotary_base
        write_llama_model(
            model_path=output_dir,
            input_base_path=input_dir,
            num_output_shards=num_output_shards,
            norm_eps=eps,
            rope_theta=rope_theta,
        )

if __name__ == "__main__":
    main()

Fetching all parameters from the checkpoint at /home/tian/mtian8/LLM4Tutorial/model/MG_model/llama3_8B_tp_1_pp_4.
Fetching iteration iter_0000001


AssertionError: Unshard your model with checkpoint_util.py first!

In [1]:
import gc
import os
import sys
import json
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
from argparse import ArgumentParser, Namespace

import torch
from tqdm.auto import trange
from transformers import LlamaConfig, LlamaForCausalLM, AutoModelForCausalLM


from megatron.training.tokenizer import build_tokenizer

def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def convert_ffn(llama_mega, layer_idx=0, n_dense=11008):
    mega_ffn = llama_mega[f"decoder.layers.{layer_idx}.mlp.linear_fc1.weight"]
    ffn_w3, ffn_w1 = mega_ffn.split(n_dense, dim=0)
    return ffn_w1, ffn_w3


def write_llama_model(model_path,
                input_base_path,
                num_output_shards: int=2,
                norm_eps: float=1e-05,
                rope_theta: float=1e4):

    # Preliminaries
    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
    os.makedirs(model_path, exist_ok=True)
    with open(os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')) as f:
        iteration = f.read()
    if iteration != "release":
        iteration = f"iter_{int(iteration):07d}"
    print(f"Fetching iteration {iteration}")

    # Load weights
    base_path = Path(input_base_path)/iteration
    assert len(list(base_path.glob("mp_rank_*"))) == 1, "Unshard your model with checkpoint_util.py first!"
    loaded = torch.load(base_path/"mp_rank_00"/"model_optim_rng.pt", map_location="cpu")
    args = loaded['args']

    loaded = loaded['model']
    if False:  # 'transformer' not in loaded:  # normalize key names
        loaded["transformer"] = loaded.pop("encoder")
        for key in list(loaded["transformer"].keys()):
            loaded["transformer"][key.replace("self_attention", "attention")] = loaded["transformer"].pop(key)
        loaded["embedding"]["word_embeddings.weight"] = loaded["embedding"].pop("word_embeddings")["weight"]
        args.num_layers = args.encoder_num_layers

    # Load arguments
    n_layers = args.num_layers
    n_heads = args.num_attention_heads
    n_heads_kv = getattr(args, "num_query_groups", n_heads)
    n_dense = args.ffn_hidden_size
    n_hidden = args.hidden_size
    hidden_per_head = n_hidden // n_heads
    dim = n_hidden // n_heads
    q_size = n_heads//n_heads_kv
    intermediate_size = args.ffn_hidden_size
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))

    print('Llama-Megatron Loaded!')
    param_count = 0
    index_dict = {"weight_map": {}}
    
    weight_parameters = {}
    
    # Start conversion
    with TemporaryDirectory(prefix=model_path) as tmp_model_path:
        print(f'Weighted Converting for {n_layers} layers...')
        for layer_i in range(n_layers):
            filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
            megatron_qkv_weight = loaded[f'decoder.layers.{layer_i}.self_attention.linear_qkv.weight'].view(n_heads_kv, (q_size+2)*dim, -1)
            q_proj_w_megatron = megatron_qkv_weight[:, :q_size*dim, :]
            k_proj_w_megatron = megatron_qkv_weight[:, q_size*dim:(q_size+1)*dim, :]
            v_proj_w_megatron = megatron_qkv_weight[:, (q_size+1)*dim:, :]
            wq_proj = q_proj_w_megatron.reshape(-1, n_hidden)
            wk_proj = k_proj_w_megatron.reshape(-1, n_hidden)
            wv_proj = v_proj_w_megatron.reshape(-1, n_hidden)
            
            ffn_w1, ffn_w3 = convert_ffn(llama_mega=loaded, 
                                        layer_idx=layer_i, 
                                        n_dense=n_dense)
            state_dict = {
                f"model.layers.{layer_i}.self_attn.q_proj.weight": wq_proj,
                f"model.layers.{layer_i}.self_attn.k_proj.weight": wk_proj,
                f"model.layers.{layer_i}.self_attn.v_proj.weight": wv_proj,
                f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"decoder.layers.{layer_i}.self_attention.linear_proj.weight"],
                f"model.layers.{layer_i}.mlp.gate_proj.weight": ffn_w3,
                f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"decoder.layers.{layer_i}.mlp.linear_fc2.weight"],
                f"model.layers.{layer_i}.mlp.up_proj.weight": ffn_w1,
                f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"decoder.layers.{layer_i}.self_attention.linear_qkv.layer_norm_weight"],
                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"decoder.layers.{layer_i}.mlp.linear_fc1.layer_norm_weight"],
                f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq": inv_freq
            }

            for k, v in state_dict.items():
                index_dict["weight_map"][k] = filename
                param_count += v.numel()
            torch.save(state_dict, os.path.join(tmp_model_path, filename))
            print(f'Sharded file saved to {filename}')

        filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
        state_dict = {
            "model.norm.weight": loaded['decoder.final_layernorm.weight'],
            "lm_head.weight": loaded['output_layer.weight'],
            "model.embed_tokens.weight": loaded['embedding.word_embeddings.weight']
        }

        for k, v in weight_parameters.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch_dtype = state_dict["lm_head.weight"].dtype
        print('model.norm.weight:',state_dict["model.norm.weight"])
        print("lm_head.weight:", state_dict["lm_head.weight"])
        print("model.embed_tokens.weight:",state_dict["model.embed_tokens.weight"])
        torch.save(state_dict, os.path.join(tmp_model_path, filename))
        print(f'Sharded file saved to {filename}')

        # Write configs and save
        # index_dict["metadata"] = {"total_size": param_count * 2}
        # write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
        #
        # Change for Config File 
        #
        config = LlamaConfig.from_json_file('/home/tian/mtian8/LLM4Tutorial/model/HF_model/Meta-Llama-3-8B/config.json')
        config.rope_theta = rope_theta
        config.save_pretrained(tmp_model_path)

        # Make space so we can load the model properly now.
        del state_dict
        del loaded
        gc.collect()

        print("Loading the checkpoint in a Llama model...")
        #
        # Change for model
        #
        model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch_dtype)
        # Avoid saving this as part of the config.
        del model.config._name_or_path

    print("Saving in the Transformers format.")
    max_num_params_per_shard = param_count*2 // max(1,(num_output_shards-1))
    print(f"Saving with a maximum of {max_num_params_per_shard} parameters per shard.")
    print(f"There are {param_count*2} parameters in total. num_output_shards={num_output_shards}")
    model.save_pretrained(model_path, max_shard_size=max_num_params_per_shard)

def main():
    # make sure megatron is importable
    output_dir = "/home/tian/mtian8/LLM4Tutorial/model/llama3_convert/"
    input_dir = "/home/tian/mtian8/LLM4Tutorial/model/MG_model/llama3_8B_tp_1"
    num_output_shards = 2
    if True:
        eps = 1e-5
        rotary_base = 1e4
        rope_theta = rotary_base
        write_llama_model(
            model_path=output_dir,
            input_base_path=input_dir,
            num_output_shards=num_output_shards,
            norm_eps=eps,
            rope_theta=rope_theta,
        )

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


Fetching all parameters from the checkpoint at /home/tian/mtian8/LLM4Tutorial/model/MG_model/llama3_8B_tp_1.
Fetching iteration iter_0000001
Llama-Megatron Loaded!
Weighted Converting for 32 layers...
Sharded file saved to pytorch_model-1-of-33.bin
Sharded file saved to pytorch_model-2-of-33.bin
Sharded file saved to pytorch_model-3-of-33.bin
Sharded file saved to pytorch_model-4-of-33.bin
Sharded file saved to pytorch_model-5-of-33.bin
Sharded file saved to pytorch_model-6-of-33.bin
Sharded file saved to pytorch_model-7-of-33.bin
Sharded file saved to pytorch_model-8-of-33.bin
Sharded file saved to pytorch_model-9-of-33.bin
Sharded file saved to pytorch_model-10-of-33.bin
Sharded file saved to pytorch_model-11-of-33.bin
Sharded file saved to pytorch_model-12-of-33.bin
Sharded file saved to pytorch_model-13-of-33.bin
Sharded file saved to pytorch_model-14-of-33.bin
Sharded file saved to pytorch_model-15-of-33.bin
Sharded file saved to pytorch_model-16-of-33.bin
Sharded file saved to py

OSError: Error no file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory /home/tian/mtian8/LLM4Tutorial/model/llama3_convert/xf3l504j.

In [4]:
from transformers import AutoModelForCausalLM,LlamaForCausalLM
import torch
def diff(model1, model2):  # compare two models
    for name1, param1 in model1.named_parameters():
        param2 = model2.state_dict()[name1]
        if not torch.equal(param1, param2):
            print(name1, "different by", torch.norm(param1 - param2).item())
model_1 = AutoModelForCausalLM.from_pretrained("/home/tian/mtian8/LLM4Tutorial/model/HF_model/Meta-Llama-3-8B")
model_2 = AutoModelForCausalLM.from_pretrained("/home/tian/mtian8/LLM4Tutorial/model/HF_converted_model/")
model_3 = AutoModelForCausalLM.from_pretrained("/home/tian/mtian8/LLM4Tutorial/model/HF_converted_from_mcore")
diff(model_1, model_3)

Loading checkpoint shards: 100%|██████████| 4/4 [01:50<00:00, 27.74s/it]
Loading checkpoint shards: 100%|██████████| 7/7 [00:00<00:00,  8.95it/s]
Loading checkpoint shards: 100%|██████████| 7/7 [00:05<00:00,  1.22it/s]


In [2]:
print(model_1.state_dict()["model.layers.0.self_attn.k_proj.weight"].shape)
model_1.state_dict()["model.layers.0.self_attn.k_proj.weight"]

torch.Size([1024, 4096])


tensor([[-0.1040, -0.1543,  0.0737,  ...,  0.0312, -0.0231,  0.0442],
        [-0.0564, -0.0869,  0.0188,  ...,  0.0193, -0.0073,  0.0293],
        [-0.0200, -0.0564,  0.0417,  ...,  0.0056, -0.0159,  0.0449],
        ...,
        [ 0.0085,  0.0250, -0.0197,  ..., -0.0177, -0.0069,  0.0030],
        [ 0.0136,  0.0356, -0.0162,  ..., -0.0177,  0.0018,  0.0102],
        [ 0.0039, -0.0100,  0.0118,  ..., -0.0153,  0.0016, -0.0206]])

In [3]:
print(model_2.state_dict()["model.layers.0.self_attn.k_proj.weight"].shape)
model_2.state_dict()["model.layers.0.self_attn.k_proj.weight"]

torch.Size([1024, 4096])


tensor([[-0.1040, -0.1543,  0.0737,  ...,  0.0312, -0.0231,  0.0442],
        [-0.0564, -0.0869,  0.0188,  ...,  0.0193, -0.0073,  0.0293],
        [-0.0200, -0.0564,  0.0417,  ...,  0.0056, -0.0159,  0.0449],
        ...,
        [ 0.0085,  0.0250, -0.0197,  ..., -0.0177, -0.0069,  0.0030],
        [ 0.0136,  0.0356, -0.0162,  ..., -0.0177,  0.0018,  0.0102],
        [ 0.0039, -0.0100,  0.0118,  ..., -0.0153,  0.0016, -0.0206]])

In [12]:
model_1.state_dict()["model.layers.0.self_attn.q_proj.weight"]

tensor([[-2.7618e-03, -2.9053e-02, -3.1586e-03,  ...,  7.3547e-03,
         -4.6875e-02, -2.1606e-02],
        [-1.2512e-02, -6.9824e-02, -3.8605e-03,  ..., -1.2573e-02,
         -4.9805e-02,  2.0508e-02],
        [-1.8799e-02, -4.6631e-02, -4.7607e-03,  ...,  1.1475e-02,
         -1.3245e-02,  1.1536e-02],
        ...,
        [-4.5776e-03, -4.0283e-02,  7.1777e-02,  ...,  5.0659e-03,
         -2.3956e-03,  2.5024e-03],
        [-5.2795e-03, -1.4709e-02,  4.1504e-02,  ...,  5.4321e-03,
         -3.2349e-03,  4.4346e-05],
        [-4.1504e-03, -1.6724e-02,  3.0396e-02,  ...,  8.6060e-03,
          8.0872e-04,  3.1433e-03]])

In [13]:
model_2.state_dict()["model.layers.0.self_attn.q_proj.weight"]

tensor([[-2.7618e-03, -2.9053e-02, -3.1586e-03,  ...,  7.3547e-03,
         -4.6875e-02, -2.1606e-02],
        [-1.8799e-02, -4.6631e-02, -4.7607e-03,  ...,  1.1475e-02,
         -1.3245e-02,  1.1536e-02],
        [ 4.7302e-03, -4.7607e-02,  3.9816e-05,  ...,  1.7700e-02,
         -1.6479e-02, -1.8066e-02],
        ...,
        [ 2.1484e-02,  5.3955e-02, -2.6245e-02,  ..., -7.1106e-03,
          8.4229e-03, -2.2736e-03],
        [-4.5776e-03, -4.0283e-02,  7.1777e-02,  ...,  5.0659e-03,
         -2.3956e-03,  2.5024e-03],
        [-4.1504e-03, -1.6724e-02,  3.0396e-02,  ...,  8.6060e-03,
          8.0872e-04,  3.1433e-03]])

In [3]:
import torch

# Path to the model checkpoint
checkpoint_path = '/home/tian/mtian8/LLM4Tutorial/model/MG_model/llama3_8B_tp_1/iter_0000001/mp_rank_00/model_optim_rng.pt'

# Load the model checkpoint
loaded = torch.load(checkpoint_path, map_location='cpu')
args = loaded['args']
n_layers = args.num_layers
n_heads = args.num_attention_heads
n_heads_kv = getattr(args, "num_query_groups", n_heads)
n_dense = args.ffn_hidden_size
n_hidden = args.hidden_size

In [4]:
hidden_per_head = n_hidden // n_heads
dim = n_hidden // n_heads
q_size = n_heads//n_heads_kv
megatron_qkv_weight = loaded["model"]['decoder.layers.0.self_attention.linear_qkv.weight'].view(n_heads_kv, (q_size+2)*dim, -1)
q_proj_w_megatron = megatron_qkv_weight[:, :q_size*dim, :]
k_proj_w_megatron = megatron_qkv_weight[:, q_size*dim:(q_size+1)*dim, :]
v_proj_w_megatron = megatron_qkv_weight[:, (q_size+1)*dim:, :]
wq_proj = q_proj_w_megatron.reshape(-1, n_hidden)
wk_proj = k_proj_w_megatron.reshape(-1, n_hidden)
wv_proj = v_proj_w_megatron.reshape(-1, n_hidden)

In [5]:
wq_proj.shape

torch.Size([4096, 4096])