In [2]:
from transformers import AutoTokenizer, BitsAndBytesConfig
from transformers.models.llama.modeling_llama import * # type: ignore
# from modeling_llama_v1 import LlamaForCausalLM


class LlamaForCausalLM_v1(LlamaForCausalLM):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.no_grad()
    def show_cos_distance(
        self,
        input_ids: torch.LongTensor,
        layer_index: int,
    ) -> torch.Tensor:
        """
        Output the cosine distance between the input_hidden_states and output_hidden_states for a specific layer
        
        Args:
            layer_index (int): The layer index to use

        Returns:
            cosine_distance (torch.Tensor): The cosine distance between the input_hidden_states and output_hidden_states
        """
        assert layer_index > 0, "layer 0 does not have input_hidden_states"
        outputs = self.model(input_ids, output_hidden_states=True)
        input_hidden_states = outputs.hidden_states[layer_index - 1]
        output_hidden_states = outputs.hidden_states[layer_index]
        return F.cosine_similarity(input_hidden_states, output_hidden_states, dim=-1)

    @torch.no_grad()
    def show_topk_token(
        self,
        input_ids: torch.LongTensor,
        layer_index: int,
        k: int = 10,
    ) -> tuple[torch.Tensor, torch.LongTensor]:
        """
        Output the top k tokens for predicting the next token using a specific layer
        
        Args:
            tokenizer: The tokenizer to use
            model: The model to use
            input_text (str): The input text to use
            layer_index (int): The layer index to use
            k (int): The number of tokens to output

        Returns:
            values, tokens (tuple[torch.Tensor, list[str]]): A tuple containing the top k values and tokens
        """
        assert k <= self.vocab_size, "k cannot be greater than vocabulary size"
        outputs = self.model(input_ids, output_hidden_states=True)
        values, indices = torch.topk(self.lm_head(outputs.hidden_states[layer_index])[:, -1, :], k, dim=-1)
        return values, indices

    @torch.no_grad()
    def show_token_attention(
        self,
        input_ids: torch.LongTensor,
        layer_index: int,
        token_a_index: int,
        token_b_index: int,
    ):
        """
        Output the attention value between two tokens in the layer_index layer
        
        Args:
            tokenizer: The tokenizer to use
            model: The model to use
            input_text (str): The input text to use
            layer_index (int): The layer index to use
            token_a_index (int): The index of the first token
            token_b_index (int): The index of the second token

        Returns:
            attention (torch.Tensor): The attention value between the two tokens
        """
        # with `output_attentions=True`, calculation of attention falls back
        # to the original implementation instead of `torch.nn.functional.scaled_dot_product_attention`
        outputs = self.model(input_ids, output_attentions=True)
        # select the attention for the layer and average over the heads
        layer_attentions = outputs.attentions[layer_index].mean(dim=1)
        # select the attention between the two tokens
        return layer_attentions[:, token_a_index][:, token_b_index]


In [3]:
model_id = "/root/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa"
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = LlamaForCausalLM_v1.from_pretrained(model_id, device_map="auto", quantization_config=quantization_config)
model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM_v1(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), 

In [4]:
input_text = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\nWhat is the meaning of life?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
inputs = tokenizer(input_text, return_tensors="pt")

In [18]:
outputs = model(inputs.input_ids, output_hidden_states=True)
# select the attention for the layer and average over the heads
layer_index = 32
torch.nn.functional.cosine_similarity(outputs.hidden_states[layer_index], outputs.hidden_states[layer_index - 1], dim=-1)

tensor([[0.3069, 0.3069, 0.5806, 0.6885, 0.5894, 0.6470, 0.6304, 0.6709, 0.6392,
         0.6230, 0.6548, 0.6675, 0.5933, 0.4868, 0.5801, 0.5083, 0.5112, 0.5103,
         0.3796, 0.3926, 0.4805, 0.3530, 0.5088, 0.5596, 0.6055, 0.5317, 0.5376,
         0.6255, 0.6484]], dtype=torch.float16, grad_fn=<SumBackward1>)

In [19]:
torch.nn.functional.cosine_similarity(outputs.hidden_states[layer_index - 1], outputs.hidden_states[layer_index], dim=-1)

tensor([[0.3069, 0.3069, 0.5806, 0.6885, 0.5894, 0.6470, 0.6304, 0.6709, 0.6392,
         0.6230, 0.6548, 0.6675, 0.5933, 0.4868, 0.5801, 0.5083, 0.5112, 0.5103,
         0.3796, 0.3926, 0.4805, 0.3530, 0.5088, 0.5596, 0.6055, 0.5317, 0.5376,
         0.6255, 0.6484]], dtype=torch.float16, grad_fn=<SumBackward1>)

In [None]:
"""
`show_topk_token` method
"""
input_text = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\nWhat is the meaning of life?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
inputs = tokenizer(input_text, return_tensors="pt")
values, indices = model.show_topk_token(inputs.input_ids, -1, 10)
tokenizer.batch_decode(indices.t())

In [None]:
device = "cuda"

messages = [
    {"role": "user", "content": "What is the meaning of life?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

outputs = model.generate(
    input_ids,
    max_new_tokens=24,
    eos_token_id=tokenizer.eos_token_id,
    do_sample=False,
    temperature=None,
    top_p=None,
)
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))

# The meaning of life! This is a question that has puzzled philosophers, theologians, scientists, and everyday people for centuries