In [1]:
from llama.modeling_llama import LlamaForCausalLM
from llama.configuration_llama import LlamaConfig
from transformers import AutoTokenizer
import torch
from copy import deepcopy
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Local checkpoint directory, defined once so the model and tokenizer are always loaded from the same place.
MODEL_PATH = "/root/mfx/huggingface/meta-llama/Meta-Llama-3-8B"
# partial_rotary_factor / rope_repeat are passed as config overrides; presumably they are consumed by the
# local `llama` package's config/attention (rope_repeat is not a stock transformers option) — verify there.
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    attn_implementation="sdpa",
    partial_rotary_factor=1,
    rope_repeat=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model

Loading checkpoint shards: 100%|██████████| 4/4 [00:10<00:00,  2.70s/it]
Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /root/mfx/huggingface/meta-llama/Meta-Llama-3-8B and are newly initialized: ['model.layers.0.self_attn.k_proj.bias', 'model.layers.0.self_attn.k_up_proj.weight', 'model.layers.0.self_attn.q_proj.bias', 'model.layers.0.self_attn.v_proj.bias', 'model.layers.0.self_attn.v_up_proj.weight', 'model.layers.1.self_attn.k_proj.bias', 'model.layers.1.self_attn.k_up_proj.weight', 'model.layers.1.self_attn.q_proj.bias', 'model.layers.1.self_attn.v_proj.bias', 'model.layers.1.self_attn.v_up_proj.weight', 'model.layers.10.self_attn.k_proj.bias', 'model.layers.10.self_attn.k_up_proj.weight', 'model.layers.10.self_attn.q_proj.bias', 'model.layers.10.self_attn.v_proj.bias', 'model.layers.10.self_attn.v_up_proj.weight', 'model.layers.11.self_attn.k_proj.bias', 'model.layers.11.self_attn.k_up_proj.weight', 'model.layers.11.self_attn.q_proj.bias', 'mod

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaMLAttention(
          (k_proj): Linear(in_features=4096, out_features=1024, bias=True)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=True)
          (q_proj): Linear(in_features=4096, out_features=4096, bias=True)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_up_proj): Linear(in_features=1024, out_features=4096, bias=False)
          (v_up_proj): Linear(in_features=1024, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
  

In [3]:
# Attention geometry shared by every weight-surgery cell in this notebook.
hidden_size = model.config.hidden_size
n_heads = model.config.num_attention_heads
kv_heads = model.config.num_key_value_heads
head_dim = hidden_size // n_heads
latent_dim = kv_heads * head_dim   # output width of k_proj / v_proj (the shared latent)
kv_groups = n_heads // kv_heads    # number of query heads served by each kv head
# The reshapes in later cells silently assume both of these divide evenly; fail loudly instead.
assert hidden_size == n_heads * head_dim, (hidden_size, n_heads)
assert n_heads == kv_heads * kv_groups, (n_heads, kv_heads)
# torch.svd_lowrank (used in the re-factorisation cell) draws random projections; seed for reproducibility.
torch.manual_seed(0)
model.config.partial_rotary_factor

1

In [4]:
# Insert identity matrices into k_up_proj / v_up_proj.
# Row block h (query head h, head_dim rows) is the identity slice that selects the latent channels of
# kv head h // kv_groups — the same head pairing GQA's repeat_kv uses. Built once, shared by all layers.
repeat_kv_identity = (
    torch.eye(latent_dim)
    .reshape(kv_heads, head_dim, latent_dim)
    .repeat_interleave(kv_groups, dim=0)          # (n_heads, head_dim, latent_dim)
    .reshape(n_heads * head_dim, latent_dim)
)
with torch.no_grad():
    for name, module in model.named_modules():
        if name.endswith(("k_up_proj", "v_up_proj")):
            assert module.weight.shape == repeat_kv_identity.shape, (name, tuple(module.weight.shape))
            # copy_ converts to the parameter's own device/dtype and keeps the Parameter object intact.
            module.weight.copy_(repeat_kv_identity)

In [5]:
# Greedy generation check after the identity initialisation.
# Inputs go to model.device rather than a hard-coded "cuda:1", which only exists on multi-GPU hosts.
# Prompt text (sic, "Tall") is kept verbatim so the output stays comparable with the recorded runs.
inputs = tokenizer("Tall me a story", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=500, do_sample=False)
print(tokenizer.batch_decode(output)[0])

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


<|begin_of_text|>Tall me a story
I am a tall woman. I am 5'10" and I have been tall my whole life. I have always been the tallest person in the room. I have always been the tallest person in my family. I have always been the tallest person in my class. I have always been the tallest person in my group of friends. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest perso

In [6]:
def _balanced_factors(product, rank):
    """Split a batch of low-rank matrices into two factors that share the singular values.

    Args:
        product: tensor of shape (batch, m, n) whose matrices have rank <= ``rank``.
        rank: number of singular triplets to keep (passed as ``q`` to ``torch.svd_lowrank``).

    Returns:
        ``(left, right)`` with shapes (batch, rank, m) and (batch, rank, n) such that
        ``left.transpose(-1, -2) @ right`` reconstructs ``product`` up to the accuracy of the
        randomized SVD. sqrt(S) is applied to both factors, so neither carries all the scale.
    """
    U, S, V = torch.svd_lowrank(product, rank, niter=16)  # U (b, m, r), S (b, r), V (b, n, r)
    sqrt_s = torch.sqrt(S)
    left = torch.einsum("bmr,br->brm", U, sqrt_s)
    right = torch.einsum("br,bnr->brn", sqrt_s, V)
    return left, right


for name, module in model.named_modules():
    if not name.endswith("self_attn"):
        continue

    # Orthogonal q_proj and k_up_proj.
    # Per head, the product W_kup_h^T @ W_q_h (latent_dim x hidden_size) has rank <= head_dim; re-split it
    # with an SVD. The q bias rides along as an extra input column so it is re-split together with the weight.
    # The weights are only read here (einsum/cat produce new tensors), so no defensive copies are needed.
    has_q_bias = module.q_proj.bias is not None
    k_up_weight = module.k_up_proj.weight.data.reshape(n_heads, head_dim, latent_dim)
    q_weight = module.q_proj.weight.data.reshape(n_heads, head_dim, hidden_size)
    if has_q_bias:
        q_bias = module.q_proj.bias.data.reshape(n_heads, head_dim, 1)
        q_weight = torch.cat([q_weight, q_bias], dim=-1)  # (n_heads, head_dim, hidden_size + 1)
    q_k_up = torch.einsum("hdc,hdD->hcD", k_up_weight, q_weight)  # (n_heads, latent_dim, hidden_size[+1])
    new_k_up, new_q = _balanced_factors(q_k_up, head_dim)  # (n_heads, head_dim, latent_dim), (n_heads, head_dim, hidden_size[+1])
    if has_q_bias:
        module.q_proj.bias.data = new_q[:, :, -1].reshape(-1).contiguous()
        new_q = new_q[:, :, :-1]
    module.k_up_proj.weight.data = new_k_up.reshape(n_heads * head_dim, latent_dim).contiguous()
    module.q_proj.weight.data = new_q.reshape(n_heads * head_dim, hidden_size).contiguous()

    # Orthogonal o_proj and v_up_proj.
    # Per head, (W_o[:, h] @ W_vup_h)^T (latent_dim x hidden_size) also has rank <= head_dim.
    v_up_weight = module.v_up_proj.weight.data.reshape(n_heads, head_dim, latent_dim)
    o_weight = module.o_proj.weight.data.reshape(hidden_size, n_heads, head_dim)
    v_up_o = torch.einsum("hdc,Dhd->hcD", v_up_weight, o_weight)  # (n_heads, latent_dim, hidden_size)
    new_v_up, new_o = _balanced_factors(v_up_o, head_dim)  # (n_heads, head_dim, latent_dim), (n_heads, head_dim, hidden_size)
    module.v_up_proj.weight.data = new_v_up.reshape(n_heads * head_dim, latent_dim).contiguous()
    # o_proj.weight is laid out (out=hidden_size, in=n_heads*head_dim): move the hidden axis to the front.
    module.o_proj.weight.data = new_o.permute(2, 0, 1).reshape(hidden_size, n_heads * head_dim).contiguous()

In [7]:
# Greedy generation check after the SVD re-factorisation; output should match the pre-factorisation run.
# Inputs go to model.device rather than a hard-coded "cuda:1", which only exists on multi-GPU hosts.
# Prompt text (sic, "Tall") is kept verbatim so the output stays comparable with the recorded runs.
inputs = tokenizer("Tall me a story", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=500, do_sample=False)
print(tokenizer.batch_decode(output)[0])

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


<|begin_of_text|>Tall me a story
I am a tall woman. I am 5'10" and I have been tall my whole life. I have always been the tallest person in the room. I have always been the tallest person in my family. I have always been the tallest person in my class. I have always been the tallest person in my group of friends. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest perso

In [9]:
for name, module in model.named_modules():
    if not name.endswith("self_attn"):
        continue
    if not hasattr(module, "k_up_proj"):
        # Already absorbed (k_up_proj is deleted below); skip so re-running this cell does not crash.
        continue

    # Absorb k_up_proj into q_proj: per head, W_q'_h = W_kup_h^T @ W_q_h maps hidden_size -> latent_dim,
    # so queries are scored directly against the latent. The q bias rides along as an extra input column.
    # NOTE: the new q_proj has n_heads*latent_dim outputs, i.e. kv_heads times as many as the old one.
    old_q = module.q_proj
    has_q_bias = old_q.bias is not None
    k_up_weight = module.k_up_proj.weight.data.reshape(n_heads, head_dim, latent_dim)  # (n_heads, head_dim, latent_dim)
    q_weight = old_q.weight.data.reshape(n_heads, head_dim, hidden_size)  # (n_heads, head_dim, hidden_size)
    if has_q_bias:
        q_weight = torch.cat([q_weight, old_q.bias.data.reshape(n_heads, head_dim, 1)], dim=-1)
    q_k_up = torch.einsum("hdc,hdD->hcD", k_up_weight, q_weight)  # (n_heads, latent_dim, hidden_size[+1]), rank<=head_dim
    # Allocate directly on the target device/dtype instead of a CPU fp32 layer that is then moved.
    q_proj = torch.nn.Linear(
        hidden_size, n_heads * latent_dim, bias=has_q_bias,
        device=old_q.weight.device, dtype=old_q.weight.dtype,
    )
    if has_q_bias:
        q_proj.bias.data = q_k_up[:, :, -1].reshape(-1).contiguous()
        q_k_up = q_k_up[:, :, :-1]
    q_proj.weight.data = q_k_up.reshape(n_heads * latent_dim, hidden_size).contiguous()
    setattr(module, "q_proj", q_proj)
    delattr(module, "k_up_proj")

    # Absorb v_up_proj into o_proj: per head, W_o'[:, h] = W_o[:, h] @ W_vup_h maps latent_dim -> hidden_size.
    old_o = module.o_proj
    has_o_bias = old_o.bias is not None
    v_up_weight = module.v_up_proj.weight.data.reshape(n_heads, head_dim, latent_dim)  # (n_heads, head_dim, latent_dim)
    o_weight = old_o.weight.data.reshape(hidden_size, n_heads, head_dim)  # (hidden_size, n_heads, head_dim)
    v_up_o = torch.einsum("hdc,Dhd->Dhc", v_up_weight, o_weight)  # (hidden_size, n_heads, latent_dim), rank<=head_dim
    o_proj = torch.nn.Linear(
        n_heads * latent_dim, hidden_size, bias=has_o_bias,
        device=old_o.weight.device, dtype=old_o.weight.dtype,
    )
    o_proj.weight.data = v_up_o.reshape(hidden_size, n_heads * latent_dim).contiguous()
    if has_o_bias:
        # Assign the bias tensor, not the old Parameter object itself.
        o_proj.bias.data = old_o.bias.data.clone()
    setattr(module, "o_proj", o_proj)
    delattr(module, "v_up_proj")
    module.absorb = True

In [8]:
# Greedy generation check intended for the absorbed model.
# NOTE(review): this cell's execution count (In[8]) is lower than the absorption cell's (In[9]), so the
# recorded output below was produced BEFORE absorption and does not validate it — Restart & Run All to confirm.
# Inputs go to model.device rather than a hard-coded "cuda:1", which only exists on multi-GPU hosts.
# Prompt text (sic, "Tall") is kept verbatim so the output stays comparable with the recorded runs.
inputs = tokenizer("Tall me a story", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=500, do_sample=False)
print(tokenizer.batch_decode(output)[0])

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


<|begin_of_text|>Tall me a story
I am a tall woman. I am 5'10" and I have been tall my whole life. I have always been the tallest person in the room. I have always been the tallest person in my family. I have always been the tallest person in my class. I have always been the tallest person in my group of friends. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest person in my social circle. I have always been the tallest perso