In [None]:
# Reload imported project modules (e.g. lkeravnos) before each cell executes,
# so edits to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch.nn import MultiheadAttention
import numpy as np

# Sanity constants: problem size and numerics shared by the PyTorch reference
# and the CUDA transformer built below.
name = "attn_cmp"  # handle passed to every Transformer.* call
batch = 2
seq_len = 2048
n_dims = 768
n_heads = 12
n_layers = 6
ff_mult = 4
dtype = torch.float16
device = "cuda"
seed = 42

# Fail fast: the per-head split (view into n_heads x head_dim) needs an exact division.
assert n_dims % n_heads == 0, f"n_dims ({n_dims}) must be divisible by n_heads ({n_heads})"
head_dim = n_dims // n_heads

# Seed torch's RNGs so the random input and attention weights are reproducible.
# The trailing `;` keeps the returned Generator object out of the cell output.
torch.manual_seed(seed);

In [None]:
# Re-seed so re-running this cell regenerates the same input and weights
# (in a fresh kernel the values are identical to seeding only once above).
torch.manual_seed(seed)

# Random input in sequence-first layout, matching batch_first=False below.
x = torch.randn(seq_len, batch, n_dims, dtype=dtype, device=device)  # [T, B, D]

# PyTorch reference attention layer; head_dim comes from the config cell.
attn = MultiheadAttention(
    embed_dim=n_dims,
    num_heads=n_heads,
    bias=True,
    dropout=0.0,
    batch_first=False,
    dtype=dtype,
    device=device,
)

In [None]:
from lkeravnos import Transformer

# Hyperparameters for the CUDA model, taken from the config cell so the CUDA
# side and the PyTorch reference are sized identically.
transformer_config = {
    "batch_size": batch,
    "sequence_length": seq_len,
    "num_dims": n_dims,
    "num_heads": n_heads,
    "num_layers": n_layers,
    "ff_multiplier": ff_mult,
}

# Register the CUDA transformer under the handle `name`.
Transformer.construct(name, **transformer_config, verbose=True)

In [None]:
# Query the CUDA model built above via the shared `name` handle rather than a
# duplicated string literal, so renaming the run only touches the config cell.
Transformer.get_info(name)

In [None]:
def half_tensor_to_uint16_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Reinterpret a float16 tensor's raw bits as a host-side ``np.uint16`` array.

    No numeric conversion happens: each float16 value becomes the uint16 with
    the same bit pattern, which is the form handed to ``Transformer.edit_tensor``
    in this notebook.

    Args:
        tensor: A ``torch.float16`` tensor on any device. It may require grad
            and may be non-contiguous.

    Returns:
        A C-contiguous ``np.uint16`` array with the same shape as ``tensor``.

    Raises:
        TypeError: If ``tensor`` is not ``torch.float16``. A bit-level view of
            any other dtype would silently produce the wrong element count.
    """
    if tensor.dtype != torch.float16:
        raise TypeError(f"expected torch.float16 tensor, got {tensor.dtype}")
    # detach(): .numpy() refuses tensors that require grad, which happens for
    # nn.Parameters whenever .cpu() is a no-op (tensor already on the host).
    host = tensor.detach().cpu().contiguous().numpy()
    return host.view(np.uint16)

with torch.no_grad():
    # Handles to the reference layer's projection parameters.
    qkv_weight = attn.in_proj_weight   # [3D, D]: q, k, v weights stacked by rows
    qkv_bias = attn.in_proj_bias       # [3D]
    out_weight = attn.out_proj.weight  # [D, D]
    out_bias = attn.out_proj.bias      # [D]

    # Give the stacked q/k/v rows an explicit leading axis of 3 before export;
    # contiguous() ensures a dense buffer for the bit-level copy.
    qkv_weight_np = half_tensor_to_uint16_numpy(qkv_weight.reshape(3, n_dims, n_dims).contiguous())
    qkv_bias_np = half_tensor_to_uint16_numpy(qkv_bias.reshape(3, n_dims).contiguous())
    out_weight_np = half_tensor_to_uint16_numpy(out_weight.contiguous())
    out_bias_np = half_tensor_to_uint16_numpy(out_bias.contiguous())

    # Write each exported array into the CUDA model's tensor of the same role.
    for cuda_tensor_name, host_bits in (
        ("qkv_proj", qkv_weight_np),
        ("qkv_proj_bias", qkv_bias_np),
        ("out_proj", out_weight_np),
        ("out_proj_bias", out_bias_np),
    ):
        Transformer.edit_tensor(name, cuda_tensor_name, host_bits)

In [None]:
# The CUDA model is fed the input batch-first and flattened to [B*T, D];
# x is sequence-first [T, B, D], so swap the leading axes before flattening.
# Reuses the shared bit-reinterpretation helper instead of inlining it.
token_embed = half_tensor_to_uint16_numpy(x.transpose(0, 1).reshape(batch * seq_len, n_dims))
Transformer.edit_tensor(name, "input_embed", token_embed, verbose=True)

In [None]:
# Run the CUDA causal self-attention, seeded from the config cell rather than a
# duplicated literal (with dropout=0.0 the seed presumably has no effect).
Transformer.causal_self_attention(name, use_bias=True, dropout=0.0, seed=seed, verbose=True)

In [None]:
# Read back the CUDA QKV projection as raw float16 bits and move it to the GPU.
# NOTE(review): the third get_tensor argument is passed positionally as True;
# its meaning is defined in lkeravnos — confirm.
qkv_cuda_bits = Transformer.get_tensor(name, "qkv_matrix", True)
qkv_cuda = (
    torch.from_numpy(qkv_cuda_bits.view(np.float16))
    .to(device=device, dtype=dtype)
    # Make the assumed [B, T, 3, D] layout explicit so the comparison with
    # qkv_ref does not depend on how get_tensor shapes its result.
    .reshape(batch, seq_len, 3, n_dims)
)
qkv_cuda

In [None]:
# PyTorch reference for the QKV projection. x is sequence-first, so the linear
# output is [T, B, 3D]; it must be transposed to batch-first before splitting
# the last axis into (3, D). A bare .view() would reinterpret [T, B] memory
# as [B, T] and scramble tokens across batch entries.
with torch.no_grad():
    qkv_ref = torch.nn.functional.linear(x, qkv_weight, qkv_bias)  # [T, B, 3D]
    qkv_ref = qkv_ref.transpose(0, 1).reshape(batch, seq_len, 3, n_dims)  # [B, T, 3, D]
qkv_ref

In [None]:
# Largest element-wise absolute error between the reference and CUDA QKV.
diff = (qkv_ref - qkv_cuda).abs()
max_diff = diff.max()
print("Max diff:", max_diff.item())

# Pass/fail verdict against a fixed absolute tolerance of 1e-2.
qkv_matches = bool(max_diff < 1e-2)
print("✅ QKV projection matches!" if qkv_matches else "❌ QKV projection mismatch!")


In [None]:
# Full PyTorch forward pass as an end-to-end reference. The CUDA kernel is
# causal self-attention, so future positions are masked here too
# (True = not allowed to attend). Without the mask every query would see all
# tokens and the reference could not match the CUDA result.
causal_mask = torch.triu(
    torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1
)
attn.eval()
with torch.no_grad():
    # NOTE: despite its name, qkv_ref_pytorch is the attention *output* [T, B, D].
    qkv_ref_pytorch, attn_weights_pytorch = attn(
        x, x, x,
        attn_mask=causal_mask,
        need_weights=True,
        average_attn_weights=False,  # keep per-head weights: [B, H, T, T]
    )

In [None]:
# Reference attention probabilities computed step by step from the same
# in-projection weights. Runs under no_grad so autograd does not keep the
# large [B, H, T, T] intermediates alive.
with torch.no_grad():
    # Raw Q, K, V split per head.
    qkv = torch.nn.functional.linear(x, qkv_weight, qkv_bias)  # [T, B, 3D]
    qkv = qkv.view(seq_len, batch, 3, n_heads, head_dim).permute(2, 1, 3, 0, 4)  # [3, B, H, T, hd]
    q, k, v = qkv[0], qkv[1], qkv[2]  # each: [B, H, T, hd]

    # Scaled dot-product scores (before softmax).
    q_scaled = q / head_dim**0.5
    attn_scores_ref = torch.einsum("bhid,bhjd->bhij", q_scaled, k)  # [B, H, T, T]

    # Causal mask: query i may only attend to keys j <= i (True = blocked).
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1)
    attn_scores_ref = attn_scores_ref.masked_fill(mask, float("-inf"))

    # Softmax over keys; the diagonal is never masked, so no row is all -inf.
    attn_probs_ref = torch.softmax(attn_scores_ref, dim=-1)  # [B, H, T, T]

    # Optional dropout. Keep this equal to the dropout passed to the CUDA
    # kernel (0.0): with p > 0 the random masks presumably differ from the CUDA
    # RNG, so probabilities would no longer be comparable element-wise.
    dropout_p = 0.0
    if dropout_p > 0.0:
        attn_probs_ref = torch.nn.functional.dropout(attn_probs_ref, p=dropout_p, training=True)

In [None]:
# Inspect the reference probabilities (torch's repr summarizes large tensors).
attn_probs_ref

In [None]:
# Fetch the CUDA-side attention tensor as raw float16 bits and move it to the GPU.
# NOTE(review): it is stored under "attention_scores" but is compared below
# against post-softmax probabilities — confirm it holds probabilities.
attn_probs_cuda_bits = Transformer.get_tensor(name, "attention_scores")
attn_probs_cuda = (
    torch.from_numpy(attn_probs_cuda_bits.view(np.float16))
    .to(device=device, dtype=dtype)
    # Assumed [B, H, T, T] layout, which the comparisons below rely on.
    .reshape(batch, n_heads, seq_len, seq_len)
)
attn_probs_cuda

In [None]:
# Every query row of a softmax output should sum to ~1 over the keys.
# Check all batches, heads and query positions, accumulating in float32 so
# summing seq_len float16 values does not add its own rounding error.
row_sums = attn_probs_cuda.sum(dim=-1, dtype=torch.float32)  # [B, H, T]
print("Max |row sum - 1|:", (row_sums - 1.0).abs().max().item())
row_sums[0, 3, :8]  # spot-check a few query rows of batch 0, head 3

In [None]:
# Element-wise closeness: |ref - cuda| <= 1e-3 + 1e-2 * |cuda| for every entry.
attn_probs_ref.allclose(attn_probs_cuda, rtol=1e-2, atol=1e-3)

In [None]:
# Number of probabilities whose absolute error exceeds 1e-2.
(attn_probs_cuda - attn_probs_ref).abs().gt(1e-2).sum()