In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from warnings import filterwarnings

filterwarnings('ignore')

In [2]:
device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")

def get_device_memory_report(device):
    print(f'Device: {device} [{torch.cuda.get_device_name(device)}]')
    free_memory, total_memory = torch.cuda.mem_get_info(device)
    
    free_memory_gb = free_memory / (1024 ** 3)
    total_memory_gb = total_memory / (1024 ** 3)
    
    print(f"Free Memory: {free_memory_gb:.2f}/{total_memory_gb:.2f} GB [{free_memory / total_memory * 100:.2f}%]")

get_device_memory_report(device)

Device: cuda:5 [NVIDIA RTX 6000 Ada Generation]
Free Memory: 47.08/47.50 GB [99.11%]


In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_model(model_id="meta-llama/Llama-3.2-1B", device="cuda:1"):

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16, 
        device_map=device,
        attn_implementation="eager"
    )
    model.to(device)
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

In [4]:
mistral_model = "mistralai/Mistral-7B-Instruct-v0.1"
llama_model = "meta-llama/Llama-3.1-8B"

llama_1B, llama_1B_tokenizer = load_model(model_id=llama_model, device=device)
llama_1B

Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.31s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [5]:
import random

def randomly_permute_sentence(sentence):
    words = sentence.split()
    random.shuffle(words)
    return ' '.join(words)

# sentence = randomly_permute_sentence('A niece of most senators haven\'t descended most slopes')

In [17]:
good_sentence = 'Who should Derek hug after shocking Richard?'
bad_sentence = 'Who should Derek hug Richard after shocking?'

# sentence = randomly_permute_sentence(sentence)

__LLM_prompt = f"""
    Which sentence is more grammatical and native-like?
    1) {bad_sentence}
    2) {good_sentence}
"""
# sentence = 'Hello how are you?'
tokens = llama_1B_tokenizer(__LLM_prompt, return_tensors="pt")
tokens = {k: v.to(llama_1B.device) for k, v in tokens.items()}

with torch.no_grad():
    outputs = llama_1B(
        **tokens,
        output_hidden_states=True, 
        output_attentions=True,
        return_dict=True,
    )

    generation = llama_1B.generate(
        **tokens,
        do_sample=True,  # Enable sampling for more creative outputs
        max_new_tokens=5,
        top_k=50,       # Optional: Limits sampling to top-k tokens
        top_p=0.95,     # Optional: Nucleus sampling
        temperature=0.1 # Optional: Controls randomness
    )

text = llama_1B_tokenizer.decode(generation[0], skip_special_tokens=False)
print(text)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|>
    Which sentence is more grammatical and native-like?
    1) Who should Derek hug Richard after shocking?
    2) Who should Derek hug after shocking Richard?
    3) Who


In [22]:
import transformers

pipeline = transformers.pipeline(
    "text-generation",
    model="mistralai/Mistral-7B-Instruct-v0.1",
    # tokenizer=llama_1B_tokenizer,
    # device=0
)

messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant that will help me understand which sentence is grammatically correct."
    },
    {
        "role": "user",
        "content": f"Which sentence is more grammatical and native-like? 1) {bad_sentence} 2) {good_sentence}"
    }
]

response = pipeline(messages)
print(response)

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.09it/s]
Device set to use cuda:0


OutOfMemoryError: CUDA out of memory. Tried to allocate 224.00 MiB. GPU 0 has a total capacity of 47.50 GiB of which 50.31 MiB is free. Including non-PyTorch memory, this process has 47.44 GiB memory in use. Of the allocated memory 47.03 GiB is allocated by PyTorch, and 367.50 KiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [24]:

generation = 

Question, are you? I am a. thank you. I


In [60]:
embeddings = outputs.hidden_states[0]

Q = llama_1B.base_model.layers[0].self_attn.q_proj.weight
K = llama_1B.base_model.layers[0].self_attn.k_proj.weight
V = llama_1B.base_model.layers[0].self_attn.v_proj.weight

print('Query: ', Q.shape)
print('Key: ', K.shape)
print('Value: ', V.shape)

Query:  torch.Size([2048, 2048])
Key:  torch.Size([512, 2048])
Value:  torch.Size([512, 2048])


In [61]:
query_states = embeddings @ Q.T
key_states = embeddings @ K.T
value_states = embeddings @ V.T

print('Query: ', query_states.shape)
print('Key: ', key_states.shape)
print('Value: ', value_states.shape)

n_tokens = query_states.shape[1]
print('n_tokens: ', n_tokens)

Query:  torch.Size([1, 11, 2048])
Key:  torch.Size([1, 11, 512])
Value:  torch.Size([1, 11, 512])
n_tokens:  11


In [62]:
query_states = query_states.view(n_tokens, 16, 128)

key_states = key_states.view(n_tokens, 16, 32)

value_states = value_states.view(n_tokens, 16, 32)

In [63]:
# For each head h in range(16):
import math

head_outputs = torch.zeros(n_tokens, 16, 32)

for h in range(16):
    # Extract states for this head
    q_h = query_states[:, h, :]  # [n_tokens, 128]
    k_h = key_states[:, h, :]    # [n_tokens, 32]
    v_h = value_states[:, h, :]  # [n_tokens, 32]
    
    print(q_h.shape)
    print(k_h.shape)
    print(v_h.shape)
    break
    
    # Compute attention scores: [n_tokens, n_tokens]
    # We compute this as q_h @ k_h.T which gives attention from each token to each other
    attention_scores_h = torch.matmul(q_h, k_h.transpose(0, 1)) / math.sqrt(32)
    
    # Apply causal mask if needed
    # masked_attention_scores_h = apply_mask(attention_scores_h)
    
    # Apply softmax to get attention weights: [n_tokens, n_tokens]
    attention_weights_h = torch.softmax(attention_scores_h, dim=-1)
    
    # Apply attention weights to values: [n_tokens, 32]
    head_output_h = torch.matmul(attention_weights_h, v_h)
    
    # Store the output for this head
    head_outputs[:, h, :] = head_output_h

torch.Size([11, 128])
torch.Size([11, 32])
torch.Size([11, 32])


In [64]:
def get_head_contributions(model, tokens, device, layer_idx=0):

    model.eval()
    
    with torch.no_grad():
        outputs = model(
            **tokens,
            output_hidden_states=True,
            output_attentions=True,
            return_dict=True
        )
        
        layer = model.base_model.layers[layer_idx]
        attn = layer.self_attn
        
        if layer_idx == 0:
            prev_hidden_states = outputs.hidden_states[0]
        else:
            prev_hidden_states = outputs.hidden_states[layer_idx]
        
        normalized_hidden = layer.input_layernorm(prev_hidden_states)
        
        config = model.config
        num_heads = config.num_attention_heads
        num_kv_heads = config.num_key_value_heads
        head_dim = config.hidden_size // num_heads
        
        batch_size, seq_len = tokens['input_ids'].shape
        print(batch_size, seq_len)
        
        # 1. Get the Q, K, V projection matrices
        q_proj = attn.q_proj
        k_proj = attn.k_proj
        v_proj = attn.v_proj
        o_proj = attn.o_proj
        
        # 2. Compute Q, K, V projections
        q = q_proj(normalized_hidden) 
        k = k_proj(normalized_hidden) 
        v = v_proj(normalized_hidden) 
        
        kv_dim = k.shape[-1]
        hidden_size = q.shape[-1]
        kv_head_dim = kv_dim // num_kv_heads
        
        q_heads = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)  
        k_heads = k.view(batch_size, seq_len, num_kv_heads, kv_head_dim).transpose(1, 2)  
        v_heads = v.view(batch_size, seq_len, num_kv_heads, kv_head_dim).transpose(1, 2) 
        
        attn_weights = outputs.attentions[layer_idx]  
        
        head_outputs = []
        
        num_kv_groups = num_heads // num_kv_heads  
        
        for h in range(num_heads):
            kv_head_idx = h // num_kv_groups
            
            head_weights = attn_weights[:, h]
            head_v = v_heads[:, kv_head_idx] 
            
            weighted_v = torch.matmul(head_weights, head_v)  
            head_outputs.append(weighted_v)
        
        head_outputs = torch.stack(head_outputs)

        print(head_outputs.shape)
        
        head_contributions = []
        
        for h in range(num_heads):
            concat_shape = (batch_size, seq_len, hidden_size)
            head_concat = torch.zeros(concat_shape, device=device)
            
            head_idx = h % num_heads
            kv_head_idx = h // num_kv_groups
            
            start_idx = head_idx * head_dim
            end_idx = start_idx + head_dim
            
            head_concat[:, :, start_idx:end_idx] = head_outputs[h].transpose(1, 2)
            
            head_contribution = o_proj(head_concat)
            head_contributions.append(head_contribution)
        
        head_contributions = torch.stack(head_contributions)
        combined_output = outputs.hidden_states[layer_idx + 1]
        
        return {
            'input_hidden_states': prev_hidden_states,
            'normalized_hidden': normalized_hidden,
            'query_projections': q_heads,
            'key_projections': k_heads,
            'value_projections': v_heads,
            'attention_weights': attn_weights,
            'head_outputs': head_outputs,
            'head_contributions': head_contributions,
            'combined_output': combined_output
        }


In [54]:
get_head_contributions(llama_1B, tokens, device)

1 11
torch.Size([32, 1, 11, 64])


RuntimeError: The expanded size of the tensor (64) must match the existing size (11) at non-singleton dimension 2.  Target sizes: [1, 11, 64].  Tensor sizes: [64, 11]