In [1]:
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM


In [2]:
model_id = "Qwen/Qwen3-235B-A22B"


In [3]:
cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=True)  # 94 layers originally  [oai_citation:0‡Hugging Face](https://huggingface.co/Qwen/Qwen3-235B-A22B/blob/main/config.json?utm_source=chatgpt.com)


In [4]:
cfg.num_hidden_layers = 1          # ← the only change strictly required


In [5]:
cfg

Qwen3MoeConfig {
  "architectures": [
    "Qwen3MoeForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "decoder_sparse_step": 1,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 12288,
  "max_position_embeddings": 40960,
  "max_window_layers": 94,
  "mlp_only_layers": [],
  "model_type": "qwen3_moe",
  "moe_intermediate_size": 1536,
  "norm_topk_prob": true,
  "num_attention_heads": 64,
  "num_experts": 128,
  "num_experts_per_tok": 8,
  "num_hidden_layers": 1,
  "num_key_value_heads": 4,
  "output_router_logits": false,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "router_aux_loss_coef": 0.001,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.51.3",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}

In [6]:
torch.set_default_device("cuda")                # PyTorch 2.1+


In [7]:
torch.manual_seed(42)              # reproducible randomness
with torch.no_grad():
    model = AutoModelForCausalLM.from_config(cfg, trust_remote_code=True)


In [8]:
model

Qwen3MoeForCausalLM(
  (model): Qwen3MoeModel(
    (embed_tokens): Embedding(151936, 4096)
    (layers): ModuleList(
      (0): Qwen3MoeDecoderLayer(
        (self_attn): Qwen3MoeAttention(
          (q_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (k_proj): Linear(in_features=4096, out_features=512, bias=False)
          (v_proj): Linear(in_features=4096, out_features=512, bias=False)
          (o_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (q_norm): Qwen3MoeRMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3MoeRMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MoeSparseMoeBlock(
          (gate): Linear(in_features=4096, out_features=128, bias=False)
          (experts): ModuleList(
            (0-127): 128 x Qwen3MoeMLP(
              (gate_proj): Linear(in_features=4096, out_features=1536, bias=False)
              (up_proj): Linear(in_features=4096, out_features=1536, bias=False)
              (down_proj): Linear(in_fe

In [10]:
import transformers.modeling_utils

In [12]:
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

loading file vocab.json from cache at /global/homes/j/jundac/.cache/huggingface/hub/models--Qwen--Qwen3-235B-A22B/snapshots/c30ce1aa8a0ff9cebf95e95b4b8fd90826043fd0/vocab.json
loading file merges.txt from cache at /global/homes/j/jundac/.cache/huggingface/hub/models--Qwen--Qwen3-235B-A22B/snapshots/c30ce1aa8a0ff9cebf95e95b4b8fd90826043fd0/merges.txt
loading file tokenizer.json from cache at /global/homes/j/jundac/.cache/huggingface/hub/models--Qwen--Qwen3-235B-A22B/snapshots/c30ce1aa8a0ff9cebf95e95b4b8fd90826043fd0/tokenizer.json
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at /global/homes/j/jundac/.cache/huggingface/hub/models--Qwen--Qwen3-235B-A22B/snapshots/c30ce1aa8a0ff9cebf95e95b4b8fd90826043fd0/tokenizer_config.json
loading file chat_template.jinja from cache at None
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or

In [14]:
profile = {}  # name → dict(start, end, params)

def want_hook(name, module):
    # leaf = no children; skip attention classes (FlashAttention2, QwenAttention, …)
    is_leaf = len(list(module.children())) == 0
    is_attn = "attn" in module.__class__.__name__.lower() or \
              "attention" in module.__class__.__name__.lower()
    return is_leaf 
    # return is_leaf and not is_attn

In [15]:
import time

In [16]:
for name, m in model.named_modules():
    if not want_hook(name, m):
        continue

    # forward-pre & forward hooks share the closure variable 'name'
    def pre_hook(mod, inp, name=name):
        profile[name] = {"start": time.perf_counter(),
                         "params": sum(p.numel() for p in mod.parameters())}

    def post_hook(mod, inp, out, name=name):
        torch.cuda.synchronize()
        profile[name]["end"] = time.perf_counter()

    m.register_forward_pre_hook(pre_hook, prepend=True)
    m.register_forward_hook(post_hook)


In [17]:
old_attn_funcs = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS._global_mapping

In [18]:
# Hook the attention functions too
new_attn_funcs = {}
for k, v in old_attn_funcs.items():
    def new_func(*args, **kwargs):
        start = time.perf_counter()
        ret = v(*args, **kwargs)
        torch.cuda.synchronize()
        end = time.perf_counter()
        profile["model.layers.0.self_attn.attn"] = {
            "start": start,
            "params": 0,
            "end": end,
        }
        return ret
    new_attn_funcs[k] = new_func


transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS._global_mapping = new_attn_funcs

In [None]:
K = 1024
ctx_len = K * 4
inp = tok(" a" * ctx_len, return_tensors="pt").to(model.device)
ctx_len = len(inp.input_ids[0])

with torch.no_grad():
    model(**inp)                        # warm-up (CUDA kernels/JIT)
    torch.cuda.synchronize()

    start = time.perf_counter()
    model(**inp)
    torch.cuda.synchronize()
    total = time.perf_counter() - start
    total *= 1e3


rows = []
for name, rec in profile.items():
    dur_ms = (rec["end"] - rec["start"]) * 1e3
    rows.append((dur_ms, name, rec["params"]))
# rows.sort(reverse=True)                # slowest first

print(f"{'module':45}  latency  params")
print("-"*70)

total_layer0_time = 0
expert_time = 0
expert_params = 0

layer_attn_time = 0
for dur, name, nparam in rows:
    # if "layers.0" not in name:
    #     continue
    if "experts" in name:
        expert_time += dur
        expert_params += nparam
    else:
        print(f"{name:45}  {dur:7.3f} ms  {nparam/1e6:7.2f} M")
    
    if name == "model.layers.0.self_attn.attn":
        layer_attn_time += dur

    if "layers.0" not in name:
        continue

    total_layer0_time += dur

print(f"{'model.layers.0.*expert*':45}  {expert_time:7.3f} ms  {expert_params/1e6:7.2f} M")
print("-"*70)
print(f"ctx_len: {ctx_len}")
print(f"total layer 0 time: {total_layer0_time:.3f} ms")
print(f"layer 0 attn time: {layer_attn_time:.3f} ms")
print(f"full forward time: {total:.3f} ms")

module                                         latency  params
----------------------------------------------------------------------
model.embed_tokens                               0.248 ms   622.33 M
model.rotary_emb                                 0.352 ms     0.00 M
model.layers.0.input_layernorm                   0.496 ms     0.00 M
model.layers.0.self_attn.q_proj                  1.172 ms    33.55 M
model.layers.0.self_attn.q_norm                  0.935 ms     0.00 M
model.layers.0.self_attn.k_proj                  0.147 ms     2.10 M
model.layers.0.self_attn.k_norm                  0.164 ms     0.00 M
model.layers.0.self_attn.v_proj                  0.148 ms     2.10 M
model.layers.0.self_attn.attn                    2.956 ms     0.00 M
model.layers.0.self_attn.o_proj                  1.143 ms    33.55 M
model.layers.0.post_attention_layernorm          0.513 ms     0.00 M
model.layers.0.mlp.gate                          0.085 ms     0.52 M
model.norm                            

In [None]:
"""

module                                         latency  params
----------------------------------------------------------------------
model.layers.0.input_layernorm                   0.495 ms     0.00 M
model.layers.0.self_attn.q_proj                  1.176 ms    33.55 M
model.layers.0.self_attn.q_norm                  0.936 ms     0.00 M
model.layers.0.self_attn.k_proj                  0.146 ms     2.10 M
model.layers.0.self_attn.k_norm                  0.158 ms     0.00 M
model.layers.0.self_attn.v_proj                  0.145 ms     2.10 M
model.layers.0.self_attn.o_proj                  3.853 ms    33.55 M
model.layers.0.post_attention_layernorm          0.510 ms     0.00 M
model.layers.0.mlp.gate                          0.085 ms     0.52 M
model.layers.0.*expert*                         23.759 ms  2415.92 M
----------------------------------------------------------------------
ctx_len: 4096
total layer 0 time: 31.263 ms
full forward time: 83.191 ms



module                                         latency  params
----------------------------------------------------------------------
model.layers.0.input_layernorm                   0.933 ms     0.00 M
model.layers.0.self_attn.q_proj                  2.200 ms    33.55 M
model.layers.0.self_attn.q_norm                  1.796 ms     0.00 M
model.layers.0.self_attn.k_proj                  0.279 ms     2.10 M
model.layers.0.self_attn.k_norm                  0.180 ms     0.00 M
model.layers.0.self_attn.v_proj                  0.272 ms     2.10 M
model.layers.0.self_attn.o_proj                 10.788 ms    33.55 M
model.layers.0.post_attention_layernorm          0.993 ms     0.00 M
model.layers.0.mlp.gate                          0.113 ms     0.52 M
model.layers.0.*expert*                         31.005 ms  2415.92 M
----------------------------------------------------------------------
ctx_len: 8192 (8K)
total layer 0 time: 48.558 ms
full forward time: 123.544 ms


module                                         latency  params
----------------------------------------------------------------------
model.layers.0.input_layernorm                   1.755 ms     0.00 M
model.layers.0.self_attn.q_proj                  4.331 ms    33.55 M
model.layers.0.self_attn.q_norm                  3.517 ms     0.00 M
model.layers.0.self_attn.k_proj                  0.392 ms     2.10 M
model.layers.0.self_attn.k_norm                  0.289 ms     0.00 M
model.layers.0.self_attn.v_proj                  0.385 ms     2.10 M
model.layers.0.self_attn.o_proj                 32.830 ms    33.55 M
model.layers.0.post_attention_layernorm          1.951 ms     0.00 M
model.layers.0.mlp.gate                          0.163 ms     0.52 M
model.layers.0.*expert*                         41.588 ms  2415.92 M
----------------------------------------------------------------------
ctx_len: 16384 (16K)
total layer 0 time: 87.201 ms
full forward time: 205.412 ms



module                                         latency  params
----------------------------------------------------------------------
model.layers.0.input_layernorm                   3.340 ms     0.00 M
model.layers.0.self_attn.q_proj                  7.718 ms    33.55 M
model.layers.0.self_attn.q_norm                  6.709 ms     0.00 M
model.layers.0.self_attn.k_proj                  0.563 ms     2.10 M
model.layers.0.self_attn.k_norm                  0.483 ms     0.00 M
model.layers.0.self_attn.v_proj                  0.559 ms     2.10 M
model.layers.0.self_attn.o_proj                 99.912 ms    33.55 M
model.layers.0.post_attention_layernorm          3.724 ms     0.00 M
model.layers.0.mlp.gate                          0.323 ms     0.52 M
model.layers.0.*expert*                         64.315 ms  2415.92 M
----------------------------------------------------------------------
ctx_len: 32768 (32K)
total layer 0 time: 187.646 ms
full forward time: 397.876 ms



module                                         latency  params
----------------------------------------------------------------------
model.layers.0.input_layernorm                   4.983 ms     0.00 M
model.layers.0.self_attn.q_proj                 11.568 ms    33.55 M
model.layers.0.self_attn.q_norm                 10.093 ms     0.00 M
model.layers.0.self_attn.k_proj                  0.886 ms     2.10 M
model.layers.0.self_attn.k_norm                  0.692 ms     0.00 M
model.layers.0.self_attn.v_proj                  0.855 ms     2.10 M
model.layers.0.self_attn.o_proj                209.589 ms    33.55 M
model.layers.0.post_attention_layernorm          5.603 ms     0.00 M
model.layers.0.mlp.gate                          0.317 ms     0.52 M
model.layers.0.*expert*                         86.480 ms  2415.92 M
----------------------------------------------------------------------
ctx_len: 49152 (48K)
total layer 0 time: 331.066 ms
full forward time: 632.409 ms



module                                         latency  params
----------------------------------------------------------------------
model.layers.0.input_layernorm                   6.600 ms     0.00 M
model.layers.0.self_attn.q_proj                 15.371 ms    33.55 M
model.layers.0.self_attn.q_norm                 13.356 ms     0.00 M
model.layers.0.self_attn.k_proj                  1.065 ms     2.10 M
model.layers.0.self_attn.k_norm                  0.903 ms     0.00 M
model.layers.0.self_attn.v_proj                  1.059 ms     2.10 M
model.layers.0.self_attn.o_proj                359.716 ms    33.55 M
model.layers.0.post_attention_layernorm          7.454 ms     0.00 M
model.layers.0.mlp.gate                          0.425 ms     0.52 M
model.layers.0.*expert*                        110.195 ms  2415.92 M
----------------------------------------------------------------------
ctx_len: 65536 (64K)
total layer 0 time: 516.144 ms
full forward time: 908.428 ms




module                                         latency  params
----------------------------------------------------------------------
model.layers.0.input_layernorm                   9.858 ms     0.00 M
model.layers.0.self_attn.q_proj                 23.034 ms    33.55 M
model.layers.0.self_attn.q_norm                 20.070 ms     0.00 M
model.layers.0.self_attn.k_proj                  1.570 ms     2.10 M
model.layers.0.self_attn.k_norm                  1.328 ms     0.00 M
model.layers.0.self_attn.v_proj                  1.565 ms     2.10 M
model.layers.0.self_attn.o_proj                786.287 ms    33.55 M
model.layers.0.post_attention_layernorm         11.157 ms     0.00 M
model.layers.0.mlp.gate                          0.617 ms     0.52 M
model.layers.0.*expert*                        155.536 ms  2415.92 M
----------------------------------------------------------------------
ctx_len: 98304 (96K)
total layer 0 time: 1011.022 ms
full forward time: 1587.503 ms
"""

In [52]:
mlp_forward_time = {
    4: 31.263,
    8: 48.558,
    16: 87.201,
    32: 187.646,
    48: 331.066,
    64: 516.144,
    96: 1011.022,
}
mlp_forward_time

{4: 31.263,
 8: 48.558,
 16: 87.201,
 32: 187.646,
 48: 331.066,
 64: 516.144,
 96: 1011.022}

In [53]:
torch.cuda.empty_cache()

In [23]:
rows = []
for name, rec in profile.items():
    dur_ms = (rec["end"] - rec["start"]) * 1e3
    rows.append((dur_ms, name, rec["params"]))
# rows.sort(reverse=True)                # slowest first

print(f"{'module':45}  latency  params")
print("-"*70)

total_doc_time = 0
expert_time = 0
expert_params = 0
for dur, name, nparam in rows:
    if "experts" in name:
        expert_time += dur
        expert_params += nparam
    else:
        print(f"{name:45}  {dur:7.3f} ms  {nparam/1e6:7.2f} M")
        pass
    total_doc_time += dur
print(f"{'model.layers.0.*expert*':45}  {expert_time:7.3f} ms  {expert_params/1e6:7.2f} M")
print("-"*70)
print(f"full forward time: {total:.3f} ms")
print(f"full document forward time: {total_doc_time:.3f} ms")
print(f"undocumented forward time: {total - total_doc_time:.3f} ms")

module                                         latency  params
----------------------------------------------------------------------
model.embed_tokens                               0.248 ms   622.33 M
model.rotary_emb                                 0.352 ms     0.00 M
model.layers.0.input_layernorm                   0.496 ms     0.00 M
model.layers.0.self_attn.q_proj                  1.172 ms    33.55 M
model.layers.0.self_attn.q_norm                  0.935 ms     0.00 M
model.layers.0.self_attn.k_proj                  0.147 ms     2.10 M
model.layers.0.self_attn.k_norm                  0.164 ms     0.00 M
model.layers.0.self_attn.v_proj                  0.148 ms     2.10 M
model.layers.0.self_attn.attn                    2.956 ms     0.00 M
model.layers.0.self_attn.o_proj                  1.143 ms    33.55 M
model.layers.0.post_attention_layernorm          0.513 ms     0.00 M
model.layers.0.mlp.gate                          0.085 ms     0.52 M
model.norm                            

In [27]:
import flash_attn
from flash_attn.flash_attn_interface import flash_attn_varlen_func

In [29]:
num_qo_heads = cfg.num_attention_heads
num_kv_heads = cfg.num_key_value_heads
head_dim = cfg.hidden_size // num_qo_heads
tp = 1
cp = 1
device = model.device

In [30]:
batch = [16 * 1024]

In [31]:
# Qwen3 253B activate attention data
num_qo_heads = num_qo_heads // tp
num_kv_heads = max(num_kv_heads // tp, 1)

kv_lens = [(1/2 + 1/(2 * cp)) * i for i in batch]

batch = [int(i // cp) for i in batch]
kv_lens = [int(i) for i in kv_lens]

total_tokens = sum(batch)
total_kv_tokens = sum(kv_lens)

q = torch.randn(total_tokens, num_qo_heads, head_dim, device=device, dtype=torch.bfloat16)
k = torch.randn(total_kv_tokens, num_kv_heads, head_dim, device=device, dtype=torch.bfloat16)
v = torch.randn(total_kv_tokens, num_kv_heads, head_dim, device=device, dtype=torch.bfloat16)
max_seqlen_q = max(batch)
max_seqlen_k = max(kv_lens)

cu_seqlens_q = [0,]
cu_seqlens_k = [0,]
for idx, _ in enumerate(batch):
    cu_seqlens_q.append(sum(batch[:idx+1]))
    cu_seqlens_k.append(sum(kv_lens[:idx+1]))
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, device=device)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, device=device)
max_seqlen_q = torch.tensor(max_seqlen_q, dtype=torch.int32, device=device)
max_seqlen_k = torch.tensor(max_seqlen_k, dtype=torch.int32, device=device)


def test_flash_attn():
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    output = flash_attn_varlen_func(
        q, k, v, 
        cu_seqlens_q, cu_seqlens_k, 
        max_seqlen_q, max_seqlen_k,
        dropout_p=0.0, causal=True,
    )
    end_event.record()
    torch.cuda.synchronize()
    duration = start_event.elapsed_time(end_event)
    return duration

# warmup
for _ in range(5):
    test_flash_attn()

# benchmark
num_iters = 10
durations = []
for _ in range(num_iters):
    duration = test_flash_attn()
    durations.append(duration)

avg_duration = sum(durations) / len(durations)
print(f"average latency: {avg_duration:.3f} ms")

average latency: 12.630 ms
