diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py index 967cd3b1..a9f90111 100644 --- a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py +++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py @@ -16,6 +16,8 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor +from twinkle.utils.framework import Torch +from twinkle.kernel import kernelize_model logger = get_logger() @@ -102,6 +104,9 @@ def train(): } }, ) + # npu patch + if Torch.is_npu_available(): + model = kernelize_model(model, mode='train', device='npu') lora_cfg = _build_lora_config(ENABLE_EP) model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) model.set_optimizer('AdamW', lr=LR, foreach=False) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 450906c5..d7d001d3 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -9,6 +9,8 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor +from twinkle.utils.framework import Torch +from twinkle.kernel import kernelize_model logger = get_logger() @@ -72,7 +74,9 @@ def train(): # Use a TransformersModel model = TransformersModel(model_id=MODEL_ID) model.model._no_split_modules = {'Qwen3_5DecoderLayer'} - + # npu patch + if Torch.is_npu_available(): + model = kernelize_model(model, mode='train', device='npu') lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') # Add a lora to model, with name `default` diff --git a/src/twinkle/kernel/__init__.py b/src/twinkle/kernel/__init__.py index 564d0c4f..c7262eb0 100644 --- a/src/twinkle/kernel/__init__.py +++ b/src/twinkle/kernel/__init__.py @@ -49,7 +49,7 @@ def kernelize_model( # so that patched module classes are used when new instances are created. if device == 'npu' or (device is None and _is_npu_device(model)): try: - apply_npu_patch() + apply_npu_patch(model) except Exception: logger.warning('NPU patch failed. Continuing without fused ops.', exc_info=True) diff --git a/src/twinkle/kernel/chunk_gated_delta_rule.py b/src/twinkle/kernel/chunk_gated_delta_rule.py new file mode 100644 index 00000000..553fb122 --- /dev/null +++ b/src/twinkle/kernel/chunk_gated_delta_rule.py @@ -0,0 +1,362 @@ +'''Ascend NPU implementation of chunk_gated_delta_rule for Flash Linear Attention (FLA). +This module provides a drop-in replacement for fla.ops.gated_delta_rule.chunk_gated_delta_rule, +redirecting the underlying Triton kernels to MindSpeed's NPU-compatible counterparts. +It is consumed by twinkle.kernel.monkey_patch_npu to enable the fast linear-attention +path of Qwen3.5 on Ascend hardware.''' + +import torch +import warnings +from mindspeed.lite.ops.triton.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h +from mindspeed.lite.ops.triton.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o +from mindspeed.lite.ops.triton.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from mindspeed.lite.ops.triton.cumsum import chunk_local_cumsum +from mindspeed.lite.ops.triton.solve_tril import solve_tril +from mindspeed.lite.ops.triton.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard +from mindspeed.lite.ops.triton.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd +from typing import Optional + + +def _torch_l2norm_fwd( + x: torch.Tensor, + eps: float = 1e-6, + output_dtype: Optional[torch.dtype] = None, +): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + x_float = x.float() + rstd = torch.rsqrt(torch.sum(x_float * x_float, dim=-1) + eps) + y = x_float * rstd.unsqueeze(-1) + y = y.to(output_dtype if output_dtype is not None else x.dtype) + return y.view(x_shape_og), rstd.view(x_shape_og[:-1]) + + +def _torch_l2norm_bwd( + y: torch.Tensor, + rstd: torch.Tensor, + dy: torch.Tensor, + eps: float = 1e-6, +): + y_shape_og = y.shape + y = y.view(-1, y.shape[-1]) + dy = dy.view(-1, dy.shape[-1]) + y_float = y.float() + dy_float = dy.float() + rstd = rstd.view(-1).float() + dx = dy_float * rstd.unsqueeze(-1) + dx = dx - torch.sum(dy_float * y_float, dim=-1, keepdim=True) * y_float * rstd.unsqueeze(-1) + return dx.to(y.dtype).view(y_shape_og) + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +): + g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, head_first=False) + # obtain WY representation. u is actually the new v. + A = chunk_scaled_dot_kkt_fwd( + k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size, output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + return g, o, A, final_state + + +def chunk_gated_delta_rule_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +): + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, _ = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=False, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dv = chunk_bwd_dv_local( + q=q, + k=k, + g=g, + do=do, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( + q=q, + k=k, + w=w, + g=g, + h0=initial_state, + dht=dht, + do=do, + dv=dv, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dq, dk, dw, dg = chunk_bwd_dqkwg( + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + chunk_size=chunk_size, + scale=scale, + cu_seqlens=cu_seqlens, + ) + dk2, dv, db, dg2 = prepare_wy_repr_bwd( + k=k, v=v, beta=beta, g=g, A=A, dw=dw, du=dv, cu_seqlens=cu_seqlens, chunk_size=chunk_size) + dk.add_(dk2) + dg.add_(dg2) + if dg.dtype != torch.float32: + raise ValueError(f'dg current type is {dg.dtype} , should be float32') + dg = chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens, head_first=False) + return dq, dk, dv, db, dg, dh0 + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + chunk_size: int = 64, + ): + if use_qk_l2norm_in_kernel: + q, q_rstd = _torch_l2norm_fwd(q) + k, k_rstd = _torch_l2norm_fwd(k) + else: + q_rstd, k_rstd = None, None + + g, o, A, final_state = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size) + ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + ctx.chunk_size = chunk_size + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do: torch.Tensor, dht: torch.Tensor): + q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors + dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + A=A, + scale=ctx.scale, + initial_state=initial_state, + do=do, + dht=dht, + cu_seqlens=cu_seqlens, + chunk_size=ctx.chunk_size, + ) + if ctx.use_qk_l2norm_in_kernel: + dq = _torch_l2norm_bwd(q, q_rstd, dq) + dk = _torch_l2norm_bwd(k, k_rstd, dk) + return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + head_first: bool = False, +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]`. + beta (torch.Tensor): + betas of shape `[B, T, H]`. + scale (Optional[float]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + use_qk_l2norm_in_kernel (bool): + Whether to apply L2norm to the q/k tensor internally. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + This argument has been deprecated. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + if q.dtype != k.dtype or k.dtype != v.dtype: + raise ValueError( + f'q current type is {q.dtype}, k current type is {k.dtype}, v current type is {v.dtype}, should be equal') + if q.dtype == torch.float32: + raise ValueError('ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.') + if len(beta.shape) != 3: + raise ValueError(f'beta current shape len is {len(beta.shape)}, beta must be of shape [B, T, H] ' + f'if head_first=False, or [B, H, T] otherwise.') + if head_first: + warnings.warn('head_first is deprecated and will be removed in a future version. ' + 'Please use head_first=False for now instead.') + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f'Input tensor shape suggests format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). ' + 'This may indicate the inputs were passed in head-first format [B, H, T, ...] ' + 'when head_first=False was specified. ' + 'Please verify your input tensor format matches the expected shape [B, T, H, ...].') + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f'The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`.' + f'Please flatten variable-length inputs before processing.') + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f'The number of initial states is expected to be equal to the number of input sequences, ' + f'i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.') + if scale is None: + scale = k.shape[-1]**-0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel, + chunk_size, + ) + return o, final_state diff --git a/src/twinkle/kernel/monkey_patch_npu.py b/src/twinkle/kernel/monkey_patch_npu.py index e04fe44c..acb2a6aa 100644 --- a/src/twinkle/kernel/monkey_patch_npu.py +++ b/src/twinkle/kernel/monkey_patch_npu.py @@ -4,11 +4,13 @@ >>> from twinkle.kernel.monkey_patch_npu import apply_npu_patch >>> if Torch.is_npu_available(): - ... apply_npu_patch() # Enables all patches unconditionally + ... apply_npu_patch(model) """ import importlib +import os import torch +import torch.nn.functional as F from torch import nn from transformers.utils import is_torch_npu_available @@ -22,6 +24,26 @@ if _is_torch_npu_available: import torch_npu +# --------------------------------------------------------------------------- +# Utils +# --------------------------------------------------------------------------- + + +def import_optional_module(module_name: str): + """Import a module, returning None if unavailable.""" + try: + return importlib.import_module(module_name) + except ImportError as exc: + logger.debug('Failed to import optional module %s: %s', module_name, exc) + return None + + +def _resolve_unsqueeze_dim(position_ids=None, unsqueeze_dim=1): + if isinstance(position_ids, int) and unsqueeze_dim == 1: + return position_ids + return unsqueeze_dim + + # ============================================================================= # Section 1: MoE Grouped MatMul (GMM) # ============================================================================= @@ -85,7 +107,168 @@ def apply_hf_moe_grouped_mm_patch() -> None: # ============================================================================= -# Section 2: Fused Operators +# Section 1b: MoE Packed Experts +# ============================================================================= + + +def _normalize_packed_expert_weights(module, input_dtype: torch.dtype, hidden_dim: int): + """Normalize packed expert weight shapes for NPU grouped matmul.""" + gate_up_proj = module.gate_up_proj.to(input_dtype) + down_proj = module.down_proj.to(input_dtype) + + if gate_up_proj.shape[1] == hidden_dim: + gate_up_weight = gate_up_proj + elif gate_up_proj.shape[2] == hidden_dim: + gate_up_weight = gate_up_proj.transpose(1, 2) + else: + raise RuntimeError(f'Unsupported gate_up_proj shape for NPU MoE patch: {tuple(gate_up_proj.shape)}.') + + if down_proj.shape[2] == hidden_dim: + down_weight = down_proj + elif down_proj.shape[1] == hidden_dim: + down_weight = down_proj.transpose(1, 2) + else: + raise RuntimeError(f'Unsupported down_proj shape for NPU MoE patch: {tuple(down_proj.shape)}.') + + return gate_up_weight, down_weight + + +def _get_cached_expert_weights(self, target_dtype: torch.dtype, hidden_dim: int): + """Return normalized expert weights with automatic cache invalidation. + + Cache key combines (dtype, gate_version, down_version). This correctly + handles: + - Full-parameter training: optimizer in-place updates bump _version + - LoRA training: frozen weights keep _version stable, cache persists + - Inference: cache is permanent + - AMP autocast: separate cache per dtype + + Safety: when weights require gradients, the cache is bypassed to avoid + breaking the PyTorch autograd graph (non-leaf tensors from .to() cannot + be reused across forward passes). + """ + requires_grad = ( + getattr(self.gate_up_proj, 'requires_grad', False) or getattr(self.down_proj, 'requires_grad', False)) + cache_attr = '_npu_expert_cache' + if not requires_grad and hasattr(self, cache_attr): + cached_dtype, cached_gate_ver, cached_down_ver, cached = getattr(self, cache_attr) + if (cached_dtype == target_dtype and cached_gate_ver == self.gate_up_proj._version + and cached_down_ver == self.down_proj._version): + return cached + + weights = _normalize_packed_expert_weights(self, target_dtype, hidden_dim) + if not requires_grad: + setattr( + self, + cache_attr, + (target_dtype, self.gate_up_proj._version, self.down_proj._version, weights), + ) + return weights + + +def npu_packed_moe_experts_forward( + self, + hidden_states: torch.Tensor, + router_indices_or_routing_weights: torch.Tensor, + routing_weights_or_router_indices: torch.Tensor, +) -> torch.Tensor: + """Packed MoE experts forward using NPU grouped matmul. + + Compatible with Qwen3-MoE, Qwen3.5-MoE, and any model using packed experts + with the standard ``(hidden_states, router_indices, routing_weights)`` call convention. + """ + if router_indices_or_routing_weights.dtype in {torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8}: + router_indices = router_indices_or_routing_weights + routing_weights = routing_weights_or_router_indices + else: + routing_weights = router_indices_or_routing_weights + router_indices = routing_weights_or_router_indices + + output_shape = hidden_states.shape + hidden_dim = output_shape[-1] + hidden_states = hidden_states.reshape(-1, hidden_dim) + + if routing_weights.shape != router_indices.shape: + routing_weights = torch.gather(routing_weights, dim=-1, index=router_indices.to(torch.long)) + routing_weights = routing_weights.to(hidden_states.dtype) + router_indices = router_indices.to(torch.int32) + + permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, router_indices) + tokens_per_expert = torch.bincount(router_indices.view(-1), minlength=self.num_experts).to(torch.int64) + + # Cached normalized weights: auto-invalidates on weight updates (full-param) + # and persists when frozen (LoRA / inference). + gate_up_weight, down_weight = _get_cached_expert_weights(self, hidden_states.dtype, hidden_dim) + + intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, tokens_per_expert, gate_up_weight) + intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1) + output = GmmFunction.apply(intermediate_activations, tokens_per_expert, down_weight) + next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights) + return next_states.view(*output_shape) + + +# ============================================================================= +# Section 1c: MoE Sparse Block +# ============================================================================= + + +def _topk_from_router_logits(module, hidden_states: torch.Tensor, router_logits: torch.Tensor): + """Compute top-k routing from router logits (Transformers 4.x style).""" + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, module.top_k, dim=-1) + if getattr(module, 'norm_topk_prob', True): + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + return routing_weights, router_indices + + +def _add_shared_expert(self, hidden_states: torch.Tensor, expert_output: torch.Tensor) -> torch.Tensor: + """Add shared expert output with sigmoid gating. + + Automatically skips if the module lacks shared_expert / shared_expert_gate. + """ + if not (hasattr(self, 'shared_expert') and hasattr(self, 'shared_expert_gate')): + return expert_output + + shared_expert_output = self.shared_expert(hidden_states) + shared_expert_output = (F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output) + return expert_output + shared_expert_output + + +def _qwen3_5_moe_forward_transformers_5(self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, + selected_experts: torch.Tensor) -> torch.Tensor: + """Transformers 5.x path: gate returns (router_logits, routing_weights, selected_experts).""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + expert_output = self.experts(hidden_states, selected_experts, routing_weights) + expert_output = _add_shared_expert(self, hidden_states, expert_output) + return expert_output.reshape(batch_size, sequence_length, hidden_dim) + + +def _qwen3_5_moe_forward_linear_gate(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor: + """Transformers 4.x path: gate is nn.Linear and returns router logits.""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + routing_weights, selected_experts = _topk_from_router_logits(self, hidden_states, router_logits) + expert_output = self.experts(hidden_states, selected_experts, routing_weights) + expert_output = _add_shared_expert(self, hidden_states, expert_output) + return expert_output.reshape(batch_size, sequence_length, hidden_dim) + + +def npu_qwen3_5_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """NPU-accelerated SparseMoeBlock forward with dual Transformers version support.""" + hidden_dim = hidden_states.shape[-1] + gate_output = self.gate(hidden_states.view(-1, hidden_dim)) + + if isinstance(gate_output, tuple): + _, routing_weights, selected_experts = gate_output + return _qwen3_5_moe_forward_transformers_5(self, hidden_states, routing_weights, selected_experts) + + return _qwen3_5_moe_forward_linear_gate(self, hidden_states, gate_output) + + +# ============================================================================= +# Section 2: Fused Operators (RMSNorm / RoPE / SwiGLU / SDPA) # ============================================================================= @@ -96,33 +279,106 @@ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + # Detect residual parameterization (e.g. Qwen3.5: scale = 1.0 + weight) + # once at initialization to avoid CPU-synchronizing Tensor.item() calls. + self._residual_param = abs(self.weight.data.mean().item()) < 0.3 + if self._residual_param: + logger.debug('[NPU] NpuRMSNorm using residual parameterization (1.0 + weight)') + + def _get_effective_weight(self, target_dtype: torch.dtype): + if self._residual_param: + return (1.0 + self.weight).to(dtype=target_dtype) + return self.weight.to(dtype=target_dtype) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] + scale = self._get_effective_weight(hidden_states.dtype) + return torch_npu.npu_rms_norm(hidden_states, scale, epsilon=self.variance_epsilon)[0] def extra_repr(self) -> str: return f'{tuple(self.weight.shape)}, eps={self.variance_epsilon}' +def npu_gated_rms_norm_forward(self, hidden_states, gate=None): + """NPU forward for Gated RMSNorm. + + The FP32 mode is controlled by ``TWINKLE_NPU_GATED_RMSNorm_FP32``, + resolved once during patching and stored in ``self._twinkle_force_fp32``. + """ + input_dtype = hidden_states.dtype + _eps = getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) + + # Read the cached flag; no env lookup in the hot path. + force_fp32 = getattr(self, '_twinkle_force_fp32', False) + if force_fp32: + hidden_states = hidden_states.to(torch.float32) + weight = self.weight.float() + gate = gate.to(torch.float32) if gate is not None else None + else: + weight = self.weight + + hidden_states = torch_npu.npu_rms_norm(hidden_states, weight, epsilon=_eps)[0] + + if gate is not None: + hidden_states = hidden_states * F.silu(gate) + + return hidden_states.to(input_dtype) + + +def _make_apply_npu_rotary_emb(): + _cached_partial = {} + + def _apply_npu_rotary_emb(q, k, cos, sin): + rotary_dim = cos.shape[-1] + query_dim = q.shape[-1] + shape_key = (rotary_dim, query_dim) + + use_partial = _cached_partial.get(shape_key) + if use_partial is None: + use_partial = rotary_dim < query_dim + _cached_partial[shape_key] = use_partial + + if use_partial: + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + q_embed = torch_npu.npu_rotary_mul(q_rot, cos, sin).to(q.dtype) + k_embed = torch_npu.npu_rotary_mul(k_rot, cos, sin).to(k.dtype) + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + else: + q_embed = torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype) + + return q_embed, k_embed + + return _apply_npu_rotary_emb + + +_apply_npu_rotary_emb = _make_apply_npu_rotary_emb() + + def npu_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - r"""Fused RoPE via ``torch_npu.npu_rotary_mul``.""" + """Fused RoPE via ``torch_npu.npu_rotary_mul`` with automatic Partial RoPE support.""" + unsqueeze_dim = _resolve_unsqueeze_dim(position_ids, unsqueeze_dim) cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) - return torch_npu.npu_rotary_mul(q, cos, sin), torch_npu.npu_rotary_mul(k, cos, sin) + return _apply_npu_rotary_emb(q, k, cos, sin) def npu_apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - r"""Multimodal RoPE for Qwen2.5-VL.""" + """Multimodal RoPE for Qwen2.5-VL with automatic Partial RoPE support.""" mrope_section = mrope_section * 2 cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) - return torch_npu.npu_rotary_mul(q, cos, sin), torch_npu.npu_rotary_mul(k, cos, sin) + return _apply_npu_rotary_emb(q, k, cos, sin) def npu_swiglu_forward(self, hidden_state): - r"""Fused SwiGLU (Qwen-style).""" + """Fused SwiGLU (Qwen-style).""" return self.down_proj( - torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)) + torch_npu.npu_swiglu( + torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), + dim=-1, + )) def npu_sdpa_attention_forward(module, @@ -164,6 +420,112 @@ def npu_sdpa_attention_forward(module, return attn_output.transpose(1, 2).contiguous(), None +# ============================================================================= +# Section 2c: Flash Linear Attention (FLA) for Qwen3.5 +# ============================================================================= + + +def _patch_qwen3_5_fla(model=None) -> None: + """Enable Flash Linear Attention (FLA) fast path for Qwen3.5 on NPU. + + Controlled by environment variable ``TWINKLE_NPU_FLA`` (default: True). + """ + if not _is_env_enabled('TWINKLE_NPU_FLA', default=True): + logger.info('[NPU] [FLA] Disabled by TWINKLE_NPU_FLA environment variable') + return + + if not _is_torch_npu_available: + logger.info('[NPU] [FLA] Skip: NPU not available') + return + + # 1. Force FLA availability flag + def _is_fla_available() -> bool: + return True + + for utils_mod_name in ('transformers.utils', 'transformers.utils.import_utils'): + try: + utils_mod = importlib.import_module(utils_mod_name) + setattr(utils_mod, 'is_flash_linear_attention_available', _is_fla_available) + logger.info( + '[NPU] [FLA] Patched %s.is_flash_linear_attention_available', + utils_mod_name, + ) + except Exception as exc: + logger.debug('[NPU] [FLA] Failed to patch %s: %s', utils_mod_name, exc) + + # 2. Try MindSpeed Triton FLA backend + mindspeed_fla = None + try: + from .chunk_gated_delta_rule import chunk_gated_delta_rule as _ms_fla + mindspeed_fla = _ms_fla + logger.info('[NPU] [FLA] MindSpeed Triton chunk_gated_delta_rule loaded') + except ImportError as exc: + logger.warning('[NPU] [FLA] MindSpeed not available: %s', exc) + + # 3. Patch Qwen3.5 modeling modules + fla_target_modules = [ + 'transformers.models.qwen3_5.modeling_qwen3_5', + 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', + ] + + for module_name in fla_target_modules: + module = import_optional_module(module_name) + if module is None: + logger.info('[NPU] [FLA] %s: module not found, skip', module_name) + continue + + # Only enable FLA flags if we actually have a backend to serve it + if mindspeed_fla is not None: + setattr(module, 'is_flash_linear_attention_available', _is_fla_available) + setattr(module, 'is_fast_path_available', True) + + # Disable CUDA-only fused op + if hasattr(module, 'FusedRMSNormGated'): + setattr(module, 'FusedRMSNormGated', None) + logger.info('[NPU] [FLA] %s: disabled FusedRMSNormGated', module_name) + + # Replace chunk_gated_delta_rule with MindSpeed implementation + setattr(module, 'chunk_gated_delta_rule', mindspeed_fla) + logger.info( + '[NPU] [FLA] Patched %s.chunk_gated_delta_rule -> MindSpeed', + module_name, + ) + else: + logger.warning( + '[NPU] [FLA] %s: MindSpeed unavailable, FLA flags NOT set', + module_name, + ) + + # 4. Traverse instantiated model and replace per-layer chunk_gated_delta_rule + if model is not None and mindspeed_fla is not None: + # Resolve the underlying PyTorch model from TransformersModel wrapper + model = getattr(model, 'model', getattr(model, 'module', model)) + if not hasattr(model, 'named_modules'): + logger.warning('[NPU] [FLA] Model does not support named_modules, skipping instance patch') + return + patched_instances = 0 + for _name, _module in model.named_modules(): + if hasattr(_module, 'chunk_gated_delta_rule') and callable(getattr(_module, 'chunk_gated_delta_rule')): + if _module.chunk_gated_delta_rule is mindspeed_fla: + continue + + _module.chunk_gated_delta_rule = mindspeed_fla + patched_instances += 1 + logger.debug( + '[NPU] [FLA] Replaced %s(%s).chunk_gated_delta_rule -> MindSpeed', + _name, + type(_module).__name__, + ) + + if patched_instances > 0: + logger.info( + '[NPU] [FLA] Patched %d linear attention instance(s)', + patched_instances, + ) + else: + logger.info('[NPU] [FLA] No linear attention instances found in model') + + # ============================================================================= # Section 3: Patching Helpers # ============================================================================= @@ -173,48 +535,163 @@ def _patch_sdpa_forward() -> None: from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface AttentionInterface._global_mapping['sdpa'] = npu_sdpa_attention_forward ALL_ATTENTION_FUNCTIONS['sdpa'] = npu_sdpa_attention_forward - logger.debug('[NPU] Patched SDPA attention forward') + logger.debug('[NPU] [SDPA] Patched global SDPA attention forward') def _patch_rmsnorm(module, class_name: str) -> None: - setattr(module, class_name, NpuRMSNorm) - logger.debug(f'[NPU] Patched {module.__name__}.{class_name} -> NpuRMSNorm') + """Patch RMSNorm class with NPU-optimized implementation.""" + if 'Gated' in class_name: + orig_cls = getattr(module, class_name) + setattr(orig_cls, 'forward', npu_gated_rms_norm_forward) + + # Cache the FP32 env flag once at patch time to avoid per-forward overhead. + orig_cls._twinkle_force_fp32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', + '0').lower() in ('1', 'true', 'on', 'yes') + if orig_cls._twinkle_force_fp32: + logger.info( + '[NPU] [RMSNorm] %s.%s forced to FP32 mode', + module.__name__, + class_name, + ) + + logger.info( + '[NPU] [RMSNorm] Patched %s.%s.forward -> npu_gated_rms_norm_forward', + module.__name__, + class_name, + ) + else: + setattr(module, class_name, NpuRMSNorm) + logger.info( + '[NPU] [RMSNorm] Patched %s.%s -> NpuRMSNorm', + module.__name__, + class_name, + ) def _patch_rope(module, func_name: str) -> None: setattr(module, func_name, npu_apply_rotary_pos_emb) - logger.debug(f'[NPU] Patched {module.__name__}.{func_name} -> npu_apply_rotary_pos_emb') + logger.debug( + '[NPU] [RoPE] Patched %s.%s -> npu_apply_rotary_pos_emb', + module.__name__, + func_name, + ) def _patch_swiglu(module, class_name: str) -> None: setattr(getattr(module, class_name), 'forward', npu_swiglu_forward) - logger.debug(f'[NPU] Patched {module.__name__}.{class_name}.forward -> npu_swiglu_forward') + logger.debug( + '[NPU] [MLP] Patched %s.%s.forward -> npu_swiglu_forward', + module.__name__, + class_name, + ) + + +def _patch_moe_sparse_block(module, class_name: str) -> None: + """Patch SparseMoeBlock forward with NPU-optimized implementation.""" + setattr(getattr(module, class_name), 'forward', npu_qwen3_5_moe_sparse_block_forward) + logger.info( + '[NPU] [MoE] Patched %s.%s.forward -> npu_qwen3_5_moe_sparse_block_forward', + module.__name__, + class_name, + ) + + +def _patch_moe_experts(module, class_name: str) -> None: + """Patch packed Experts forward with NPU grouped matmul.""" + setattr(getattr(module, class_name), 'forward', npu_packed_moe_experts_forward) + logger.debug( + '[NPU] [MoE] Patched %s.%s.forward -> npu_packed_moe_experts_forward', + module.__name__, + class_name, + ) # ============================================================================= -# Section 4: Unified Patching Logic +# Section 4: Environment Control # ============================================================================= -def _apply_all_fused_ops() -> None: - r"""Apply fused ops to supported model families unconditionally.""" +def _is_env_enabled(var_name: str, default: bool = True) -> bool: + """Check whether an environment variable is enabled. + + Supports: ``1``/``true``/``on``/``yes`` (force on), + ``0``/``false``/``off``/``no`` (force off), + unset (use ``default``). + """ + env = os.environ.get(var_name, '').lower().strip() + if not env: + return default + if env in ('1', 'true', 'on', 'yes'): + return True + if env in ('0', 'false', 'off', 'no'): + logger.info('[NPU] %s=%s: disabled.', var_name, env) + return False + return default + + +# ============================================================================= +# Section 5: Unified Patching Logic (Fused Ops) +# ============================================================================= + + +def _apply_all_fused_ops(model=None) -> None: + """Apply fused ops to supported model families.""" + logger.info('[NPU] === _apply_all_fused_ops ENTERED ===') if not _is_torch_npu_available: return + if not _is_env_enabled('TWINKLE_NPU_FUSED_OPS', default=True): + return + + target_archs = set() + if model is not None: + config = getattr(model, 'hf_config', getattr(model, 'config', None)) + archs = getattr(config, 'architectures', None) if config else None + if archs: + target_archs = set(archs) + logger.debug('[NPU] Detected architectures for fused ops: %s', archs) + logger.info('[NPU] Auto-applying fused ops to supported model families') - # Patch global SDPA first _patch_sdpa_forward() - # Supported model families: (module_path, class_prefix, mlp_class_name) model_families = [ - ('transformers.models.qwen3.modeling_qwen3', 'Qwen3', 'Qwen3MLP'), - ('transformers.models.qwen3_moe.modeling_qwen3_moe', 'Qwen3Moe', 'Qwen3MoeMLP'), - ('transformers.models.qwen2_5_vl.modeling_qwen2_5_vl', 'Qwen2_5_VL', 'Qwen2MLP'), + ('transformers.models.qwen3.modeling_qwen3', 'Qwen3', 'Qwen3MLP', 'Qwen3ForCausalLM'), + ('transformers.models.qwen3_moe.modeling_qwen3_moe', 'Qwen3Moe', 'Qwen3MoeMLP', 'Qwen3MoeForCausalLM'), + ( + 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl', + 'Qwen2_5_VL', + 'Qwen2MLP', + 'Qwen2_5_VLForConditionalGeneration', + ), + ( + 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', + 'Qwen3_5Moe', + 'Qwen3_5MoeMLP', + 'Qwen3MoeForCausalLM', + ), ] + modeling_qwen3_5 = import_optional_module('transformers.models.qwen3_5.modeling_qwen3_5') + if modeling_qwen3_5 is not None: + model_families.append(( + 'transformers.models.qwen3_5.modeling_qwen3_5', + 'Qwen3_5', + 'Qwen3_5MLP', + 'Qwen3_5ForCausalLM', + )) + + modeling_qwen3_5_moe = import_optional_module('transformers.models.qwen3_5_moe.modeling_qwen3_5_moe') + if modeling_qwen3_5_moe is not None: + model_families.append(( + 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', + 'Qwen3_5Moe', + 'Qwen3_5MoeMLP', + 'Qwen3_5MoeForCausalLM', + )) + patched_count = 0 - for module_name, prefix, mlp_name in model_families: + for module_name, prefix, mlp_name, trigger_arch in model_families: try: module = importlib.import_module(module_name) @@ -234,7 +711,16 @@ def _apply_all_fused_ops() -> None: _patch_swiglu(module, mlp_name) patched_count += 1 - # Qwen2.5-VL special cases + experts_cls = f'{prefix}Experts' + if hasattr(module, experts_cls): + _patch_moe_experts(module, experts_cls) + patched_count += 1 + + sparse_cls = f'{prefix}SparseMoeBlock' + if hasattr(module, sparse_cls): + _patch_moe_sparse_block(module, sparse_cls) + patched_count += 1 + if prefix == 'Qwen2_5_VL': if hasattr(module, 'Qwen2_5_VLMLP'): _patch_swiglu(module, 'Qwen2_5_VLMLP') @@ -242,65 +728,182 @@ def _apply_all_fused_ops() -> None: setattr(module, 'apply_multimodal_rotary_pos_emb', npu_apply_multimodal_rotary_pos_emb) logger.debug('[NPU] Patched Qwen2_5_VL multimodal RoPE') - logger.debug(f'[NPU] Patched {prefix} fused ops') + if prefix == 'Qwen3_5': + gated_rmsnorm_cls = f'{prefix}GatedRMSNorm' + if hasattr(module, gated_rmsnorm_cls): + _patch_rmsnorm(module, gated_rmsnorm_cls) + patched_count += 1 + if hasattr(module, 'Qwen3_5VisionMLP'): + _patch_swiglu(module, 'Qwen3_5VisionMLP') + patched_count += 1 + if hasattr(module, 'Qwen3_5VisionRMSNorm'): + _patch_rmsnorm(module, 'Qwen3_5VisionRMSNorm') + patched_count += 1 + + if prefix == 'Qwen3_5Moe': + if hasattr(module, 'Qwen3_5MoeGatedRMSNorm'): + _patch_rmsnorm(module, 'Qwen3_5MoeGatedRMSNorm') + patched_count += 1 + + logger.debug('[NPU] Patched %s fused ops', prefix) except ImportError: - pass # Model family not installed, skip silently + pass + + if not target_archs: + patched_count += _discover_and_patch_unknown_models() - logger.info(f'[NPU] Auto-patched {patched_count} components') + _patch_qwen3_5_fla(model) + + logger.info('[NPU] Auto-patched %d components', patched_count) # ============================================================================= -# Section 5: Public API +# Section 5b: Dynamic model discovery (no hard-coding) # ============================================================================= -def apply_npu_patch() -> None: - r"""Apply all NPU patches unconditionally. +def _discover_and_patch_unknown_models() -> int: + """Dynamically discover and patch additional transformers model families.""" + patched = 0 + already_patched_modules = { + 'transformers.models.qwen3.modeling_qwen3', + 'transformers.models.qwen3_moe.modeling_qwen3_moe', + 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl', + 'transformers.models.qwen3_5.modeling_qwen3_5', + 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', + } - All Ascend NPU optimizations are applied when running on NPU: + try: + import transformers.models as models_pkg + except ImportError: + return 0 + + candidate_modules = [] + for model_name in dir(models_pkg): + if model_name.startswith('_'): + continue + modeling_path = f'transformers.models.{model_name}.modeling_{model_name}' + if modeling_path not in already_patched_modules: + candidate_modules.append(modeling_path) + + for module_name in candidate_modules: + module = import_optional_module(module_name) + if module is None: + continue + + has_rmsnorm = any('RMSNorm' in attr_name and isinstance(getattr(module, attr_name, None), type) + for attr_name in dir(module)) + has_rope = hasattr(module, 'apply_rotary_pos_emb') + has_mlp = any( + attr_name.endswith('MLP') and isinstance(getattr(module, attr_name, None), type) + for attr_name in dir(module)) + + if not (has_rmsnorm or has_rope or has_mlp): + continue + + for attr_name in dir(module): + if attr_name.startswith('_'): + continue + obj = getattr(module, attr_name, None) + if not isinstance(obj, type): + continue + + if 'RMSNorm' in attr_name and issubclass(obj, nn.Module): + try: + _patch_rmsnorm(module, attr_name) + patched += 1 + except Exception as exc: + logger.debug('[NPU] Failed to patch %s.%s: %s', module_name, attr_name, exc) + + if attr_name.endswith('MLP') and hasattr(obj, 'forward'): + try: + _patch_swiglu(module, attr_name) + patched += 1 + except Exception as exc: + logger.debug('[NPU] Failed to patch %s.%s: %s', module_name, attr_name, exc) + + if attr_name.endswith('Experts') and hasattr(obj, 'forward'): + try: + _patch_moe_experts(module, attr_name) + patched += 1 + except Exception as exc: + logger.debug('[NPU] Failed to patch %s.%s: %s', module_name, attr_name, exc) + + if attr_name.endswith('SparseMoeBlock') and hasattr(obj, 'forward'): + try: + _patch_moe_sparse_block(module, attr_name) + patched += 1 + except Exception as exc: + logger.debug('[NPU] Failed to patch %s.%s: %s', module_name, attr_name, exc) + + if has_rope: + try: + _patch_rope(module, 'apply_rotary_pos_emb') + patched += 1 + except Exception as exc: + logger.debug('[NPU] Failed to patch %s.apply_rotary_pos_emb: %s', module_name, exc) + + if patched > 0: + logger.debug('[NPU] Dynamically patched %s', module_name) + + return patched + + +# ============================================================================= +# Section 6: Public API +# ============================================================================= + + +def apply_npu_patch(model=None) -> None: + """Apply all NPU patches. + + Ascend NPU optimizations applied: - MoE grouped_matmul (GMM) - RMSNorm fused kernel - RoPE fused kernel - SwiGLU fused kernel - SDPA Attention compatibility fixes - - Safe to call multiple times — patches are only applied once. - - Example:: - - >>> # Unified entry — all patches applied - >>> if Torch.is_npu_available(): - ... apply_npu_patch() + - Flash Linear Attention (FLA) for Qwen3.5 + + Environment variables: + - ``TWINKLE_NPU_PATCH``: overall switch (``1``/``0``) + - ``TWINKLE_NPU_FUSED_OPS``: fused ops switch (``1``/``0``) + - ``TWINKLE_NPU_MOE_PATCH``: MoE GMM switch (``1``/``0``) + - ``TWINKLE_NPU_FLA``: FLA switch (``1``/``0``) + - ``TWINKLE_NPU_GATED_RMSNorm_FP32``: force FP32 in Gated RMSNorm (``1``/``0``) + + Args: + model: Optional model instance. Required for FLA to traverse and + replace per-instance ``chunk_gated_delta_rule`` bindings. """ global _NPU_PATCH_APPLIED + if not _is_env_enabled('TWINKLE_NPU_PATCH', default=True): + return + + moe_enabled = _is_env_enabled('TWINKLE_NPU_MOE_PATCH', default=True) + if _NPU_PATCH_APPLIED: logger.debug('[NPU] Patches already applied, skipping.') return try: import torch_npu - from torch_npu.contrib import transfer_to_npu except ImportError: logger.warning('torch_npu not available. Skipping NPU patches.') return - # 1. MoE GMM (always) - apply_hf_moe_grouped_mm_patch() + if moe_enabled: + apply_hf_moe_grouped_mm_patch() - # 2. Fused operators (always, unconditional) - _apply_all_fused_ops() + _apply_all_fused_ops(model) _NPU_PATCH_APPLIED = True logger.info('[NPU] All patches applied successfully') def register_npu_fused_function_kernels() -> None: - r"""Register NPU fused ops as Twinkle function kernels (optional). - - Integrates with Twinkle's ``kernelize_model()`` so that RoPE and SDPA - are automatically replaced when running on Ascend devices. - """ + """Register NPU fused ops as Twinkle function kernels (optional).""" if not _is_torch_npu_available: return