# Interpreting LoRA Finetuning of LLaMa with Singular Value Decomposition

### Installs and imports packages

In [2]:
# get everything set up
# more rapidly install node
# !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
# # install repo with the data
# !git clone https://github.com/BerenMillidge/svd_directions
# %cd svd_directions
# !bash setup.sh

In [7]:
import transformers
import textwrap
from transformers import LlamaTokenizer, LlamaForCausalLM
import os
import sys
from typing import List


import fire
import torch
from datasets import load_dataset
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib as mpl
# import seaborn as sns
from pylab import rcParams

%matplotlib inline
# sns.set(rc={'figure.figsize':(10, 7)})
# sns.set(rc={'figure.dpi':100})
# sns.set(style='white', palette='muted', font_scale=1.2)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE



'cuda'

In [8]:
import torch
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
# import seaborn as sns
import gc
from copy import deepcopy
from tqdm.auto import tqdm, trange
import re
from collections import defaultdict
from transformers import AutoModelForCausalLM, AutoTokenizer
# utils
import json
from torch import nn
import torch.nn.functional as F
from datasets import load_dataset
from copy import deepcopy
from torch.nn import functional as F
from tabulate import tabulate
from tqdm import tqdm, trange
import functools
import math

# this resets up the site so you don't have to restart the runtime to use pysvelte
import site
site.main()


# sns.set_palette('colorblind')
# cmap = sns.color_palette('colorblind')

## Load Dataset

We will be using the BTC Tweets Sentiment dataset4, which is available on Kaggle and contains around 50,000 tweets related to Bitcoin.

The format of the dataset5 in the original Alpaca repository consists of a JSON file that has a list of objects with instruction, input, and output strings.



In [9]:
!gdown 1xQ89cpZCnafsW5T3G3ZQWvR7q682t2BN

Downloading...
From: https://drive.google.com/uc?id=1xQ89cpZCnafsW5T3G3ZQWvR7q682t2BN
To: /home/ubuntu/notebooks/bitcoin-sentiment-tweets.csv
100%|████████████████████████████████████████| 242k/242k [00:00<00:00, 29.9MB/s]


In [10]:
df = pd.read_csv("bitcoin-sentiment-tweets.csv")
df.head()

Unnamed: 0,date,tweet,sentiment
0,Fri Mar 23 00:40:40 +0000 2018,@p0nd3ea Bitcoin wasn't built to live on excha...,1.0
1,Fri Mar 23 00:40:40 +0000 2018,@historyinflicks Buddy if I had whatever serie...,1.0
2,Fri Mar 23 00:40:42 +0000 2018,@eatBCH @Bitcoin @signalapp @myWickr @Samsung ...,0.0
3,Fri Mar 23 00:41:04 +0000 2018,@aantonop Even if Bitcoin crash tomorrow morni...,0.0
4,Fri Mar 23 00:41:07 +0000 2018,I am experimenting whether I can live only wit...,1.0


In [11]:
def sentiment_score_to_name(score: float):
    if score > 0:
        return "Positive"
    elif score < 0:
        return "Negative"
    return "Neutral"

dataset_data = [
    {
        "instruction": "Detect the sentiment of the tweet.",
        "input": row_dict["tweet"],
        "output": sentiment_score_to_name(row_dict["sentiment"])
    }
    for row_dict in df.to_dict(orient="records")
]

dataset_data[0]

{'instruction': 'Detect the sentiment of the tweet.',
 'input': "@p0nd3ea Bitcoin wasn't built to live on exchanges.",
 'output': 'Positive'}

In [12]:

with open("alpaca-bitcoin-sentiment-dataset.json", "w") as f:
   json.dump(dataset_data, f)

## Load pretrained-model weights
This code loads the pre-trained Llama model using the LlamaForCausalLM class from the Hugging Face Transformers library. The load_in_8bit=True parameter loads the model using 8-bit quantization to reduce memory usage and improve inference speed.

The code also loads the tokenizer for the same Llama model using the LlamaTokenizer class, and sets some additional properties for padding tokens. Specifically, it sets the pad_token_id to 0 to represent unknown tokens, and sets the padding_side to "left" to pad sequences on the left side.

In [13]:
BASE_MODEL = "decapoda-research/llama-7b-hf"

model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    # load_in_8bit=True,
    torch_dtype=torch.float16,
    # device_map="auto",
)

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

tokenizer.pad_token_id = (
    0  # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left"

Downloading (…)lve/main/config.json:   0%|          | 0.00/427 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/25.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/33 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00002-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00003-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00004-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00005-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00006-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00007-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00008-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00009-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00010-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00011-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00012-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00013-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00014-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00015-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00016-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00017-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00018-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00019-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00020-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00021-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00022-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00023-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00024-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00025-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00026-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00027-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00028-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00029-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00030-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00031-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00032-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00033-of-00033.bin:   0%|          | 0.00/524M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.


In [14]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=31999)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaR

## Dataset

In [15]:
data = load_dataset("json", data_files="alpaca-bitcoin-sentiment-dataset.json")
data["train"]

Downloading and preparing dataset json/default to /home/ubuntu/.cache/huggingface/datasets/json/default-005f1c3e35d3e366/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/json/default-005f1c3e35d3e366/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 1897
})

In [19]:
CUTOFF_LEN = 64

In [20]:
def generate_prompt(data_point):
    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
### Instruction:
{data_point["instruction"]}
### Input:
{data_point["input"]}
### Response:
{data_point["output"]}"""


def tokenize(prompt, add_eos_token=True):
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN,
        padding=False,
        return_tensors=None,
    )
    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < CUTOFF_LEN
        and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)

    result["labels"] = result["input_ids"].copy()

    return result

def generate_and_tokenize_prompt(data_point):
    full_prompt = generate_prompt(data_point)
    tokenized_full_prompt = tokenize(full_prompt)
    return tokenized_full_prompt

In [21]:
train_val = data["train"].train_test_split(
    test_size=200, shuffle=True, seed=42
)
train_data = (
    train_val["train"].map(generate_and_tokenize_prompt)
)
val_data = (
    train_val["test"].map(generate_and_tokenize_prompt)
)

Loading cached split indices for dataset at /home/ubuntu/.cache/huggingface/datasets/json/default-005f1c3e35d3e366/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-f37e27c74674db45.arrow and /home/ubuntu/.cache/huggingface/datasets/json/default-005f1c3e35d3e366/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-3de12248d3b3afe1.arrow


Map:   0%|          | 0/1697 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

### Interpretation of LLaMA using SVD

In [22]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=31999)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaR

## Training

In [15]:
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT= 0.05
LORA_TARGET_MODULES = [
    "q_proj",
    "v_proj",
]

BATCH_SIZE = 128
MICRO_BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LEARNING_RATE = 3e-4
TRAIN_STEPS = 300
OUTPUT_DIR = "experiments"



In [None]:
model = prepare_model_for_int8_training(model)
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=LORA_TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()

In [None]:
training_arguments = transformers.TrainingArguments(
    per_device_train_batch_size=MICRO_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    warmup_steps=10,
    max_steps=TRAIN_STEPS,
    learning_rate=LEARNING_RATE,
    fp16=True,
    logging_steps=10,
    optim="adamw_torch",
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=10,
    save_steps=10,
    output_dir=OUTPUT_DIR,
    save_total_limit=3,
    load_best_model_at_end=True,
    report_to="tensorboard"
)

In [None]:
data_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)

In [None]:
trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=training_arguments,
    data_collator=data_collator
)
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(
        self, old_state_dict()
    )
).__get__(model, type(model))

model = torch.compile(model)

trainer.train()
model.save_pretrained(OUTPUT_DIR)

In [None]:
from huggingface_hub import notebook_login

notebook_login()

model.push_to_hub("curiousily/alpaca-bitcoin-tweets-sentiment", use_auth_token=True)

## Inference

In [None]:
!git clone https://github.com/tloen/alpaca-lora.git
%cd alpaca-lora
!git checkout a48d947

In [None]:
!python generate.py \
    --load_8bit \
    --base_model 'decapoda-research/llama-7b-hf' \
    --lora_weights 'curiousily/alpaca-bitcoin-tweets-sentiment' \
    --share_gradio

### Interpretability

### LLaMa Model Architecture
We see that like GPT-2, LLaMa is a decoder-only Causal LM. There are 32 decoder blocks, consisting of:
- Self Attnetion (Masked)
- MLP
- Residual connections
- RMS Normalization
- SwiLU activation
- A hidden dimension of 4096
LLaMA is an autoregressive decoder only transformer. It has 32 blocks consisting of an attention and MLP layer. The attention layers have 16 heads. It has a residual dimension of 1024 dimensions and a typical sequence length also of 1024 tokens. The MLP hidden width is $4 \times 4096 = 4096$.

In [23]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=31999)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaR

In [24]:
for layer in model.state_dict().keys():
    if "layers.0" in layer:
        print(layer)

model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.k_proj.weight
model.layers.0.self_attn.v_proj.weight
model.layers.0.self_attn.o_proj.weight
model.layers.0.self_attn.rotary_emb.inv_freq
model.layers.0.mlp.gate_proj.weight
model.layers.0.mlp.down_proj.weight
model.layers.0.mlp.up_proj.weight
model.layers.0.input_layernorm.weight
model.layers.0.post_attention_layernorm.weight


In [25]:
# # Load up the model and get all the key weight matrices.
# model, tokenizer, emb, device = get_model_tokenizer_embedding()
# my_tokenizer = tokenizer
# num_layers, num_heads, hidden_dim, head_size = get_model_info(model)
# all_tokens = [tokenizer.decode([i]) for i in range(tokenizer.vocab_size)]

# K,V = get_mlp_weights(model, num_layers = num_layers, hidden_dim = hidden_dim)
# W_Q_heads, W_K_heads, W_V_heads, W_O_heads = get_attention_heads(model, num_layers=num_layers, hidden_dim=hidden_dim, num_heads=num_heads, head_size = head_size)


In [26]:
def get_llama_de_embedding(model):
    """Embedding matrix is assumed orthogonal, so de-embedding matrix is its transpose"""
    return model.get_output_embeddings().weight.data.T.detach()
de_embedding = get_llama_de_embedding(model)
de_embedding.shape

torch.Size([4096, 32000])

In [27]:
def get_llama_info(model):
    info = {
        "num_layers" : model.config.num_hidden_layers, #number of hidden layers in the transformer encoder
        "hidden_dim" : model.config.hidden_size,
        "num_heads" : model.config.num_attention_heads, #number of attention heaqds for each attention layer in Transformer encoder
        "head_size" : model.config.hidden_size // model.config.num_attention_heads
    }
    return info

model_info = get_llama_info(model)
model_info

{'num_layers': 32, 'hidden_dim': 4096, 'num_heads': 32, 'head_size': 128}

In [28]:
all_tokens = [tokenizer.decode([i]) for i in range(tokenizer.vocab_size)]
print(len(all_tokens))
print(all_tokens[1000:1005])

32000
['ied', 'ER', 'stat', 'fig', 'me']


In [29]:
# def get_llama_mlp_weights(model,num_layers, hidden_dim, key="down_proj_weights"):
#     Ks = []
#     for j in range(num_layers):
#         down_param = f"model.layers.{j}.mlp.down_proj.weight"
#         up_param = f"model.layers.{j}.mlp.up_proj.weight"
#         gate_param = f"model.layers.{j}.mlp.gate_proj.weight"

#         deembedding = f'model.layers.{j}.self_attn.rotary_emb.inv_freq'
#         input_layer_norm = f'model.layers.{j}.input_layernorm.weight',
#         output_layer_norm = f'model.layers.{j}.post_attention_layernorm.weight'

#         parameters = {
#             # "down_proj_weights" : model.get_parameter(down_param),
#             # "up_proj_weights" : model.get_parameter(up_param),
#             "gate_proj_weights" : model.get_parameter(gate_param),
#             # "de_embedding" : model.get_parameter(deembedding),
#             # "input_layer_norm" : model.get_parameter(input_layer_norm),
#             # "output_layer_norm" : model.get_parameter(output_layer_norm),
#         }
#         param_of_interest = parameters[key]
#         Ks.append(param_of_interest)
#         # Ks.append(
#         #     torch.dequantize(param_of_interest)
#         #     )
#     Ks = torch.cat(Ks)
#     Ks = Ks.reshape(num_layers, -1, hidden_dim)
#     return Ks

# Ks = get_llama_mlp_weights(model,
#                           model_info["num_layers"],
#                           model_info["hidden_dim"],
#                           key="gate_proj_weights")

: 

: 

In [None]:
torch.cuda.empty_cache()

In [None]:
def get_max_token_length(tokens):
    maxlen = 0
    for t in tokens:
        l = len(t)
        if l > maxlen:
            maxlen = l
    return maxlen

def pad_with_space(t, maxlen):
    spaces_to_add = maxlen - len(t)
    for i in range(spaces_to_add):
        t += " "
    return t

def convert_to_tokens(indices, tokenizer, extended, extra_values_pos, strip=True, pad_to_maxlen=False):
    if extended:
        res = [tokenizer.convert_ids_to_tokens([idx])[0] if idx < len(tokenizer) else
               (f"[pos{idx-len(tokenizer)}]" if idx < extra_values_pos else f"[val{idx-extra_values_pos}]")
               for idx in indices]
    else:
        res = tokenizer.convert_ids_to_tokens(indices)
    if strip:
        res = list(map(lambda x: x[1:] if x[0] == 'Ġ' else "#" + x, res))
    if pad_to_maxlen:
      maxlen = get_max_token_length(res)
      res = list(map(lambda t: pad_with_space(t, maxlen), res))
    return res

def top_tokens(tokenizer, v_tok, k=100, only_english=False, only_ascii=True, with_values=False,
               exclude_brackets=False, extended=True, extra_values=None, pad_to_maxlen=False):
    v_tok = deepcopy(v_tok)
    ignored_indices = []
    if only_ascii:
        ignored_indices = [key for val, key in tokenizer.vocab.items() if not val.strip('Ġ').isascii()]
    if only_english:
        ignored_indices =[key for val, key in tokenizer.vocab.items() if not (val.strip('Ġ').isascii() and val.strip('Ġ[]').isalnum())]
    if exclude_brackets:
        ignored_indices = set(ignored_indices).intersection(
            {key for val, key in tokenizer.vocab.items() if not (val.isascii() and val.isalnum())})
        ignored_indices = list(ignored_indices)
    v_tok[ignored_indices] = -np.inf
    extra_values_pos = len(v_tok)
    if extra_values is not None:
        v_tok = torch.cat([v_tok, extra_values])
    values, indices = torch.topk(v_tok, k=k)
    res = convert_to_tokens(indices, tokenizer, extended=extended, extra_values_pos=extra_values_pos,pad_to_maxlen = pad_to_maxlen)
    if with_values:
        res = list(zip(res, values.cpu().numpy()))
    return res

def MLP_K_top_singular_vectors(K,emb,layer_idx, all_tokens, k=20,
                               N_singular_vectors=10, with_negative = False):
    W_matrix = K[layer_idx, :,:]
    U,S,V = torch.linalg.svd(W_matrix,full_matrices=False)
    print(V.shape)
    Vs = []

    for i in range(N_singular_vectors):
        acts = V[i,:].float() @ emb
        Vs.append(acts)

    Vs = [top_tokens(Vs[i].float().cpu(), k = k, pad_to_maxlen=True) for i in range(len(Vs))]
    # print(tabulate([*zip(*Vs)]))

emb = get_llama_de_embedding(model)
MLP_K_top_singular_vectors(Ks, emb, layer_idx = 22, k=20, N_singular_vectors= 50, all_tokens = all_tokens)

In [None]:
Ks

tensor([[[ -12,  -13,   79,  ...,  -25,  -28,   -8],
         [  -9,    8,  -52,  ...,  -22,   58,  -35],
         [  47,    5,   22,  ...,   34,   60,   44],
         ...,
         [  15,  -14,   13,  ...,   13,  -25,  -36],
         [ -17,   24,    9,  ...,   27,   -8,    4],
         [  25,   16,  -20,  ...,   62,   19,   -7]],

        [[  -4, -103,   30,  ...,   91,   -1,   -8],
         [ -21,   25,   36,  ...,   31,   -7,  -23],
         [ -75,  -43,  -32,  ...,  -21,    7,  -12],
         ...,
         [  25,   36,   21,  ...,   10,   51,  -15],
         [ -56,   18,   17,  ...,  -25,   19,    4],
         [  -7,    1,   83,  ...,   24,  -65,   18]],

        [[ -23,   19,   31,  ...,  -44,   29,  -59],
         [ -28,  -42,   36,  ...,  -38,   52,  -73],
         [  96,  -44,   43,  ...,  -58,   78,   -4],
         ...,
         [  13,   70,   10,  ...,   -1,   -7,   12],
         [  20,   -5,   49,  ...,  -36,   24,    3],
         [ -10,   -4,   10,  ...,  -17,   12,  -31]],

In [None]:
def get_attention_heads(model, num_layers, hidden_dim, num_heads, head_size):
  qkvs = []
  for j in range(num_layers):
    qkv = model.get_parameter(f"transformer.h.{j}.attn.c_attn.weight").detach().T
    ln_weight_1 = model.get_parameter(f"transformer.h.{j}.ln_1.weight").detach()

    qkv = qkv - torch.mean(qkv, dim=0)
    qkv = torch.einsum("oi,i -> oi", qkv, ln_weight_1)
    qkvs.append(qkv.T)

  W_Q, W_K, W_V = torch.cat(qkvs).chunk(3, dim=-1)
  W_O = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_proj.weight") for j in range(num_layers)]).detach()
  W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
  W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)
  W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
  W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
  return W_Q_heads, W_K_heads, W_V_heads, W_O_heads

In [None]:
W_Q_heads, W_K_heads, W_V_heads, W_O_heads = get_attention_heads(model, num_layers=num_layers, hidden_dim=hidden_dim, num_heads=num_heads, head_size = head_size)
