In [1]:
import numpy as np
import os, torch

In [2]:
# Dump directories for the HuggingFace reference run and the FlexFlow run.
# Set HF_PEFT_TENSORS_DIR / FF_INFERENCE_TENSORS_DIR to point elsewhere instead of
# editing the notebook; the defaults are the original author's local paths.
hf_weight_base_path = os.environ.get(
    "HF_PEFT_TENSORS_DIR", "/usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors"
)
ff_weight_base_path = os.environ.get(
    "FF_INFERENCE_TENSORS_DIR", "/usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors"
)
def compare_tensors(hf_tensor_filepath, ff_tensor_filepath, tolerance=1e-2, max_mismatch_ratio=0.05):
    """Compare a HuggingFace tensor dump against a FlexFlow comma-separated dump.

    The HF file is read with ``torch.load``. If it holds a tuple/list it must have
    exactly one element, which is unwrapped. NaN/inf values are replaced via
    ``torch.nan_to_num`` and the tensor is flattened. The FF file is read with
    ``np.loadtxt(..., delimiter=',')``, flattened, and truncated to the HF element
    count, so trailing FF values are ignored.

    Args:
        hf_tensor_filepath: path of the ``torch.save``-d HF tensor.
        ff_tensor_filepath: path of the comma-separated FF dump.
        tolerance: ``atol`` passed to ``np.isclose`` (its default ``rtol`` also applies).
        max_mismatch_ratio: largest fraction of elements allowed outside tolerance
            (default 0.05, i.e. 5%).

    Raises:
        AssertionError: a file is missing, the HF tuple/list does not hold exactly one
            element, the FF dump has fewer elements than the HF tensor, or too many
            elements mismatch.
    """
    assert os.path.exists(hf_tensor_filepath) and os.path.exists(ff_tensor_filepath)
    # NOTE: torch.load unpickles the file -- only use on trusted local dumps.
    hf_tensor = torch.load(hf_tensor_filepath)
    if isinstance(hf_tensor, (tuple, list)):
        assert len(hf_tensor) == 1
        hf_tensor = hf_tensor[0]
    hf_tensor = torch.nan_to_num(hf_tensor).flatten().detach().cpu().numpy()
    # np.loadtxt returns a 0-d array for a single value; normalise to 1-D so that
    # slicing and the length check below work.
    ff_tensor = np.atleast_1d(np.loadtxt(ff_tensor_filepath, delimiter=',')).ravel()

    len_hf_tensor = hf_tensor.shape[0]
    # A shorter FF dump used to fail with an opaque broadcast/index error.
    assert ff_tensor.shape[0] >= len_hf_tensor, (
        f"{ff_tensor_filepath} has {ff_tensor.shape[0]} elements, expected at least {len_hf_tensor}"
    )
    ff_tensor = ff_tensor[:len_hf_tensor]

    close = np.isclose(ff_tensor, hf_tensor, atol=tolerance)
    mismatches = np.where(~close)[0]
    if mismatches.size > 0:
        print(f"mismatch between {hf_tensor_filepath} and {ff_tensor_filepath}")
        print(f"HF: {hf_tensor}\nFF:{ff_tensor}")
        print(close)
        print(mismatches)
    assert len(mismatches) <= max_mismatch_ratio * len_hf_tensor
    print("Ok!")
def compare_tensors_difference(hf_tensor_filepath, ff_tensor1_filepath, ff_tensor2_filepath, tolerance=1e-2, max_mismatch_ratio=0.05):
    """Check that an HF tensor equals ``ff_tensor1 - ff_tensor2`` element-wise.

    The HF file is loaded with ``torch.load``; a one-element tuple/list is unwrapped,
    NaN/inf values are replaced via ``torch.nan_to_num`` and the tensor is flattened.
    Both FF dumps are read with ``np.loadtxt(..., delimiter=',')``, flattened and
    truncated to the HF element count before subtracting.

    Args:
        hf_tensor_filepath: path of the ``torch.save``-d HF tensor.
        ff_tensor1_filepath: FF dump of the minuend.
        ff_tensor2_filepath: FF dump of the subtrahend.
        tolerance: ``atol`` passed to ``np.isclose`` (its default ``rtol`` also applies).
        max_mismatch_ratio: largest fraction of elements allowed outside tolerance.

    Raises:
        AssertionError: a file is missing, the HF tuple/list does not hold exactly one
            element, an FF dump is shorter than the HF tensor, or too many elements
            mismatch.
    """
    assert os.path.exists(hf_tensor_filepath)
    assert os.path.exists(ff_tensor1_filepath)
    assert os.path.exists(ff_tensor2_filepath)
    # NOTE: torch.load unpickles the file -- only use on trusted local dumps.
    hf_tensor = torch.load(hf_tensor_filepath)
    if isinstance(hf_tensor, (tuple, list)):
        assert len(hf_tensor) == 1
        hf_tensor = hf_tensor[0]
    hf_tensor = torch.nan_to_num(hf_tensor).flatten().detach().cpu().numpy()
    # np.loadtxt returns a 0-d array for a single value; normalise to 1-D.
    ff_tensor1 = np.atleast_1d(np.loadtxt(ff_tensor1_filepath, delimiter=',')).ravel()
    ff_tensor2 = np.atleast_1d(np.loadtxt(ff_tensor2_filepath, delimiter=',')).ravel()

    len_hf_tensor = hf_tensor.shape[0]
    for ff_fp, ff_arr in ((ff_tensor1_filepath, ff_tensor1), (ff_tensor2_filepath, ff_tensor2)):
        assert ff_arr.shape[0] >= len_hf_tensor, (
            f"{ff_fp} has {ff_arr.shape[0]} elements, expected at least {len_hf_tensor}"
        )
    ff_tensor = ff_tensor1[:len_hf_tensor] - ff_tensor2[:len_hf_tensor]

    close = np.isclose(ff_tensor, hf_tensor, atol=tolerance)
    mismatches = np.where(~close)[0]
    if mismatches.size > 0:
        print(f"mismatch between {hf_tensor_filepath} and {ff_tensor1_filepath} - {ff_tensor2_filepath}")
        print(f"HF: {hf_tensor}\nFF:{ff_tensor}")
        print(close)
        print(mismatches)
    assert len(mismatches) <= max_mismatch_ratio * len_hf_tensor
    print("Ok!")
def compare_hf_tensors(tensor1_fp, tensor2_fp, rtol=1e-05, atol=1e-08):
    """Check that two HuggingFace tensor dumps hold the same values.

    Each file is loaded with ``torch.load``; a one-element tuple/list is unwrapped.
    Size-1 dimensions are squeezed away *before* comparing. Previously the shapes
    were checked after squeezing but the values were compared unsqueezed, so e.g.
    ``(3, 1)`` vs ``(1, 3)`` broadcast into a 3x3 outer comparison. NaN/inf values
    are replaced via ``torch.nan_to_num``.

    Args:
        tensor1_fp, tensor2_fp: paths of the ``torch.save``-d tensors.
        rtol, atol: tolerances for ``np.isclose`` (defaults match ``np.allclose``).

    Raises:
        AssertionError: a file is missing, a tuple/list does not hold exactly one
            element, the squeezed shapes differ, or any element mismatches.
    """
    assert os.path.exists(tensor1_fp) and os.path.exists(tensor2_fp)
    # NOTE: torch.load unpickles the file -- only use on trusted local dumps.
    hf_tensor1 = torch.load(tensor1_fp)
    hf_tensor2 = torch.load(tensor2_fp)
    if isinstance(hf_tensor1, (tuple, list)):
        assert len(hf_tensor1) == 1
        hf_tensor1 = hf_tensor1[0]
    if isinstance(hf_tensor2, (tuple, list)):
        assert len(hf_tensor2) == 1
        hf_tensor2 = hf_tensor2[0]
    hf_tensor1 = torch.nan_to_num(torch.squeeze(hf_tensor1))
    hf_tensor2 = torch.nan_to_num(torch.squeeze(hf_tensor2))
    assert hf_tensor1.shape == hf_tensor2.shape

    array1 = hf_tensor1.detach().cpu().numpy()
    array2 = hf_tensor2.detach().cpu().numpy()
    close = np.isclose(array1, array2, rtol=rtol, atol=atol)
    if not close.all():
        print(f"mismatch between {tensor1_fp} and {tensor2_fp}")
        print(hf_tensor1)
        print(hf_tensor2)
        print(close)
        # Flat indices, so multi-dimensional mismatches are reported unambiguously.
        print(np.flatnonzero(~close))
        raise AssertionError(f"mismatch between {tensor1_fp} and {tensor2_fp}")
    print("Ok!")

def check_hf_sum_tensors(tensor_sum_fp, tensor1_fp, tensor2_fp, rtol=1e-05, atol=1e-08):
    """Check that the HF dump at ``tensor_sum_fp`` equals the sum of two other dumps.

    Every file is loaded with ``torch.load``; a one-element tuple/list is unwrapped,
    size-1 dimensions are squeezed away and NaN/inf values are replaced via
    ``torch.nan_to_num``. Squeezing before adding and comparing prevents e.g.
    ``(n,)`` + ``(n, 1)`` from broadcasting into an ``(n, n)`` sum.

    Args:
        tensor_sum_fp: path of the expected sum.
        tensor1_fp, tensor2_fp: paths of the two addends.
        rtol, atol: tolerances for ``np.isclose`` (defaults match ``np.allclose``).

    Raises:
        AssertionError: a file is missing, a tuple/list does not hold exactly one
            element, squeezed shapes differ, or any element of the sum mismatches.
    """
    assert os.path.exists(tensor_sum_fp) and os.path.exists(tensor1_fp) and os.path.exists(tensor2_fp)

    def _load_single(fp):
        # NOTE: torch.load unpickles the file -- only use on trusted local dumps.
        tensor = torch.load(fp)
        if isinstance(tensor, (tuple, list)):
            assert len(tensor) == 1
            tensor = tensor[0]
        return torch.nan_to_num(torch.squeeze(tensor))

    hf_tensor_sum = _load_single(tensor_sum_fp)
    hf_tensor1 = _load_single(tensor1_fp)
    hf_tensor2 = _load_single(tensor2_fp)
    assert hf_tensor_sum.shape == hf_tensor1.shape
    assert hf_tensor1.shape == hf_tensor2.shape

    sum_check_tensor = hf_tensor1 + hf_tensor2
    close = np.isclose(
        sum_check_tensor.detach().cpu().numpy(),
        hf_tensor_sum.detach().cpu().numpy(),
        rtol=rtol,
        atol=atol,
    )
    if not close.all():
        print(f"mismatch between {tensor_sum_fp} and {tensor1_fp} + {tensor2_fp}")
        print("Loaded sum:")
        print(hf_tensor_sum)
        print("Computed sum:")
        print(sum_check_tensor)
        print(hf_tensor1)
        print(hf_tensor2)
        print(close)
        print(np.flatnonzero(~close))
        raise AssertionError(f"mismatch between {tensor_sum_fp} and {tensor1_fp} + {tensor2_fp}")
    print("Ok!")
def check_hf_zero_tensor(hf_tensor_fp):
    """Assert that the HF dump at ``hf_tensor_fp`` contains only zeros.

    A tuple/list payload must hold exactly one tensor, which is checked. NaNs are
    mapped to 0 by ``torch.nan_to_num`` before counting, so they count as zeros.
    Nothing is printed on success.
    """
    assert os.path.exists(hf_tensor_fp)
    payload = torch.load(hf_tensor_fp)
    if type(payload) in (tuple, list):
        assert len(payload) == 1
        (payload,) = payload
    nonzero_count = torch.count_nonzero(torch.nan_to_num(payload)).sum()
    assert nonzero_count == 0
def print_tensors(hf_tensor_filepath, ff_tensor_filepath, txt=""):
    """Print an HF tensor dump and the same-length prefix of a FlexFlow dump.

    The HF payload, after unwrapping a one-element tuple/list, is passed through
    ``torch.nan_to_num`` and flattened. The FF dump (comma-separated, read with
    ``np.loadtxt``) is cut to the HF element count. Both are printed under headings
    prefixed with ``txt``.
    """
    assert all(os.path.exists(path) for path in (hf_tensor_filepath, ff_tensor_filepath))
    payload = torch.load(hf_tensor_filepath)
    if type(payload) in (tuple, list):
        assert len(payload) == 1
        payload = payload[0]
    hf_values = torch.nan_to_num(payload).flatten().detach().cpu().numpy()
    ff_values = np.loadtxt(ff_tensor_filepath, delimiter=',')
    ff_values = ff_values[:hf_values.shape[0]]

    print(f"{txt} - HF tensor:")
    print(hf_values)
    print(f"{txt} - FF tensor: ")
    print(ff_values)
def compare_flexflow_tensors(ff_tensor1_fp, ff_tensor2_fp, tolerance=1e-5, max_len=-1):
    """Compare two FlexFlow comma-separated dumps of identical shape.

    The full shapes must match (checked before any truncation). When ``max_len`` is
    non-negative, only the first ``max_len`` entries are compared. Up to 5% of the
    compared entries may fall outside ``atol=tolerance``.
    """
    assert os.path.exists(ff_tensor1_fp) and os.path.exists(ff_tensor2_fp)
    first = np.loadtxt(ff_tensor1_fp, delimiter=',')
    second = np.loadtxt(ff_tensor2_fp, delimiter=',')

    shapes_match = first.shape == second.shape
    if not shapes_match:
        print(first.shape, second.shape)
    assert shapes_match

    if max_len > -1:
        first, second = first[:max_len], second[:max_len]

    mismatches = []
    if not np.allclose(first, second, atol=tolerance):
        print(f"mismatch between {ff_tensor1_fp} and {ff_tensor2_fp}")
        print(f"Tensor1: {first}\nTensor2:{second}")
        elementwise = np.isclose(first, second, atol=tolerance)
        print(elementwise)
        mismatches = np.where(~elementwise)[0]
        print(mismatches)
    assert len(mismatches) <= .05 * len(first)
    print("Ok!")
def compare_flexflow_tensors_shortest(ff_tensor1_fp, ff_tensor2_fp, tolerance=1e-5):
    """Compare two FlexFlow dumps over the length of the shorter one.

    Both comma-separated dumps are truncated along their first axis to the smaller
    length. Up to 5% of the compared entries may fall outside ``atol=tolerance``.
    """
    assert os.path.exists(ff_tensor1_fp) and os.path.exists(ff_tensor2_fp)
    loaded = [np.loadtxt(path, delimiter=',') for path in (ff_tensor1_fp, ff_tensor2_fp)]
    common_len = min(arr.shape[0] for arr in loaded)
    first, second = (arr[:common_len] for arr in loaded)

    mismatches = []
    if not np.allclose(first, second, atol=tolerance):
        print(f"mismatch between {ff_tensor1_fp} and {ff_tensor2_fp}")
        print(f"Tensor1: {first}\nTensor2:{second}")
        elementwise = np.isclose(first, second, atol=tolerance)
        print(elementwise)
        mismatches = np.where(~elementwise)[0]
        print(mismatches)
    assert len(mismatches) <= .05 * len(first)
    print("Ok!")
def check_flexflow_tensors_sum(ff_tensor_sum_fp, ff_tensor1_fp, ff_tensor2_fp, tolerance=1e-5, max_mismatch_ratio=0.05):
    """Check that a FlexFlow dump equals the element-wise sum of two other FF dumps.

    All three files are comma-separated and read with ``np.loadtxt``. The addends
    must have the same shape, and the sum file must have exactly that shape. Without
    this check a single-value sum file (a 0-d array) would silently broadcast and
    pass.

    Args:
        ff_tensor_sum_fp: path of the expected sum.
        ff_tensor1_fp, ff_tensor2_fp: paths of the addends.
        tolerance: ``atol`` passed to ``np.isclose`` (its default ``rtol`` also applies).
        max_mismatch_ratio: largest fraction of elements allowed outside tolerance.

    Raises:
        AssertionError: a file is missing, shapes disagree, or too many elements mismatch.
    """
    assert os.path.exists(ff_tensor_sum_fp) and os.path.exists(ff_tensor1_fp) and os.path.exists(ff_tensor2_fp)
    # atleast_1d: np.loadtxt returns a 0-d array for a single value.
    ff_tensor1 = np.atleast_1d(np.loadtxt(ff_tensor1_fp, delimiter=','))
    ff_tensor2 = np.atleast_1d(np.loadtxt(ff_tensor2_fp, delimiter=','))
    ff_tensor_sum = np.atleast_1d(np.loadtxt(ff_tensor_sum_fp, delimiter=','))

    # Check shapes before adding so a mismatch gives a clear message rather than a
    # broadcast error (or a silently broadcast result).
    assert ff_tensor1.shape == ff_tensor2.shape, f"{ff_tensor1.shape} != {ff_tensor2.shape}"
    ff_sum = ff_tensor1 + ff_tensor2
    assert ff_tensor_sum.shape == ff_sum.shape, f"{ff_tensor_sum.shape} != {ff_sum.shape}"

    close = np.isclose(ff_tensor_sum, ff_sum, atol=tolerance)
    mismatches = np.where(~close)[0]
    if mismatches.size > 0:
        print(f"mismatch between {ff_tensor_sum_fp} and sum of {ff_tensor1_fp} + {ff_tensor2_fp}")
        print(f"Tensor1: {ff_tensor1}\nTensor2:{ff_tensor2}")
        print(f"Sum Tensor: {ff_tensor_sum}\nActual sum:{ff_sum}")
        print(close)
        print(mismatches)
    assert len(mismatches) <= max_mismatch_ratio * len(ff_tensor1)
    print("Ok!")

In [3]:
# Forward-pass alignment: for each decoder layer, compare the HF dumps in
# hf_weight_base_path with the FlexFlow dumps in ff_weight_base_path, then check the
# final norm and lm_head once after the loop.
tot_num_layers = 12
for layer_num in range(tot_num_layers):
    # Filename prefixes shared by most dumps of this layer.
    hf_fwd_prefix = f"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}"
    ff_fwd_prefix = f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}"
    hf_down_proj_weight_prefix = f"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj"

    # Input RMSNorm: layer 0 is dumped under the generic "RMSNorm" name (output_0),
    # later layers under "<layer>_attention_norm" (output_1).
    hf_input_ln_out = f"{hf_fwd_prefix}.input_layernorm.output_0"
    if layer_num == 0:
        ff_input_ln_out = f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_RMSNorm_shard-id_0_output_0"
    else:
        ff_input_ln_out = f"{ff_fwd_prefix}_attention_norm_shard-id_0_output_1"
    compare_tensors(hf_input_ln_out, ff_input_ln_out)

    # Attention output projection.
    hf_attn_out = f"{hf_fwd_prefix}.self_attn.o_proj.output_0"
    ff_attn_out = f"{ff_fwd_prefix}_attention_shard-id_0_output_0"
    compare_tensors(hf_attn_out, ff_attn_out)

    # Post-attention (FFN) RMSNorm.
    hf_ffn_norm_out = f"{hf_fwd_prefix}.post_attention_layernorm.output_0"
    ff_ffn_norm_out = f"{ff_fwd_prefix}_ffn_norm_shard-id_0_output_1"
    compare_tensors(hf_ffn_norm_out, ff_ffn_norm_out)

    # MLP projections: gate_proj <-> w1, up_proj <-> w3, down_proj <-> w2.
    hf_gate_proj_out = f"{hf_fwd_prefix}.mlp.gate_proj.output_0"
    ff_gate_proj_out = f"{ff_fwd_prefix}_feed_forward_w1_shard-id_0_output_0"
    compare_tensors(hf_gate_proj_out, ff_gate_proj_out)
    hf_up_proj_out = f"{hf_fwd_prefix}.mlp.up_proj.output_0"
    ff_up_proj_out = f"{ff_fwd_prefix}_feed_forward_w3_shard-id_0_output_0"
    compare_tensors(hf_up_proj_out, ff_up_proj_out)
    hf_down_proj_in = f"{hf_fwd_prefix}.mlp.down_proj.input_0"
    hf_down_proj_out = f"{hf_fwd_prefix}.mlp.down_proj.output_0"
    ff_down_proj_in = f"{ff_fwd_prefix}_feed_forward_w2_shard-id_0_input_0"
    ff_down_proj_out = f"{ff_fwd_prefix}_feed_forward_w2_shard-id_0_output_0"
    compare_tensors(hf_down_proj_in, ff_down_proj_in)

    # LoRA on down_proj: its input must equal down_proj's input on the HF side.
    hf_lora_A_in = f"{hf_fwd_prefix}.mlp.down_proj.lora_A.default.input_0"
    ff_lora_A_in = f"{ff_fwd_prefix}_feed_forward_w2_lora_shard-id_0_input_0"
    compare_hf_tensors(hf_down_proj_in, hf_lora_A_in)
    compare_tensors(hf_lora_A_in, ff_lora_A_in)

    # LoRA weights.
    hf_lora_A_weight_fp = f"{hf_down_proj_weight_prefix}.lora_A.default.weight"
    ff_lora_A_weight_fp = f"{ff_fwd_prefix}_feed_forward_w2_lora_shard-id_0_weight_A"
    compare_tensors(hf_lora_A_weight_fp, ff_lora_A_weight_fp)
    hf_lora_B_weight_fp = f"{hf_down_proj_weight_prefix}.lora_B.default.weight"
    ff_lora_B_weight_fp = f"{ff_fwd_prefix}_feed_forward_w2_lora_shard-id_0_weight_B"
    compare_tensors(hf_lora_B_weight_fp, ff_lora_B_weight_fp)

    # HF-internal consistency: lora_A output feeds lora_B.
    hf_lora_A_out = f"{hf_fwd_prefix}.mlp.down_proj.lora_A.default.output_0"
    hf_lora_B_in = f"{hf_fwd_prefix}.mlp.down_proj.lora_B.default.input_0"
    compare_hf_tensors(hf_lora_A_out, hf_lora_B_in)

    # The HF LoRA contribution is compared with FF (lora output - w2 output).
    hf_lora_out = f"{hf_fwd_prefix}.mlp.down_proj.lora_B.default.output_0"
    ff_lora_out = f"{ff_fwd_prefix}_feed_forward_w2_lora_shard-id_0_output_0"
    compare_tensors_difference(hf_lora_out, ff_lora_out, ff_down_proj_out)

# After the last layer only: final norm and lm_head.
hf_norm_out = f"{hf_weight_base_path}/fwd_step_0_norm.output_0"
ff_norm_out = f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{tot_num_layers-1}_layer-name_norm_shard-id_0_output_1"
compare_tensors(hf_norm_out, ff_norm_out)
hf_lm_head_out = f"{hf_weight_base_path}/fwd_step_0_base_model.model.lm_head.output_0"
ff_lm_head_out = f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{tot_num_layers-1}_layer-name_output_shard-id_0_output_0"
compare_tensors(hf_lm_head_out, ff_lm_head_out)

Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!


In [4]:
# Backward-pass alignment of the model head: lm_head gradients/weight, then the
# final RMSNorm gradients/weight, all taken from the last layer's FlexFlow dumps.
tot_num_layers = 12
ff_bwd_last_layer_prefix = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{tot_num_layers-1}_layer-name"
ff_fwd_last_layer_prefix = f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{tot_num_layers-1}_layer-name"

# Path of the FlexFlow softmax input gradient (not compared in this cell).
ff_BWD_softmax_in = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_100_layer-name_Softmax_shard-id_0_input_0"

# lm_head: output gradient, weight, input gradient.
hf_BWD_lm_head_out = f"{hf_weight_base_path}/bwd_step_0_base_model.model.lm_head.go_0"
ff_BWD_lm_head_out = f"{ff_bwd_last_layer_prefix}_output_shard-id_0_output_0"
compare_tensors(hf_BWD_lm_head_out, ff_BWD_lm_head_out, tolerance=1e-5)
hf_lm_head_weight = f"{hf_weight_base_path}/base_model.model.lm_head.weight"
ff_lm_head_weight = f"{ff_fwd_last_layer_prefix}_output_shard-id_0_weight_0"
compare_tensors(hf_lm_head_weight, ff_lm_head_weight, tolerance=1e-5)
hf_BWD_lm_head_in = f"{hf_weight_base_path}/bwd_step_0_base_model.model.lm_head.gi_0"
ff_BWD_lm_head_in = f"{ff_bwd_last_layer_prefix}_output_shard-id_0_input_0"
compare_tensors(hf_BWD_lm_head_in, ff_BWD_lm_head_in, tolerance=1e-5)

# Final norm: the HF lm_head input gradient must equal the norm output gradient.
hf_BWD_norm_out = f"{hf_weight_base_path}/bwd_step_0_norm.go_0"
ff_BWD_norm_out = f"{ff_bwd_last_layer_prefix}_norm_shard-id_0_output_0"
compare_hf_tensors(hf_BWD_lm_head_in, hf_BWD_norm_out)
compare_tensors(hf_BWD_norm_out, ff_BWD_norm_out)
# The FF norm weight comes from the forward (decoding) dump.
ff_BWD_norm_weight = f"{ff_fwd_last_layer_prefix}_norm_shard-id_0_weight_0"
hf_FWD_norm_weight = f"{hf_weight_base_path}/base_model.model.model.norm.weight"
compare_tensors(hf_FWD_norm_weight, ff_BWD_norm_weight, tolerance=1e-5)
hf_BWD_norm_in = f"{hf_weight_base_path}/bwd_step_0_norm.gi_0"
ff_BWD_norm_in = f"{ff_bwd_last_layer_prefix}_norm_shard-id_0_input_1"
compare_tensors(hf_BWD_norm_in, ff_BWD_norm_in, tolerance=1e-5)


Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!


In [5]:
from torch import nn
class LlamaRotaryEmbedding(nn.Module):
    """Cache of rotary position-embedding cos/sin tables.

    ``inv_freq[i] = 1 / base ** (2i / dim)``. The cached tables hold ``cos`` and
    ``sin`` of ``position * inv_freq``, with the frequency vector concatenated with
    itself. For even ``dim`` each table therefore has shape ``(seq_len_cached, dim)``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        # Outer product: freqs[pos, i] = pos * inv_freq[i].
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        # Re-registering an existing buffer name overwrites it.
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: tensor laid out as [bs, num_attention_heads, seq_len, head_size]. Only
                its dtype/device are used, plus its sequence axis when ``seq_len`` is
                None.
            seq_len: number of positions to return. Defaults to ``x.shape[-2]``
                (comparing ``None`` against the cache size would raise TypeError).

        Returns:
            ``(cos, sin)``, each the first ``seq_len`` rows of the cache, cast to
            ``x.dtype``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]
        # Grow the cache on demand.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last axis into a first half ``x1`` (``size // 2`` entries) and a second
    half ``x2`` (the rest), and returns ``cat(x2, -x1)``.
    NOTE(review): the common RoPE helper returns ``cat(-x2, x1)``. This variant is
    its negation, so ``apply_rotary_pos_emb`` rotates by the opposite angle --
    presumably intentional for the backward checks, confirm.
    """
    half = x.shape[-1] // 2
    first_half, second_half = x.split([half, x.shape[-1] - half], dim=-1)
    return torch.cat((second_half, first_half.neg()), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding to a query/key pair.

    Each of ``q`` and ``k`` becomes ``t * cos_p + rotate_half(t) * sin_p``, where
    ``cos_p``/``sin_p`` are the rows of ``cos``/``sin`` selected by ``position_ids``.

    Args:
        q (`torch.Tensor`): query tensor.
        k (`torch.Tensor`): key tensor.
        cos (`torch.Tensor`): cosine table of the rotary embedding.
        sin (`torch.Tensor`): sine table of the rotary embedding.
        position_ids (`torch.Tensor`): position index of each token, used to gather
            rows of ``cos``/``sin`` (e.g. offset positions when using a KV-cache).
        unsqueeze_dim (`int`, *optional*, defaults to 1): axis inserted into the
            gathered ``cos``/``sin`` so they broadcast against ``q``/``k``. If the
            gathered tables are [batch_size, seq_len, head_dim], use 1 for q/k laid
            out as [batch_size, heads, seq_len, head_dim] and 2 for
            [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated ``(q, k)``.
    """
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        return tensor * cos_sel + rotate_half(tensor) * sin_sel

    rotated_q = _rotate(q)
    rotated_k = _rotate(k)
    return rotated_q, rotated_k
# RoPE configuration used to rebuild the rotary tables in this notebook.
head_dim, max_position_embeddings, rope_theta = 64, 2048, 10_000
kv_seq_len = 24
rotary_emb = LlamaRotaryEmbedding(
    head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
)

In [11]:
# Backward-pass alignment, walking layers from last to first. Each iteration runs
# (1) HF-internal gradient consistency checks, (2) FlexFlow-internal checks,
# (3) HF-vs-FlexFlow comparisons, then (4) checks the packed attention weights.
tot_num_layers = 12
for layer_num in range(tot_num_layers-1, -1, -1):
    # HuggingFace filepaths. Suffixes presumably follow the dump hooks' convention:
    # "go_0" = gradient w.r.t. module output, "gi_0" = gradient w.r.t. module input.
    hf_BWD_norm_in = f"{hf_weight_base_path}/bwd_step_0_norm.gi_0"
    hf_BWD_loraB_out = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.lora_B.default.go_0"
    hf_BWD_loraB_in = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.lora_B.default.gi_0"
    hf_BWD_loraA_out = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.lora_A.default.go_0"
    hf_BWD_loraA_in = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.lora_A.default.gi_0"
    hf_loraA_weight = f"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.lora_A.default.weight"
    hf_loraB_weight = f"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.lora_B.default.weight"
    hf_BWD_lora_dropout_out = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.lora_dropout.default.go_0"
    hf_BWD_lora_dropout_in = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.lora_dropout.default.gi_0"
    hf_BWD_w2_out = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.go_0"
    hf_BWD_w2_in = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.gi_0"
    hf_w2_weight = f"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.weight"
    hf_BWD_w3_out = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.up_proj.go_0"
    hf_BWD_w3_in = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.up_proj.gi_0"
    hf_BWD_w1_out = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.gate_proj.go_0"
    hf_BWD_w1_in = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.gate_proj.gi_0"
    hf_BWD_act_fn_in = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.act_fn.gi_0"
    hf_BWD_act_fn_out = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.act_fn.go_0"
    hf_BWD_ffn_norm_out = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.post_attention_layernorm.go_0"
    hf_BWD_ffn_norm_in = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.post_attention_layernorm.gi_0"
    hf_BWD_attn_out_out = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.o_proj.go_0"
    # Was hard-coded to layer 11; follow the loop variable like every other path.
    hf_BWD_attn_q_in = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.q_proj.gi_0"
    hf_FWD_w1_out = f"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.gate_proj.output_0"
    hf_FWD_w3_out = f"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.up_proj.output_0"
    hf_FWD_act_fn_out = f"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.act_fn.output_0"
    hf_BWD_attn_oproj_in = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.o_proj.gi_0"
    hf_attn_qproj_weight = f"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.self_attn.q_proj.weight"
    hf_attn_kproj_weight = f"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.self_attn.k_proj.weight"
    hf_attn_vproj_weight = f"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.self_attn.v_proj.weight"
    hf_attn_oproj_weight = f"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.self_attn.o_proj.weight"
    # hf_BWD_attn_vproj_in = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.v_proj.gi_0"
    # FlexFlow filepaths
    ff_BWD_w2_out = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_shard-id_0_output_0"
    ff_BWD_w2_in = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_shard-id_0_input_0"
    ff_BWD_w2_in_pre = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_shard-id_0_pre_input_0"
    ff_w2_weight = f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_shard-id_0_weight_0"
    ff_BWD_ssm_out = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_SigmoidSiluMulti_shard-id_0_output_0"
    ff_BWD_ssm_in1 = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_SigmoidSiluMulti_shard-id_0_input_0"
    ff_BWD_ssm_in2 = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_SigmoidSiluMulti_shard-id_0_input_1"
    ff_BWD_w3_out = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w3_shard-id_0_output_0"
    ff_BWD_w3_in = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w3_shard-id_0_input_0"
    ff_BWD_lora_A_in = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_input_0"
    ff_BWD_lora_B_out = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_output_0"
    ff_lora_A_weight = f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_weight_A"
    ff_lora_B_weight = f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_weight_B"
    ff_BWD_w1_out = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_output_0"
    ff_BWD_w1_in = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_input_0"
    ff_BWD_w1_in_pre = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_pre_input_0"
    ff_w1_weight = f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_weight_0"
    ff_BWD_ffn_norm_in1 = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_ffn_norm_shard-id_0_input_0"
    ff_BWD_ffn_norm_in2 = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_ffn_norm_shard-id_0_input_1"
    ff_BWD_ffn_norm_out = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_ffn_norm_shard-id_0_output_0"
    ff_BWD_attn_out = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_output_0"
    ff_BWD_attn_in = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_input_0"
    ff_BWD_ssm_cached_w1_input = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_SigmoidSiluMulti_shard-id_0_cached_w1_output"
    ff_BWD_ssm_cached_w3_input = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_SigmoidSiluMulti_shard-id_0_cached_w3_output"
    ff_FWD_w1_out = f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_output_0"
    ff_FWD_w3_out = f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w3_shard-id_0_output_0"
    ff_FWD_act_fnc_out = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_SigmoidSiluMulti_shard-id_0_act_fn_output"
    ff_BWD_attn_o_proj_in = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_o_proj_in_grad"
    # ff_BWD_attn_v_proj_in = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_v_proj_in_grad"
    # Packed attention weights of *this* layer (sliced into q/k/v/o below). This path
    # was hard-coded to layer 11, so every other layer compared FF layer-11 weights
    # with HF layer-`layer_num` weights.
    ff_attn_oproj_weight = f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_weight_0"
    # ff_attn_qk_prods_softmax = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax"

    # HuggingFace checks
    print("\nHuggingface checks:")
    if layer_num == tot_num_layers-1:
        compare_hf_tensors(hf_BWD_norm_in, hf_BWD_loraB_out)
        compare_hf_tensors(hf_BWD_norm_in, hf_BWD_w2_out)
    compare_hf_tensors(hf_BWD_loraB_out, hf_BWD_w2_out)
    compare_hf_tensors(hf_BWD_loraB_in, hf_BWD_loraA_out)
    # compare_hf_tensors(hf_BWD_w3_out, hf_BWD_w2_out)
    compare_hf_tensors(hf_BWD_act_fn_in, hf_BWD_w1_out)
    check_hf_sum_tensors(hf_BWD_ffn_norm_out, hf_BWD_w1_in, hf_BWD_w3_in)
    # NOTE(review): hf_BWD_norm_in is the final norm's input gradient for every
    # iteration; as the residual term this presumably only holds for the last
    # layer -- confirm what the expected residual gradient is for earlier layers.
    check_hf_sum_tensors(hf_BWD_attn_out_out, hf_BWD_ffn_norm_in, hf_BWD_norm_in)

    # FlexFlow checks
    print("\nFlexFlow checks:")
    compare_flexflow_tensors(ff_BWD_w2_out, ff_BWD_lora_B_out)
    compare_flexflow_tensors(ff_BWD_w2_in_pre, ff_BWD_lora_A_in)
    compare_flexflow_tensors(ff_BWD_w2_in, ff_BWD_ssm_out)
    compare_flexflow_tensors(ff_BWD_ssm_in2, ff_BWD_w3_out)
    compare_flexflow_tensors(ff_BWD_ssm_in1, ff_BWD_w1_out)
    compare_flexflow_tensors(ff_BWD_w1_in, ff_BWD_ffn_norm_out)
    compare_flexflow_tensors(ff_BWD_w1_in_pre, ff_BWD_w3_in)
    compare_flexflow_tensors(ff_BWD_ffn_norm_in1, ff_BWD_ffn_norm_in2, max_len=24*768)
    #compare_flexflow_tensors(ff_BWD_ffn_norm_in2, ff_BWD_attn_out, max_len=24*768) # should fail

    # HF-FlexFlow checks
    print("\nHuggingface-FlexFlow checks:")
    compare_tensors(hf_BWD_w2_out, ff_BWD_w2_out, tolerance=1e-5)
    compare_tensors(hf_w2_weight, ff_w2_weight, tolerance=1e-5)
    compare_tensors(hf_loraA_weight, ff_lora_A_weight, tolerance=1e-5)
    compare_tensors(hf_loraB_weight, ff_lora_B_weight, tolerance=1e-5)

    compare_tensors(hf_BWD_loraB_out, ff_BWD_lora_B_out)
    compare_tensors(hf_BWD_loraA_in, ff_BWD_lora_A_in)

    compare_tensors(hf_BWD_w2_in, ff_BWD_ssm_out)
    compare_tensors(hf_BWD_w2_in, ff_BWD_w2_in)
    compare_tensors(hf_BWD_w1_out, ff_BWD_w1_out)
    compare_tensors_difference(hf_BWD_w1_in, ff_BWD_w1_in, ff_BWD_w1_in_pre)

    compare_tensors(hf_FWD_w1_out, ff_FWD_w1_out)
    compare_tensors(hf_FWD_w3_out, ff_FWD_w3_out)
    compare_tensors(hf_BWD_w3_out, ff_BWD_w3_out)
    compare_tensors(hf_BWD_w3_in, ff_BWD_w3_in)
    # compare_tensors(hf_BWD_ffn_norm_out, ff_BWD_ffn_norm_out)
    # compare_tensors(hf_BWD_ffn_norm_in, ff_BWD_ffn_norm_in2)
    # compare_tensors(hf_BWD_attn_out_out, ff_BWD_ffn_norm_in2)
    compare_tensors(hf_BWD_attn_out_out, ff_BWD_attn_out)

    # compare attn weight tensors
    hidden_size = 768
    qProjSize = 64
    num_heads = 12
    num_new_tokens = num_tokens = 24
    # Assumes FF stores the q, k, v and o projection weights back-to-back in one
    # column-major (order='F') buffer, as implied by the slicing below.
    ff_attn_weight_tensor = np.loadtxt(ff_attn_oproj_weight, delimiter=',')
    ff_attn_qproj_weight_tensor = ff_attn_weight_tensor[:hidden_size*qProjSize*num_heads].reshape((hidden_size,qProjSize*num_heads), order = 'F')
    ff_attn_kproj_weight_tensor = ff_attn_weight_tensor[hidden_size*qProjSize*num_heads:2*hidden_size*qProjSize*num_heads].reshape((hidden_size,qProjSize*num_heads), order = 'F')
    ff_attn_vproj_weight_tensor = ff_attn_weight_tensor[2*hidden_size*qProjSize*num_heads:3*hidden_size*qProjSize*num_heads].reshape((hidden_size,qProjSize*num_heads), order = 'F')
    ff_attn_oproj_weight_tensor = ff_attn_weight_tensor[3*hidden_size*qProjSize*num_heads:].reshape((qProjSize*num_heads,hidden_size), order='F')

    # HF nn.Linear weights are transposed so they line up with the FF layout above.
    hf_attn_qproj_weight_tensor = torch.load(hf_attn_qproj_weight).T.detach().cpu().numpy()
    hf_attn_kproj_weight_tensor = torch.load(hf_attn_kproj_weight).T.detach().cpu().numpy()
    hf_attn_vproj_weight_tensor = torch.load(hf_attn_vproj_weight).T.detach().cpu().numpy()
    hf_attn_oproj_weight_tensor = torch.load(hf_attn_oproj_weight).T.detach().cpu().numpy()

    assert(np.allclose(ff_attn_qproj_weight_tensor, hf_attn_qproj_weight_tensor, atol=1e-5))
    assert(np.allclose(ff_attn_kproj_weight_tensor, hf_attn_kproj_weight_tensor, atol=1e-5))
    assert(np.allclose(ff_attn_vproj_weight_tensor, hf_attn_vproj_weight_tensor, atol=1e-5))
    assert(np.allclose(ff_attn_oproj_weight_tensor, hf_attn_oproj_weight_tensor, atol=1e-5))

    # Compare attn outproj grad in tensors
    compare_tensors(hf_BWD_attn_oproj_in, ff_BWD_attn_o_proj_in)
    
    ########### Compare value projs grads ######################
    # 1. compare qk prods softmax
    hf_qk_prods_softmax = f"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.self_attn.qk_prods_softmax"
    ff_attn_qk_prods_softmax = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax"
    
    hf_qk_prods_softmax = torch.load(hf_qk_prods_softmax)
    ff_qk_prods_softmax = np.loadtxt(ff_attn_qk_prods_softmax, delimiter=',').reshape((num_new_tokens, num_tokens, num_heads), order = 'F')

    for head_idx in range(num_heads):
        hf_qkps = hf_qk_prods_softmax.squeeze()[head_idx, :, :].detach().cpu().numpy()
        ff_qkps = ff_qk_prods_softmax[:,:,head_idx]
        assert(np.allclose(ff_qkps, hf_qkps, atol=1e-5))
    
    # 2. compare attn heads grads
    hf_attn_heads_grads = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.o_proj.gi_0"
    ff_attn_heads_grads = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_o_proj_in_grad"

    hf_attn_heads_grads = torch.load(hf_attn_heads_grads).T.squeeze().detach().cpu().numpy()
    ff_attn_heads_grads = np.loadtxt(ff_attn_heads_grads, delimiter=',').reshape((qProjSize*num_heads, num_new_tokens), order = 'F')
    assert(np.allclose(ff_attn_heads_grads, hf_attn_heads_grads, atol=1e-2))

    # 3. vproj grads
    hf_vproj_grads = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.v_proj.go_0"
    ff_vproj_grads = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_v_proj_in_grad"

    hf_vproj_grads = torch.load(hf_vproj_grads).squeeze().detach().cpu().numpy()
    ff_vproj_grads = np.loadtxt(ff_vproj_grads, delimiter=',').reshape((num_tokens, qProjSize*num_heads), order='F')
    assert(np.allclose(hf_vproj_grads, ff_vproj_grads, atol=1e-2))

    
    
    
    ##############################
    hf_value_states = f"{hf_weight_base_path}/fwd_step_0_layers.11.self_attn.value_states"
    hf_value_states = torch.load(hf_value_states).squeeze().permute(2,0,1).detach().cpu().numpy()
    # print(hf_value_states.shape)
    ff_value_states = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_vcache"
    ff_value_states = np.loadtxt(ff_value_states, delimiter=',').reshape((qProjSize, num_heads, num_tokens), order='F')
    # print(ff_value_states.shape)
    assert(np.allclose(hf_value_states, ff_value_states, atol=1e-2))
    
    
    
    ########## Compare key and query projs grads ##################
    ff_devQKVPRojArray = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_devQKVPRojArray"
    ff_devQKVPRojArray = np.loadtxt(ff_devQKVPRojArray, delimiter=',').reshape((num_tokens, qProjSize*num_heads, 3), order = 'F')
    ff_qProjGrads = ff_devQKVPRojArray[:,:,0]
    ff_kProjGrads = ff_devQKVPRojArray[:,:,1]
    ff_vProjGrads = ff_devQKVPRojArray[:,:,2]
    assert(np.allclose(ff_vProjGrads, ff_vproj_grads, atol=1e-5))

    # simulate qk_prods_softmax
    ff_attn_heads_grads = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_o_proj_in_grad"
    ff_attn_heads_grads = np.loadtxt(ff_attn_heads_grads, delimiter=',').reshape((qProjSize,num_heads, num_new_tokens), order = 'F')
    ff_attn_heads_grads = torch.from_numpy(ff_attn_heads_grads)
    ff_attn_heads_grads = ff_attn_heads_grads.permute(1,2,0)
    ff_value_states = torch.from_numpy(ff_value_states)
    ff_value_states = ff_value_states.permute(1,0,2)
    # print(ff_attn_heads_grads.shape)
    # print(ff_value_states.shape)
    simulated_qk_prods_softmax_grads = torch.matmul(ff_attn_heads_grads, ff_value_states)
    #simulated_qk_prods_softmax_grads = simulated_qk_prods_softmax_grads
    #print("Simulated QK prods grads:")
    #print(simulated_qk_prods_softmax_grads[0,:,:])

    # qk prods softmax right before softmax
    hf_qk_prods_softmax2 = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.softmax_op.go_0"
    hf_qk_prods_softmax2 = torch.load(hf_qk_prods_softmax2)
    ff_qk_prods_softmax2 = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax_grad"
    ff_qk_prods_softmax2 = np.loadtxt(ff_qk_prods_softmax2, delimiter=',').reshape((num_new_tokens, num_tokens, num_heads), order = 'F')
    hf_qk_prods_softmax2 = hf_qk_prods_softmax2.squeeze().permute(1,2,0)
    hf_qk_prods_softmax2 = hf_qk_prods_softmax2.detach().cpu().numpy()
    # assert(np.allclose(ff_qk_prods_softmax2, hf_qk_prods_softmax2, atol=1e-2))
    mismatches = np.where(~np.isclose(ff_qk_prods_softmax2, hf_qk_prods_softmax2))
    mismatches = [(mismatches[0][i],mismatches[1][i], mismatches[2][i]) for i in range(len(mismatches[0]))]
    pct_mismatch = len(mismatches) / (hf_qk_prods_softmax2.shape[0] * hf_qk_prods_softmax2.shape[1] * hf_qk_prods_softmax2.shape[2])
    print(f"{pct_mismatch*100}% mismatch in QK prods softmax out grad")
    assert(pct_mismatch <= 0.05)

    # qk prods softmax right after softmax
    hf_qk_prods_softmax2 = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.softmax_op.gi_0"
    hf_qk_prods_softmax2 = torch.load(hf_qk_prods_softmax2)
    ff_qk_prods_softmax2 = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax_grad_in"
    ff_qk_prods_softmax2 = np.loadtxt(ff_qk_prods_softmax2, delimiter=',').reshape((num_new_tokens, num_tokens, num_heads), order = 'F')
    hf_qk_prods_softmax2 = hf_qk_prods_softmax2.squeeze().permute(1,2,0)
    hf_qk_prods_softmax2 = hf_qk_prods_softmax2.detach().cpu().numpy()
    assert(np.allclose(ff_qk_prods_softmax2, hf_qk_prods_softmax2, atol=1e-2))
    
    # qk prods softmax after mask
    hf_qk_prods_softmax2 = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.matmul_op.go_0"
    hf_qk_prods_softmax2 = torch.load(hf_qk_prods_softmax2)
    ff_qk_prods_softmax2 = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax_grad_in_masked"
    ff_qk_prods_softmax2 = np.loadtxt(ff_qk_prods_softmax2, delimiter=',').reshape((num_new_tokens, num_tokens, num_heads), order = 'F')
    hf_qk_prods_softmax2 = hf_qk_prods_softmax2.squeeze().permute(1,2,0)
    hf_qk_prods_softmax2 = hf_qk_prods_softmax2.detach().cpu().numpy()
    assert(np.allclose(ff_qk_prods_softmax2, hf_qk_prods_softmax2, atol=1e-2))

    # Compare query activation
    hf_query_activation = f"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.self_attn.query_activation"
    hf_query_activation = torch.load(hf_query_activation)
    ff_query_activation = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_query_activation"
    ff_query_activation = np.loadtxt(ff_query_activation, delimiter=',').reshape((qProjSize, num_heads, num_new_tokens), order = 'F')
    hf_query_activation = hf_query_activation.squeeze().permute(2,0,1).detach().cpu().numpy()
    assert(np.allclose(ff_query_activation, hf_query_activation, atol=1e-2))
    
    ########################################## ROPE and Kproj ##########################################

    # Compare FF kproj with intermediate kproj data from HF
    hf_kproj_grads_post_rotary = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.identity_kv_post_rotary.go_0"
    hf_kproj_grads_post_rotary = torch.load(hf_kproj_grads_post_rotary)
    hf_kproj_grads_post_rotary_copy = hf_kproj_grads_post_rotary.squeeze().permute(1,2,0).detach().cpu().numpy()
    # print("hf_kproj_grads_post_rotary: ", hf_kproj_grads_post_rotary_copy.shape)
    # print(hf_kproj_grads_post_rotary_copy[:,:,0])
    # Check hf ROPE 
    cos, sin = rotary_emb(hf_kproj_grads_post_rotary, seq_len=24)
    cos = cos.cuda()
    sin = sin.cuda()
    # query_states:  torch.Size([1, 12, 24, 64])
    # key_states:  torch.Size([1, 12, 24, 64])
    # position_ids:  torch.Size([1, 24])
    # tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
    #          18, 19, 20, 21, 22, 23]], device='cuda:0')
    query_states = torch.zeros([1, 12, 24, 64]).cuda()
    position_ids = torch.arange(24).unsqueeze(0).cuda()
    query_states, hf_kproj_grads_post_rotary = apply_rotary_pos_emb(query_states, hf_kproj_grads_post_rotary, cos, sin, position_ids)
    hf_kproj_grads_post_rotary = hf_kproj_grads_post_rotary.squeeze().permute(1,2,0).detach().cpu().numpy()
    # print("hf_kproj_grads_post_rotary: ", hf_kproj_grads_post_rotary.shape)
    # print(hf_kproj_grads_post_rotary[:,:,0])
    
    hf_kproj_grads_before_rotary = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.identity_kv_before_rotary.go_0"
    hf_kproj_grads_before_rotary = torch.load(hf_kproj_grads_before_rotary)
    hf_kproj_grads_before_rotary = hf_kproj_grads_before_rotary.squeeze().permute(1,2,0).detach().cpu().numpy()
    # print("hf_kproj_grads_before_rotary: ", hf_kproj_grads_before_rotary.shape)
    # print(hf_kproj_grads_before_rotary[:,:,0])
    # Compare HF rope with manual ROPE
    assert(np.allclose(hf_kproj_grads_post_rotary, hf_kproj_grads_before_rotary, atol=1e-5))
    # Compare HF Kproj with FF Kproj (before ROPE) 
    ff_kproj_pre = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_devkproj_pre"
    ff_kproj_pre = np.loadtxt(ff_kproj_pre, delimiter=',').reshape((num_tokens, qProjSize, num_heads), order = 'F')
    # print("ff_kproj_pre: ", ff_kproj_pre.shape)
    #print(ff_kproj_pre[:,:,0])
    mismatches = np.where(~np.isclose(ff_kproj_pre, hf_kproj_grads_post_rotary_copy, atol=1e-5))
    mismatches = [(mismatches[0][i],mismatches[1][i], mismatches[2][i]) for i in range(len(mismatches[0]))]
    pct_mismatch = len(mismatches) / (ff_kproj_pre.shape[0] * ff_kproj_pre.shape[1] * ff_kproj_pre.shape[2])
    print(f"{pct_mismatch*100}% mismatch between HF and FF for kproj (before applying ROPE)")
    assert(pct_mismatch <= 0.05)
    #assert(np.allclose(ff_kproj_pre, hf_kproj_grads_post_rotary_copy, atol=1e-5))
    
    ff_kproj = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_devkproj"
    ff_kproj = np.loadtxt(ff_kproj, delimiter=',').reshape((num_tokens, qProjSize, num_heads), order = 'F')
    # print("ff_kproj: ", ff_kproj.shape)
    #print(ff_kproj[:,:,0])
    mismatches = np.where(~np.isclose(ff_kproj, hf_kproj_grads_before_rotary, atol=1e-5))
    mismatches = [(mismatches[0][i],mismatches[1][i], mismatches[2][i]) for i in range(len(mismatches[0]))]
    pct_mismatch = len(mismatches) / (ff_kproj.shape[0] * ff_kproj.shape[1] * ff_kproj.shape[2])
    print(f"{pct_mismatch*100}% mismatch between HF and FF for kproj (after applying ROPE)")
    assert(pct_mismatch <= 0.05)
    #assert(np.allclose(ff_kproj, hf_kproj_grads_before_rotary, atol=1e-5))
    
    
    #assert(np.allclose(hf_kproj_grads_post_rotary, hf_kproj_grads_before_rotary, atol=1e-2))
    hf_kproj_grads = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.k_proj.go_0"
    hf_kproj_grads = torch.load(hf_kproj_grads).squeeze()
    #print("hf_kproj_grads: ", hf_kproj_grads.shape)
    #print(hf_kproj_grads[:,:64])
    reshaped_tensor = hf_kproj_grads.view(24, 12, 64).transpose(1, 2).contiguous().detach().cpu().numpy()
    #print(reshaped_tensor.shape)
    assert(np.allclose(ff_kproj, reshaped_tensor, atol=1e-2))

    ########################################## Qproj (with ROPE) ##########################################

    # Compare QProj
    hf_qproj_grads = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.q_proj.go_0"
    hf_qproj_grads = torch.load(hf_qproj_grads).squeeze()
    # print("HF Qproj:")
    # print(hf_qproj_grads.shape)
    reshaped_tensor = hf_qproj_grads.view(24, 12, 64).transpose(1, 2).contiguous().detach().cpu().numpy()
    # print("\t reshaped: ", reshaped_tensor.shape)
    # print(reshaped_tensor[:,:,0])
    ff_qproj = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_devQKVPRojArray"
    ff_qproj = np.loadtxt(ff_qproj, delimiter=',').reshape((num_tokens, qProjSize, num_heads, 3), order = 'F')[:,:,:,0]
    # print("FF Qproj:")
    # print(ff_qproj.shape)
    # print(ff_qproj[:,:,0])
    assert(np.allclose(ff_qproj, reshaped_tensor, atol=1e-2))

    hf_attn_in = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.input_layernorm.go_0"
    hf_attn_in = torch.load(hf_attn_in)
    # print("hf_attn_in: ", hf_attn_in.shape)
    hf_attn_in = hf_attn_in.squeeze().T
    hf_attn_in = hf_attn_in.detach().cpu().numpy()
    # print("hf_attn_in: ", hf_attn_in.shape)
    # print(hf_attn_in)

    ff_attn_in = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_attn_final_grad_in"
    ff_attn_in = np.loadtxt(ff_attn_in, delimiter=',').reshape((768,num_tokens), order = 'F')
    # print("ff_attn_in: ", ff_attn_in.shape)
    # print(ff_attn_in)
    #assert(np.allclose(ff_attn_in, hf_attn_in, atol=1e-2))

    mismatches = np.where(~np.isclose(ff_attn_in, hf_attn_in))
    mismatches = [(mismatches[0][i], mismatches[1][i]) for i in range(len(mismatches[0]))]
    pct_mismatch = len(mismatches) / (hf_attn_in.shape[0] * hf_attn_in.shape[1])
    print(f"{pct_mismatch*100}% mismatch in attention input grads")
    assert(pct_mismatch <= 0.05)
    

    assert False


Huggingface checks:
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!

FlexFlow checks:
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!

Huggingface-FlexFlow checks:
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.11.mlp.down_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/model_0_bwd-step_0_layer-num_11_layer-name_SigmoidSiluMulti_shard-id_0_output_0
HF: [ 6.4350547e+03 -6.4898600e+05  1.1761116e+05 ...  2.1410337e+01
  1.2096541e+01  3.6424692e+00]
FF:[ 6.43506250e+03 -6.48986000e+05  1.17611156e+05 ...  2.14103374e+01
  1.20965424e+01  3.64246750e+00]
[ True  True  True ...  True  True  True]
[2394]
Ok!
mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.11.mlp.down_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/model_0_bwd-step_0_layer-num_11_layer-name_layers_11_feed_forward_w2_shard-id_0_input_0
HF: [ 6.4350547e+03 -6.4898600e+05  1.1761116e+05 .

AssertionError: 

In [None]:
# Re-apply HF's rotary embedding by hand to the k_proj gradients held in
# `hf_kproj_grads_post_rotary` and print head 0 next to
# `hf_kproj_grads_before_rotary` for a visual comparison.
# hf_kproj_grads_post_rotary is laid out (tokens, head_dim, heads) by the
# comparison cell above; permute(2, 0, 1) + unsqueeze(0) turns it into the
# (batch, heads, tokens, head_dim) layout HF's rotary helpers expect,
# i.e. torch.Size([1, 12, 24, 64]) for this model.
value_states = torch.from_numpy(hf_kproj_grads_post_rotary).permute(2, 0, 1).unsqueeze(0)
key_states = value_states
# Derive the sequence length from the tensor itself instead of relying on a
# `kv_seq_len` variable that is not defined anywhere in this notebook.
rope_seq_len = key_states.shape[-2]
cos, sin = rotary_emb(value_states, seq_len=rope_seq_len)
# Only the keys matter here: the query input is an all-zeros placeholder of
# the same shape so apply_rotary_pos_emb can be called unchanged.
query_states = torch.zeros(key_states.shape)
# Positions 0..rope_seq_len-1, shaped (1, rope_seq_len) like HF's position_ids.
position_ids = torch.arange(rope_seq_len).unsqueeze(0)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
key_states = key_states.squeeze()
print(key_states.shape)
print(key_states[0,:,:])
print(hf_kproj_grads_before_rotary.shape)
print(hf_kproj_grads_before_rotary[:,:,0])

torch.Size([12, 24, 64])
tensor([[-1.5730e-02, -4.1161e-02,  3.0593e-02,  ...,  3.8630e-01,
          3.2884e-01,  3.6067e-01],
        [-2.8613e+01, -5.5872e+00,  2.9385e+01,  ...,  3.8782e+01,
          9.6901e+01,  9.8470e+01],
        [ 3.3027e+00,  1.8276e-01, -1.8497e+00,  ..., -4.4052e+01,
         -2.0010e+01, -2.9788e+01],
        ...,
        [-7.6471e-02, -1.8892e-01,  3.6430e-01,  ..., -2.7493e-01,
          5.7017e-01, -1.5986e-01],
        [ 2.5780e+00, -1.8153e+00,  2.5088e+00,  ..., -1.0776e+01,
          6.2167e-01,  8.3755e-01],
        [-6.8324e-02,  1.7568e-01, -3.2311e-01,  ...,  3.1202e+00,
         -2.6652e-01, -1.1917e+00]])
(24, 64, 12)
[[-1.5729919e-02 -4.1160699e-02  3.0592799e-02 ...  3.8629669e-01
   3.2884139e-01  3.6066702e-01]
 [-2.8613457e+01 -5.5871558e+00  2.9384506e+01 ...  3.8781765e+01
   9.6900581e+01  9.8469597e+01]
 [ 3.3027239e+00  1.8275940e-01 -1.8496730e+00 ... -4.4052174e+01
  -2.0009745e+01 -2.9787930e+01]
 ...
 [-7.6470733e-02 -1.8891659e

In [None]:
# Position ids 0..23 as a (1, 24) row on the GPU, matching HF's position_ids layout.
torch.arange(24)[None, :].cuda()

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23]], device='cuda:0')

In [None]:
# For one decoder layer: check that FlexFlow's forward softmax(QK^T)
# probabilities match HuggingFace head by head, then print the shapes of the
# tensors involved in the attention-output / value backward pass.
layer_num = 11
num_heads = 12
num_tokens = 24

# Build every path from layer_num and the configured base paths, so changing
# layer_num (or the base paths) updates all of them consistently.
hf_qk_prods_softmax_path = f"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.self_attn.qk_prods_softmax"
ff_qk_prods_softmax_path = f"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax"
hf_value_states_path = f"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.self_attn.value_states"

# HF tensor is (batch, heads, new_tokens, tokens); the FF dump is a flat
# column-major buffer that reshapes to (new_tokens, tokens, heads).
hf_qk_prods_softmax = torch.load(hf_qk_prods_softmax_path)
ff_qk_prods_softmax = np.loadtxt(ff_qk_prods_softmax_path, delimiter=',').reshape((num_tokens, num_tokens, num_heads), order='F')
print(hf_qk_prods_softmax.shape)

# Convert once, outside the per-head loop.
hf_qk_prods_softmax_np = hf_qk_prods_softmax.squeeze().detach().cpu().numpy()
for head_idx in range(num_heads):
    hf_qkps = hf_qk_prods_softmax_np[head_idx, :, :]
    ff_qkps = ff_qk_prods_softmax[:, :, head_idx]
    assert np.allclose(ff_qkps, hf_qkps, atol=1e-5), f"qk_prods_softmax mismatch in head {head_idx} of layer {layer_num}"


hf_value_states = torch.load(hf_value_states_path)
print(hf_value_states.shape)
attn_output = torch.matmul(hf_qk_prods_softmax, hf_value_states)
print()
print(attn_output.shape)
print(attn_output.transpose(1, 2).contiguous().shape)
print("Hf attn heads")
print(torch.load(f"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.self_attn.o_proj.input_0").shape)

print("Attn heads grads:")
hf_attn_heads_grads = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.o_proj.gi_0"
print(torch.load(hf_attn_heads_grads).shape)
print("HF value grads:")
# NOTE(review): the comparison cell above uses v_proj.go_0 for the value-proj
# grads; gi_0 here may not be the tensor intended — confirm.
vproj_grads = f"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.v_proj.gi_0"
print(torch.load(vproj_grads).shape)


torch.Size([1, 12, 24, 24])


AssertionError: 

In [None]:
# Sanity check: `.T` on a 3-D tensor reverses *all* dims, so (2, 3, 4) -> (4, 3, 2).
a = torch.randn(2, 3, 4)
for shape in (a.shape, a.T.shape):
    print(shape)

torch.Size([2, 3, 4])
torch.Size([4, 3, 2])


In [None]:
# Print, for layer 11, the post_attention_layernorm gi_0 tensor and the
# self_attn.o_proj go_0 tensor (presumably its input-grad and output-grad
# respectively) so the two can be eyeballed side by side.
# Paths are built from hf_weight_base_path, like the rest of the notebook,
# rather than "./hf_peft_tensors", so the cell no longer depends on the
# kernel's current working directory.
post_attn_norm_grad_in_path = f"{hf_weight_base_path}/bwd_step_0_layers.11.post_attention_layernorm.gi_0"
o_proj_grad_out_path = f"{hf_weight_base_path}/bwd_step_0_layers.11.self_attn.o_proj.go_0"
a = torch.load(post_attn_norm_grad_in_path)
b = torch.load(o_proj_grad_out_path)
print(a)
print(b)

tensor([[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
             0.0000],
         [  27.8890,  -21.5089,   45.8214,  ...,    5.4010,  -10.8787,
            39.7619],
         [  19.2197,   27.4681,  -68.7141,  ...,  102.3280,   66.7925,
          -160.8711],
         ...,
         [  63.9532,   17.4273,  -29.4416,  ...,  101.6105,   67.5937,
          -198.4432],
         [  31.2799,   13.0724,  -44.7179,  ...,  132.4898,   42.3135,
          -194.4037],
         [  42.3453,  -16.2693,  -55.7386,  ...,   90.5921,   52.2032,
          -124.1802]]], device='cuda:0')
tensor([[[-1.1845e+06, -6.7460e+05,  7.4494e+05,  ..., -9.1441e+05,
          -1.4912e+05,  3.5769e+06],
         [-7.3920e+01, -7.9389e+01,  1.1027e+02,  ..., -7.3020e+01,
          -2.3540e+01,  3.4587e+02],
         [-5.3885e+01, -1.7373e+01, -1.9780e+01,  ...,  4.1291e+01,
           5.5099e+01,  5.5910e+01],
         ...,
         [-2.1948e+01, -3.2109e+01,  2.8364e+01,  ...,  3.4321e+01,
           5

In [None]:
# NOTE: this entire cell is commented out. It collects scratch re-derivations
# used while debugging the HF-vs-FlexFlow MLP backward comparisons: manual
# w2 / LoRA down_proj grad matmuls, simulated act_fn and w3 grads, and a
# manual w1 input-grad check. Uncomment individual snippets to re-run them.
# # Manual matmul checks
# ff_w2_grad_out_tensor = np.loadtxt(ff_BWD_w2_out, delimiter=',').reshape((768,128), order='F')
# ff_w2_weight_tensor = np.loadtxt(ff_w2_weight, delimiter=',').reshape((3072,768), order='F')
# ff_w2_gradin_tensor = np.matmul(ff_w2_weight_tensor, ff_w2_grad_out_tensor).reshape((3072,128), order='F')

# ff_lora_gradout_tensor = np.loadtxt(ff_BWD_lora_B_out, delimiter=',').reshape((768,128), order='F')
# ff_lora_A_weight_tensor = np.loadtxt(ff_lora_A_weight, delimiter=',').reshape((3072,16), order='F')
# ff_lora_B_weight_tensor = np.loadtxt(ff_lora_B_weight, delimiter=',').reshape((16,768), order='F')
# ff_lora_int_grad_tensor = np.matmul(ff_lora_B_weight_tensor, ff_lora_gradout_tensor)
# ff_lora_gradint_tensor = np.matmul(ff_lora_A_weight_tensor, ff_lora_int_grad_tensor)

# # ff_w2_gradin_tensor = ff_w2_gradin_tensor + ff_lora_gradint_tensor
# #print(ff_w2_gradin_tensor[:,:24])
# print("calculated LORA grad in")
# print(ff_lora_gradint_tensor[:,:24])
# # ff_BWD_w2_in_pre_tensor = np.loadtxt(ff_BWD_w2_in_pre, delimiter=',').reshape((3072,128), order='F')
# ff_BWD_lora_A_in_tensor = np.loadtxt(ff_BWD_lora_A_in, delimiter=',').reshape((3072,128), order='F')
# print("FlexFlow LORA grad in")
# print(ff_BWD_lora_A_in_tensor[:,:24])
# # print(ff_BWD_w2_in_pre_tensor[:,:24])
# print("HF lora grad in")
# print(torch.load(hf_BWD_loraA_in).squeeze().T.detach().cpu().numpy())
# compare_tensors(hf_BWD_loraA_in, ff_BWD_lora_A_in)

# simulate act_fn_grad
# ssm_out_grad_tensor = np.loadtxt(ff_BWD_ssm_out, delimiter=',').reshape((3072,128), order='F')
# w3_fwd_out_tensor = np.loadtxt(ff_FWD_w3_out, delimiter=',').reshape((3072,128), order='F')
# #print(ssm_out_grad_tensor.shape, w3_fwd_out_tensor.shape)
# act_fn_out_check = np.multiply(ssm_out_grad_tensor, w3_fwd_out_tensor)
# print("simulated act fn out - simulated")
# print(act_fn_out_check[:,:24])
# print("simulated act fn out - HF")
# print(torch.load(hf_BWD_act_fn_out).detach().cpu().numpy().squeeze().T)

# Simulated w3_grad
# ssm_out_grad_tensor = np.loadtxt(ff_BWD_ssm_out, delimiter=',').reshape((3072,128), order='F')[:,:24]
# act_fnc_out_tensor = np.loadtxt(ff_FWD_act_fnc_out, delimiter=',').reshape((3072,24), order='F')
# w3_out_gard_check = np.multiply(ssm_out_grad_tensor, act_fnc_out_tensor)
# print("simulated w3 out - FF")
# print(w3_out_gard_check)
# ff_BWD_w3_out_tensor = np.loadtxt(ff_BWD_w3_out, delimiter=',').reshape((3072,128), order='F')
# hf_BWD_w3_out_tensor = torch.load(hf_BWD_w3_out).detach().cpu().numpy().squeeze().T
# print("w3 out, FF")
# print(ff_BWD_w3_out_tensor[:,:24])
# print("w3 out, HF")
# print(hf_BWD_w3_out_tensor)

# print_tensors(hf_BWD_w3_out, ff_BWD_w3_out, "w3 out")
# assert False
# print()
# print()
# print_tensors(hf_BWD_w3_out, ff_BWD_w3_out, "w3 out")
# print_tensors(hf_BWD_w3_in, ff_BWD_w3_in, "w3 in")
# print_tensors(hf_BWD_w1_out, ff_BWD_w1_out, "w1 out")
# print_tensors(hf_BWD_w1_in, ff_BWD_w1_in, "w1 in")
# print_tensors(hf_BWD_ffn_norm_out, ff_BWD_ffn_norm_out, "ffn norm out")
# print_tensors(hf_BWD_ffn_norm_in, ff_BWD_ffn_norm_in2, "ffn norm in")
# print()
# ff_w1_out_tensor = np.loadtxt(ff_BWD_w1_out, delimiter=',').reshape((3072,128), order='F')
# ff_w1_in_tensor = np.loadtxt(ff_BWD_w1_in, delimiter=',').reshape((768,128), order='F')
# ff_w1_in_pre_tensor = np.loadtxt(ff_BWD_w1_in_pre, delimiter=',').reshape((768,128), order='F')
# ff_w1_only_in_tensor = ff_w1_in_tensor - ff_w1_in_pre_tensor
# ff_w1_weight_tensor = np.loadtxt(ff_w1_weight, delimiter=',').reshape((768,3072), order='F')
# ff_w1_in_check_tensor = np.matmul(ff_w1_weight_tensor, ff_w1_out_tensor)
# print("W1 in (simulated):")
# print(ff_w1_in_check_tensor[:,:24])
# print("W1 in (FF):")
# print(ff_w1_only_in_tensor[:,:24])
# print("W1 in (HF):")
# print(torch.load(hf_BWD_w1_in).squeeze().T.detach().cpu().numpy())

# compare_tensors_difference(hf_BWD_w2_in, ff_BWD_w2_in, ff_BWD_lora_A_in)
# compare_tensors(hf_BWD_w3_out, ff_BWD_w3_out)
#compare_hf_tensors(hf_BWD_ffn_norm_in, hf_BWD_attn_out_out)
# print("\nw1 out:")

# print_tensors(hf_BWD_w1_out, ff_BWD_w1_out)
# print("\nW1 in\n")
# print_tensors(hf_BWD_w1_in, ff_BWD_w1_in)
# compare_tensors(hf_BWD_w1_in, ff_BWD_w1_in)
# print("\nffn_norm")
# compare_tensors(hf_BWD_ffn_norm_out, ff_BWD_ffn_norm_out)


In [None]:
# Verify, layer by layer, that FlexFlow holds the same MLP weights as
# HuggingFace. Order per layer: LoRA A and B on down_proj (FlexFlow's
# w2_lora), then gate_proj (w1), up_proj (w3) and down_proj (w2).
# Every pair is compared with a tight 1e-5 tolerance.
for layer_num in range(12):
    # Common path prefixes for this layer's MLP tensors in each dump.
    hf_mlp = f"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp"
    ff_ffn = (
        f"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}"
        f"_layer-name_layers_{layer_num}_feed_forward"
    )

    hf_lora_A_weight_fp = hf_mlp + ".down_proj.lora_A.default.weight"
    ff_lora_A_weight_fp = ff_ffn + "_w2_lora_shard-id_0_weight_A"
    compare_tensors(hf_lora_A_weight_fp, ff_lora_A_weight_fp, tolerance=1e-5)

    hf_lora_B_weight_fp = hf_mlp + ".down_proj.lora_B.default.weight"
    ff_lora_B_weight_fp = ff_ffn + "_w2_lora_shard-id_0_weight_B"
    compare_tensors(hf_lora_B_weight_fp, ff_lora_B_weight_fp, tolerance=1e-5)

    hf_w1_weight = hf_mlp + ".gate_proj.weight"
    ff_w1_weight = ff_ffn + "_w1_shard-id_0_weight_0"
    compare_tensors(hf_w1_weight, ff_w1_weight, tolerance=1e-5)

    hf_w3_weight = hf_mlp + ".up_proj.weight"
    ff_w3_weight = ff_ffn + "_w3_shard-id_0_weight_0"
    compare_tensors(hf_w3_weight, ff_w3_weight, tolerance=1e-5)

    hf_w2_weight = hf_mlp + ".down_proj.weight"
    ff_w2_weight = ff_ffn + "_w2_shard-id_0_weight_0"
    compare_tensors(hf_w2_weight, ff_w2_weight, tolerance=1e-5)
    

Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
Ok!
