# MHA PoC

In [124]:
import torch

# Seed the global torch RNG first so that the random weight initialisation, the
# random inputs below, and any later torch.randperm draws are reproducible when
# the notebook is restarted and run top to bottom.
SEED = 0
torch.manual_seed(SEED)

# Toy problem size. kdim and vdim deliberately differ from embed_dim, which makes
# the module keep separate q/k/v projection weights (see the state-dict dump in
# the next cell) instead of a single packed in_proj_weight.
embed_dim=64
kdim=5
vdim=6
num_heads=16

seqlen = 2
batch_size = 3

mha = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, kdim=kdim, vdim=vdim, bias=False)

# batch_first is left at its default (False), so every input is sequence-major.
x = torch.rand(seqlen, batch_size, embed_dim)  # query: (sequence_length, batch_size, embed_dim)
y = torch.rand(seqlen, batch_size, kdim)  # key:   (sequence_length, batch_size, kdim)
z = torch.rand(seqlen, batch_size, vdim)  # value: (sequence_length, batch_size, vdim)

attn_output, attn_output_weights = mha(x, y, z)
print("Attention output shape:", attn_output.shape)  # (sequence_length, batch_size, embed_dim)

Attention output shape: torch.Size([2, 3, 64])


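If kdim == embed_dim and vdim == embed_dim (e.g. self-attention):
    in_proj_weight (3*embed_dim, embed_dim)   # q, k, v weights stacked along dim 0
else:
    q_proj_weight (embed_dim, embed_dim)
    k_proj_weight (embed_dim, kdim)
    v_proj_weight (embed_dim, vdim)

out_proj.weight (embed_dim, embed_dim)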

In [57]:
# Dump one "<parameter name>: <shape>" line per entry of the module's state dict.
sd = mha.state_dict()
for k in sd:
    v = sd[k]
    print(f"{k}: {v.shape}")

q_proj_weight: torch.Size([32, 32])
k_proj_weight: torch.Size([32, 5])
v_proj_weight: torch.Size([32, 6])
out_proj.weight: torch.Size([32, 32])


In [None]:
import copy

# def get_perm_idxs(n_c):
#     return torch.flip(torch.arange(0, n_c, dtype=int), dims=[0])

def get_perm_idxs(n_c):
    """Return a random permutation of ``0 .. n_c - 1`` as a 1-D integer index tensor."""
    return torch.randperm(n_c, dtype=int)


def mha_perm(sd, inplace:bool=False, n_heads=None):
    """Randomly shuffle the channels inside every head of a MultiheadAttention state dict.

    Within each head the same channel permutation is applied to the output rows of
    the q, k and v projections (and their biases) and to the matching input columns
    of ``out_proj.weight``, so loading the result into the module is expected to
    leave both the attention output and the attention weights unchanged.

    Both state-dict layouts of ``torch.nn.MultiheadAttention`` are handled:
    separate ``q_proj_weight`` / ``k_proj_weight`` / ``v_proj_weight`` tensors, and
    the packed ``in_proj_weight`` that stacks q, k, v along dim 0.

    Args:
        sd: state dict of a ``torch.nn.MultiheadAttention`` module.
        inplace: if True the entries of ``sd`` are rebound to the permuted tensors;
            otherwise a deep copy is permuted and ``sd`` is left untouched. The
            original tensors themselves are never modified in either mode.
        n_heads: number of attention heads. ``None`` (the default) falls back to
            the notebook-level ``num_heads`` variable, which is what the original
            implementation always used.

    Returns:
        The permuted state dict (``sd`` itself when ``inplace`` is True).

    Raises:
        ValueError: if the embedding size is not a multiple of the head count.
    """
    if n_heads is None:
        # Backward-compatible default: head count configured at notebook level.
        n_heads = num_heads

    sd_perm = sd if inplace else copy.deepcopy(sd)

    in_proj_weight_key = 'in_proj_weight'
    q_proj_weight_key = 'q_proj_weight'
    k_proj_weight_key = 'k_proj_weight'
    v_proj_weight_key = 'v_proj_weight'
    in_proj_bias_key = 'in_proj_bias'
    out_proj_weight_key = 'out_proj.weight'

    # out_proj.weight is (embed_dim, embed_dim) and is present in both layouts.
    embed_dim = sd_perm[out_proj_weight_key].shape[0]
    if embed_dim % n_heads != 0:
        raise ValueError(
            f"embed_dim ({embed_dim}) must be divisible by the number of heads ({n_heads})"
        )
    params_per_head = embed_dim // n_heads

    # One independent permutation per head, offset to that head's channel range,
    # so channels never move from one head to another.
    perm_idxs = torch.cat(
        [get_perm_idxs(params_per_head) + (i * params_per_head) for i in range(n_heads)],
        dim=0,
    )
    # Same permutation repeated for the q, k and v thirds of a packed tensor.
    packed_perm_idxs = torch.cat([perm_idxs + (part * embed_dim) for part in range(3)], dim=0)

    if in_proj_weight_key in sd_perm:
        sd_perm[in_proj_weight_key] = sd_perm[in_proj_weight_key][packed_perm_idxs, ...]
    else:
        sd_perm[q_proj_weight_key] = sd_perm[q_proj_weight_key][perm_idxs, ...]
        sd_perm[k_proj_weight_key] = sd_perm[k_proj_weight_key][perm_idxs, ...]
        sd_perm[v_proj_weight_key] = sd_perm[v_proj_weight_key][perm_idxs, ...]

    # The input-projection bias is always packed as (3 * embed_dim,), whichever
    # weight layout is used; it is absent when the module was built with bias=False.
    if in_proj_bias_key in sd_perm:
        sd_perm[in_proj_bias_key] = sd_perm[in_proj_bias_key][packed_perm_idxs]

    # add_bias_kv entries are appended to the already-projected keys/values, so
    # they live in the permuted channel space and must follow the same shuffle.
    for extra_key in ('bias_k', 'bias_v'):
        if extra_key in sd_perm:
            sd_perm[extra_key] = sd_perm[extra_key][..., perm_idxs]

    # Only the input columns of out_proj are permuted; its output dimension (and
    # therefore out_proj.bias) is left as is.
    sd_perm[out_proj_weight_key] = sd_perm[out_proj_weight_key][..., perm_idxs]

    return sd_perm

# Snapshot the weights BY VALUE. state_dict() hands back tensors that share
# storage with the live parameters, and load_state_dict() copies into those
# parameters in place; without the deep copy, loading sd_perm would silently
# overwrite this "original" snapshot and the restore further down would do nothing.
sd_orig = copy.deepcopy(mha.state_dict())
sd_perm = mha_perm(sd_orig, inplace=False)

# Reference forward pass with the untouched weights.
output_orig, attn_orig_weights = mha(x, y, z)

# Same inputs, permuted weights.
mha.load_state_dict(sd_perm)
output_perm, attn_perm_weights = mha(x, y, z)

# Put the original weights back so later cells see the unmodified module.
mha.load_state_dict(sd_orig)

# The permutation is expected to be function-preserving: both the attention
# output and the attention weights must match.
torch.testing.assert_close(output_orig, output_perm)
torch.testing.assert_close(attn_orig_weights, attn_perm_weights)

# Llama-3.2-1B / Llama-2 7B

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Checkpoint under test. Swap the two assignments to run the Llama-2 7B
# alternative named in the section title instead.
# model_id = "meta-llama/Llama-2-7b-hf"
model_id = "meta-llama/Llama-3.2-1B"

# Load the tokenizer and the causal-LM model for the selected checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

In [5]:
# Print the attention-geometry fields of the loaded model's config, one per line.
cfg = model.config

for field in ("num_attention_heads", "num_key_value_heads", "head_dim", "hidden_size"):
    print(f"{field}: {getattr(cfg, field)}")

num_attention_heads: 32
num_key_value_heads: 32
head_dim: 128
hidden_size: 4096


In [None]:
import torch
import copy

def get_perm_idxs(n_c):
    """Return a random permutation of ``0 .. n_c - 1`` as a 1-D integer index tensor."""
    # Deterministic alternative that is handy when debugging:
    # return torch.flip(torch.arange(0, n_c, dtype=int), dims=[0])
    return torch.randperm(n_c, dtype=int)

def permute_llama2_layer(
    sd,
    layer_name="model.layers.0",
    inplace:bool=False,
    num_heads=32,

    attn_perm:bool=True,
    mlp_perm:bool=True,

    gqa:bool=True,
):
    """Apply a random head / hidden-unit permutation to one decoder layer of a state dict.

    Attention (``attn_perm``): whole key/value heads are shuffled. The rows of
    ``k_proj`` and ``v_proj`` move head by head, the rows of ``q_proj`` move in
    groups of ``num_heads // n_kv_heads`` consecutive query heads together with the
    key/value head they are grouped with, and the columns of ``o_proj`` follow the
    query rows. The order of channels inside a head is never changed.

    MLP (``mlp_perm``): the rows of ``gate_proj`` and ``up_proj`` and the columns of
    ``down_proj`` are shuffled with one shared permutation of the hidden units.

    Args:
        sd: state dict containing the ``{layer_name}.self_attn.*`` and/or
            ``{layer_name}.mlp.*`` weight entries.
        layer_name: key prefix of the layer to permute.
        inplace: if True the entries of ``sd`` are rebound to the permuted tensors;
            otherwise a deep copy is permuted and ``sd`` is left untouched.
        num_heads: number of query heads of the checkpoint. It cannot be read from
            the weights, so it must match the model config.
        attn_perm: permute the attention projections.
        mlp_perm: permute the MLP projections.
        gqa: True accepts any grouping of query heads per key/value head (plain
            multi-head attention is the group-size-1 case). False additionally
            requires that k/v have as many heads as q.

    Returns:
        The permuted state dict (``sd`` itself when ``inplace`` is True).

    Raises:
        ValueError: if the projection shapes are inconsistent with ``num_heads``,
            or if ``gqa`` is False but the layer uses grouped key/value heads.
    """
    sd_perm = sd if inplace else copy.deepcopy(sd)

    if attn_perm:
        q_proj_weight_key = f"{layer_name}.self_attn.q_proj.weight"
        k_proj_weight_key = f"{layer_name}.self_attn.k_proj.weight"
        v_proj_weight_key = f"{layer_name}.self_attn.v_proj.weight"
        o_proj_weight_key = f"{layer_name}.self_attn.o_proj.weight"

        embed_dim = sd_perm[q_proj_weight_key].shape[0]
        kv_dim = sd_perm[k_proj_weight_key].shape[0]

        # Fail loudly on a head count that does not fit the weights: a wrong
        # head_dim would still reshape fine in many cases and silently produce a
        # permutation that changes the layer's function.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"q_proj output size ({embed_dim}) is not divisible by num_heads ({num_heads})"
            )
        head_dim = embed_dim // num_heads

        if kv_dim == 0 or kv_dim % head_dim != 0:
            raise ValueError(
                f"k_proj output size ({kv_dim}) is not a positive multiple of head_dim ({head_dim})"
            )
        n_kv_heads = kv_dim // head_dim

        if num_heads % n_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) is not divisible by the number of key/value heads ({n_kv_heads})"
            )
        if not gqa and n_kv_heads != num_heads:
            raise ValueError(
                f"gqa=False requires as many key/value heads as query heads, "
                f"got {n_kv_heads} and {num_heads}; pass gqa=True for grouped-query layers"
            )
        group_size = num_heads // n_kv_heads

        # One permutation of the key/value heads drives everything else.
        kv_head_perm_idxs = get_perm_idxs(n_kv_heads)

        kv_idxs = torch.arange(kv_dim, dtype=int).reshape((n_kv_heads, head_dim))
        kv_perm_idxs = kv_idxs[kv_head_perm_idxs, :].reshape(-1)

        # Query rows viewed as (kv head, query head within the group, channel):
        # moving along the first axis keeps each group attached to its kv head.
        q_idxs = torch.arange(embed_dim, dtype=int).reshape((n_kv_heads, group_size, head_dim))
        q_perm_idxs = q_idxs[kv_head_perm_idxs, :, :].reshape(-1)

        sd_perm[q_proj_weight_key] = sd_perm[q_proj_weight_key][q_perm_idxs, ...]
        sd_perm[k_proj_weight_key] = sd_perm[k_proj_weight_key][kv_perm_idxs, ...]
        sd_perm[v_proj_weight_key] = sd_perm[v_proj_weight_key][kv_perm_idxs, ...]
        # o_proj consumes the concatenated head outputs, so its input columns
        # follow the query-head order.
        sd_perm[o_proj_weight_key] = sd_perm[o_proj_weight_key][..., q_perm_idxs]

    if mlp_perm:
        mlp_gate_weight_key = f"{layer_name}.mlp.gate_proj.weight"
        mlp_up_proj_weight_key = f"{layer_name}.mlp.up_proj.weight"
        mlp_down_proj_weight_key = f"{layer_name}.mlp.down_proj.weight"

        gate_hidden_dim = sd_perm[mlp_gate_weight_key].shape[0]
        perm_idxs_mlp = get_perm_idxs(gate_hidden_dim)
        # gate and up are combined element-wise per hidden unit, so they share the
        # row permutation; down_proj reads those units as its input columns.
        sd_perm[mlp_gate_weight_key] = sd_perm[mlp_gate_weight_key][perm_idxs_mlp, ...]
        sd_perm[mlp_up_proj_weight_key] = sd_perm[mlp_up_proj_weight_key][perm_idxs_mlp, ...]
        sd_perm[mlp_down_proj_weight_key] = sd_perm[mlp_down_proj_weight_key][..., perm_idxs_mlp]

    return sd_perm

def permute_llama2_all_layers(
    sd,
    layers= range(0, 32),
    inplace:bool=False,
    num_heads=32,
    attn_perm:bool=True,
    mlp_perm:bool=True,
    gqa:bool=True,
):
    """Run ``permute_llama2_layer`` on several decoder layers of one state dict.

    Args:
        sd: model state dict holding ``model.layers.{i}.*`` entries.
        layers: iterable of layer indices to permute. Every index must exist in
            ``sd``; the default covers layers 0..31.
        inplace: if True ``sd`` itself is updated; otherwise a single deep copy is
            taken up front and all layers are permuted inside that copy.
        num_heads: forwarded to ``permute_llama2_layer``.
        attn_perm: forwarded to ``permute_llama2_layer``.
        mlp_perm: forwarded to ``permute_llama2_layer``.
        gqa: forwarded to ``permute_llama2_layer``.

    Returns:
        The permuted state dict (``sd`` itself when ``inplace`` is True).
    """
    sd_perm = sd if inplace else copy.deepcopy(sd)

    for i in layers:
        # The copy (if any) was made above, so each layer is updated in place.
        permute_llama2_layer(
            sd_perm,
            layer_name=f"model.layers.{i}",
            inplace=True,
            num_heads=num_heads,
            attn_perm=attn_perm,
            mlp_perm=mlp_perm,
            gqa=gqa,
        )

    return sd_perm

# Move the model to CPU, then take a by-value snapshot of its weights so that
# later load_state_dict() calls cannot overwrite the reference copy.
model.to('cpu')
sd_orig = copy.deepcopy(model.state_dict())

# Permuted variant of the snapshot, covering decoder layers 0..15.
sd_perm = permute_llama2_all_layers(sd=sd_orig, layers=range(16), inplace=False)

In [3]:
from notebook_utils import *

# Baseline: evaluate the model with the original weights.
# eval_on_sqad_ds is not defined in this notebook; presumably it comes from the
# notebook_utils star import above (verify).
model.load_state_dict(sd_orig)
results_orig = eval_on_sqad_ds(model, tokenizer)
print(f"Results for original model: {results_orig}")
# Result noted from an earlier run (it differs from the output recorded below this cell):
# Results for original model: {'exact_match': 0.0, 'f1': 28.052425969092635}

# Same evaluation with the permuted weights loaded into the same model object.
model.load_state_dict(sd_perm)
results_perm = eval_on_sqad_ds(model, tokenizer)
# Restoring the original weights is currently disabled, so `model` keeps the
# permuted weights after this cell:
# model.load_state_dict(sd_orig)

# NOTE(review): the permutation is meant to leave the network's function unchanged,
# so the two scores would be expected to match. The recorded output shows f1 of
# about 16.59 for the original and 0.0 for the permuted model; investigate before
# trusting either number.
print(f"Results for permuted model: {results_perm}")

6it [00:01,  3.41it/s]                       


Results for original model: {'exact_match': 0.0, 'f1': 16.58730158730159}


6it [00:00,  7.65it/s]                       

Results for permuted model: {'exact_match': 0.0, 'f1': 0.0}





In [None]:
def extract_weights_sd(sd):
    """Flatten every tensor of ``sd`` and join them, in dict order, into one 1-D CPU tensor."""
    flat_parts = []
    for tensor in sd.values():
        flat_parts.append(tensor.detach().cpu().flatten())

    return torch.cat(flat_parts)

def compare_sds(sd_orig, sd_perm):
    """Count the elements of ``sd_orig`` and how many of them differ in ``sd_perm``.

    Tensors are compared element-wise, key by key, using the keys of ``sd_orig``.

    Returns:
        ``(n_weights, n_changed_w)``: total element count over ``sd_orig`` and the
        number of positions whose value is not equal in ``sd_perm``.
    """
    total = sum(tensor.numel() for tensor in sd_orig.values())

    differing = 0
    for name, reference in sd_orig.items():
        differing += int((reference != sd_perm[name]).sum())

    return total, differing

# Summarise how much of the state dict the permutation actually touched.
n_weights, n_changed_w = compare_sds(sd_orig, sd_perm)

# print(f"Model: {model_name}")
for label, value in (
    ("Number of weights", n_weights),
    ("Number of changed weights", n_changed_w),
):
    print(f"{label}: {value}")
print(f"Percentage of changed weights: {n_changed_w / n_weights:.2%}")

Number of weights: 1498482688
Number of changed weights: 889112217
Percentage of changed weights: 59.33%
