In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from warnings import filterwarnings

filterwarnings('ignore')

In [2]:
device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")

In [3]:
def get_device_memory_report(device):
    print(f'Device: {device} [{torch.cuda.get_device_name(device)}]')
    free_memory, total_memory = torch.cuda.mem_get_info(device)
    
    free_memory_gb = free_memory / (1024 ** 3)
    total_memory_gb = total_memory / (1024 ** 3)
    
    print(f"Free Memory: {free_memory_gb:.2f}/{total_memory_gb:.2f} GB [{free_memory / total_memory * 100:.2f}%]")

In [4]:
get_device_memory_report(device)

Device: cuda:5 [NVIDIA RTX 6000 Ada Generation]
Free Memory: 37.98/47.50 GB [79.95%]


In [5]:
def pprint_matrix(matrix):
    if len(matrix.shape) == 1:
        print(" ".join(f"{x:.2f}" for x in matrix))
    else:
        for row in matrix:
            print(" ".join(f"{x:.2f}" for x in row))

In [6]:
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_model(model_id="meta-llama/Llama-3.2-1B", device="cuda:1"):

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16, 
        device_map=device,
        attn_implementation="eager"
    )
    model.to(device)
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

In [7]:
llama_1B, llama_1B_tokenizer = load_model(model_id="meta-llama/Llama-3.2-1B", device=device)
llama_1B

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [8]:
sentence = "The cat chased the mouse. The mouse ran away. Another word"

tokens = llama_1B_tokenizer(sentence, return_tensors="pt")
tokens = {k: v.to(llama_1B.device) for k, v in tokens.items()}

print(tokens['attention_mask'])

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:5')


In [9]:
with torch.no_grad():
    outputs = llama_1B(
        **tokens,
        output_hidden_states=True,
        return_dict=True,
        output_attentions=True,
    )  

In [14]:
print(outputs.hidden_states[0].shape)

first_attention_layer = 

torch.Size([1, 14, 2048])


In [25]:
print(llama_1B.base_model.layers[0].self_attn)

LlamaAttention(
  (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
  (k_proj): Linear(in_features=2048, out_features=512, bias=False)
  (v_proj): Linear(in_features=2048, out_features=512, bias=False)
  (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
)


In [31]:
import inspect

print(inspect.getsource(llama_1B.base_model.layers[0].self_attn.forward))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE mod

In [22]:
V_matrix = llama_1B.base_model.layers[0].self_attn.v_proj.weight 
O_matrix = llama_1B.base_model.layers[0].self_attn.o_proj.weight
print(O_matrix.shape)
print(V_matrix.shape)

num_heads = 32
embedding_dim = 2048
value_matrix_per_head = V_matrix.view(num_heads, -1, embedding_dim)

print(value_matrix_per_head.shape)

torch.Size([2048, 2048])
torch.Size([512, 2048])
torch.Size([32, 16, 2048])


In [20]:
print(value_layer.shape)

torch.Size([512, 2048])


In [10]:
print(outputs.attentions[0].shape)
first_attention_head = outputs.attentions[0][0, 0, :, :]

# first_attention_head = first_attenion_head.cpu().numpy()
row_sums = first_attention_head.sum(axis=1)

# attention_value = torch.matmul(first_attention_head, value_layer)

pprint_matrix(first_attention_head)

torch.Size([1, 32, 14, 14])
1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
0.86 0.14 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
0.44 0.51 0.05 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
0.46 0.37 0.06 0.11 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
0.13 0.52 0.12 0.20 0.03 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
0.23 0.11 0.08 0.30 0.22 0.06 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
0.14 0.09 0.07 0.29 0.16 0.07 0.17 0.00 0.00 0.00 0.00 0.00 0.00 0.00
0.10 0.04 0.03 0.09 0.09 0.09 0.52 0.04 0.00 0.00 0.00 0.00 0.00 0.00
0.13 0.03 0.02 0.05 0.06 0.07 0.38 0.22 0.04 0.00 0.00 0.00 0.00 0.00
0.12 0.01 0.01 0.04 0.02 0.03 0.36 0.28 0.07 0.05 0.00 0.00 0.00 0.00
0.11 0.01 0.01 0.03 0.03 0.03 0.20 0.20 0.14 0.21 0.03 0.00 0.00 0.00
0.09 0.01 0.01 0.03 0.03 0.04 0.15 0.11 0.10 0.20 0.06 0.15 0.00 0.00
0.09 0.00 0.00 0.01 0.01 0.01 0.06 0.03 0.02 0.09 0.08 0.53 0.06 0.00
0.07 0.00 0.00 0.00 0.00 0.00 0.02 0.03 0.01 0.04 0.05 0.40 0.

In [14]:
hidden_states = outputs.hidden_states
for i in range(len(hidden_states)):
    print(hidden_states[i].shape)


torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
torch.Size([1, 14, 2048])
