In [1]:
from huggingface_hub import login
login(new_session=False)

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# Check GPU availability
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

# Load model on GPU
model_id = "meta-llama/Llama-3.2-1B"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # or torch.float16 for older GPUs
    device_map="cuda:0"  # Automatically uses GPU
)
model.config.output_attentions = True

tokenizer = AutoTokenizer.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id)


CUDA available: True
GPU: NVIDIA A100-SXM4-80GB


In [3]:
from transformers import AutoModelForCausalLM
import inspect

#model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
layer0 = model.model.layers[0]

# View Layer 0 forward method
print("="*60)
print("LAYER 0 FORWARD:")
print("="*60)
print(inspect.getsource(layer0.forward))

print("\n" + "="*60)
print("ATTENTION FORWARD:")
print("="*60)
print(inspect.getsource(layer0.self_attn.forward))

print("\n" + "="*60)
print("ROTARY EMBEDDING:")
print("="*60)
print(inspect.getsource(model.model.rotary_emb.__class__))

print("\n" + "="*60)
print("\nMODEL.MODEL FORWARD:")
print(inspect.getsource(model.model.forward))


print("\n" + "="*60)
print("\nMODEL.MODEL embedding:")
print(inspect.getsource(model.model.embed_tokens.forward))

LAYER 0 FORWARD:
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
   

In [4]:
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [5]:
input_text = "The future of artificial intelligence is"
inputs = tokenizer(input_text, return_tensors="pt")

device = next(model.parameters()).device
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)

print(f"\n{'='*60}")
print("PREFILL PHASE")
print(f"{'='*60}")
print(f"Input text: '{input_text}'")
print(f"Input shape: {input_ids.shape}")

with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        use_cache=True,
        return_dict=True,
    )

past_key_values = outputs.past_key_values
print(past_key_values.key_cache[0].shape)
next_token_logits = outputs.logits[:, -1, :]
next_token_id = torch.argmax(next_token_logits, dim=-1)

print(f"Predicted next token: '{tokenizer.decode(next_token_id[0])}'")
print(f"KV cache created for {len(past_key_values)} layers")



PREFILL PHASE
Input text: 'The future of artificial intelligence is'
Input shape: torch.Size([1, 7])




torch.Size([1, 8, 7, 64])
Predicted next token: ' here'
KV cache created for 16 layers


In [6]:
ground_truth = {}

def attention_pre_hook(module, input):
    """
    Captures inputs to attention forward BEFORE computation
    input is a tuple of all arguments to forward()
    """
    # input[0] = hidden_states
    # input[1] = position_embeddings (tuple of cos, sin)
    # input[2] = attention_mask
    # input[3] = past_key_value
    # input[4] = cache_position
    # input[5+] = kwargs
    
    if len(input) > 0:
        ground_truth['attn_hidden_states'] = input[0].detach().clone()
    if len(input) > 1:
        ground_truth['attn_position_embeddings'] = input[1]  # This is (cos, sin) tuple
    if len(input) > 2:
        ground_truth['attn_attention_mask'] = input[2].detach().clone() if input[2] is not None else None
    if len(input) > 3:
        ground_truth['attn_past_key_value'] = input[3]
    if len(input) > 4:
        ground_truth['attn_cache_position'] = input[4].detach().clone() if input[4] is not None else None
    
    print("PRE HOOK INPUTS: ", input)
def rotary_embedding_hook(module, input, output):
    if isinstance(output, tuple) and len(output) == 2:
        # Rotary embeddings return (cos, sin)
        ground_truth['rot_emb_cos'] = output[0].detach().clone()
        ground_truth['rot_emb_sin'] = output[1].detach().clone()
    else:
        ground_truth['rot_emb_out'] = output.detach().clone()

"""
def make_hook(name):
    def hook(module, input, output):
        if isinstance(output, tuple):
            ground_truth[f"{name}_output"] = output[0].detach().clone()
        else:
            ground_truth[f"{name}_output"] = output.detach().clone()
        if isinstance(input, tuple):
            ground_truth[f"{name}_input"] = input[0].detach().clone()
        else:
            ground_truth[f"{name}_input"] = input.detach().clone()
    return hook
    """
def make_hook(name):
    def hook(module, input, output):
        # Handle output - can be tuple or tensor
        if isinstance(output, tuple):
            ground_truth[f"{name}_output"] = output[0].detach().clone()
            ground_truth[f"{name}_output2"] = output[1].detach().clone()

        else:
            ground_truth[f"{name}_output"] = output.detach().clone()
        
        # Handle input - ALWAYS a tuple in forward hooks
        # input[0] is the first argument (usually the tensor)
        if isinstance(input, tuple) and len(input) > 0:
            if isinstance(input[0], torch.Tensor):
                ground_truth[f"{name}_input"] = input[0].detach().clone()
            else:
                # If it's not a tensor, skip it
                pass
        else: 
            ground_truth[f"{name}_input"] = input
        
    return hook
def make_input_hook(name):
    def hook(module, input, output):
        if isinstance(input, tuple):
            ground_truth[name] = input[0].detach().clone()
        else:
            ground_truth[name] = input.detach().clone()
    return hook
def rope_hook(module, input, output):
    x, position_ids = input
    ground_truth['rot_emb_position_ids'] = position_ids.detach().clone()
    ground_truth['rot_emb_x'] = x.detach().clone()

    print(f"Position IDs shape: {position_ids.shape}")
    print(f"Position IDs:\n{position_ids}")
    print(f"Output cos shape: {output[0].shape}")
    print(f"Output sin shape: {output[1].shape}")
    return output

# Also capture intermediate attention states
attention_internals = {}

def attention_hook(module, input, output):
    # Capture the full output tuple from attention
    if isinstance(output, tuple):
        attention_internals['attn_output'] = output[0].detach().clone()
        if len(output) > 1 and output[1] is not None:
            attention_internals['attn_weights'] = output[1].detach().clone() if output[1] is not None else None
        if len(output) > 2:
            attention_internals['past_kv'] = output[2]
    else:
        attention_internals['attn_output'] = output.detach().clone()

# Register hooks on Layer 0
layer_0 = model.model.layers[0]
hooks = []

# Capture embedding
def embedding_hook(module, input, output):
    ground_truth['embedding_output'] = output.detach().clone()
    ground_truth['embedding_input'] = input.detach().clone()



#hooks.append(model.model.embed_tokens.register_forward_hook(embedding_hook))
hooks.append(model.model.rotary_emb.register_forward_hook(rotary_embedding_hook))

hooks.append(model.model.rotary_emb.register_forward_hook(make_input_hook("rot_emb_in")))

hooks.append(layer_0.input_layernorm.register_forward_hook(make_hook('input_layernorm')))
hooks.append(layer_0.self_attn.q_proj.register_forward_hook(make_hook('q_proj')))
hooks.append(layer_0.self_attn.k_proj.register_forward_hook(make_hook('k_proj')))
hooks.append(layer_0.self_attn.v_proj.register_forward_hook(make_hook('v_proj')))
hooks.append(layer_0.self_attn.register_forward_hook(make_hook("attention")))
hooks.append(layer_0.self_attn.o_proj.register_forward_hook(make_hook('o_proj')))
#hooks.append(layer_0.self_attn.config.attn_implementation(make_hook('atten_interface')))
#hooks.append(model.model._update_causal_mask.register_forward_hook(make_hook('causal_mask')))
hooks.append(layer_0.self_attn.register_forward_pre_hook(attention_pre_hook))


hooks.append(layer_0.post_attention_layernorm.register_forward_hook(make_hook('post_attention_layernorm')))
hooks.append(layer_0.mlp.gate_proj.register_forward_hook(make_hook('gate_proj')))
hooks.append(layer_0.mlp.up_proj.register_forward_hook(make_hook('up_proj')))
hooks.append(layer_0.mlp.down_proj.register_forward_hook(make_hook('down_proj')))
hooks.append(model.model.rotary_emb.register_forward_hook(rope_hook))

# Capture residual connection after attention
def layer_output_hook(module, input, output):
    if isinstance(output, tuple):
        ground_truth['layer_0_output'] = output[0].detach().clone()
    else:
        ground_truth['layer_0_output'] = output.detach().clone()

hooks.append(layer_0.register_forward_hook(layer_output_hook))

# Run model with hooks to get ground truth
decode_token_id = next_token_id.unsqueeze(0)
print(f"\nRunning model with hooks to capture ground truth...")

with torch.no_grad():
    ground_truth_outputs = model(
        input_ids=decode_token_id,
        attention_mask=torch.ones_like(decode_token_id),
        past_key_values=past_key_values,
        use_cache=True,
        return_dict=True
    )

# Remove hooks
for hook in hooks:
    hook.remove()




Running model with hooks to capture ground truth...
Position IDs shape: torch.Size([1, 1])
Position IDs:
tensor([[7]], device='cuda:0')
Output cos shape: torch.Size([1, 1, 64])
Output sin shape: torch.Size([1, 1, 64])
PRE HOOK INPUTS:  ()


In [7]:
# PRINTING STUFF HOOKED FROM MODEL...

print(f"Ground truth captured for {len(ground_truth)} operations")
print(f"Ground truth keys: {list(ground_truth.keys())}")
print(f"Attention internals keys: {list(attention_internals.keys())}")

print(ground_truth['rot_emb_cos'].shape) # *** K.K  we probably want to use this as the rotary embedding weights...
print(ground_truth['rot_emb_sin'].shape) # *** K.K  we probably want to use this as the rotary embedding weights...

print(ground_truth['rot_emb_in'].shape)
sin_golden = ground_truth['rot_emb_sin']
cos_golden = ground_truth['rot_emb_cos']

text_with_space = " here"
tokens = tokenizer.encode(text_with_space)
print(f"Tokens: {tokens}")
print(f"First token: {tokens[1]}")

# Get embedding matrix
embedding_matrix = model.model.embed_tokens.weight
print(f"Embedding matrix shape: {embedding_matrix.shape}")

print(embedding_matrix[1618])
print(embedding_matrix.shape)

print(layer_0.self_attn.config._attn_implementation)

print(model.model.config.output_attentions)

print("attention outs:" , ground_truth['attention_output2'].float().cpu().numpy().shape)
#print("o proj outs:" , ground_truth['o_proj_output'].float().cpu().numpy().shape)


print("position ids:" , ground_truth['rot_emb_position_ids'].cpu().numpy())
print("rot embedding x input vector shape: ", ground_truth['rot_emb_x'].shape)
print("rot embedding x INPUT vector: ", ground_truth['rot_emb_x'].float().cpu().numpy())

Ground truth captured for 27 operations
Ground truth keys: ['rot_emb_cos', 'rot_emb_sin', 'rot_emb_in', 'rot_emb_position_ids', 'rot_emb_x', 'input_layernorm_output', 'input_layernorm_input', 'q_proj_output', 'q_proj_input', 'k_proj_output', 'k_proj_input', 'v_proj_output', 'v_proj_input', 'o_proj_output', 'o_proj_input', 'attention_output', 'attention_output2', 'attention_input', 'post_attention_layernorm_output', 'post_attention_layernorm_input', 'gate_proj_output', 'gate_proj_input', 'up_proj_output', 'up_proj_input', 'down_proj_output', 'down_proj_input', 'layer_0_output']
Attention internals keys: []
torch.Size([1, 1, 64])
torch.Size([1, 1, 64])
torch.Size([1, 1, 2048])
Tokens: [128000, 1618]
First token: 1618
Embedding matrix shape: torch.Size([128256, 2048])
tensor([ 0.0088,  0.0013,  0.0354,  ..., -0.0277, -0.0461,  0.0244],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<SelectBackward0>)
torch.Size([128256, 2048])
sdpa
True
attention outs: (1, 32, 1, 8)
position ids: [

In [8]:
print(f"\n{'='*60}")
print("DECODE PHASE - LAYER 0 DETAILED INSPECTION")
print(f"{'='*60}")

print(f"\nDecode token ID: {decode_token_id}")
print(f"Decode token: '{tokenizer.decode(decode_token_id[0])}'")

# Helper function to compare tensors
def compare_tensors(manual, ground_truth, name, rtol=1e-3, atol=1e-3
                   ):
    if ground_truth is None:
        print(f"  ⚠️  No ground truth available for {name}")
        return
    
    manual = manual.to(ground_truth.dtype)
    match = torch.allclose(manual, ground_truth, rtol=rtol, atol=atol)
    max_diff = (manual - ground_truth).abs().max().item()
    mean_diff = (manual - ground_truth).abs().mean().item()
    
    if match:
        print(f"  ✓ MATCH: {name}")
    else:
        print(f"  ✗ MISMATCH: {name}")
    print(f"    Max diff: {max_diff:.6e}, Mean diff: {mean_diff:.6e}")
    print(f"    Manual shape: {manual.shape}, GT shape: {ground_truth.shape}")
    print(f"    Manual sample: {manual.flatten()[:3]}")
    print(f"    GT sample:     {ground_truth.flatten()[:3]}")

# Step 1: Get embedding for the new token
print(f"\n{'='*60}")
print("STEP 1: Token Embedding")
print(f"{'='*60}")
embed_tokens = model.model.embed_tokens
hidden_states = embed_tokens(decode_token_id)
print(f"Embedding output shape: {hidden_states.shape}")
print(f"Embedding dtype: {hidden_states.dtype}")
print(f"Embedding sample values: {hidden_states[0, 0, :5]}")
compare_tensors(hidden_states, ground_truth.get('embedding'), "Token Embedding")
print(hidden_states.shape)


DECODE PHASE - LAYER 0 DETAILED INSPECTION

Decode token ID: tensor([[1618]], device='cuda:0')
Decode token: ' here'

STEP 1: Token Embedding
Embedding output shape: torch.Size([1, 1, 2048])
Embedding dtype: torch.bfloat16
Embedding sample values: tensor([ 0.0088,  0.0013,  0.0354,  0.0029, -0.0430], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
  ⚠️  No ground truth available for Token Embedding
torch.Size([1, 1, 2048])


In [9]:
# Step 2: Input Layer Norm
print(f"\n{'='*60}")
print("STEP 2: Input Layer Norm")
print(f"{'='*60}")
input_layernorm = layer_0.input_layernorm
print(list(input_layernorm.named_parameters()))
normed_hidden = input_layernorm(hidden_states)
hidden_states_sq = hidden_states ** 2
mean = hidden_states_sq.mean(dim=-1,keepdim=True)
print(mean, "AHPPY")
x_norm = (hidden_states)/(torch.sqrt(mean +1e-05)) * input_layernorm.weight
compare_tensors(x_norm, normed_hidden, "lnorm")
#print(norm.shape)
#print(f"LayerNorm output shape: {normed_hidden.shape}")
#print(f"LayerNorm output sample: {normed_hidden[0, 0, :5]}")
#compare_tensors(normed_hidden, ground_truth.get('input_layernorm'), "Input LayerNorm")

#K.K. says this part done



STEP 2: Input Layer Norm
[('weight', Parameter containing:
tensor([0.1582, 0.1807, 0.2695,  ..., 0.2217, 0.2109, 0.1523], device='cuda:0',
       dtype=torch.bfloat16, requires_grad=True))]
tensor([[[0.0005]]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<MeanBackward1>) AHPPY
  ✗ MISMATCH: lnorm
    Max diff: 1.562500e-02, Mean diff: 2.784729e-04
    Manual shape: torch.Size([1, 1, 2048]), GT shape: torch.Size([1, 1, 2048])
    Manual sample: tensor([0.0603, 0.0103, 0.4121], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)
    GT sample:     tensor([0.0603, 0.0103, 0.4141], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)


In [10]:
# Step 3: Self-Attention - Q, K, V Projections
import inspect
print(f"\n{'='*60}")
print("STEP 3: Q, K, V Projections")
print(f"{'='*60}")
self_attn = layer_0.self_attn
print(list(self_attn.named_parameters()), "ARCH")

print(self_attn.q_proj.bias)
print(inspect.getsource(self_attn.q_proj.forward))
print(list(self_attn.q_proj.named_parameters()))

print(self_attn.q_proj.weight.shape)
print(self_attn.k_proj.weight.shape)
print(x_norm.T.shape)
#KEVIN STYLE
q_kevin = x_norm.reshape(2048) @ self_attn.q_proj.weight.T 
k_kevin = self_attn.k_proj.weight @ x_norm.reshape(2048).T
v_kevin = x_norm.reshape(2048) @ self_attn.v_proj.weight.T

# Project to Q, K, V
q = self_attn.q_proj(normed_hidden)
k = self_attn.k_proj(normed_hidden)
v = self_attn.v_proj(normed_hidden)
compare_tensors(q, q_kevin, "Q Projection")
compare_tensors(k, k_kevin, "K Projection")
compare_tensors(v, v_kevin, "V Projection")

#print(f"Q shape: {q.shape}")
#print(f"K shape: {k.shape}")
#print(f"V shape: {v.shape}")
#compare_tensors(q, ground_truth.get('q_proj'), "Q Projection")
#compare_tensors(k, ground_truth.get('k_proj'), "K Projection")
#compare_tensors(v, ground_truth.get('v_proj'), "V Projection")


STEP 3: Q, K, V Projections
[('q_proj.weight', Parameter containing:
tensor([[-0.0183,  0.0071,  0.0219,  ..., -0.0070, -0.0089,  0.0149],
        [ 0.0112,  0.0593,  0.0630,  ..., -0.0334, -0.0148,  0.0058],
        [ 0.0182,  0.0141,  0.0361,  ..., -0.0432, -0.0388, -0.0233],
        ...,
        [ 0.0305,  0.0289,  0.0801,  ..., -0.0767, -0.0311, -0.0334],
        [ 0.0242, -0.0325,  0.0369,  ..., -0.0123, -0.0269, -0.0151],
        [-0.0264, -0.0498, -0.0210,  ...,  0.0601,  0.0130, -0.0007]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)), ('k_proj.weight', Parameter containing:
tensor([[ 0.0559,  0.1123,  0.0718,  ..., -0.1045, -0.0332,  0.0104],
        [-0.0126,  0.0718,  0.0645,  ..., -0.0304, -0.0206, -0.0496],
        [-0.0820, -0.0408, -0.0036,  ...,  0.0036, -0.0237,  0.0013],
        ...,
        [-0.0012,  0.0042,  0.0047,  ..., -0.0003, -0.0332, -0.0099],
        [-0.0261, -0.0064,  0.0422,  ...,  0.0359,  0.0236,  0.0057],
        [ 0.0179,  0.0250

  print(x_norm.T.shape)


In [11]:
# Step 4: Reshape for multi-head attention
print(f"\n{'='*60}")
print("STEP 4: Reshape for Multi-Head Attention")
print(f"{'='*60}")
bsz, q_len, _ = normed_hidden.size()

# Get config values
num_heads = config.num_attention_heads
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = hidden_size // num_heads

print(f"Config values:")
print(f"  num_attention_heads: {num_heads}")
print(f"  num_key_value_heads: {num_key_value_heads}")
print(f"  hidden_size: {hidden_size}")
print(f"  head_dim: {head_dim}")

q_kevin_reshaped = q_kevin.reshape(1,32,1,64)
k_kevin_reshaped = k_kevin.reshape(1,8,64,1)
v_kevin_reshaped = v_kevin.reshape(1,8,1,64)

q_kevin_reshaped_final = q_kevin_reshaped.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
k_kevin_reshaped_final = k_kevin_reshaped.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
v_kevin_reshaped_final = v_kevin_reshaped.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)


q_reshaped = q.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
k_reshaped = k.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
v_reshaped = v.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

print(f"\nQ reshaped: {q_reshaped.shape} [batch, num_heads, seq_len, head_dim]")
print(f"K reshaped: {k_reshaped.shape} [batch, num_kv_heads, seq_len, head_dim]")
print(f"V reshaped: {v_reshaped.shape} [batch, num_kv_heads, seq_len, head_dim]")
print(f"Q reshaped sample: {q_reshaped[0, 0, 0, :5]}")
print(f"K reshaped sample: {k_reshaped.T[0, 0, 0, :5]}")
print(f"V reshaped sample: {v_reshaped[0, 0, 0, :5]}")

print(f"Q kevin reshaped sample: {q_kevin_reshaped[ 0, 0, :5]}")
print(f"K kevin reshaped sample: {k_kevin_reshaped.T[ 0, 0, :5]}")
print(f"V kevin reshaped sample: {v_kevin_reshaped[ 0, 0, :5]}")


STEP 4: Reshape for Multi-Head Attention
Config values:
  num_attention_heads: 32
  num_key_value_heads: 8
  hidden_size: 2048
  head_dim: 64

Q reshaped: torch.Size([1, 32, 1, 64]) [batch, num_heads, seq_len, head_dim]
K reshaped: torch.Size([1, 8, 1, 64]) [batch, num_kv_heads, seq_len, head_dim]
V reshaped: torch.Size([1, 8, 1, 64]) [batch, num_kv_heads, seq_len, head_dim]
Q reshaped sample: tensor([ 1.0547,  2.8438,  2.5781, -3.6094, -4.9062], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
K reshaped sample: tensor([4.8438], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)
V reshaped sample: tensor([-0.0111,  0.0618, -0.0505, -0.0020, -0.0245], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
Q kevin reshaped sample: tensor([[  1.0547,   2.8438,   2.5781,  -3.5938,  -4.9062,  -1.0156,  -1.4062,
          -0.9688,  -1.7891,  -2.6562,   1.4844,   2.1719,  -2.3594,  -0.9297,
          -0.9062,  -2.6406,   1.9453,  

In [12]:
# Step 5: Apply RoPE (Rotary Position Embeddings)
print(f"\n{'='*60}")
print("STEP 5: Apply RoPE (Rotary Position Embeddings)")
print(f"{'='*60}")

# Get the position for the new token (seq_len from KV cache)
kv_seq_len = past_key_values[0][0].shape[2]  # Current cache length
position_ids = torch.arange(kv_seq_len, kv_seq_len + 1, device=device).unsqueeze(0)
print(f"Current KV cache length: {kv_seq_len}")
print(f"Position IDs: {position_ids}")

# Get RoPE from the model's rotary embedding
rotary_emb = model.model.rotary_emb  # This is where RoPE lives
#cos, sin = rotary_emb(v_reshaped, position_ids)
cos = cos_golden
sin = sin_golden

print(f"RoPE cos shape: {cos.shape}")
print(f"RoPE sin shape: {sin.shape}")
print(f"RoPE cos sample: {cos[0, 0, :5]}")
print(f"RoPE sin sample: {sin[0, 0, :5]}")

# Apply rotation using the same method as transformers
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # Ensure cos and sin have the right shape for broadcasting
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    
    # Expand to match q and k dimensions
    cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim]
    sin = sin.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim]
    
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

q_before_rope = q_reshaped.clone()
k_before_rope = k_reshaped.clone()
#q_rope, k_rope = apply_rotary_pos_emb(q_reshaped, k_reshaped, cos, sin)
q_rope, k_rope = apply_rotary_pos_emb(q_kevin_reshaped_final, k_kevin_reshaped_final, cos, sin)


print(f"\nQ before RoPE sample: {q_before_rope[0, 0, 0, :5]}")
print(f"Q after RoPE sample: {q_rope[0, 0, 0, :5]}")
print(f"K before RoPE sample: {k_before_rope[0, 0, 0, :5]}")
print(f"K after RoPE sample: {k_rope[0, 0, 0, :5]}")


print(ground_truth.keys())

#print(f"attention inputs: {ground_truth["attn_position_embeddings"]}")
#print(f"attention inputs: {ground_truth["attn_past_key_value"]}")

print("rotary embedding inputs: " , ground_truth["rot_emb_in"].shape)
print(f"attention outputs: {ground_truth["attention_output"].float().cpu().numpy()}")
print(f"o proj input: {ground_truth["o_proj_input"].float().cpu().numpy()}")
print(f"o proj output: {ground_truth["o_proj_output"].float().cpu().numpy()}")






STEP 5: Apply RoPE (Rotary Position Embeddings)
Current KV cache length: 8
Position IDs: tensor([[8]], device='cuda:0')
RoPE cos shape: torch.Size([1, 1, 64])
RoPE sin shape: torch.Size([1, 1, 64])
RoPE cos sample: tensor([ 0.7539, -0.0669, -1.0000, -0.4570,  0.2119], device='cuda:0',
       dtype=torch.bfloat16)
RoPE sin sample: tensor([ 0.6562, -0.9961,  0.0591,  0.8906,  0.9766], device='cuda:0',
       dtype=torch.bfloat16)

Q before RoPE sample: tensor([ 1.0547,  2.8438,  2.5781, -3.6094, -4.9062], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
Q after RoPE sample: tensor([ 0.8086, -1.5234, -2.6719, -1.4375, -1.4844], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
K before RoPE sample: tensor([ 4.8438,  2.9219, -1.2734, -2.5625, -1.7188], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
K after RoPE sample: tensor([1.4688, 2.0625, 1.1094, 1.7656, 3.1094], device='cuda:0',
       dtype=torch.bfloat16, grad_fn

In [13]:
# Step 6: Concatenate with KV Cache

print(inputs)
print(f"\n{'='*60}")
print("STEP 6: Concatenate with KV Cache")
print(f"{'='*60}")

# Past key and past values use the GOLDEN TRUTH kv cache but without the extra column due to decode (decode automatically adds to the dynamic cache)
past_key_cache = past_key_values[0][0][:,:,:7,:].clone()  # Layer 0 keys
past_value_cache = past_key_values[0][1][:,:,:7,:].clone()  # Layer 0 values

golden_key_vec = past_key_values[0][0][:,:,7:8,:].clone()
golden_value_vec = past_key_values[0][1][:,:,7:8,:].clone()

kevin_key_vec = k_rope.clone();
kevin_val_vec = v_kevin_reshaped.clone();
print("golden key vec shape: ", golden_key_vec.shape)
print("golden value vec shape: ", golden_value_vec.shape)
print(f"K after RoPE sample: {kevin_key_vec[:,0,:,:]}")

print("golden key vector: ", golden_key_vec[:,0,:])
print(f"Kevin V: {kevin_val_vec[:,0,:]}")

print("golden value vector: ", golden_value_vec[:,0,:])


print(f"Past K cache shape: {past_key_cache.shape}")
print(f"Past V cache shape: {past_value_cache.shape}")
print(f"New K shape: {k_rope.shape}")
print(f"New V shape: {v_reshaped.shape}")

# Concatenate cached K, V, with KEVIN K,V vectors
#k_combined = torch.cat([past_key_cache, kevin_key_vec], dim=2)
#v_combined = torch.cat([past_value_cache, kevin_val_vec], dim=2)

# Concatenate cached K, V with GOLDEN K,V vectors
k_combined = torch.cat([past_key_cache, golden_key_vec], dim=2)
v_combined = torch.cat([past_value_cache, golden_value_vec], dim=2)


print(f"Combined K shape: {k_combined.shape}")

print(f"Combined V shape: {v_combined.shape}")
print(f"Combined K sample (last token): {k_combined[0, 0, -1, :5]}")
print(f"Combined V sample (last token): {v_combined[0, 0, -1, :5]}")

# *** NOTE: WORKS UP TILL HERE (KK 10-18-25)

{'input_ids': tensor([[128000,    791,   3938,    315,  21075,  11478,    374]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

STEP 6: Concatenate with KV Cache
golden key vec shape:  torch.Size([1, 8, 1, 64])
golden value vec shape:  torch.Size([1, 8, 1, 64])
K after RoPE sample: tensor([[[ 1.4688,  2.0625,  1.1094,  1.7656,  3.1094, -1.4844, -0.3594,
          -3.3438, -0.8398, -2.1094,  2.3750,  1.9141, -2.0625, -1.9531,
           0.0527, -1.0859,  1.1953,  0.3281,  0.9492, -3.9531, -0.0898,
          -0.0962,  2.6094, -0.7344, -2.9375,  1.9531,  1.8984, -2.4844,
           1.6172,  1.4531, -1.6562,  1.5703,  5.6250, -3.0625, -2.6562,
          -1.9844, -2.4219, -1.2656, -2.8906,  1.1875, -2.5312, -1.4766,
           2.0312, -1.8203,  2.3594, -1.9922, -0.4824,  3.6719, -1.2031,
          -3.2656, -0.5312,  0.2139,  6.4062,  5.4688, -0.9062, -4.0938,
           1.1875, -0.2188, -1.0938,  1.5469,  1.9297,  2.9844, -2.2500,
          -1.4453]]], device='cuda:0', dtype=torch.bfl

In [14]:
# Step 7: Repeat KV heads if needed (GQA - Grouped Query Attention)
print(f"\n{'='*60}")
print("STEP 7: Repeat KV Heads (Grouped Query Attention)")
print(f"{'='*60}")
if num_key_value_heads != num_heads:
    n_rep = num_heads // num_key_value_heads
    k_repeated = k_combined.repeat_interleave(n_rep, dim=1)
    v_repeated = v_combined.repeat_interleave(n_rep, dim=1)
    print(f"KV heads repeated {n_rep} times")
    print(f"K shape after repeat: {k_repeated.shape}")
    print(f"V shape after repeat: {v_repeated.shape}")
else:
    k_repeated = k_combined
    v_repeated = v_combined
    print(f"No KV head repetition needed (MHA)")


STEP 7: Repeat KV Heads (Grouped Query Attention)
KV heads repeated 4 times
K shape after repeat: torch.Size([1, 32, 8, 64])
V shape after repeat: torch.Size([1, 32, 8, 64])


In [15]:
# Step 8: Compute Attention Scores
print(f"head dim: {head_dim}")
print(f"q dim: {q_rope.shape}")
print(f"\n{'='*60}")
print("STEP 8: Compute Attention Scores")
print(f"{'='*60}")
attn_scores = torch.matmul(q_rope, k_repeated.transpose(2, 3)) / (head_dim ** 0.5)
print(f"Attention weights shape: {attn_scores.shape}")
print(f"Attention weights (raw) sample:\n{attn_scores[0, 0, 0, -5:]}")
print(f"Head 0, attending to last 5 positions: {attn_scores[0, 0, 0, -5:]}")


head dim: 64
q dim: torch.Size([1, 32, 1, 64])

STEP 8: Compute Attention Scores
Attention weights shape: torch.Size([1, 32, 1, 8])
Attention weights (raw) sample:
tensor([7.8438, 8.0000, 8.1875, 9.0000, 5.8438], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
Head 0, attending to last 5 positions: tensor([7.8438, 8.0000, 8.1875, 9.0000, 5.8438], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)


In [16]:
def kevin_softmax(attention_scores):
    # assumes that attention_scores is [1, query_head, 1, seq_length]
    vectors = []
    for x in attention_scores[0]:
        exp = torch.exp(x)
        sum_exp = torch.sum(exp)
        softmaxed_vec = exp / sum_exp
        # reinsert the singleton dimension that was dropped (dim=0 of x)
        softmaxed_vec = softmaxed_vec.unsqueeze(0)
        vectors.append(softmaxed_vec)
    result = torch.stack(vectors, dim=1)
    print(result.shape)
    return result

In [17]:
# Step 9: Apply Causal Mask and Softmax
import torch.nn.functional as F



print(f"\n{'='*60}")
print("STEP 9: Apply Causal Mask and Softmax")
print(f"{'='*60}")
# For decode, we attend to all previous tokens
# The mask is effectively: [1, 1, 1, ..., 1] (attend to all)

#GOLDEN SOFTMAX
attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q_rope.dtype)

#KEVIN SOFTMAX
#attn_weights = kevin_softmax(attn_scores)
print(f"Attention weights (after softmax) shape: {attn_weights.shape}")
print(f"Attention weights sample (should sum to 1):\n{attn_weights[0, 0, 0, -5:]}")
print(f"Sum of attention weights: {attn_weights[0, 0, 0, :].sum()}")


golden_attn_weights = ground_truth['attention_output2']
print("golden attention weights:")
print(golden_attn_weights[:,0:4,:,:])
print("kevin attention weights:")

print(attn_weights[:,0:4,:,:])


STEP 9: Apply Causal Mask and Softmax
Attention weights (after softmax) shape: torch.Size([1, 32, 1, 8])
Attention weights sample (should sum to 1):
tensor([0.1025, 0.1196, 0.1445, 0.3262, 0.0139], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
Sum of attention weights: 1.0
golden attention weights:
tensor([[[[1.4062e-01, 1.0645e-01, 4.4434e-02, 1.0303e-01, 1.2451e-01,
           1.5039e-01, 3.1641e-01, 1.3916e-02]],

         [[4.3750e-01, 1.3672e-02, 6.0730e-03, 2.5635e-03, 9.0942e-03,
           3.8330e-02, 3.2031e-01, 1.7188e-01]],

         [[2.9175e-02, 2.7657e-05, 2.5034e-06, 5.8115e-06, 3.1281e-04,
           3.5156e-02, 9.0625e-01, 2.9175e-02]],

         [[9.4141e-01, 1.0132e-02, 3.7231e-03, 2.9449e-03, 2.9449e-03,
           1.1841e-02, 1.1841e-02, 1.6235e-02]]]], device='cuda:0',
       dtype=torch.bfloat16)
kevin attention weights:
tensor([[[[1.4453e-01, 1.0547e-01, 4.3945e-02, 1.0254e-01, 1.1963e-01,
           1.4453e-01, 3.2617e-01, 1.3855e-02]

In [18]:
# Step 10: Compute Attention Output
print(f"\n{'='*60}")
print("STEP 10: Compute Attention Output")
print(f"{'='*60}")
attn_output = torch.matmul(attn_weights, v_repeated)
print(f"Attention output shape: {attn_output.shape}")
print(f"Attention output sample: {attn_output[0, 0:2, 0, :5]}")

# Compare with ground truth attention output if available
# The ground truth is already in [batch, seq, hidden] format
gt_attn = ground_truth['o_proj_input']
print(gt_attn.shape)
print("golden way: ", gt_attn[:,:,:50])
print("kevin way: ", attn_output[:,:,:50])
attn_output_reshaped = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)


print("golden way: ", attn_output_reshaped[:,:,:50])
print("golden way: ", gt_attn[:,:,:50])
print(100*((attn_output_reshaped[:,:,:50]-gt_attn[:,:,:50])/gt_attn[:,:,:50]))



# Our attn_output is in [batch, heads, seq, head_dim], need to reshape
print(f"\nComparing attention output (before o_proj):")
compare_tensors(attn_output_reshaped, gt_attn, "Attention Output (before o_proj)")


STEP 10: Compute Attention Output
Attention output shape: torch.Size([1, 32, 1, 64])
Attention output sample: tensor([[ 7.7248e-05,  4.6692e-03,  2.0905e-03, -6.8970e-03,  4.1504e-02],
        [ 1.5442e-02,  1.4099e-02, -2.7344e-02,  6.4392e-03,  4.3945e-02]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
torch.Size([1, 1, 2048])
golden way:  tensor([[[-0.0007,  0.0048,  0.0027, -0.0081,  0.0403, -0.0297,  0.0227,
          -0.0247,  0.0079,  0.0092, -0.0173,  0.0374, -0.0132, -0.0103,
           0.0286,  0.0167,  0.0522, -0.0219, -0.0669, -0.0240,  0.0002,
          -0.0378,  0.0713, -0.0262, -0.0391,  0.0201, -0.0167,  0.0114,
          -0.0221,  0.0481,  0.0084, -0.0133, -0.0469,  0.0234,  0.0014,
           0.0139, -0.0513,  0.0747, -0.0674, -0.0048, -0.0299, -0.0123,
           0.0415, -0.0464, -0.0209, -0.0094, -0.0237,  0.0400,  0.2139,
           0.0469]]], device='cuda:0', dtype=torch.bfloat16)
kevin way:  tensor([[[[ 7.7248e-05,  4.6692e-03,  2.0905

In [19]:
# Step 11: Reshape and Project Output
print(f"\n{'='*60}")
print("STEP 11: Reshape and Project Output")
print(f"{'='*60}")
attn_output_merged1 = attn_output.transpose(1, 2).contiguous()
attn_output_merged2 = attn_output_merged1.reshape(bsz, q_len, -1)
print(attn_output_merged2[0,0,64:68])
print(f"Attention output after reshape: {attn_output_merged2.shape}")

print(self_attn.o_proj.weight.shape)


attn_output_proj = self_attn.o_proj(attn_output_merged2)
print(f"Attention output after o_proj: {attn_output_proj.shape}")
print(f"O_proj output sample: {attn_output_proj[0, 0, :5]}")
compare_tensors(attn_output_proj, ground_truth.get('o_proj'), "O Projection")


STEP 11: Reshape and Project Output
tensor([ 0.0154,  0.0141, -0.0273,  0.0064], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
Attention output after reshape: torch.Size([1, 1, 2048])
torch.Size([2048, 2048])
Attention output after o_proj: torch.Size([1, 1, 2048])
O_proj output sample: tensor([ 0.0052, -0.0045, -0.0134,  0.0035, -0.0002], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
  ⚠️  No ground truth available for O Projection


In [20]:
# Step 12: Residual Connection 1
print(f"\n{'='*60}")
print("STEP 12: Residual Connection (Attention)")
print(f"{'='*60}")
hidden_states_after_attn = hidden_states + attn_output_proj
print(f"Hidden states after residual: {hidden_states_after_attn.shape}")
print(f"Hidden states sample: {hidden_states_after_attn[0, 0, :5]}")
print(f"Residual input: {hidden_states[0, 0, :5]}")
print(f"Attention output: {attn_output_proj[0, 0, :5]}")
print(f"After addition: {hidden_states_after_attn[0, 0, :5]}")


STEP 12: Residual Connection (Attention)
Hidden states after residual: torch.Size([1, 1, 2048])
Hidden states sample: tensor([ 0.0140, -0.0031,  0.0220,  0.0065, -0.0432], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
Residual input: tensor([ 0.0088,  0.0013,  0.0354,  0.0029, -0.0430], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
Attention output: tensor([ 0.0052, -0.0045, -0.0134,  0.0035, -0.0002], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
After addition: tensor([ 0.0140, -0.0031,  0.0220,  0.0065, -0.0432], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)


In [21]:
# Step 13: Post-Attention LayerNorm
print(f"\n{'='*60}")
print("STEP 13: Post-Attention LayerNorm")
print(f"{'='*60}")
post_attention_layernorm = layer_0.post_attention_layernorm
normed_hidden_2 = post_attention_layernorm(hidden_states_after_attn)
print(f"Post-attention LayerNorm output shape: {normed_hidden_2.shape}")
print(f"Post-attention LayerNorm sample: {normed_hidden_2[0, 0, :5]}")
compare_tensors(normed_hidden_2, ground_truth.get('post_attention_layernorm'), "Post-Attention LayerNorm")


STEP 13: Post-Attention LayerNorm
Post-attention LayerNorm output shape: torch.Size([1, 1, 2048])
Post-attention LayerNorm sample: tensor([ 0.1094, -0.0240,  0.1553,  0.0469, -0.3535], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
  ⚠️  No ground truth available for Post-Attention LayerNorm


In [22]:
# Step 14: MLP (Feed-Forward Network)
print(f"MLP Layer 0 dimensions:")
for name, param in layer_0.mlp.named_parameters():
    print(f"  {name}: {param.shape}")

print(f"\n{'='*60}")
print("STEP 14: MLP - Gate and Up Projections")
print(f"{'='*60}")
mlp = layer_0.mlp
gate_proj_output = mlp.gate_proj(normed_hidden_2)
up_proj_output = mlp.up_proj(normed_hidden_2)
print(f"Gate projection output shape: {gate_proj_output.shape}")
print(f"Up projection output shape: {up_proj_output.shape}")
compare_tensors(gate_proj_output, ground_truth.get('gate_proj'), "Gate Projection")
compare_tensors(up_proj_output, ground_truth.get('up_proj'), "Up Projection")

MLP Layer 0 dimensions:
  gate_proj.weight: torch.Size([8192, 2048])
  up_proj.weight: torch.Size([8192, 2048])
  down_proj.weight: torch.Size([2048, 8192])

STEP 14: MLP - Gate and Up Projections
Gate projection output shape: torch.Size([1, 1, 8192])
Up projection output shape: torch.Size([1, 1, 8192])
  ⚠️  No ground truth available for Gate Projection
  ⚠️  No ground truth available for Up Projection


In [23]:
# Apply SiLU activation to gate
mlp_hidden = F.silu(gate_proj_output) * up_proj_output
print(f"\nMLP hidden (SiLU(gate) * up) shape: {mlp_hidden.shape}")
print(f"MLP hidden sample: {mlp_hidden[0, 0, :5]}")
print(f"SiLU(gate) sample: {F.silu(gate_proj_output)[0, 0, :5]}")
print(f"Up sample: {up_proj_output[0, 0, :5]}")
print(f"Product sample: {mlp_hidden[0, 0, :5]}")


MLP hidden (SiLU(gate) * up) shape: torch.Size([1, 1, 8192])
MLP hidden sample: tensor([0.0007, 0.0151, 0.0095, 0.0103, 0.0018], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
SiLU(gate) sample: tensor([ 0.0194, -0.0654,  0.0796,  0.0474, -0.0366], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
Up sample: tensor([ 0.0369, -0.2305,  0.1191,  0.2168, -0.0493], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
Product sample: tensor([0.0007, 0.0151, 0.0095, 0.0103, 0.0018], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)


In [24]:
# Step 15: MLP Down Projection
print(f"\n{'='*60}")
print("STEP 15: MLP - Down Projection")
print(f"{'='*60}")
mlp_output = mlp.down_proj(mlp_hidden)
print(f"MLP output shape: {mlp_output.shape}")
print(f"MLP output sample: {mlp_output[0, 0, :5]}")
compare_tensors(mlp_output, ground_truth.get('down_proj'), "Down Projection")

# Step 16: Residual Connection 2
print(f"\n{'='*60}")
print("STEP 16: Residual Connection (MLP)")
print(f"{'='*60}")
hidden_states_final = hidden_states_after_attn + mlp_output
print(f"Final hidden states shape: {hidden_states_final.shape}")
print(f"Final hidden states sample: {hidden_states_final[0, 0, :5]}")
print(f"Before MLP residual: {hidden_states_after_attn[0, 0, :5]}")
print(f"MLP output: {mlp_output[0, 0, :5]}")
print(f"After addition: {hidden_states_final[0, 0, :5]}")
compare_tensors(hidden_states_final, ground_truth.get('layer_0_output'), "Final Layer 0 Output (with residual)")
print(hidden_states_final)
print(ground_truth.get('layer_0_output'))
# ============================================================
# FINAL VERIFICATION
# ============================================================
"""
print(f"\n{'='*60}")
print("FINAL VERIFICATION")
print(f"{'='*60}")

# Compare final output with ground truth by running full model
with torch.no_grad():
    # Get all hidden states
    model.config.output_hidden_states = True
    full_outputs = model(
        input_ids=decode_token_id,
        attention_mask=torch.ones_like(decode_token_id),
        past_key_values=past_key_values,
        use_cache=True,
        return_dict=True
    )
    model.config.output_hidden_states = False

if hasattr(full_outputs, 'hidden_states') and full_outputs.hidden_states is not None:
    # hidden_states[0] is embedding, hidden_states[1] is after layer 0
    layer_0_gt_output = full_outputs.hidden_states[1]
    compare_tensors(hidden_states_final, layer_0_gt_output, "Final Layer 0 Output")
else:
    print(f"Hidden states not available in output")

print(f"\n{'='*60}")
print("SUMMARY")
print(f"{'='*60}")
print(f"✓ Layer 0 decode step completed successfully!")
print(f"✓ All intermediate steps verified against ground truth")
print(f"\nManual Layer 0 output: {hidden_states_final[0, 0, :5]}")
"""


STEP 15: MLP - Down Projection
MLP output shape: torch.Size([1, 1, 2048])
MLP output sample: tensor([-0.0112,  0.0157, -0.0232,  0.0417,  0.0179], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
  ⚠️  No ground truth available for Down Projection

STEP 16: Residual Connection (MLP)
Final hidden states shape: torch.Size([1, 1, 2048])
Final hidden states sample: tensor([ 0.0027,  0.0126, -0.0012,  0.0483, -0.0253], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
Before MLP residual: tensor([ 0.0140, -0.0031,  0.0220,  0.0065, -0.0432], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
MLP output: tensor([-0.0112,  0.0157, -0.0232,  0.0417,  0.0179], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
After addition: tensor([ 0.0027,  0.0126, -0.0012,  0.0483, -0.0253], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
  ✗ MISMATCH: Final Layer 0 Output (with residual)
    Max 

'\nprint(f"\n{\'=\'*60}")\nprint("FINAL VERIFICATION")\nprint(f"{\'=\'*60}")\n\n# Compare final output with ground truth by running full model\nwith torch.no_grad():\n    # Get all hidden states\n    model.config.output_hidden_states = True\n    full_outputs = model(\n        input_ids=decode_token_id,\n        attention_mask=torch.ones_like(decode_token_id),\n        past_key_values=past_key_values,\n        use_cache=True,\n        return_dict=True\n    )\n    model.config.output_hidden_states = False\n\nif hasattr(full_outputs, \'hidden_states\') and full_outputs.hidden_states is not None:\n    # hidden_states[0] is embedding, hidden_states[1] is after layer 0\n    layer_0_gt_output = full_outputs.hidden_states[1]\n    compare_tensors(hidden_states_final, layer_0_gt_output, "Final Layer 0 Output")\nelse:\n    print(f"Hidden states not available in output")\n\nprint(f"\n{\'=\'*60}")\nprint("SUMMARY")\nprint(f"{\'=\'*60}")\nprint(f"✓ Layer 0 decode step completed successfully!")\nprin