From 05ab01d555b021004df7e8b5f136974df6221d46 Mon Sep 17 00:00:00 2001 From: Emilien Macchi Date: Mon, 24 Nov 2025 20:16:47 -0500 Subject: [PATCH] (fix) timm: ROCm 7.0 compatibility for Attention2d modules ROCm 7.0 enforces GEMM paths for 1x1 convolutions, requiring strict memory contiguity. This ROCm behavior change causes "HIP error: invalid argument" when non-contiguous tensors (from reshape/permute/slice operations) are passed to Attention2d and MultiQueryAttention2d modules. Changes: - Add contiguity checks in Attention2d.forward() - Add contiguity checks in MultiQueryAttention2d.forward() - Force .contiguous() only when the tensor is non-contiguous Fixes #2613 Signed-off-by: Emilien Macchi --- timm/layers/attention2d.py | 10 +++++++++- timm/layers/helpers.py | 10 ++++++++++ timm/layers/norm.py | 9 ++------- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/timm/layers/attention2d.py b/timm/layers/attention2d.py index d454374a68..8a6d0ad306 100644 --- a/timm/layers/attention2d.py +++ b/timm/layers/attention2d.py @@ -6,7 +6,7 @@ from .config import use_fused_attn from .create_conv2d import create_conv2d -from .helpers import to_2tuple +from .helpers import to_2tuple, is_contiguous from .pool2d_same import create_pool2d @@ -271,6 +271,10 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None): """Run layer computation.""" B, C, H, W = s = x.shape + # Force memory contiguity to satisfy GEMM constraints for 1x1 convolutions + if not is_contiguous(x): + x = x.contiguous() + q = self.query(x) # desired q shape: [b, h, k, n x n] - [b, l, h, k] q = self._reshape_projected_query(q, self.num_heads, self.key_dim) @@ -351,6 +355,10 @@ def __init__( def forward(self, x, attn_mask: Optional[torch.Tensor] = None): B, C, H, W = x.shape + # Force memory contiguity to satisfy GEMM constraints for 1x1 convolutions + if not is_contiguous(x): + x = x.contiguous() + if self.head_first: q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2) else:
diff --git a/timm/layers/helpers.py b/timm/layers/helpers.py index b003f48d84..3a0b8ea55f 100644 --- a/timm/layers/helpers.py +++ b/timm/layers/helpers.py @@ -4,6 +4,7 @@ """ from itertools import repeat import collections.abc +import torch # From PyTorch internals @@ -41,3 +42,12 @@ def extend_tuple(x, n): if pad_n <= 0: return x[:n] return x + (x[-1],) * pad_n + + +def is_contiguous(tensor: torch.Tensor) -> bool: + """Check tensor contiguity with proper handling for TorchScript compilation.""" + # jit is oh so lovely :/ + if torch.jit.is_scripting(): + return tensor.is_contiguous() + else: + return tensor.is_contiguous(memory_format=torch.contiguous_format) diff --git a/timm/layers/norm.py b/timm/layers/norm.py index cca8eecfe4..e5dea27701 100644 --- a/timm/layers/norm.py +++ b/timm/layers/norm.py @@ -21,6 +21,7 @@ fast_simple_norm, simple_norm, ) +from .helpers import is_contiguous try: from torch.nn.functional import rms_norm @@ -155,12 +156,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def _is_contiguous(tensor: torch.Tensor) -> bool: - # jit is oh so lovely :/ - if torch.jit.is_scripting(): - return tensor.is_contiguous() - else: - return tensor.is_contiguous(memory_format=torch.contiguous_format) def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): @@ -191,7 +186,7 @@ def __init__(self, num_channels: int, eps: float = 1e-6): super().__init__(num_channels, eps=eps) def forward(self, x) -> torch.Tensor: - if _is_contiguous(x): + if is_contiguous(x): x = F.layer_norm( x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) else: